from math import pi
# ngsolve stuff
from ngsolve import *
# visualization stuff
from ngsolve.internal import *
# basic xfem functionality
from xfem import *
from xfem.lsetcurv import *

from netgen.geom2d import SplineGeometry

# from make_uniform2D_grid import MakeUniform2DGrid
# from manufactured_solution_plain2d import get_solution_str

# Number of threads handed to NGSolve's SetNumThreads.
SetNumThreads(3)

# Selects the level set function below; also becomes part of the name of the
# result file written at the end of the script.
geometry = 'circle' # circle / starfish

# `levelset` is the symbolic level set function, `levelset_str` the same
# expression as text (only needed by the disabled get_solution_str helper).
if geometry == 'circle':
    R = 1
    levelset = sqrt(x*x+y*y) - R
    levelset_str = 'sqrt(x*x+y*y) - R'
elif geometry == 'starfish':
    r0 = 1
    omega = 5
    levelset =  sqrt(x*x+y*y)-(r0+0.2*sin(omega*atan2(x,y)))
    levelset_str = 'sqrt(x*x+y*y)-('+str(r0)+'+0.2*sin('+str(omega)+'*atan2(x,y)))'
else:
    # Fail here with a clear message instead of a NameError on `levelset`
    # much later in the script.
    raise ValueError("unknown geometry '" + str(geometry) + "' (expected 'circle' or 'starfish')")

# Reference solution used for the error computation, symbolically and as text.
exact = sin(y)
exact_string = 'sin(y)'

print('Calculating exact f and grad(u)')

# Disabled alternative that derives the right-hand side and the gradient from
# the strings above via get_solution_str:
#sol_str, exact_grad_arr = get_solution_str(exact_string, levelset_str)
#print("Solution string:", sol_str)
#print("Exact grad str array:", exact_grad_arr)

# Gradient of the reference solution by symbolic differentiation.
# (Disabled alternative: CoefficientFunction((eval(exact_grad_arr[0]),  eval(exact_grad_arr[1]))))
exact_grad = CF((exact.Diff(x),exact.Diff(y)))

# Error histories, one entry per refinement level; read again after the loop
# for the convergence rates and the result file.
l2errors = []
h1errors = []

# False: unstructured triangulation generated by netgen.
# True:  grid built by MakeUniform2DGrid (imported lazily inside the loop).
structured_mesh = False


# tangential projection to given normal
def P(u, n_phi):
    return u - (u*n_phi)*n_phi


# normalization (pointwise) of a vector
def Normalized(u):
    return 1.0 / Norm(u) * u


# difference between u and its counterpart obtained through .Other()
# (NGSolve's accessor for the neighbouring element in skeleton integrals)
def jump(u):
    return u - u.Other()


# Weights of the penalty / stabilization terms of the bilinear form below.
beta_E = 50.
beta_F = 50.
gamma = 0.1
lam_nd = 1

# Mesh-independent symbolic quantities: local mesh size and facet normal.
h = specialcf.mesh_size
n_F = specialcf.normal(2)

# Hard-coded right-hand side. It does not depend on `geometry` or `exact`, so
# it has to be adapted by hand when either of them is changed.
# NOTE(review): looks like the manufactured source for exact = sin(y) on the
# unit circle -- confirm before switching to 'starfish'.
# (Disabled alternative: f_coeff = eval(sol_str))
f_coeff = ((1+x*x)*sin(y)+cos(y)*y)

