from ngsolve import *
from netgen.geom2d import unit_square
from ngsolve.krylovspace import CGSolver
from ngsPETSc import KrylovSolver

# Run configuration.
#   static_condensation -- passed to BilinearForm(eliminate_internal=...); when
#       True only the condensed system on the coupling dofs is solved.
#   bddcngsolve -- only consulted when static_condensation is True: selects
#       NGSolve's CGSolver + BDDC (True) or ngsPETSc's KrylovSolver (False).
#   order -- polynomial order of both the element space and the facet space.
static_condensation: bool = True
bddcngsolve: bool = True
order: int = 2

# Manufactured solution: uex is the exact solution and `source` is its
# negative Laplacian, so the discrete problem approximates -Laplace(u) = source
# with u = uex on the boundary.
uex = CoefficientFunction(sin(3.7*pi*x)*sin(2.5*pi*y))
uex_xx = uex.Diff(x).Diff(x)  # d^2 uex / dx^2
uex_yy = uex.Diff(y).Diff(y)  # d^2 uex / dy^2
source = CoefficientFunction(-uex_xx - uex_yy)

# Mesh of the unit square (maxh = 0.05) and the geometric coefficient
# functions used by the element-boundary terms of the bilinear form.
mesh = Mesh(unit_square.GenerateMesh(maxh=0.05))
n = specialcf.normal(mesh.dim)  # unit normal vector
h = specialcf.mesh_size         # element-wise mesh size

# Hybrid space: discontinuous element polynomials V for u, and facet
# polynomials Vbar for the trace ubar, which carries the Dirichlet condition.
V = L2(mesh, order=order)
Vbar = FacetFESpace(mesh, order=order, dirichlet='bottom|right|top|left')
X = FESpace([V, Vbar])
u, ubar = X.TrialFunction()
v, vbar = X.TestFunction()

# Penalty parameter, scaled by order**2.
alpha = 6.0*order**2
# Integration over the boundary of each element.
dS = dx(element_boundary = True)

# Jumps between the element unknowns and the facet (trace) unknowns.
jump_u = u - ubar
jump_v = v - vbar

# Symmetric hybrid interior-penalty form. With eliminate_internal=True NGSolve
# condenses the element-internal dofs at assembly time.
a = BilinearForm(X, eliminate_internal=static_condensation)
a += grad(u)*grad(v)*dx                # volume stiffness
a += -(grad(u)*n)*jump_v*dS            # consistency term
a += -(grad(v)*n)*jump_u*dS            # symmetrization term
a += (alpha/h)*jump_v*jump_u*dS        # penalty on the jump

# Load vector: the source acts on the element test function only.
f = LinearForm(X)
f += source*v*dx

# gfu will hold the final solution. Its facet component is pre-set to the
# exact solution on the Dirichlet boundary (the boundary lifting); the
# solver result is accumulated on top of it later.
gfu = GridFunction(X)
dirichlet_bnd = mesh.Boundaries('bottom|right|top|left')
gfu.components[1].Set(uex, definedon=dirichlet_bnd)

# Receives the correction computed by the linear solver.
gfu2 = GridFunction(X)

# BDDC preconditioner for NGSolve's CGSolver. A preconditioner registered on
# `a` is set up when `a.Assemble()` runs, so it must be created before
# assembly. It is only used by the static-condensation + NGSolve-CG path.
# Creating it unconditionally would pay for its setup, including the direct
# solver requested via `inverse=`, even on the PETSc or LU paths where it is
# never used.
if static_condensation and bddcngsolve:
    cP = Preconditioner(a, type="bddc", inverse="sparsecholesky")
else:
    cP = None

with TaskManager():

    a.Assemble()
    f.Assemble()

    # Residual of the boundary lifting: gfu currently holds only the Dirichlet
    # values, so the solver computes the correction gfu2 for the free dofs.
    res = f.vec.CreateVector()
    res.data = f.vec - a.mat * gfu.vec

    if static_condensation:
        # a.mat is the condensed system on the coupling dofs. First fold the
        # element-internal part of the residual into the coupling rhs.
        res.data += a.harmonic_extension_trans * res
        if bddcngsolve:
            solver = CGSolver(mat=a.mat, pre=cP.mat, maxiter=500, tol=1e-10,
                              printrates="\r")
            solver.Solve(sol=gfu2.vec, rhs=res, initialize=False)
        else:
            # CG + BDDC through PETSc. PCBDDC needs a MATIS matrix, hence
            # ngs_mat_type "is". Tolerance and iteration cap mirror the
            # NGSolve branch above; without them PETSc falls back to its
            # default rtol (1e-5) and the two paths are not comparable.
            solver = KrylovSolver(a, X.FreeDofs(True),
                                  solverParameters={"ksp_type": "cg",
                                                    "ksp_rtol": 1e-10,
                                                    "ksp_max_it": 500,
                                                    "ksp_monitor": "",
                                                    "pc_type": "bddc",
                                                    "ngs_mat_type": "is"})
            solver.solve(res, gfu2.vec)
        # Recover the full solution: add the coupling-dof correction, extend
        # it into the element interiors, then add the interior response to
        # the element-local part of the residual.
        gfu.vec.data += gfu2.vec
        gfu.vec.data += a.harmonic_extension * gfu.vec
        gfu.vec.data += a.inner_solve * res
    else:
        # No condensation: solve the full system directly with MUMPS LU.
        solver = KrylovSolver(a, X.FreeDofs(),
                              solverParameters={"ksp_type": "preonly",
                                                "pc_type": "lu",
                                                "pc_factor_mat_solver_type": "mumps"})
        solver.solve(res, gfu2.vec)
        gfu.vec.data += gfu2.vec

# L2 error of the element component of the discrete solution against uex.
squared_err = (gfu.components[0] - uex)**2
erru = sqrt(Integrate(squared_err, mesh))
print(f"err={erru:2.2e}")
