# %%

import netgen.occ as occ
import numpy as np
from ngsolve import *
import time
from ngsolve.webgui import Draw
from netgen.meshing import MeshingStep

# Function computing mean curvature WITHOUT gradient jump stabilization
def ComputeMC(mesh):
    """Return the vector-valued mean curvature of a surface mesh.

    Solves the L2-projection

        (kappa, eta)_S = -(P_s, grad_S eta)_S    for all test functions eta,

    on a first-order ``VectorH1`` space defined on every boundary region of
    ``mesh``, where ``P_s = I - n n^T`` is the tangential projector built from
    the surface normal ``n``. No stabilization is applied.

    Parameters
    ----------
    mesh : Mesh
        Surface mesh; all boundary regions (``'.*'``) are used.

    Returns
    -------
    GridFunction
        The curvature field on the ``VectorH1`` space.
    """
    space = VectorH1(mesh, order=1, definedon=mesh.Boundaries('.*'))
    trial = space.TrialFunction()
    test = space.TestFunction()

    # Tangential projector I - n n^T onto the surface.
    normal = specialcf.normal(mesh.dim)
    tangential_proj = Id(mesh.dim) - OuterProduct(normal, normal)

    # Left-hand side: plain surface mass matrix.
    mass = BilinearForm(space)
    mass += trial*test*ds
    mass.Assemble()

    # Right-hand side: -(P_s, surface gradient of the test function).
    rhs = LinearForm(space)
    rhs += -InnerProduct(tangential_proj, grad(test).Trace())*ds
    rhs.Assemble()

    curvature = GridFunction(space)
    curvature.vec.data = mass.mat.Inverse(space.FreeDofs()) * rhs.vec
    return curvature

# Function computing mean curvature WITH gradient jump stabilization
def ComputeStabMC(mesh, stab=1e-3):
    """Return the vector-valued mean curvature with gradient-jump stabilization.

    Solves the same L2-projection as the unstabilized variant: a surface mass
    matrix on a first-order ``VectorH1`` space, with right-hand side
    ``-(P_s, grad_S eta)`` and ``P_s = I - n n^T``. It adds a penalty on the
    jump of the co-normal derivative across surface edges. The single
    normal-derivative value per edge is an auxiliary ``NormalFacetSurface``
    (order 0) unknown.

    Parameters
    ----------
    mesh : Mesh
        Surface mesh; all boundary regions (``'.*'``) are used.
    stab : float, optional
        Weight of the jump penalty, scaled by the local length ``h_f``.
        Defaults to ``1e-3``, the value previously hard-coded. It must be
        strictly positive. The auxiliary facet unknowns appear only in the
        penalty term, so ``stab == 0`` leaves them unconstrained and makes
        the system matrix singular.

    Returns
    -------
    GridFunction
        Curvature field on the ``VectorH1`` space. The auxiliary facet
        unknowns are discarded.

    Raises
    ------
    ValueError
        If ``stab`` is not strictly positive (this includes NaN).
    """
    # Validate first, before any FE objects are built.
    if not stab > 0:
        raise ValueError(f"stab must be strictly positive, got {stab!r}")

    V = VectorH1(mesh, order=1, definedon=mesh.Boundaries('.*'))
    gfu = GridFunction(V)
    dV = NormalFacetSurface(mesh, order=0, definedon=mesh.Boundaries('.*'))

    ns = specialcf.normal(mesh.dim)
    tE = specialcf.tangential(mesh.dim)
    # Cross(surface normal, edge tangent): lies in the surface, perpendicular to the edge.
    nE = Cross(ns, tE)
    # Tangential projector onto the surface.
    Ps = Id(mesh.dim) - OuterProduct(ns, ns)

    fes = V*dV
    gfu_c = GridFunction(fes)
    (kappa0, dkappa), (eta0, deta) = fes.TnT()
    # Difference between each element's co-normal derivative and the shared
    # facet value. Penalising it couples neighbouring elements' gradients.
    jump_dkappadn = grad(kappa0).Trace()*nE - dkappa.Trace()
    jump_detadn = grad(eta0).Trace()*nE - deta.Trace()

    # Local length scale h_f = 2 * area / Norm(jac * tE).
    # sqrt(Det(J^T J)) / 2 is the element area if J maps from a reference
    # triangle of area 1/2.
    # NOTE(review): tE is the physical tangent; confirm that Norm(jac*tE)
    # yields the intended edge-length scale.
    J = specialcf.JacobianMatrix(mesh.dim, mesh.dim-1)
    jac = specialcf.JacobianMatrix(mesh.dim)
    area = sqrt(Det(J.trans*J))/2
    edge_len = Norm(jac*tE)
    h_f = area*2/edge_len

    A0 = BilinearForm(fes)
    A0 += kappa0*eta0*ds
    A0 += stab*h_f*InnerProduct(jump_dkappadn, jump_detadn) \
        * ds(element_boundary=True)
    A0.Assemble()

    F0 = LinearForm(fes)
    F0 += -InnerProduct(Ps, grad(eta0).Trace())*ds
    F0.Assemble()

    gfu_c.vec.data = A0.mat.Inverse(fes.FreeDofs())*F0.vec
    # Keep only the VectorH1 (curvature) component.
    gfu.vec.data = gfu_c.components[0].vec

    return gfu

