import numpy as np
from netgen.occ import *
from netgen.csg import *
from netgen.geom2d import SplineGeometry
from ngsolve import *
import netgen.gui
from matplotlib import pyplot as plt
from ngsolve.solvers import *


# ---------------------------------------------------------------------------
# Model parameters.  The trailing scale factors (1e-18, 1e12, 1e-6, ...)
# presumably convert SI values into the unit system of the mesh
# (micrometre based?) -- TODO confirm the intended units.
# ---------------------------------------------------------------------------

# Physical constants.
R = 8.3145
F = 96485.3329
T = 298.15
FoverRT = F / R / T  # F / (R * T), the factor in the Butler-Volmer exponent

# NMC (solid) material data.
rho_nmc = 1.8e3  # kg/m3
m_nmc = 0.57
epsilon_nmc = 0.57  # (1 - epsilon_nmc) is used as the electrolyte fraction below
epsilon_void = 0.14
SC_nmc = 200  # Ah/kg
V_nmc = 100e-6 * 0.19e-4
c_Li_max = 18946  # the 1e-18 scaling is applied where gf_cs is initialised

# Reaction kinetics.
alpha_a = 0.5
alpha_c = 1.0 - alpha_a
i_0 = 0.25 * 1e-12

# Applied current.
I_tot = rho_nmc * m_nmc * SC_nmc / 10 * 1e-18 * 100

# Bulk transport coefficients; sigma_nmc depends on the stoichiometry `sto`.
sto = 0.5
_log10_sigma_nmc = 2 * (1.05 - sto) / 0.8 - 4
sigma_nmc = 10 ** _log10_sigma_nmc * 1e2 * 1e-6
D_nmc = 1e-14 * 1e12
D_e = 3.53e-12 * 1e12
sigma_e = 0.116 * 1e-6

# Effective coefficients: bulk value scaled by phase_fraction ** (1 + tau).
tau_nmc = 1.5
tau_e = 4.5


def _effective(bulk, phase_fraction, tau):
    """Return ``bulk * phase_fraction ** (1 + tau)``."""
    return bulk * phase_fraction ** (1 + tau)


sigma_nmc_eff = _effective(sigma_nmc, epsilon_nmc, tau_nmc)
sigma_e_eff = _effective(sigma_e, 1 - epsilon_nmc, tau_e)
D_e_eff = _effective(D_e, 1 - epsilon_nmc, tau_e)
D_nmc_eff = _effective(D_nmc, epsilon_nmc, tau_nmc)


# Particle layout: n_par discs of radius r in one row at height r * scale,
# with centres at r * scale times 1, 3, 5, ...  The centre spacing is
# 2 * r * scale, which is below 2 * r for scale = 0.95, so neighbours overlap.
h = 20
n_par = 10
r = h / 4
scale = 0.95

_odd_multiples = np.arange(1, 2 * n_par + 1, 2)  # 1, 3, ..., 2 * n_par - 1
pc_y = [r * scale]
pc_x = list(_odd_multiples * r * scale)
w = pc_x[-1] + r * scale  # box width: last centre + r * scale
geo = []  # placeholder; reassigned before the mesh is built

class Particles:
    """Build one labelled face from a row of fused circular particles.

    A disc of radius ``r`` is placed at ``(x, pc_y[0])`` for every ``x`` in
    ``pc_x`` (only ``pc_y[0]`` is used).  The discs are fused, optionally
    clipped by subtracting ``outside_rect``, and the result's face and edges
    receive the names used later as mesh material/boundary labels.

    Attributes:
        outside_rect: shape subtracted from the fused discs, or ``None``.
        par_new: the resulting labelled particle shape.
    """

    def __init__(self, pc_x, pc_y, r, outside_rect=None):
        self.outside_rect = outside_rect
        lower_particles = [MoveTo(x, pc_y[0]).Circle(r).Face() for x in pc_x]
        self.par_new = self.create_particles(lower_particles)

    def create_particles(self, particles):
        """Fuse ``particles``, clip by ``self.outside_rect`` and label the result.

        Labels: face -> "solid"; all edges -> "internal_solid", then the edge
        selected by ``Min(X)`` -> "left_solid" and by ``Max(X)`` -> "right_solid".
        (Mesh sizes could be tuned here via ``edges.maxh`` / ``faces.maxh``.)

        Raises:
            ValueError: if ``particles`` is empty.
        """
        shapes = iter(particles)
        try:
            par_new = next(shapes)
        except StopIteration:
            raise ValueError("create_particles() needs at least one particle") from None
        # Start from the first particle and fuse the rest; the first one must
        # not be fused with itself.
        for par in shapes:
            par_new = par_new + par
        if self.outside_rect is not None:
            # Clip with the shape this instance was given, not a module global.
            par_new = par_new - self.outside_rect
        par_new.faces.name = "solid"
        par_new.edges.name = "internal_solid"
        par_new.edges.Min(X).name = "left_solid"
        par_new.edges.Max(X).name = "right_solid"
        return par_new
    

