diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index dc5aeadb..2d2a5a00 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -190,6 +190,106 @@ class Data(object): indBot += rx.nD +class Fields(object): + """Fancy Field Storage + + u[:,'phi'] = phi + print u[tx0,'phi'] + + """ + + knownFields = None + txPair = BaseTx + + def __init__(self, mesh, survey, **kwargs): + self.survey = survey + self.mesh = mesh + Utils.setKwargs(self, **kwargs) + self._fields = {} + + def _initStore(self, name): + if name in self._fields: + return self._fields[name] + + assert name in self.knownFields, 'field name is not known.' + + loc = self.knownFields[name] + + nP = {'CC': self.mesh.nC, + 'F': self.mesh.nF, + 'E': self.mesh.nE}[loc] + + nTx = self.survey.nTx + field = np.empty((nP, nTx)) + + self._fields[name] = field + + return field + + def _indexAndNameFromKey(self, key): + if type(key) is not tuple: + key = (key,) + if len(key) == 1: + key += (None,) + + assert len(key) == 2, 'must be [Tx, fieldName]' + + txTestList, name = key + + if name is not None and name not in self.knownFields: + raise KeyError('Invalid field name') + + if type(txTestList) is slice: + ind = txTestList + else: + if type(txTestList) is not list: + txTestList = [txTestList] + for txTest in txTestList: + if not isinstance(txTest, self.txPair): + raise KeyError('First index must be a Transmitter') + if txTest not in self.survey.txList: + raise KeyError('Invalid Transmitter, not in survey list.') + + ind = np.in1d(self.survey.txList, txTestList) + + return ind, name + + def __setitem__(self, key, value): + ind, name = self._indexAndNameFromKey(key) + if name is None: + freq = key + assert type(value) is dict, 'New fields must be a dictionary, if field is not specified.' + newFields = value + elif name in self.knownFields: + assert type(value) is np.ndarray, 'Must be set to a numpy array' + newFields = {name: value} + else: + raise Exception('Unknown setter') + + for name in newFields: + field = self._initStore(name) + NEWF = newFields[name] + if field.shape[1] == 1 or NEWF.ndim == 1: + NEWF = Utils.mkvc(NEWF,2) + field[:,ind] = NEWF + + def __getitem__(self, key): + ind, name = self._indexAndNameFromKey(key) + if name is None: + out = {} + for name in self._fields: + out[name] = self._fields[name][:,ind] + if out[name].shape[1] == 1: + out[name] = Utils.mkvc(out[name]) + return out + + out = self._fields[name][:,ind] + if out.shape[1] == 1: + out = Utils.mkvc(out) + return out + + + class BaseSurvey(object): """Survey holds the observed data, and the standard deviations.""" diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Survey.py index f9866bb5..3a40cddf 100644 --- a/SimPEG/Tests/test_Survey.py +++ b/SimPEG/Tests/test_Survey.py @@ -1,7 +1,7 @@ import unittest from SimPEG import * -class DataTest(unittest.TestCase): +class DataAndFieldsTest(unittest.TestCase): def setUp(self): mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30]) @@ -20,6 +20,7 @@ class DataTest(unittest.TestCase): txList = [Tx0,Tx1,Tx2,Tx3,Tx4] survey = Survey.BaseSurvey(txList=txList) self.D = Survey.Data(survey) + self.F = Survey.Fields(mesh, survey, knownFields={'phi':'CC','e':'E','b':'F'}) self.Tx0 = Tx0 self.Tx1 = Tx1 self.mesh = mesh @@ -45,7 +46,54 @@ class DataTest(unittest.TestCase): txs += [txs[0]] self.assertRaises(AssertionError, Survey.BaseSurvey, txList=txs) + def test_SetGet(self): + F = self.F + nTx = F.survey.nTx + e = np.random.rand(F.mesh.nE, nTx) + F[:, 'e'] = e + b = np.random.rand(F.mesh.nF, nTx) + F[:, 'b'] = b + self.assertTrue(np.all(F[:, 'e'] == e)) + self.assertTrue(np.all(F[:, 'b'] == b)) + F[:] = {'b':b,'e':e} + self.assertTrue(np.all(F[:, 'e'] == e)) + self.assertTrue(np.all(F[:, 'b'] == b)) + + b = np.random.rand(F.mesh.nF,1) + F[self.Tx0, 'b'] = b + self.assertTrue(np.all(F[self.Tx0, 'b'] == Utils.mkvc(b))) + + b = np.random.rand(F.mesh.nF) + F[self.Tx0, 'b'] = b + self.assertTrue(np.all(F[self.Tx0, 'b'] == b)) + + phi = np.random.rand(F.mesh.nC,2) + F[[self.Tx0,self.Tx1], 'phi'] = phi + self.assertTrue(np.all(F[[self.Tx0,self.Tx1], 'phi'] == phi)) + + fdict = F[:] + self.assertTrue(type(fdict) is dict) + self.assertTrue(sorted([k for k in fdict]) == ['b','e','phi']) + + b = np.random.rand(F.mesh.nF, 2) + F[[self.Tx0, self.Tx1],'b'] = b + self.assertTrue(F[self.Tx0]['b'].shape == (F.mesh.nF,)) + self.assertTrue(F[self.Tx0,'b'].shape == (F.mesh.nF,)) + self.assertTrue(np.all(F[self.Tx0,'b'] == b[:,0])) + self.assertTrue(np.all(F[self.Tx1,'b'] == b[:,1])) + + def test_assertions(self): + freq = [self.Tx0, self.Tx1] + bWrongSize = np.random.rand(self.F.mesh.nE, self.F.survey.nTx) + def fun(): self.F[freq, 'b'] = bWrongSize + self.assertRaises(ValueError, fun) + def fun(): self.F[-999.] + self.assertRaises(KeyError, fun) + def fun(): self.F['notRight'] + self.assertRaises(KeyError, fun) + def fun(): self.F[freq,'notThere'] + self.assertRaises(KeyError, fun) if __name__ == '__main__': unittest.main()