from ngsolve import *
from mpi4py import MPI
from netgen.occ import *
from netgen.occ import unit_cube
import sys

# Communicator spanning all MPI processes launched for this run.
comm = MPI.COMM_WORLD

# Geometry source: a first command-line argument of "step" loads "Cube.step"
# from the working directory; anything else (or no argument) uses the
# unit_cube geometry imported from netgen.occ.
_use_step_file = sys.argv[1:2] == ["step"]
_geometry = OCCGeometry("Cube.step") if _use_step_file else unit_cube

# Mesh generation receives the communicator (maxh=0.1); the resulting netgen
# mesh is then wrapped as an NGSolve Mesh.
ngmesh = _geometry.GenerateMesh(maxh=0.1, comm=comm)
mesh = Mesh(ngmesh)

# Order-3 H1 space on the mesh. dirichlet=".*" is a pattern over boundary
# names; ".*" presumably matches every boundary, so all boundary DOFs are
# treated as Dirichlet (fixed) DOFs.
fes = H1(mesh, order=3, dirichlet=".*")
# Trial (u) and test (v) functions of the space.
u,v = fes.TnT()

# Bilinear form (grad u, grad v) over the volume (dx).
# NOTE: the preconditioner is created BEFORE a.Assemble(); NGSolve
# preconditioners are expected to be set up during assembly of the form they
# are attached to, so keep this statement order.
a = BilinearForm(grad(u)*grad(v)*dx)
pre = preconditioners.Local(a)
a.Assemble()

# Right-hand side: linear form of the constant source 1 against v.
f = LinearForm(1*v*dx).Assemble()
# Container for the discrete solution; its vector is filled by the solve below.
gfu = GridFunction(fes)

# Conjugate-gradient solver for a.mat, preconditioned by pre.mat; applying it
# to f.vec solves the linear system and stores the result in gfu.
inv = CGSolver(a.mat, pre.mat)
gfu.vec.data = inv*f.vec

# Inner product (u, f) of solution and right-hand-side vectors. printonce is
# presumably used so only one MPI rank prints it — confirm against ngsolve.
ip = InnerProduct(gfu.vec, f.vec)
printonce("(u,f) =", ip)

import pickle

# Import explicitly rather than relying on `netgen` leaking in through the
# star-imports at the top of the file.
import netgen.meshing

# Switch netgen pickling into parallel mode before serializing the
# MPI-distributed GridFunction.
netgen.meshing.SetParallelPickling(True)

# Each MPI rank writes its own file, suffixed with its rank number
# (solution.pickle0, solution.pickle1, ...). The `with` block guarantees the
# file is flushed and closed, instead of relying on garbage collection.
with open(f"solution.pickle{comm.rank}", "wb") as outfile:
    pickle.dump(gfu, outfile)
