import ngsolve
import netgen.meshing as ngm
import netgen.gui
from netgen.geom2d import SplineGeometry
from ngsolve import Draw, Redraw
import numpy as np
import time


def generate_circle_mesh(n=11, radius=2.0):
    """Build a structured quad mesh of a disc, curve it, and display it.

    An ``n`` x ``n`` grid of points on the logical square
    ``[-radius, radius]^2`` is pushed onto the disc with the elliptical
    square-to-disc mapping, connected into quadrilaterals, and bounded by
    four groups of segments named "bottom", "right", "top" and "left".
    A ``SplineGeometry`` made of four ``spline3`` arcs through the same
    corner points is attached to the mesh before ``mesh.Curve(3)`` is called.

    Args:
        n: Number of grid points per side of the logical square. Must be at
            least 2, otherwise no element can be formed and the grid spacing
            ``1 / (n - 1)`` is undefined.
        radius: Radius of the disc. Must be non-zero, because the mapping
            divides by ``2 * radius ** 2``.

    Raises:
        ValueError: If ``n < 2`` or ``radius == 0``.

    Note:
        This function does not return normally: after drawing the geometry
        and the mesh it loops forever calling ``Redraw()`` every 0.1 s.
    """
    # Fail early with a clear message instead of a ZeroDivisionError deep
    # inside the point-generation loop.
    if n < 2:
        raise ValueError(f"n must be at least 2, got {n}")
    if radius == 0:
        raise ValueError("radius must be non-zero")

    print("--- Generating Circle Mesh ---")

    ngmesh = ngm.Mesh(dim=2)

    # ---------------------------------------------------------
    # 1. Define Regions
    # ---------------------------------------------------------
    # The indices returned here are reused both as boundary-condition numbers
    # of the geometry segments and as indices of the 1D mesh elements, so the
    # two descriptions of the boundary stay in agreement.
    idx_dom = ngmesh.AddRegion("dom", dim=2)
    idx_bottom = ngmesh.AddRegion("bottom", dim=1)
    idx_right = ngmesh.AddRegion("right", dim=1)
    idx_top = ngmesh.AddRegion("top", dim=1)
    idx_left = ngmesh.AddRegion("left", dim=1)

    print(f"Indices: Bottom={idx_bottom}, Right={idx_right}, Top={idx_top}, Left={idx_left}")

    # ---------------------------------------------------------
    # 2. Define Geometry (creation order matters: bottom, right, top, left)
    # ---------------------------------------------------------
    geo = SplineGeometry()

    # Arc end points: the images of the square's corners on the circle.
    val = radius / np.sqrt(2)
    p_sw = geo.AppendPoint(-val, -val)  # SW
    p_se = geo.AppendPoint(val, -val)  # SE
    p_ne = geo.AppendPoint(val, val)  # NE
    p_nw = geo.AppendPoint(-val, val)  # NW

    # spline3 control points: where the circle tangents at two neighbouring
    # end points intersect, i.e. on the axes at distance radius * sqrt(2).
    ctrl = radius * np.sqrt(2)
    c_s = geo.AppendPoint(0, -ctrl)  # South
    c_e = geo.AppendPoint(ctrl, 0)  # East
    c_n = geo.AppendPoint(0, ctrl)  # North
    c_w = geo.AppendPoint(-ctrl, 0)  # West

    # Arcs run SW -> SE -> NE -> NW -> SW with the domain on their left.
    arcs = (
        (p_sw, c_s, p_se, idx_bottom),
        (p_se, c_e, p_ne, idx_right),
        (p_ne, c_n, p_nw, idx_top),
        (p_nw, c_w, p_sw, idx_left),
    )
    for start, control, end, bc in arcs:
        geo.Append(["spline3", start, control, end], bc=bc, leftdomain=idx_dom)

    # Bind Geometry
    ngmesh.SetGeometry(geo)

    # ---------------------------------------------------------
    # 3. Generate Mesh Points (row-major: index = j * n + i)
    # ---------------------------------------------------------
    print("Generating Points...")
    pids = []
    for j in range(n):
        # Logical square coordinate in [-R, R]; constant along a row.
        y0 = -radius + 2 * radius * (j / (n - 1))
        for i in range(n):
            x0 = -radius + 2 * radius * (i / (n - 1))

            # Elliptical Mapping (Square -> Circle)
            x = x0 * np.sqrt(1 - (y0 ** 2 / (2 * radius ** 2)))
            y = y0 * np.sqrt(1 - (x0 ** 2 / (2 * radius ** 2)))

            pids.append(ngmesh.Add(ngm.MeshPoint(ngm.Pnt(x, y, 0))))

    # ---------------------------------------------------------
    # 4. Connect Elements
    # ---------------------------------------------------------
    print("Connecting 2D Elements...")
    for j in range(n - 1):
        for i in range(n - 1):
            base = j * n + i
            # Quad corners in the order SW, SE, NE, NW of the grid cell.
            corners = (base, base + 1, base + n + 1, base + n)
            ngmesh.Add(ngm.Element2D(idx_dom, [pids[c] for c in corners]))

    # ---------------------------------------------------------
    # 5. Connect Boundaries (same indices and orientation as the geometry)
    # ---------------------------------------------------------
    print("Connecting Boundaries...")
    last = n - 1
    # Each chain lists grid-point indices in the direction of its arc; the
    # top and left chains therefore run backwards through the grid.
    boundary_chains = (
        (idx_bottom, list(range(n))),
        (idx_right, [j * n + last for j in range(n)]),
        (idx_top, [last * n + i for i in range(last, -1, -1)]),
        (idx_left, [j * n for j in range(last, -1, -1)]),
    )
    for index, chain in boundary_chains:
        for a, b in zip(chain, chain[1:]):
            ngmesh.Add(ngm.Element1D([pids[a], pids[b]], index=index))

    # ---------------------------------------------------------
    # 6. Curve
    # ---------------------------------------------------------
    print("Curving...")
    mesh = ngsolve.Mesh(ngmesh)

    mesh.Curve(3)

    print("Done.")
    Draw(geo)
    Draw(mesh)

    # Keep redrawing forever; this loop never exits on its own.
    while True:
        Redraw()
        time.sleep(0.1)

if __name__ == "__main__":
    # Script entry point: build and show an 8 x 8 point grid on a disc of
    # radius 2.
    generate_circle_mesh(radius=2.0, n=8)