#%%
import netgen.gui
from ngsolve import *
import netgen.gui
import matplotlib.pyplot as plt
from netgen.geom2d import unit_square
from ngsolve.meshes import MakeStructured2DMesh
import netgen.occ as ngocc
import numpy as np

# Create tabular (x, y, f) data in the form it would be read from a file

def generate_combinations(x_min, x_max, y_min, y_max, step):
    """Return all (x, y) nodes of a regular grid as an ``(N, 2)`` float array.

    Each axis is sampled from its minimum with spacing ``step``. The maximum
    is included when ``max - min`` is (up to round-off) a multiple of
    ``step`` and is never exceeded otherwise. Rows are ordered with x as the
    outer and y as the inner index: ``(x0, y0), (x0, y1), ..., (x1, y0), ...``.

    Args:
        x_min, x_max: Bounds of the x axis (inclusive).
        y_min, y_max: Bounds of the y axis (inclusive).
        step: Grid spacing shared by both axes; must be positive.

    Returns:
        ``numpy.ndarray`` of shape ``(len(xs) * len(ys), 2)``; shape
        ``(0, 2)`` when an axis is empty (e.g. ``max < min``).

    Raises:
        ValueError: If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step!r}")

    def _axis(lo, hi):
        # Count the whole steps that fit into [lo, hi]; the small tolerance
        # keeps hi when round-off makes (hi - lo) / step land just below an
        # integer. np.arange(lo, hi + step, step) is unreliable with a float
        # step and can emit a point beyond hi.
        n_steps = int(np.floor((hi - lo) / step + 1e-9))
        return lo + step * np.arange(n_steps + 1, dtype=float)

    xs = _axis(x_min, x_max)
    ys = _axis(y_min, y_max)

    # "ij" indexing + C-order ravel keeps x as the slow (outer) index.
    gx, gy = np.meshgrid(xs, ys, indexing="ij")
    return np.column_stack((gx.ravel(), gy.ravel()))

# Sample the (x, y) parameter plane on a regular grid of nodes.
pxy_data = generate_combinations(-1e-2, 1e-2, -1e-2, 1e-2, 5e-3)
n_pairs = len(pxy_data)

# Placeholder values f(x_i, y_i) -- substitute the real data here.
fpxy_data = np.linspace(1, 3, n_pairs)

# One row per sample: x, y, f(x, y).
data = np.column_stack((pxy_data, fpxy_data))

# Distinct node coordinates along each axis (np.unique returns them sorted).
x_vals, y_vals = np.unique(data[:, 0]), np.unique(data[:, 1])
nx, ny = x_vals.size, y_vals.size

# np.lexsort treats its LAST key as the primary one: rows end up ordered by
# y first, and by x within equal y.
sorted_data = data[np.lexsort((data[:, 0], data[:, 1]))]

# f_grid[iy, ix] = f(x_vals[ix], y_vals[iy]): rows follow y, columns follow x.
# Assumes every (x, y) combination appears exactly once in `data`.
f_grid = sorted_data[:, 2].reshape(ny, nx)

# Lower-left and upper-right corners of the data box for VoxelCoefficient.
start, end = (x_vals[0], y_vals[0]), (x_vals[-1], y_vals[-1])


# Test geometry: box spanned by the corners (0, 2, -2) and (10, -2, 2).
geo = ngocc.Box(ngocc.Pnt(0, 2, -2), ngocc.Pnt(10, -2, 2))

# Name the pair of faces normal to each axis: (min side, max side).
geo.faces.Min(ngocc.X).name, geo.faces.Max(ngocc.X).name = "left", "right"
geo.faces.Min(ngocc.Y).name, geo.faces.Max(ngocc.Y).name = "bot", "top"
geo.faces.Min(ngocc.Z).name, geo.faces.Max(ngocc.Z).name = "back", "front"

# Wrap the shape for meshing; maxh caps the element size.
geo = ngocc.OCCGeometry(geo)
mesh = Mesh(geo.GenerateMesh(maxh=0.5))


# ---FE Space Definition--- #
# Displacement: vector-valued first-order H1 field, Dirichlet on "left".
fes_u = VectorH1(mesh, order=1, dirichlet="left")
# A single scalar unknown attached to the "right" face; it serves as the
# Lagrange multiplier of the displacement constraint imposed there.
fes_lag = NumberSpace(mesh, definedon=mesh.Boundaries("right"))
fes = fes_u * fes_lag  # product space: (displacement, multiplier)

print(f"Dofs: {fes.ndof}")

# q = (u, lag) is the current solution; q0 = (u0, lag0) keeps the solution
# of the previous load step.
q, q0 = GridFunction(fes), GridFunction(fes)
u, lag = q.components
u0, lag0 = q0.components

# Symbolic trial (U, LAG) and test (varU, varLAG) functions of the product space.
(U, LAG), (varU, varLAG) = fes.TrialFunction(), fes.TestFunction()

def eps_(u):
    """Return the linearized strain tensor: the symmetric part of grad(u)."""
    grad_u = grad(u)
    return 0.5 * (grad_u + grad_u.trans)

def S_Hooke(mu, lam, eps):
    """Return the Hooke stress 2*mu*eps + lam*tr(eps)*I for a 3x3 strain tensor."""
    mu_term = 2*mu*eps
    lam_term = lam*Trace(eps)*Id(3)
    return mu_term + lam_term

def Psi_Hooke(mu, lam, E):
    """Return the strain energy density lam/2 * tr(E)^2 + mu * E:E.

    With ``E`` the Green-Lagrange strain this is the St. Venant-Kirchhoff
    energy; with the linearized strain (as produced by ``eps_``) it is the
    classical linear-elastic (Hooke) energy. The formula is dimension-agnostic;
    in this script it is evaluated on a 3D mesh (full 3D, not plane strain).

    Args:
        mu: Shear modulus (second Lame parameter); number or CoefficientFunction.
        lam: First Lame parameter; number or CoefficientFunction.
        E: Strain tensor as a matrix-valued CoefficientFunction.

    Returns:
        Scalar CoefficientFunction holding the energy density.
    """
    psi = (1/2)*lam*Trace(E)**2 + mu*InnerProduct(E,E)
    return CF(psi)

def Lame_params(EModul, PoissonRatio):
    """Convert Young's modulus and Poisson's ratio to the Lame parameters.

    Args:
        EModul: Young's modulus E (scalar or numpy array).
        PoissonRatio: Poisson's ratio nu (scalar or numpy array); must lie
            strictly inside (-1, 0.5).

    Returns:
        Tuple ``(lam, mu)``: the first Lame parameter, then the shear modulus.

    Raises:
        ValueError: If any ``PoissonRatio`` value lies outside (-1, 0.5).
    """
    # The formulas divide by (1 - 2*nu) and (1 + nu); outside (-1, 0.5) they
    # either blow up or describe a non-positive-definite material.
    nu = np.asarray(PoissonRatio)
    if np.any((nu <= -1) | (nu >= 0.5)):
        raise ValueError(
            f"PoissonRatio must lie in (-1, 0.5), got {PoissonRatio!r}")
    lam = (PoissonRatio*EModul)/((1-2*PoissonRatio)*(1+PoissonRatio))
    mu = EModul/(2*(1+PoissonRatio))
    return (lam, mu)

# Reference Lame parameters for E = 70e3, nu = 0.3 (units presumably MPa,
# matching the 'MPa' axis label of the plot -- confirm).
lam0, mu0 = Lame_params(70e3, 0.3)

# xx (11) component of the linearized strain of u0, the displacement of the
# previous load step (entry 0 of the strain matrix).
eps0_11 = eps_(u0)[0]

# Scaling factor interpolated from the tabulated f_grid on the box
# [start, end] and evaluated at the lookup point given by trafocf.
# NOTE(review): both lookup coordinates are eps0_11, so only the diagonal
# f(eps_11, eps_11) of the table is used -- confirm this is intended.
# NOTE(review): check how VoxelCoefficient behaves if the strain leaves
# [start, end].
paramCF_vis = VoxelCoefficient(
    start=start,
    end=end,
    values=f_grid,
    linear=True,  # bilinear interpolation between table values
    trafocf=(eps0_11, eps0_11)  # lookup point (p1, p2) as a 2D CoefficientFunction
).Compile()

# Strain-dependent material parameters: reference values scaled by the table.
# They depend on u0 (previous step), not on the unknown U, so the material
# update lags one load step behind the solution.
lam_fkt = lam0*paramCF_vis
mu_fkt = mu0*paramCF_vis

Area = Integrate(1, mesh.Boundaries("right")) # area of the "right" face; used to average u_x over it

# Load stepping: 21 equally spaced load factors in [0, 1].
steps = np.linspace(0, 1, 21)
load_param = Parameter(0)  # current load factor, updated with Set() in the loop
u_max = 0.05  # prescribed mean x-displacement of "right" at load factor 1
n = CF((1,0,0))  # +x direction (outward normal of "right"); not used in the code shown

# Energy functional: strain energy with the lagged material parameters, plus a
# Lagrange-multiplier term on "right". LAG is a single number (NumberSpace),
# so stationarity in LAG makes the face integral of (U_x - u_max*load_param)
# vanish, i.e. the *mean* x-displacement of "right" is prescribed; LAG then
# plays the role of the x-traction on that face.
a1 = BilinearForm(fes)
a1 += SymbolicEnergy(Psi_Hooke(lam=lam_fkt,mu=mu_fkt,E=eps_(U)).Compile())
a1 += SymbolicEnergy(InnerProduct(-LAG,U[0]-u_max*load_param).Compile(), definedon=mesh.Boundaries("right"))

q.vec[:] = 0
q0.vec[:] = 0

# Mean x-displacement of "right" and multiplier value, one entry per step.
u_list, f_list = [], []

# Residual work vector; allocated once and reused in every load step.
res2 = q.vec.CreateVector()

for step in steps:
    print(f"Step: {step}")

    load_param.Set(step)
    # Freeze the previous step's solution; the material parameters are
    # evaluated with the strain of u0 (lagged update).
    q0.vec.data = q.vec

    # Solve. With u0 frozen the energy is quadratic in q, so a single Newton
    # step (linearize at q, subtract K^{-1} * residual) gives the solution.
    with TaskManager():
        a1.AssembleLinearization(q.vec)
        a1.Apply(q.vec, res2)
        q.vec.data -= a1.mat.Inverse(fes.FreeDofs(), inverse="pardiso") * res2

    u_list.append((1/Area)*Integrate(u[0], mesh.Boundaries("right")))
    # The multiplier space has a single dof; index the vector directly
    # (`.data` is meant as an assignment target, not for reading).
    f_list.append(lag.vec[0])

# Load-displacement curve: mean x-displacement of "right" vs. the multiplier
# value (which acts as the x-traction on that face).
plt.plot(u_list, f_list, marker='o')
plt.xlabel('u_x / mm')
plt.ylabel('S_xx / MPa')

Draw(u, mesh, 'u')
Draw(BoundaryFromVolumeCF(eps_(u)), mesh, 'eps')
# Stress of the final state: the strain of the converged u combined with the
# (u0-lagged) material parameters -- the same constitutive law the last solve
# used. eps_(u0) would show the previous load step's strain instead.
Draw(BoundaryFromVolumeCF(S_Hooke(mu_fkt, lam_fkt, eps_(u))), mesh, 'S')

#%%

# Visualize the interpolated lookup table on a separate 2D mesh.

def mapping(x, y):
    """Map the unit square onto the data box [start, end].

    Derived from ``start``/``end`` so the plot mesh always covers exactly the
    tabulated region, whatever range the data spans.
    """
    return (start[0] + (end[0] - start[0])*x,
            start[1] + (end[1] - start[1])*y)

param_mesh = MakeStructured2DMesh(
    quads=True,
    nx=20,
    ny=20,
    mapping=mapping
)

# Same table as the material coefficient, but without trafocf, so (per the
# NGSolve default) it is evaluated at the plot mesh's own (x, y) coordinates.
# Distinct names keep this cell from overwriting the `paramCF_vis` and `fes`
# globals of the FE model when cells are re-run.
paramCF_plot = VoxelCoefficient(
    start=start,
    end=end,
    values=f_grid,
    linear=True,  # bilinear interpolation
).Compile()

fes_param = H1(param_mesh, order=1)

paramGF = GridFunction(fes_param)
paramGF.Set(paramCF_plot)

Draw(paramGF, param_mesh, 'paramGF')

#%%
