
from ngsolve import BaseMatrix, BitArray

class BaseDirectSolver(BaseMatrix):
    """Shared setup for the SciPy-backed direct solvers in this module.

    Converts the NGSolve matrix ``a`` to a SciPy CSR matrix via ``a.CSR()``.
    When ``freedofs`` is given, the matrix is restricted to the free rows and
    columns. Subclasses provide ``Factorize`` and ``Mult``.

    Args:
        a: NGSolve matrix exposing ``CSR()``. NOTE(review): ``CSR()`` is
            presumably only available on sparse matrices, not on every
            ``BaseMatrix``; confirm against callers.
        freedofs: Optional mask of free degrees of freedom. ``None`` means the
            full matrix is used.
    """

    def __init__(self, a: BaseMatrix, freedofs: BitArray = None):
        import scipy.sparse as sp
        super().__init__()
        self.mat = sp.csr_matrix(a.CSR())
        # Always define the attribute. Subclasses check
        # ``self.freedofs is not None`` in Mult, which used to raise
        # AttributeError when no freedofs were passed.
        self.freedofs = None
        if freedofs is not None:
            # Presumably one bool per dof, used as a boolean mask for both
            # the SciPy matrix and the NumPy vector views (confirm).
            self.freedofs = list(freedofs)
            self.mat = self.mat[self.freedofs, :][:, self.freedofs]
        self._factorized = False


class Pardiso(BaseDirectSolver):
    """Direct solver backed by MKL PARDISO through ``pypardiso``.

    The factorization is computed lazily on the first ``Mult`` call, or
    explicitly via ``Factorize``.

    Raises:
        ImportError: if ``pypardiso`` is not installed.
    """

    def __init__(self, a : BaseMatrix, freedofs : BitArray = None):
        # Check the optional dependency first, so a missing package fails
        # before the (possibly large) matrix conversion is done.
        try:
            import pypardiso
        except ImportError as err:
            raise ImportError("Pardiso not available, install pypardiso using `pip install pypardiso`") from err
        super().__init__(a, freedofs)
        self.inv = pypardiso.PyPardisoSolver()
        self._factorized = False

    def Factorize(self):
        """Factorize ``self.mat`` with PARDISO."""
        self.inv.factorize(self.mat)
        # Set the flag only after success, so a failed factorization is
        # retried on the next Mult (same order as the other solvers).
        self._factorized = True

    def Mult(self, x, y):
        """Solve ``mat * y = x`` and write the result into ``y``.

        With freedofs, only the free entries are solved for. All other
        entries of ``y`` are set to zero instead of being left untouched.
        """
        if not self._factorized:
            self.Factorize()
        rhs = x.FV().NumPy()
        out = y.FV().NumPy()
        if self.freedofs is not None:
            # Solve before zeroing ``out`` so this stays correct when x is y.
            sol = self.inv.solve(self.mat, rhs[self.freedofs])
            out[:] = 0.0
            out[self.freedofs] = sol
        else:
            out[:] = self.inv.solve(self.mat, rhs)

class SuperLU(BaseDirectSolver):
    """Direct solver using SciPy's SuperLU factorization.

    The factorization is computed lazily on the first ``Mult`` call, or
    explicitly via ``Factorize``.
    """

    def __init__(self, a: BaseMatrix, freedofs: BitArray = None):
        super().__init__(a, freedofs)
        self._factorized = False

    def Factorize(self):
        """Compute the SuperLU factorization of ``self.mat``.

        Stores the factorization's solve function in ``self.lu``.
        """
        import scipy.sparse as sp
        import scipy.sparse.linalg as spla
        # Call splu directly. spla.factorized() silently switches to UMFPACK
        # whenever scikits.umfpack is installed, so it would not guarantee
        # SuperLU. splu works best with CSC input.
        self.lu = spla.splu(sp.csc_matrix(self.mat)).solve
        self._factorized = True

    def Mult(self, x, y):
        """Solve ``mat * y = x`` and write the result into ``y``.

        With freedofs, only the free entries are solved for. All other
        entries of ``y`` are set to zero instead of being left untouched.
        """
        if not self._factorized:
            self.Factorize()
        rhs = x.FV().NumPy()
        out = y.FV().NumPy()
        if self.freedofs is not None:
            # Solve before zeroing ``out`` so this stays correct when x is y.
            sol = self.lu(rhs[self.freedofs])
            out[:] = 0.0
            out[self.freedofs] = sol
        else:
            out[:] = self.lu(rhs)

class Umfpack(BaseDirectSolver):
    """Direct solver backed by UMFPACK through ``scikits.umfpack``.

    The factorization is computed lazily on the first ``Mult`` call, or
    explicitly via ``Factorize``.

    Raises:
        ImportError: if ``scikits.umfpack`` is not installed.
    """

    def __init__(self, a: BaseMatrix, freedofs: BitArray = None):
        # Availability check only. Run it before the (possibly large) matrix
        # conversion, and chain the original error instead of printing it.
        try:
            import scikits.umfpack  # noqa: F401
        except ImportError as err:
            raise ImportError("Umfpack not available, install scikits.umfpack using `pip install scikit-umfpack`") from err
        super().__init__(a, freedofs)
        self._factorized = False

    def Factorize(self):
        """Compute the UMFPACK LU factorization of ``self.mat``.

        Stores the factorization's solve function in ``self.lu``.
        """
        from scikits.umfpack import splu
        # splu returns an UmfpackLU object, which is not callable. Keep its
        # bound ``solve`` method so that ``self.lu(b)`` in Mult works.
        self.lu = splu(self.mat).solve
        self._factorized = True

    def Mult(self, x, y):
        """Solve ``mat * y = x`` and write the result into ``y``.

        With freedofs, only the free entries are solved for. All other
        entries of ``y`` are set to zero instead of being left untouched.
        """
        if not self._factorized:
            self.Factorize()
        rhs = x.FV().NumPy()
        out = y.FV().NumPy()
        if self.freedofs is not None:
            # Solve before zeroing ``out`` so this stays correct when x is y.
            sol = self.lu(rhs[self.freedofs])
            out[:] = 0.0
            out[self.freedofs] = sol
        else:
            out[:] = self.lu(rhs)