# Electrolyte box [0, w] x [-h/2, h] with named outer edges.
rct = MoveTo(0, -h/2).Rectangle(w, 1.5*h).Face()
rct.edges.Min(X).name = "left_ele"
rct.edges.Max(X).name = "right_ele"
rct.edges.Min(Y).name = "bottom_ele"
rct.edges.Max(Y).name = "top_ele"

# Frame around the box (large rectangle minus the box); subtracting it from
# the particles clips them to the box.
rect2 = MoveTo(-w, -h).Rectangle(3*w, 3*h).Face()
rect2 = rect2 - rct

par_new_pom = Particles(pc_x, pc_y, r, outside_rect=rect2).par_new

# Electrolyte = box minus particles.  Edges that carry the particles'
# "*_solid" labels (presumably propagated through the cut) are relabelled to
# "*_ele"; the four outer box edges keep their names.  Relabelling is
# idempotent, and edges without a name are skipped.
_outer_ele_edges = {"left_ele", "right_ele", "top_ele", "bottom_ele"}
ele = rct - par_new_pom
for edge in ele.edges:
    edge_name = edge.name
    if edge_name and edge_name not in _outer_ele_edges:
        edge.name = edge_name.replace("solid", "ele")

ele.faces.name = "ele"

geo = [ele, par_new_pom]
geometry = Compound(geo)
mesh = Mesh(
    OCCGeometry(geometry, dim=2).GenerateMesh(maxh=0.5)
)

# Finite-element spaces: second-order H1 throughout.
_FE_ORDER = 2
_ELE_OUTER = "left_ele|right_ele|top_ele|bottom_ele"

# Whole-mesh scalar space (holds the overpotential eta below).
fes_all = H1(mesh, order=_FE_ORDER, dirichlet=_ELE_OUTER)
# Electrolyte: potential (Dirichlet on the outer box edges) and concentration.
fes_ele_v = H1(mesh, order=_FE_ORDER, definedon="ele", dirichlet=_ELE_OUTER)
fes_ele_c = H1(mesh, order=_FE_ORDER, definedon="ele")
# Solid: potential (Dirichlet on "left_solid") and concentration.
fes_solid_v = H1(mesh, order=_FE_ORDER, definedon="solid", dirichlet="left_solid")
fes_solid_c = H1(mesh, order=_FE_ORDER, definedon="solid")

# Product spaces, component order (potential, concentration).
fes_solid = fes_solid_v * fes_solid_c
fes_ele = fes_ele_v * fes_ele_c

_trial_solid, _test_solid = fes_solid.TnT()
u_phis, u_cs = _trial_solid
v_phis, v_cs = _test_solid
_trial_ele, _test_ele = fes_ele.TnT()
u_phie, u_ce = _trial_ele
v_phie, v_ce = _test_ele

# Solution fields and their (potential, concentration) components.
gfs = GridFunction(fes_solid)
gf_phis, gf_cs = gfs.components
gfe = GridFunction(fes_ele)
gf_phie, gf_ce = gfe.components

eta = GridFunction(fes_all)

# NOTE(review): gfflux is not used further in this file.
gfflux = GridFunction(HDiv(mesh, order=_FE_ORDER))
c_e_flux = GridFunction(HDiv(mesh, order=_FE_ORDER, definedon="ele"))

# Uniform initial concentrations.
gf_ce.Set(1000 * 1e-18)
gf_cs.Set(c_Li_max * 1e-18)

c_e_flux.Set(grad(gf_ce) / gf_ce)

# ---------------------------------------------------------------------------
# Discretisation: stiffness (a_*), mass (m_*) and right-hand sides (f_*).
# M* = M + dt * A is the matrix of the time update, which amounts to
# (M + dt*A) x_new = M x_old + dt * f(x_old).
# ---------------------------------------------------------------------------
dt = 10  # time-step size

_dx_solid = dx(definedon=mesh.Materials("solid"))
_dx_ele = dx(definedon=mesh.Materials("ele"))

# Stiffness: concentration diffusion (D_*_eff) and potential (sigma_*_eff).
a_s = BilinearForm(fes_solid, symmetric=False)
a_s += D_nmc_eff * grad(u_cs) * grad(v_cs) * _dx_solid
a_s += sigma_nmc_eff * grad(u_phis) * grad(v_phis) * _dx_solid
a_s.Assemble()

a_e = BilinearForm(fes_ele, symmetric=False)
a_e += D_e_eff * grad(u_ce) * grad(v_ce) * _dx_ele
a_e += sigma_e_eff * grad(u_phie) * grad(v_phie) * _dx_ele
a_e.Assemble()

