From df3ba5f8eeb17d15f98ea6f342fa819385ca45e3 Mon Sep 17 00:00:00 2001 From: rowanc1 Date: Sat, 26 Apr 2014 17:52:11 -0700 Subject: [PATCH] Fancy timeFields. --- SimPEG/Survey.py | 64 +++++++++++++++++++++++++++--- SimPEG/Tests/test_Survey.py | 78 +++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 6 deletions(-) diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index 1ed649db..a6acd1a5 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -198,9 +198,9 @@ class Fields(object): """ - knownFields = None + knownFields = None #: the known fields, a dict with locations, e.g. {"phi": "CC", "b": "F"} txPair = BaseTx - dtype = float + dtype = float #: dtype is the type of the storage matrix. This can be a dictionary. def __init__(self, mesh, survey, **kwargs): self.survey = survey @@ -221,6 +221,7 @@ class Fields(object): loc = self.knownFields[name] nP = {'CC': self.mesh.nC, + 'N': self.mesh.nN, 'F': self.mesh.nF, 'E': self.mesh.nE}[loc] @@ -282,23 +283,74 @@ class Fields(object): NEWF = newFields[name] if field.shape[1] == 1 or NEWF.ndim == 1: NEWF = Utils.mkvc(NEWF,2) - field[:,ind] = NEWF + self._setField(field, NEWF, ind) def __getitem__(self, key): ind, name = self._indexAndNameFromKey(key) if name is None: out = {} for name in self._fields: - out[name] = self._fields[name][:,ind] - if out[name].shape[1] == 1: - out[name] = Utils.mkvc(out[name]) + out[name] = self._getField(name, ind) return out + return self._getField(name, ind) + def _setField(self, field, val, ind): + field[:,ind] = val + + def _getField(self, name, ind): out = self._fields[name][:,ind] if out.shape[1] == 1: out = Utils.mkvc(out) return out +class TimeFields(Fields): + """Fancy Field Storage for time domain problems + + u[:,'phi', timeInd] = phi + print u[tx0,'phi'] + + """ + + def _storageShape(self, nP): + nTx = self.survey.nTx + nT = self.survey.prob.nT + return (nP, nTx, nT) + + def _indexAndNameFromKey(self, key): + if type(key) is not tuple: + key = (key,) + if len(key) == 1: + key += (None,) + if len(key) == 2: + key += (slice(None,None,None),) + + assert len(key) == 3, 'must be [Tx, fieldName, times]' + + txTestList, name, timeInd = key + + if name is not None and name not in self.knownFields: + raise KeyError('Invalid field name') + + txInd = self._txIndex(txTestList) + + return (txInd, timeInd), name + + def _setField(self, field, val, ind): + txInd, timeInd = ind + if val.ndim == 2: + val = val[:, np.newaxis, :] + field[:,txInd,timeInd] = val + + def _getField(self, name, ind): + txInd, timeInd = ind + out = self._fields[name][:,txInd,timeInd] + if out.shape[1] == 1: + if out.ndim == 2: + out = out[:,0] + else: + out = out[:,0,:] + return out + class BaseSurvey(object): """Survey holds the observed data, and the standard deviations.""" diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Survey.py index 3a40cddf..4c0b1a8a 100644 --- a/SimPEG/Tests/test_Survey.py +++ b/SimPEG/Tests/test_Survey.py @@ -95,5 +95,83 @@ class DataAndFieldsTest(unittest.TestCase): def fun(): self.F[freq,'notThere'] self.assertRaises(KeyError, fun) + + +class FieldsTest_Time(unittest.TestCase): + + def setUp(self): + mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30]) + x = np.linspace(5,10,3) + XYZ = Utils.ndgrid(x,x,np.r_[0.]) + txLoc = np.r_[0,0,0.] + rxList0 = Survey.BaseRx(XYZ, 'exi') + Tx0 = Survey.BaseTx(txLoc, 'VMD', [rxList0]) + rxList1 = Survey.BaseRx(XYZ, 'bxi') + Tx1 = Survey.BaseTx(txLoc, 'VMD', [rxList1]) + rxList2 = Survey.BaseRx(XYZ, 'bxi') + Tx2 = Survey.BaseTx(txLoc, 'VMD', [rxList2]) + rxList3 = Survey.BaseRx(XYZ, 'bxi') + Tx3 = Survey.BaseTx(txLoc, 'VMD', [rxList3]) + Tx4 = Survey.BaseTx(txLoc, 'VMD', [rxList0, rxList1, rxList2, rxList3]) + txList = [Tx0,Tx1,Tx2,Tx3,Tx4] + survey = Survey.BaseSurvey(txList=txList) + prob = Problem.BaseTimeProblem(mesh, timeSteps=[(10.,3), (20.,2)]) + survey.pair(prob) + self.F = Survey.TimeFields(mesh, survey, knownFields={'phi':'CC','e':'E','b':'F'}) + self.Tx0 = Tx0 + self.Tx1 = Tx1 + self.mesh = mesh + self.XYZ = XYZ + + def test_SetGet(self): + F = self.F + nTx = F.survey.nTx + nT = F.survey.prob.nT + e = np.random.rand(F.mesh.nE, nTx, nT) + F[:, 'e'] = e + b = np.random.rand(F.mesh.nF, nTx, nT) + F[:, 'b'] = b + + self.assertTrue(np.all(F[:, 'e'] == e)) + self.assertTrue(np.all(F[:, 'b'] == b)) + F[:] = {'b':b,'e':e} + self.assertTrue(np.all(F[:, 'e'] == e)) + self.assertTrue(np.all(F[:, 'b'] == b)) + + b = np.random.rand(F.mesh.nF,nT) + F[self.Tx0, 'b'] = b + self.assertTrue(np.all(F[self.Tx0, 'b'] == b)) + + phi = np.random.rand(F.mesh.nC,2,nT) + F[[self.Tx0,self.Tx1], 'phi'] = phi + self.assertTrue(np.all(F[[self.Tx0,self.Tx1], 'phi'] == phi)) + + fdict = F[:] + self.assertTrue(type(fdict) is dict) + self.assertTrue(sorted([k for k in fdict]) == ['b','e','phi']) + + b = np.random.rand(F.mesh.nF, 2, nT) + F[[self.Tx0, self.Tx1],'b'] = b + self.assertTrue(F[self.Tx0]['b'].shape == (F.mesh.nF,nT)) + self.assertTrue(F[self.Tx0,'b'].shape == (F.mesh.nF,nT)) + self.assertTrue(np.all(F[self.Tx0,'b'] == b[:,0,:])) + self.assertTrue(np.all(F[self.Tx1,'b'] == b[:,1,:])) + self.assertTrue(np.all(F[self.Tx0,'b',1] == b[:,0,1])) + self.assertTrue(np.all(F[self.Tx1,'b',1] == b[:,1,1])) + self.assertTrue(np.all(F[self.Tx0,'b',4] == b[:,0,4])) + self.assertTrue(np.all(F[self.Tx1,'b',4] == b[:,1,4])) + + def test_assertions(self): + freq = [self.Tx0, self.Tx1] + bWrongSize = np.random.rand(self.F.mesh.nE, self.F.survey.nTx) + def fun(): self.F[freq, 'b'] = bWrongSize + self.assertRaises(ValueError, fun) + def fun(): self.F[-999.] + self.assertRaises(KeyError, fun) + def fun(): self.F['notRight'] + self.assertRaises(KeyError, fun) + def fun(): self.F[freq,'notThere'] + self.assertRaises(KeyError, fun) + if __name__ == '__main__': unittest.main()