Zero and Identity - Useful for writing derivatives.

These should work with sparse matrices and numpy arrays.

```python
	z = Zero()
	z*A == 0
	o = Identity()
	o*A == A
```
This commit is contained in:
Rowan Cockett
2015-10-30 12:21:30 -07:00
parent ed463c736f
commit 0885b72577
2 changed files with 185 additions and 0 deletions
+129
View File
@@ -0,0 +1,129 @@
import unittest
from SimPEG.Utils import Zero, Identity, sdiag
from SimPEG import np, sp
class Tests(unittest.TestCase):
def test_zero(self):
z = Zero()
assert z == 0
assert not (z < 0)
assert z <= 0
assert not (z > 0)
assert z >= 0
assert +z == z
assert -z == z
assert z + 1 == 1
assert z + 3 +z == 3
assert z - 3 == -3
assert z - 3 -z == -3
assert 3*z == 0
assert z*3 == 0
assert z/3 == 0
self.assertRaises(ZeroDivisionError, lambda:3/z)
def test_mat_zero(self):
z = Zero()
S = sdiag(np.r_[2,3])
assert S*z == 0
def test_one(self):
o = Identity()
assert o == 1
assert not (o < 1)
assert o <= 1
assert not (o > 1)
assert o >= 1
o = -o
assert o == -1
assert not (o < -1)
assert o <= -1
assert not (o > -1)
assert o >= -1
assert -(-o)*o == -o
o = Identity()
assert +o == o
assert -o == -o
assert o*3 == 3
assert -o*3 == -3
assert -o*o == -1
assert -o*o*-o == 1
assert -o + 3 == 2
assert 3 + -o == 2
assert -o - 3 == -4
assert o - 3 == -2
assert 3 - -o == 4
assert 3 - o == 2
assert o/2 == 0
assert o/2. == 0.5
assert -o/2 == -1
assert -o/2. == -0.5
assert 2/o == 2
assert 2/-o == -2
def test_mat_one(self):
o = Identity()
S = sdiag(np.r_[2,3])
def check(exp,ans):
assert np.all((exp).todense() == ans)
check(S * o, [[2,0],[0,3]])
check(o * S, [[2,0],[0,3]])
check(S * -o, [[-2,0],[0,-3]])
check(-o * S, [[-2,0],[0,-3]])
check(S/o, [[2,0],[0,3]])
check(S/-o, [[-2,0],[0,-3]])
self.assertRaises(NotImplementedError, lambda:o/S)
check(S + o, [[3,0],[0,4]])
check(o + S, [[3,0],[0,4]])
check(S + - o, [[1,0],[0,2]])
check(S - o, [[1,0],[0,2]])
check(- o + S, [[1,0],[0,2]])
def test_mat_shape(self):
o = Identity()
S = sdiag(np.r_[2,3])[:1,:]
self.assertRaises(ValueError, lambda:S + o)
def check(exp,ans):
assert np.all((exp).todense() == ans)
check(S * o, [[2,0]])
check(S * -o, [[-2,0]])
def test_numpy_one(self):
o = Identity()
n = np.r_[2.,3]
assert np.all(n+1 == n+o)
assert np.all(1+n == o+n)
assert np.all(n-1 == n-o)
assert np.all(1-n == o-n)
assert np.all(n/1 == n/o)
assert np.all(n/-1 == n/-o)
assert np.all(1/n == o/n)
assert np.all(-1/n == -o/n)
assert np.all(n*1 == n*o)
assert np.all(n*-1 == n*-o)
assert np.all(1*n == o*n)
assert np.all(-1*n == -o*n)
def test_both(self):
z = Zero()
o = Identity()
assert o*z == 0
assert o*z + o == 1
assert o-z == 1
if __name__ == '__main__':
unittest.main()
+56
View File
@@ -396,4 +396,60 @@ def diagEst(matFun, n, k=None, approach='Probing'):
return d
class Zero(object):
def __add__(self, v):return v
def __radd__(self, v):return v
def __sub__(self, v):return -v
def __rsub__(self, v):return v
def __mul__(self, v):return self
def __rmul__(self, v):return self
def __div__(self, v): return self
def __truediv__(self, v): return self
def __rdiv__(self, v): raise ZeroDivisionError('Cannot divide by zero.')
def __pos__(self):return self
def __neg__(self):return self
def __lt__(self, v):return 0 < v
def __le__(self, v):return 0 <= v
def __eq__(self, v):return v == 0
def __ne__(self, v):return not (0 == v)
def __ge__(self, v):return 0 >= v
def __gt__(self, v):return 0 > v
class Identity(object):
_positive = True
def __init__(self, positive=True):
self._positive = positive is True
def __pos__(self):return self
def __neg__(self):return Identity(not self._positive)
def __add__(self, v):
if sp.issparse(v):
return v + speye(v.shape[0]) if self._positive else v - speye(v.shape[0])
return v + 1 if self._positive else v - 1
def __radd__(self, v):
return self.__add__(v)
def __sub__(self, v): return self+-v
def __rsub__(self, v):return -self+v
def __mul__(self, v): return v if self._positive else -v
def __rmul__(self, v):return v if self._positive else -v
def __div__(self, v):
if sp.issparse(v): raise NotImplementedError('Sparse arrays not divisibile.')
return 1/v if self._positive else -1/v
def __truediv__(self, v):
if sp.issparse(v): raise NotImplementedError('Sparse arrays not divisibile.')
return 1.0/v if self._positive else -1.0/v
def __rdiv__(self, v):
return v if self._positive else -v
def __lt__(self, v):return 1 < v if self._positive else -1 < v
def __le__(self, v):return 1 <= v if self._positive else -1 <= v
def __eq__(self, v):return v == 1 if self._positive else v == -1
def __ne__(self, v):return (not (1 == v))if self._positive else (not (-1 == v))
def __ge__(self, v):return 1 >= v if self._positive else -1 >= v
def __gt__(self, v):return 1 > v if self._positive else -1 > v