from math import pi
# NGSolve: finite element core and mesh generators
from ngsolve.meshes import *
from ngsolve import *
# basic xfem functionality
from xfem import *
from xfem.lsetcurv import *
import sys

from netgen.csg import *

# Discretisation parameters: the knobs one would typically vary for a study
order = 2  # polynomial order of the trace space and of the mesh deformation
n_ref = 3  # refinement steps after the initial mesh (n_ref + 1 solves in total)

# Parameters kept fixed throughout
lam_ip = 4 * (order + 1) ** 2  # interior-penalty weight, scales with (order + 1)^2
lam_gp_0, lam_gp_1 = 10, 0.1  # ghost-penalty weights: value jump / normal-derivative jump

ngsglobals.msg_level = 0
SetNumThreads(3)
SetHeapSize(9999880)

# Computational box [-1.5, 1.5]^3; only init_nxyz[0] is read, giving the
# initial mesh size maxh = 2 * 1.5 / 4
length = (1.5,) * 3
init_nxyz = (4,) * 3

# Unit sphere as the zero level of its signed distance function
levelset = sqrt(x**2 + y**2 + z**2) - 1

# Manufactured solution and matching right-hand side for
#   -Laplace_Gamma u + u = f   on the unit sphere
exact = sin(pi * z)
coef_f = (exact * (pi * pi * (1 - z * z) + 1.)
          + cos(pi * z) * 2 * pi * z)
exact_grad = CoefficientFunction(tuple(exact.Diff(var) for var in (x, y, z)))

# Error history, one entry appended per refinement level
l2errors, h1errors = [], []

# Helpers that do not depend on the refinement level. They are defined once
# here instead of being re-created on every pass through the loop below.


def P(u, n_phi):
    """Return the tangential projection of *u* with respect to the normal *n_phi*."""
    return u - (u * n_phi) * n_phi


def Normalized(u):
    """Return *u* scaled pointwise to unit Euclidean length."""
    return 1.0 / Norm(u) * u


def jump(u):
    """Return the jump ``u - u.Other()`` across a facet."""
    return u - u.Other()


def eocs(errors):
    """Return experimental orders of convergence between consecutive *errors*.

    Entry j compares errors[j] with errors[j + 1], assuming the mesh size is
    halved between them (the loop below uses maxh * 0.5**i).
    """
    return [log(coarse / fine) / log(2) for coarse, fine in zip(errors, errors[1:])]


