# ------------------------------ Load Libraries -------------------------------
import sys
from netgen.csg import *
from ngsolve import *
from ngsolve.internal import *
from xfem import *
from xfem.lsetcurv import *
from math import pi
import numpy as np

# -------------------------------- Parameters ---------------------------------
# Target element size passed to the background mesh generator.
maxh = 0.2

# Polynomial order of the H1 finite element space.
spaceorder = 2

# Order of the level-set based mesh deformation (LevelSetMeshAdaptation).
meshorder = 2

# Weights of the penalty terms: betaE multiplies the surface-edge jump term,
# betaF the facet jump term of u, gamma the facet jump of the normal derivative.
betaE = betaF = gamma = 1

# Background domain: the cube [-1.5, 1.5]^3, meshed with element size maxh.
cube = OrthoBrick(Pnt(-1.5, -1.5, -1.5), Pnt(1.5, 1.5, 1.5))
geo = CSGeometry()
geo.Add(cube)
mesh = Mesh(geo.GenerateMesh(maxh=maxh, quad_dominated=False))

# The surface is the zero level of |x|^2 - 1, i.e. the unit sphere.
# (The signed distance |x| - 1 would describe the same surface.)
levelset = x**2 + y**2 + z**2 - 1

# Model problem on the sphere: -Laplace-Beltrami(uexa) + uexa = fexa.
# For u = g(z) on the unit sphere, Laplace-Beltrami(u) = (1 - z^2) g'' - 2 z g';
# inserting g = sin(pi z) yields the right-hand side below.
uexa = sin(pi * z)
fexa = (
    sin(pi * z) * (1 * pi * pi * (1 - z * z) + 1)
    + 1 * cos(pi * z) * 2 * pi * z
)

# ----------------------------------- Main ------------------------------------
# Geometry approximation: compute a mesh deformation from the level set and
# obtain the level-set approximation lset_p1, whose zero level is used as the
# discrete surface in all surface integrals below.
lsetmeshadap = LevelSetMeshAdaptation(
    mesh, order=meshorder, threshold=0.2, discontinuous_qn=True
)
deformation = lsetmeshadap.CalcDeformation(levelset)
lset_approx = lsetmeshadap.lset_p1
# NOTE(review): the deformation is attached to the mesh here AND passed
# explicitly to every integration measure; confirm it is not applied twice.
mesh.deformation = deformation

# Elements cut by the surface, and the facets whose two neighbours are both cut.
ci = CutInfo(mesh, lset_approx)
ba_IF = ci.GetElementsOfType(IF)
ba_fd_facets = GetFacetsWithNeighborTypes(mesh, a=ba_IF, b=ba_IF)

# H1 space on the background mesh, restricted to the cut elements;
# dgjumps=True is needed for the facet (skeleton) coupling terms.
V1 = H1(mesh, order=spaceorder, dirichlet=[], dgjumps=True)
V1R = Restrict(V1, ba_IF)

u, v = V1R.TrialFunction(), V1R.TestFunction()

# Geometric vectors:
#   nF - unit normal of the current mesh facet,
#   nh - discrete surface normal (normalised gradient of lset_approx),
#   tg - unit vector orthogonal to nF and nh, i.e. along the line where the
#        facet meets the discrete surface,
#   nE - in-surface unit vector orthogonal to tg (edge conormal).
nF = specialcf.normal(3)
nh = Normalize(grad(lset_approx))
tg = Normalize(Cross(nF, nh))
nE = Normalize(Cross(nh, tg))

# Local element size.
h = specialcf.mesh_size

# Tangential projection onto the discrete surface: Ph = I - nh nh^T.
Ph = Id(3) - OuterProduct(nh, nh)

# Average of the conormal derivative across a surface edge.
aveE_u = 0.5 * (nE * grad(u) + nE.Other() * grad(u.Other()))
aveE_v = 0.5 * (nE * grad(v) + nE.Other() * grad(v.Other()))

# Jumps across a facet: of the value on surface edges (E) and on facets (F1),
# and of the full gradient on facets (F2).
# NOTE(review): V1 is a continuous H1 space, so u - u.Other() vanishes on
# interior facets; the E and F1 jump terms presumably only matter if the
# space is switched to a discontinuous one -- confirm intent.
jumpE_u, jumpE_v = u - u.Other(), v - v.Other()
jumpF1_u, jumpF1_v = u - u.Other(), v - v.Other()
jumpF2_u, jumpF2_v = grad(u) - grad(u.Other()), grad(v) - grad(v.Other())

# Integration measures, all evaluated on the deformed geometry:
#   ds - on the discrete surface inside the cut elements,
#   de - on the surface edges, i.e. where the surface crosses the facets of
#        ba_fd_facets (skeleton integral),
#   db - on the volume of the cut elements (appears unused),
#   df - on the facets of ba_fd_facets (skeleton integral).
ds = dCut(levelset=lset_approx, domain_type=IF,
          definedonelements=ba_IF, deformation=deformation)
de = dCut(levelset=lset_approx, domain_type=IF, skeleton=True,
          definedonelements=ba_fd_facets, deformation=deformation)
db = dx(definedonelements=ba_IF, deformation=deformation)
df = dx(skeleton=True, definedonelements=ba_fd_facets,
        deformation=deformation)

# Bilinear Forms
# Surface form: tangential gradients plus mass on the discrete surface (ds),
# symmetric interior-penalty terms on the surface edges (de), and penalty
# terms on the facets between cut elements (df).
a = BilinearForm(V1R, symmetric = True)
a += (InnerProduct(Ph*grad(u), Ph*grad(v)) + u*v)*ds
# Consistency term -{conormal derivative of u}[v] and its symmetric counterpart.
a += - InnerProduct(aveE_u, jumpE_v)*de
a += - InnerProduct(aveE_v, jumpE_u)*de
# Edge penalty on jumps of u, scaled by 1/h.
a += betaE/h*InnerProduct(jumpE_u, jumpE_v)*de
# Facet penalty on jumps of u, scaled by 1/h^2.
a += betaF/h**2*InnerProduct(jumpF1_u, jumpF1_v)*df
# Facet penalty on the jump of the facet-normal derivative, nF . [grad u].
a += gamma*InnerProduct(nF*jumpF2_u, nF*jumpF2_v)*df

a.Assemble()

# Linear Form
# Right-hand side: integral of fexa * v over the discrete surface.
f = LinearForm(V1R)
f += fexa*v*ds

f.Assemble()

# Solve the linear system using the inverse restricted to the free dofs.
s = GridFunction(V1R)
s.vec.data = a.mat.Inverse(V1R.FreeDofs()) * f.vec

# L2 error on the discrete surface against the exact solution uexa.
err = sqrt(Integrate((s - uexa)**2*ds, mesh = mesh))
print(err)

