diff --git a/SimPEG/Utils/matutils.py b/SimPEG/Utils/matutils.py index b38bb4a1..ba900c72 100644 --- a/SimPEG/Utils/matutils.py +++ b/SimPEG/Utils/matutils.py @@ -26,6 +26,9 @@ def mkvc(x, numDims=1): if hasattr(x, 'tovec'): x = x.tovec() + if isinstance(x, Zero): + return x + assert isinstance(x, np.ndarray), "Vector must be a numpy array" if numDims == 1: @@ -37,6 +40,9 @@ def mkvc(x, numDims=1): def sdiag(h): """Sparse diagonal matrix""" + if isinstance(h, Zero): + return h + return sp.spdiags(mkvc(h), 0, h.size, h.size, format="csr") def sdInv(M): @@ -417,6 +423,12 @@ class Zero(object): def __ge__(self, v):return 0 >= v def __gt__(self, v):return 0 > v + @property + def transpose(self): return Zero() + + @property + def T(self): return Zero() + class Identity(object): _positive = True def __init__(self, positive=True): diff --git a/tests/utils/test_Zero.py b/tests/utils/test_Zero.py index 594de6a6..7b3c6e5d 100644 --- a/tests/utils/test_Zero.py +++ b/tests/utils/test_Zero.py @@ -1,5 +1,5 @@ import unittest -from SimPEG.Utils import Zero, Identity, sdiag +from SimPEG.Utils import Zero, Identity, sdiag, mkvc from SimPEG import np, sp class Tests(unittest.TestCase): @@ -29,6 +29,11 @@ class Tests(unittest.TestCase): assert a == 1 self.assertRaises(ZeroDivisionError, lambda:3/z) + assert mkvc(z) == 0 + assert sdiag(z)*a == 0 + assert z.T == 0 + assert z.transpose == 0 + def test_mat_zero(self): z = Zero() S = sdiag(np.r_[2,3])