import time
import tracemalloc
import os
import psutil
from ngsolve import Compress, Mesh, H1, unit_square, FESpace, CompressCompound
import netgen.occ as occ
import os

# Record this process's PID in pid.txt (presumably so an external tool can
# monitor the benchmark while it runs -- confirm with whoever reads the file).
with open("pid.txt", "w") as pid_file:
    pid_file.write(f"{os.getpid()}")

# Build the geometry: two overlapping boxes, each tagged with its own
# material name, glued into a single OCC shape and then meshed with Netgen.
box_1 = occ.Box((0, 0, 0), (4, 4, 4))
box_2 = occ.Box((0.5, 0.5, 0.5), (5, 5, 5))
for solid, material in ((box_1, "box_1"), (box_2, "box_2")):
    solid.mat(material)

geo = occ.Glue([box_1, box_2])
occgeo = occ.OCCGeometry(geo)
# maxh is the maximum mesh size requested from Netgen.
mesh = Mesh(occgeo.GenerateMesh(maxh=0.2))

# Show which material names ended up in the generated mesh.
print(mesh.GetMaterials())
# Benchmark: construct many identical H1 spaces restricted to the 'box_2'
# material, keep them all alive, and watch how process memory grows.
n_spaces = 3000
report_every = 500  # print an RSS checkpoint every this many spaces

# Experiment toggles for the alternatives that used to be commented out.
# Both default to off, so the plain H1 path is what gets measured.
USE_COMPRESS = False  # wrap each space in Compress(...)
BUILD_COMPOUND = False  # at each checkpoint, build + compress a compound space

spaces = []


def _rss_mb(proc):
    """Return the resident set size of *proc* in MiB (2**20 bytes).

    The printed labels say "MB", but the value is computed as bytes / 1024**2.
    """
    return proc.memory_info().rss / (1024 * 1024)


print(f"Creating {n_spaces} H1 spaces with definedon...")

process = psutil.Process(os.getpid())

# tracemalloc only traces allocations made through Python's memory
# allocator; memory allocated directly by NGSolve's C++ code shows up in
# RSS but not in these numbers.
tracemalloc.start()
# perf_counter is monotonic and high-resolution; time.time() follows the
# wall clock and can jump if the system clock is adjusted during the run.
start_time = time.perf_counter()

for i in range(n_spaces):
    fes = H1(mesh, order=2, definedon=mesh.Materials('box_2'))
    if USE_COMPRESS:
        fes = Compress(fes)
    if i == 0:
        print(f"DOFs {fes.ndof}")
    spaces.append(fes)

    if i % report_every == 0:
        if BUILD_COMPOUND:
            fes_comp = CompressCompound(FESpace(spaces))
        print(f"[{i}] RSS memory: {_rss_mb(process):.2f} MB")

# Stop the clock before tracemalloc bookkeeping so it is not timed.
elapsed = time.perf_counter() - start_time
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

rss_final = _rss_mb(process)

# Summary
print(f"\nTime taken: {elapsed:.2f} seconds")
print(f"Peak Python memory (tracemalloc): {peak / (1024*1024):.2f} MB")
print(f"Current Python memory: {current / (1024*1024):.2f} MB")
print(f"Final RSS memory (total process): {rss_final:.2f} MB")