# Convergence study: for every level i, mesh the box with maxh halved, solve
# the trace-DG problem on the mesh deformed by LevelSetMeshAdaptation and
# record L2 / H1 errors on the interface.
for i in range(n_ref + 1):
    cube = CSGeometry()
    cube.Add(OrthoBrick(Pnt(-length[0], -length[1], -length[2]),
                        Pnt(length[0], length[1], length[2])))
    mesh = Mesh(cube.GenerateMesh(maxh=(2. * length[0] / (init_nxyz[0])) * 0.5**(i),
                                  quad_dominated=False))
    print("Mesh is ready")

    # P1 interpolant of the level set; its zero level defines the interface
    # used by CutInfo and by the cut integrals below
    lsetp1 = GridFunction(H1(mesh, order=1))

    with TaskManager():
        InterpolateToP1(levelset, lsetp1)

        # class to compute the mesh transformation needed for higher order accuracy
        #  * order: order of the mesh deformation function
        #  * threshold: barrier for maximum deformation (to ensure shape regularity)
        lsetmeshadap = LevelSetMeshAdaptation(mesh, order=order, threshold=10,
                                              discontinuous_qn=True, heapsize=24999200)
        deformation = lsetmeshadap.CalcDeformation(levelset)

        # TraceFESpace
        VhG = L2(mesh, order=order, dirichlet=[], dgjumps=True, low_order_space=False)

    # overwrite freedofs of VhG to mark only dofs that are involved in the cut problem
    ci = CutInfo(mesh, lsetp1)
    ba_IF = ci.GetElementsOfType(IF)
    ba_IF_facets = GetFacetsWithNeighborTypes(mesh, a=ba_IF, b=ba_IF, use_and=True)

    VhG = Compress(VhG, GetDofsOfElements(VhG, ba_IF))

    freedofs = VhG.FreeDofs()
    freedofs &= GetDofsOfElements(VhG, ba_IF)

    gfu = GridFunction(VhG)

    # discrete level-set normals on this element and on the facet neighbour
    n_phi1 = Normalized(grad(lsetp1))
    n_phi2 = Normalized(grad(lsetp1).Other())

    h = specialcf.mesh_size
    n_F = specialcf.normal(mesh.dim)

    # co-normals: facet normal projected to the tangent plane of each side
    conormal1 = Normalized(P(n_F, n_phi1))
    conormal2 = Normalized(P(-n_F, n_phi2))

    def avg_flux(u):
        # uses this level's co-normals, hence defined inside the loop
        return (0.5 * InnerProduct(grad(u), conormal1)
                - 0.5 * InnerProduct(grad(u).Other(), conormal2))

    # expressions of test and trial functions:
    u = VhG.TrialFunction()
    v = VhG.TestFunction()

    # integration domains (and integration parameter "subdivlvl" and "force_intorder")
    lset_if = {"levelset": lsetp1, "domain_type": IF, "subdivlvl": 0}

    # The integrator factories below capture this level's lset_if, ba_IF and
    # ba_IF_facets, so they also stay inside the loop.
    def EdgeIntegral(form):
        # interface integral restricted to facets between two cut elements
        return SymbolicBFI(levelset_domain=lset_if, form=form.Compile(False, wait=True),
                           skeleton=True, definedonelements=ba_IF_facets)

    def FacetIntegral(form):
        # full-facet integral on facets between two cut elements
        return SymbolicBFI(form=form.Compile(False, wait=True),
                           skeleton=True, definedonelements=ba_IF_facets)

    def SurfaceIntegral(form):
        # integral over the discrete interface inside the cut elements
        return SymbolicBFI(levelset_domain=lset_if, form=form.Compile(False, wait=True),
                           definedonelements=ba_IF)

    def VolumeIntegral(form):
        # integral over the whole volume of the cut elements
        return SymbolicBFI(form=form.Compile(False, wait=True),
                           definedonelements=ba_IF)

    lam_nd = 1 / h + h

    # bilinear forms:
    a = RestrictedBilinearForm(VhG, "a", ba_IF, ba_IF_facets, check_unused=False)
    a += SurfaceIntegral(P(grad(u), n_phi1) * P(grad(v), n_phi1) + u * v)
    a += VolumeIntegral((lam_nd * grad(u) * n_phi1) * (grad(v) * n_phi1))

    # symmetric interior penalty terms on the interface edges
    a += EdgeIntegral(- avg_flux(u) * jump(v))
    a += EdgeIntegral(- avg_flux(v) * jump(u))
    a += EdgeIntegral(+ lam_ip / h * jump(u) * jump(v))

    # ghost-penalty terms on facets between cut elements
    a += FacetIntegral(lam_gp_1 * InnerProduct(grad(u) - grad(u.Other()), n_F)
                       * InnerProduct(grad(v) - grad(v.Other()), n_F))
    a += FacetIntegral(lam_gp_0 / (h * h) * (u - u.Other()) * (v - v.Other()))

    f = LinearForm(VhG)
    f += SymbolicLFI(levelset_domain=lset_if, form=coef_f * v, definedonelements=ba_IF)

    # assemble and solve on the deformed mesh; undone at the end of the iteration
    mesh.SetDeformation(deformation)
    with TaskManager():
        a.Assemble()
        f.Assemble()

        gfu.vec[:] = 0.0
        gfu.vec.data = a.mat.Inverse(freedofs, inverse="umfpack") * f.vec

    # L2 error on the discrete interface
    err_sqr_coefs = (gfu - exact)**2
    l2error = sqrt(Integrate(levelset_domain=lset_if, cf=err_sqr_coefs[0],
                             mesh=mesh, order=2 * order + 1))

    print("l2error : ", l2error)
    l2errors.append(l2error)
    print("l2errors : ", l2errors)
    l2eocs = eocs(l2errors)
    print("l2eocs : ", l2eocs)

    # H1 error: L2 part plus the error of the tangential gradient
    H1_err_sqr_coefs = Norm(P(gfu.Deriv() - exact_grad, n_phi1))**2
    h1error = sqrt(l2error**2 + Integrate(levelset_domain=lset_if, cf=H1_err_sqr_coefs[0],
                                          mesh=mesh, order=2 * order + 1))
    print("h1error: ", h1error)
    h1errors.append(h1error)
    print("h1errors: ", h1errors)
    h1eocs = eocs(h1errors)
    print("h1eocs : ", h1eocs)

    mesh.UnsetDeformation()
