BaseTimeDomain Problem in SimPEG

This commit is contained in:
rowanc1
2014-04-15 16:42:40 -07:00
parent cbd11d0ae4
commit 3a0942e616
2 changed files with 97 additions and 0 deletions
+72
View File
@@ -141,3 +141,75 @@ class BaseProblem(object):
class BaseTimeProblem(BaseProblem):
"""Sets up that basic needs of a time domain problem."""
@property
def timeSteps(self):
"""Sets/gets the timeSteps for the time domain problem.
You can set as an array of dt's or as a list of tuples/floats.
Tuples must be length two with [..., (dt, repeat), ...]
For example, the following setters are the same::
prob.timeSteps = [(1e-6, 3), 1e-5, (1e-4, 2)]
prob.timeSteps = np.r_[1e-6,1e-6,1e-6,1e-5,1e-4,1e-4]
"""
assert hasattr(self, '_timeSteps'), 'The timeSteps have not yet been set.'
return self._timeSteps
@timeSteps.setter
def timeSteps(self, value):
if type(value) is np.ndarray:
self._timeSteps = value
del self.timeMesh
return
if type(value) is not list:
raise Exception('timeSteps must be a np.ndarray or a list of scalars and tuples.')
proposed = []
for v in value:
if Utils.isScalar(v):
proposed += [float(v)]
elif type(v) is tuple and len(v) == 2:
proposed += [float(v[0])]*int(v[1])
else:
raise Exception('timeSteps list must contain only scalars and len(2) tuples.')
self._timeSteps = np.array(proposed)
del self.timeMesh
@property
def nT(self):
"Number of time steps."
return self.timeMesh.nC
@property
def t0(self):
return getattr(self, '_t0', 0.0)
@t0.setter
def t0(self, value):
assert Utils.isScalar(value), 't0 must be a scalar'
del self.timeMesh
self._t0 = float(value)
@property
def times(self):
"Modeling times"
return self.timeMesh.vectorNx
@property
def timeMesh(self):
if getattr(self, '_timeMesh', None) is None:
self._timeMesh = Mesh.TensorMesh([self.timeSteps], x0=[self.t0])
return self._timeMesh
@timeMesh.deleter
def timeMesh(self):
if hasattr(self, '_timeMesh'):
del self._timeMesh
+25
View File
@@ -0,0 +1,25 @@
import unittest
from SimPEG import *
class TestTimeProblem(unittest.TestCase):
def setUp(self):
mesh = Mesh.TensorMesh([10,10])
self.prob = Problem.BaseTimeProblem(mesh)
def test_timeProblem_setTimeSteps(self):
self.prob.timeSteps = [(1e-6, 3), 1e-5, (1e-4, 2)]
trueTS = np.r_[1e-6,1e-6,1e-6,1e-5,1e-4,1e-4]
self.assertTrue(np.all(trueTS == self.prob.timeSteps))
self.prob.timeSteps = trueTS
self.assertTrue(np.all(trueTS == self.prob.timeSteps))
self.assertTrue(self.prob.nT == 6)
self.assertTrue(np.all(self.prob.times == np.r_[0,trueTS].cumsum()))
if __name__ == '__main__':
unittest.main()