# Function generating the torus mesh
def TorusMesh(maxh, R=None, r=1):
    """Build a surface triangulation of a torus.

    The generating circle (centre ``(R, 0, 0)``, radius ``r``) is drawn in the
    xz-plane as two half-arcs. It is revolved 360 degrees about the z-axis,
    and the revolved shape is rotated by 90 about the x-axis.

    Parameters
    ----------
    maxh : float
        Maximal mesh size passed to ``GenerateMesh``.
    R : float, optional
        Distance from the revolution axis to the centre of the tube.
        Defaults to ``sqrt(2)``, the value previously hard-coded.
    r : float, optional
        Tube radius. Defaults to ``1``.

    Returns
    -------
    Mesh
        NGSolve mesh. Meshing stops after ``MeshingStep.MESHSURFACE``, so
        only surface elements are generated.

    Raises
    ------
    ValueError
        If ``r`` is not positive, or if ``R <= r``. In the latter case the tube
        touches or crosses the revolution axis, giving a degenerate torus.
    """
    if R is None:
        R = sqrt(2)
    if not r > 0:
        raise ValueError(f"minor radius r must be positive, got {r!r}")
    if not R > r:
        raise ValueError(
            f"major radius R must exceed minor radius r (got R={R!r}, r={r!r})"
        )

    # Four points on the generating circle: inner, top, outer, bottom.
    pnt1 = occ.Pnt(R-r, 0, 0)
    pnt2 = occ.Pnt(R, 0, r)
    pnt3 = occ.Pnt(R+r, 0, 0)
    pnt4 = occ.Pnt(R, 0, -r)

    # Two half-arcs closing the circle: inner -> top -> outer -> bottom -> inner.
    arc1 = occ.ArcOfCircle(pnt1, pnt2, pnt3)
    arc2 = occ.ArcOfCircle(pnt3, pnt4, pnt1)

    w = occ.Wire([arc1, arc2])
    body = w.Revolve(occ.Axis((0, 0, 0), occ.Z), 360).Rotate(occ.Axis((0, 0, 0), occ.X), 90)

    geo = occ.OCCGeometry(body)
    mesh = Mesh(geo.GenerateMesh(maxh=maxh, optsteps2d=3, perfstepsend=MeshingStep.MESHSURFACE))

    return mesh

# Comparing execution times between the two.
# Only the curvature computations are timed. Drawing happens outside the timed
# regions so webgui rendering does not distort the comparison.
# perf_counter is monotonic and high-resolution, which suits interval timing.
h_list = [0.4, 0.2, 0.1, 0.05]
times = np.zeros((len(h_list), 2))
for i, h in enumerate(h_list):
    mesh = TorusMesh(h)

    t_start = time.perf_counter()
    mc = ComputeMC(mesh)
    t_mid = time.perf_counter()
    mc_stab = ComputeStabMC(mesh)
    t_end = time.perf_counter()

    times[i, 0] = t_mid - t_start
    times[i, 1] = t_end - t_mid

    # Visualise after timing (same draw order as before: plain, then stabilized).
    Draw(mc)
    Draw(mc_stab)

import pandas as pd
data = {
  "maxh": h_list,
  "MC": times[:, 0],
  "MC_stab": times[:, 1],
  "ratio": times[:, 1]/times[:, 0]
}
# Load the timings into a DataFrame for tabular display.
df = pd.DataFrame(data)
print(df)