# Mass matrices act on the concentrations only.  The potential blocks stay
# zero, so the update solves the potentials quasi-statically each step.
# A BilinearForm integrand needs a trial function, so no (zero) term is
# added for the potential.
m_s = BilinearForm(fes_solid, symmetric=False)
m_s += u_cs * v_cs * _dx_solid
m_s.Assemble()
m_e = BilinearForm(fes_ele, symmetric=False)
m_e += u_ce * v_ce * _dx_ele
m_e.Assemble()

# M* = M + dt * A.  Adding the raw value arrays assumes m_* and a_* share the
# same sparsity pattern (same FE space, volume integrators only).
m_sstar = m_s.mat.CreateMatrix()
m_sstar.AsVector().data = m_s.mat.AsVector() + dt * a_s.mat.AsVector()
invm_sstar = m_sstar.Inverse(freedofs=fes_solid.FreeDofs())
m_estar = m_e.mat.CreateMatrix()
m_estar.AsVector().data = m_e.mat.AsVector() + dt * a_e.mat.AsVector()
invm_estar = m_estar.Inverse(freedofs=fes_ele.FreeDofs())

# Interface reaction rate from the overpotential eta: symmetric Butler-Volmer,
# i_0 * (exp(0.5 x) - exp(-0.5 x)) = 2 * i_0 * sinh(0.5 x).  The 0.5 is
# hard-coded; it matches alpha_a = alpha_c = 0.5 but does not follow alpha_a.
j_int = 2 * i_0 * sinh(0.5 * FoverRT * eta)
# NOTE(review): gf_phis is defined only on "solid"; evaluated on "left_ele"
# (an electrolyte boundary) it presumably yields 0, making j_sink vanish --
# confirm whether gf_phie was intended here.
j_sink = 2 * 100e-12 * sinh(-0.5 * FoverRT * gf_phis)

# Right-hand sides.  Boundaries without a term get homogeneous Neumann
# (zero-flux) conditions naturally, so no explicit zero terms are added.
f_s = LinearForm(fes_solid)
f_s += j_int / F * v_cs * ds("internal_ele")
f_s += -j_int * v_phis * ds("internal_ele")
f_s += I_tot * v_phis * ds("right_solid")  # applied current

# The interface terms mirror f_s with opposite sign.
# NOTE(review): "left_ele" is a Dirichlet boundary of fes_ele_v, so the
# v_phie term there does not affect the solution.
f_e = LinearForm(fes_ele)
f_e += -j_int / F * v_ce * ds("internal_ele")
f_e += j_int * v_phie * ds("internal_ele")
f_e += j_sink / F * v_ce * ds("left_ele")
f_e += -j_sink * v_phie * ds("left_ele")


# Time stepping: each step refreshes the interface overpotential from the
# current potentials, reassembles the right-hand sides and applies the
# increment x += (M + dt*A)^-1 (dt*f - dt*A x).
# The applied current is +I_tot for steps 0..201 and -I_tot for steps
# 202..399: f_s is rebuilt once, with the current reversed, after step 201.
_N_STEPS = 400
_REVERSE_AFTER_STEP = 201

for br in range(_N_STEPS):
    eta.Set(gf_phis - gf_phie, definedon=mesh.Boundaries("internal_ele"))
    f_s.Assemble()
    f_e.Assemble()

    res_s = dt * f_s.vec - dt * a_s.mat * gfs.vec
    gfs.vec.data += invm_sstar * res_s
    res_e = dt * f_e.vec - dt * a_e.mat * gfe.vec
    gfe.vec.data += invm_estar * res_e

    if br == _REVERSE_AFTER_STEP:
        # Build the reversed-current form once instead of on every later step.
        f_s = LinearForm(fes_solid)
        f_s += j_int / F * v_cs * ds("internal_ele")
        f_s += -j_int * v_phis * ds("internal_ele")
        f_s += -I_tot * v_phis * ds("right_solid")


# Visualisation in the netgen GUI (netgen.gui is imported at the top).
Draw(gf_phie, mesh, "phie")
Draw(gf_ce, mesh, "dce")
Draw(gf_cs, mesh, "dcs")
# NOTE(review): "i_nmc" / "i_e" show -D_eff * grad(c), i.e. a diffusive
# species flux rather than an electric current density -- confirm the label.
Draw(-D_nmc_eff * grad(gf_cs), mesh, "i_nmc")
Draw(-D_e_eff * grad(gf_ce), mesh, "i_e")
Draw(j_int, mesh, "j_ints")
Draw(gf_phis, mesh, "phis")

# Debug aid: set to True to list the mesh boundary names, e.g. when editing
# the labels used in ds(...) and dirichlet=...
PRINT_BOUNDARIES = False
if PRINT_BOUNDARIES:
    for i, bnd in enumerate(mesh.GetBoundaries(), start=1):
        print(f"{i}: {bnd}")