diff --git a/SimPEG/Tests/test_Zero.py b/SimPEG/Tests/test_Zero.py new file mode 100644 index 00000000..be7c9bbe --- /dev/null +++ b/SimPEG/Tests/test_Zero.py @@ -0,0 +1,129 @@ +import unittest +from SimPEG.Utils import Zero, Identity, sdiag +from SimPEG import np, sp + +class Tests(unittest.TestCase): + + def test_zero(self): + z = Zero() + assert z == 0 + assert not (z < 0) + assert z <= 0 + assert not (z > 0) + assert z >= 0 + assert +z == z + assert -z == z + assert z + 1 == 1 + assert z + 3 +z == 3 + assert z - 3 == -3 + assert z - 3 -z == -3 + assert 3*z == 0 + assert z*3 == 0 + assert z/3 == 0 + self.assertRaises(ZeroDivisionError, lambda:3/z) + + def test_mat_zero(self): + z = Zero() + S = sdiag(np.r_[2,3]) + assert S*z == 0 + + def test_one(self): + o = Identity() + assert o == 1 + assert not (o < 1) + assert o <= 1 + assert not (o > 1) + assert o >= 1 + o = -o + assert o == -1 + assert not (o < -1) + assert o <= -1 + assert not (o > -1) + assert o >= -1 + assert -(-o)*o == -o + o = Identity() + assert +o == o + assert -o == -o + assert o*3 == 3 + assert -o*3 == -3 + assert -o*o == -1 + assert -o*o*-o == 1 + assert -o + 3 == 2 + assert 3 + -o == 2 + + assert -o - 3 == -4 + assert o - 3 == -2 + assert 3 - -o == 4 + assert 3 - o == 2 + + assert o/2 == 0 + assert o/2. == 0.5 + assert -o/2 == -1 + assert -o/2. == -0.5 + assert 2/o == 2 + assert 2/-o == -2 + + + def test_mat_one(self): + + o = Identity() + S = sdiag(np.r_[2,3]) + def check(exp,ans): + assert np.all((exp).todense() == ans) + + check(S * o, [[2,0],[0,3]]) + check(o * S, [[2,0],[0,3]]) + check(S * -o, [[-2,0],[0,-3]]) + check(-o * S, [[-2,0],[0,-3]]) + check(S/o, [[2,0],[0,3]]) + check(S/-o, [[-2,0],[0,-3]]) + self.assertRaises(NotImplementedError, lambda:o/S) + + check(S + o, [[3,0],[0,4]]) + check(o + S, [[3,0],[0,4]]) + + check(S + - o, [[1,0],[0,2]]) + check(S - o, [[1,0],[0,2]]) + check(- o + S, [[1,0],[0,2]]) + + def test_mat_shape(self): + o = Identity() + S = sdiag(np.r_[2,3])[:1,:] + self.assertRaises(ValueError, lambda:S + o) + def check(exp,ans): + assert np.all((exp).todense() == ans) + check(S * o, [[2,0]]) + check(S * -o, [[-2,0]]) + + def test_numpy_one(self): + o = Identity() + n = np.r_[2.,3] + assert np.all(n+1 == n+o) + assert np.all(1+n == o+n) + assert np.all(n-1 == n-o) + assert np.all(1-n == o-n) + assert np.all(n/1 == n/o) + assert np.all(n/-1 == n/-o) + assert np.all(1/n == o/n) + assert np.all(-1/n == -o/n) + assert np.all(n*1 == n*o) + assert np.all(n*-1 == n*-o) + assert np.all(1*n == o*n) + assert np.all(-1*n == -o*n) + + def test_both(self): + z = Zero() + o = Identity() + assert o*z == 0 + assert o*z + o == 1 + assert o-z == 1 + + + +if __name__ == '__main__': + unittest.main() + + + + + diff --git a/SimPEG/Utils/matutils.py b/SimPEG/Utils/matutils.py index 7ba100a0..ee58bf86 100644 --- a/SimPEG/Utils/matutils.py +++ b/SimPEG/Utils/matutils.py @@ -396,4 +396,60 @@ def diagEst(matFun, n, k=None, approach='Probing'): return d +class Zero(object): + def __add__(self, v):return v + def __radd__(self, v):return v + def __sub__(self, v):return -v + def __rsub__(self, v):return v + def __mul__(self, v):return self + def __rmul__(self, v):return self + def __div__(self, v): return self + def __truediv__(self, v): return self + def __rdiv__(self, v): raise ZeroDivisionError('Cannot divide by zero.') + def __pos__(self):return self + def __neg__(self):return self + def __lt__(self, v):return 0 < v + def __le__(self, v):return 0 <= v + def __eq__(self, v):return v == 0 + def __ne__(self, v):return not (0 == v) + def __ge__(self, v):return 0 >= v + def __gt__(self, v):return 0 > v + +class Identity(object): + _positive = True + def __init__(self, positive=True): + self._positive = positive is True + + def __pos__(self):return self + def __neg__(self):return Identity(not self._positive) + + def __add__(self, v): + if sp.issparse(v): + return v + speye(v.shape[0]) if self._positive else v - speye(v.shape[0]) + return v + 1 if self._positive else v - 1 + def __radd__(self, v): + return self.__add__(v) + + def __sub__(self, v): return self+-v + def __rsub__(self, v):return -self+v + + def __mul__(self, v): return v if self._positive else -v + def __rmul__(self, v):return v if self._positive else -v + + def __div__(self, v): + if sp.issparse(v): raise NotImplementedError('Sparse arrays not divisibile.') + return 1/v if self._positive else -1/v + def __truediv__(self, v): + if sp.issparse(v): raise NotImplementedError('Sparse arrays not divisibile.') + return 1.0/v if self._positive else -1.0/v + def __rdiv__(self, v): + return v if self._positive else -v + + def __lt__(self, v):return 1 < v if self._positive else -1 < v + def __le__(self, v):return 1 <= v if self._positive else -1 <= v + def __eq__(self, v):return v == 1 if self._positive else v == -1 + def __ne__(self, v):return (not (1 == v))if self._positive else (not (-1 == v)) + def __ge__(self, v):return 1 >= v if self._positive else -1 >= v + def __gt__(self, v):return 1 > v if self._positive else -1 > v +