for i in [2,3,4,5,6]:

    # --- background mesh of the square [-1.5, 1.5]^2 for level i ---
    if structured_mesh:
        # The module-level import of MakeUniform2DGrid is commented out, so
        # this branch used to end in a NameError. Import it only when it is
        # actually needed (module name as in that commented-out import).
        from make_uniform2D_grid import MakeUniform2DGrid
        mesh = MakeUniform2DGrid(quads = False, N=2**(i+1), P1=(-1.5,-1.5),P2=(1.5,1.5))
    else:
        square = SplineGeometry()
        square.AddRectangle([-1.5,-1.5],[1.5,1.5],bc=1)
        mesh = Mesh (square.GenerateMesh(maxh=0.5**(i), quad_dominated=False))

    # Piecewise linear approximation of the level set on this mesh.
    lsetp1 = GridFunction(H1(mesh,order=1))
    InterpolateToP1(levelset,lsetp1)

    Draw(levelset, mesh, "lset")
    Draw(lsetp1, mesh, "lsetp1")

    # TraceFESpace: discontinuous (L2) space of order 1 on the background mesh
    VhG = L2(mesh, order=1, dirichlet=[], dgjumps=True)

    # overwrite freedofs of VhG to mark only dofs that are involved in the cut problem
    ci = CutInfo(mesh, lsetp1)
    reg_Th = ci.GetElementsOfType(IF)      # elements classified as IF by CutInfo
    reg_Fh = GetFacetsWithNeighborTypes(mesh,a=reg_Th,b=reg_Th,use_and=True)  # facets with both neighbours in reg_Th

    freedofs = VhG.FreeDofs()
    freedofs &= GetDofsOfElements(VhG,reg_Th)
    Draw(BitArrayCF(freedofs), mesh, "freedofs")

    gfu = GridFunction(VhG)

    # Unit normals from the gradient of lsetp1, on the element itself and on
    # the neighbouring element (.Other()).
    n_phi1 = Normalized(grad(lsetp1))
    n_phi2 = Normalized(grad(lsetp1).Other())

    # Facet normal projected with P onto the tangential direction of each
    # side, then normalized.
    conormal1 = Normalized(P(n_F,n_phi1))
    conormal2 = Normalized(P(n_F,n_phi2))

    # Average of the two one-sided fluxes grad(u) * conormal across a facet.
    # Kept inside the loop because it uses this level's conormals.
    def avg_flux(u):
        return 0.5*InnerProduct(grad(u),conormal1) + 0.5*InnerProduct(grad(u.Other()),conormal2)

    u = VhG.TrialFunction()
    v = VhG.TestFunction()
    # Integration domain descriptor: the IF part of lsetp1, without subdivision.
    lset_if  = { "levelset" : lsetp1, "domain_type" : IF , "subdivlvl" : 0}

    a = RestrictedBilinearForm(VhG,"a",reg_Th,reg_Fh,check_unused=False)
    # Tangential-gradient (stiffness) term plus mass term, integrated over lset_if.
    a += SymbolicBFI(levelset_domain = lset_if, form = P(grad(u),n_phi1) * P(grad(v),n_phi1) + u * v, definedonelements=reg_Th)
    # Element integral over reg_Th of the derivatives along n_phi1, weight lam_nd.
    a += SymbolicBFI(form = (lam_nd * grad(u)*n_phi1) * (grad(v)*n_phi1), definedonelements=reg_Th)

    # Facet coupling over lset_if: two symmetric flux/jump terms and the
    # beta_E/h jump penalty.
    a += SymbolicBFI(levelset_domain = lset_if, form = ( - avg_flux(u)*jump(v)
                                                         - avg_flux(v)*jump(u)
                                                         + beta_E/h * jump(u)*jump(v)),
                     skeleton=True, definedonelements=reg_Fh)

    # Facet integrals over reg_Fh: jump of the gradient along n_F (weight
    # gamma) and jump of the values (weight beta_F/h^2).
    a += SymbolicBFI(form = gamma * InnerProduct(grad(u) - grad(u.Other()), n_F) * InnerProduct(grad(v) - grad(v.Other()), n_F) , skeleton=True, definedonelements=reg_Fh)
    a += SymbolicBFI(form = beta_F/(h*h) * (u-u.Other()) * (v-v.Other()) , skeleton=True, definedonelements=reg_Fh)

    f = LinearForm(VhG)
    f += SymbolicLFI(levelset_domain = lset_if, form = f_coeff * v, definedonelements=reg_Th)
    a.Assemble()
    f.Assemble()

    # Direct solve restricted to the dofs marked in freedofs; all other
    # entries of the solution vector stay at zero.
    gfu.vec[:] = 0.0
    gfu.vec.data = a.mat.Inverse(freedofs, "umfpack") * f.vec

    Draw(gfu, mesh, "sol")
    Draw(gfu-exact, mesh, "res")

    # L2 error over lset_if.
    err_sqr_coefs = (gfu-exact)**2
    l2error = sqrt( Integrate( levelset_domain=lset_if, cf=err_sqr_coefs[0], mesh=mesh, order=2) )
    print ("l2error : ", l2error)
    l2errors.append(l2error)

    # H1-type error: L2 part plus the tangentially projected gradient error.
    H1_err_sqr_coefs = Norm( P (gfu.Deriv() - exact_grad, n_phi1) )**2
    h1error = sqrt( l2error**2 + Integrate( levelset_domain=lset_if, cf=H1_err_sqr_coefs[0], mesh=mesh, order=2) )
    print("h1error : ", h1error)
    h1errors.append(h1error)

# Estimated orders of convergence: base-2 logarithm of the ratio between the
# errors of two consecutive refinement levels. Needs at least two levels.
if len(l2errors) > 1:
    eocs = [log(coarse/fine)/log(2)
            for coarse, fine in zip(l2errors, l2errors[1:])]
    print ("l2 eocs : ", eocs)

    eocs = [log(coarse/fine)/log(2)
            for coarse, fine in zip(h1errors, h1errors[1:])]
    print ("h1 eocs : ", eocs)

# Write one tab-separated line per refinement level: index, L2 error, H1 error.
# The with-block closes (and thereby flushes) the file; the handle gets its
# own name so it no longer overwrites `f`, the LinearForm of the last level.
with open("conv_lobetal2d"+geometry+".dat","w") as outfile:
    for level, (l2err, h1err) in enumerate(zip(l2errors, h1errors)):
        outfile.write(str(level)+"\t"+str(l2err)+"\t"+str(h1err)+"\n")
