mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 18:25:42 +08:00
Zero and Identity - Useful for writing derivatives.
These should work with sparse matrices and numpy arrays. ```python z = Zero() z*A == 0 o = Identity() o*A == A ```
This commit is contained in:
@@ -0,0 +1,129 @@
|
||||
import unittest
|
||||
from SimPEG.Utils import Zero, Identity, sdiag
|
||||
from SimPEG import np, sp
|
||||
|
||||
class Tests(unittest.TestCase):
|
||||
|
||||
def test_zero(self):
|
||||
z = Zero()
|
||||
assert z == 0
|
||||
assert not (z < 0)
|
||||
assert z <= 0
|
||||
assert not (z > 0)
|
||||
assert z >= 0
|
||||
assert +z == z
|
||||
assert -z == z
|
||||
assert z + 1 == 1
|
||||
assert z + 3 +z == 3
|
||||
assert z - 3 == -3
|
||||
assert z - 3 -z == -3
|
||||
assert 3*z == 0
|
||||
assert z*3 == 0
|
||||
assert z/3 == 0
|
||||
self.assertRaises(ZeroDivisionError, lambda:3/z)
|
||||
|
||||
def test_mat_zero(self):
|
||||
z = Zero()
|
||||
S = sdiag(np.r_[2,3])
|
||||
assert S*z == 0
|
||||
|
||||
def test_one(self):
|
||||
o = Identity()
|
||||
assert o == 1
|
||||
assert not (o < 1)
|
||||
assert o <= 1
|
||||
assert not (o > 1)
|
||||
assert o >= 1
|
||||
o = -o
|
||||
assert o == -1
|
||||
assert not (o < -1)
|
||||
assert o <= -1
|
||||
assert not (o > -1)
|
||||
assert o >= -1
|
||||
assert -(-o)*o == -o
|
||||
o = Identity()
|
||||
assert +o == o
|
||||
assert -o == -o
|
||||
assert o*3 == 3
|
||||
assert -o*3 == -3
|
||||
assert -o*o == -1
|
||||
assert -o*o*-o == 1
|
||||
assert -o + 3 == 2
|
||||
assert 3 + -o == 2
|
||||
|
||||
assert -o - 3 == -4
|
||||
assert o - 3 == -2
|
||||
assert 3 - -o == 4
|
||||
assert 3 - o == 2
|
||||
|
||||
assert o/2 == 0
|
||||
assert o/2. == 0.5
|
||||
assert -o/2 == -1
|
||||
assert -o/2. == -0.5
|
||||
assert 2/o == 2
|
||||
assert 2/-o == -2
|
||||
|
||||
|
||||
def test_mat_one(self):
|
||||
|
||||
o = Identity()
|
||||
S = sdiag(np.r_[2,3])
|
||||
def check(exp,ans):
|
||||
assert np.all((exp).todense() == ans)
|
||||
|
||||
check(S * o, [[2,0],[0,3]])
|
||||
check(o * S, [[2,0],[0,3]])
|
||||
check(S * -o, [[-2,0],[0,-3]])
|
||||
check(-o * S, [[-2,0],[0,-3]])
|
||||
check(S/o, [[2,0],[0,3]])
|
||||
check(S/-o, [[-2,0],[0,-3]])
|
||||
self.assertRaises(NotImplementedError, lambda:o/S)
|
||||
|
||||
check(S + o, [[3,0],[0,4]])
|
||||
check(o + S, [[3,0],[0,4]])
|
||||
|
||||
check(S + - o, [[1,0],[0,2]])
|
||||
check(S - o, [[1,0],[0,2]])
|
||||
check(- o + S, [[1,0],[0,2]])
|
||||
|
||||
def test_mat_shape(self):
|
||||
o = Identity()
|
||||
S = sdiag(np.r_[2,3])[:1,:]
|
||||
self.assertRaises(ValueError, lambda:S + o)
|
||||
def check(exp,ans):
|
||||
assert np.all((exp).todense() == ans)
|
||||
check(S * o, [[2,0]])
|
||||
check(S * -o, [[-2,0]])
|
||||
|
||||
def test_numpy_one(self):
|
||||
o = Identity()
|
||||
n = np.r_[2.,3]
|
||||
assert np.all(n+1 == n+o)
|
||||
assert np.all(1+n == o+n)
|
||||
assert np.all(n-1 == n-o)
|
||||
assert np.all(1-n == o-n)
|
||||
assert np.all(n/1 == n/o)
|
||||
assert np.all(n/-1 == n/-o)
|
||||
assert np.all(1/n == o/n)
|
||||
assert np.all(-1/n == -o/n)
|
||||
assert np.all(n*1 == n*o)
|
||||
assert np.all(n*-1 == n*-o)
|
||||
assert np.all(1*n == o*n)
|
||||
assert np.all(-1*n == -o*n)
|
||||
|
||||
def test_both(self):
|
||||
z = Zero()
|
||||
o = Identity()
|
||||
assert o*z == 0
|
||||
assert o*z + o == 1
|
||||
assert o-z == 1
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -396,4 +396,60 @@ def diagEst(matFun, n, k=None, approach='Probing'):
|
||||
|
||||
return d
|
||||
|
||||
class Zero(object):
|
||||
def __add__(self, v):return v
|
||||
def __radd__(self, v):return v
|
||||
def __sub__(self, v):return -v
|
||||
def __rsub__(self, v):return v
|
||||
def __mul__(self, v):return self
|
||||
def __rmul__(self, v):return self
|
||||
def __div__(self, v): return self
|
||||
def __truediv__(self, v): return self
|
||||
def __rdiv__(self, v): raise ZeroDivisionError('Cannot divide by zero.')
|
||||
def __pos__(self):return self
|
||||
def __neg__(self):return self
|
||||
def __lt__(self, v):return 0 < v
|
||||
def __le__(self, v):return 0 <= v
|
||||
def __eq__(self, v):return v == 0
|
||||
def __ne__(self, v):return not (0 == v)
|
||||
def __ge__(self, v):return 0 >= v
|
||||
def __gt__(self, v):return 0 > v
|
||||
|
||||
class Identity(object):
|
||||
_positive = True
|
||||
def __init__(self, positive=True):
|
||||
self._positive = positive is True
|
||||
|
||||
def __pos__(self):return self
|
||||
def __neg__(self):return Identity(not self._positive)
|
||||
|
||||
def __add__(self, v):
|
||||
if sp.issparse(v):
|
||||
return v + speye(v.shape[0]) if self._positive else v - speye(v.shape[0])
|
||||
return v + 1 if self._positive else v - 1
|
||||
def __radd__(self, v):
|
||||
return self.__add__(v)
|
||||
|
||||
def __sub__(self, v): return self+-v
|
||||
def __rsub__(self, v):return -self+v
|
||||
|
||||
def __mul__(self, v): return v if self._positive else -v
|
||||
def __rmul__(self, v):return v if self._positive else -v
|
||||
|
||||
def __div__(self, v):
|
||||
if sp.issparse(v): raise NotImplementedError('Sparse arrays not divisibile.')
|
||||
return 1/v if self._positive else -1/v
|
||||
def __truediv__(self, v):
|
||||
if sp.issparse(v): raise NotImplementedError('Sparse arrays not divisibile.')
|
||||
return 1.0/v if self._positive else -1.0/v
|
||||
def __rdiv__(self, v):
|
||||
return v if self._positive else -v
|
||||
|
||||
def __lt__(self, v):return 1 < v if self._positive else -1 < v
|
||||
def __le__(self, v):return 1 <= v if self._positive else -1 <= v
|
||||
def __eq__(self, v):return v == 1 if self._positive else v == -1
|
||||
def __ne__(self, v):return (not (1 == v))if self._positive else (not (-1 == v))
|
||||
def __ge__(self, v):return 1 >= v if self._positive else -1 >= v
|
||||
def __gt__(self, v):return 1 > v if self._positive else -1 > v
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user