diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index 450f7464..c32a0440 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -219,7 +219,8 @@ class Fields(object): """ - knownFields = None #: the known fields, a dict with locations, e.g. {"phi": "CC", "b": "F"} + knownFields = None #: Known fields, a dict with locations, e.g. {"phi": "CC", "b": "F"} + aliasFields = None #: Aliased fields, a dict with [alias, location, and function or float], e.g. {"b":["e",lambda(F,e)]} txPair = BaseTx dtype = float #: dtype is the type of the storage matrix. This can be a dictionary. @@ -229,6 +230,11 @@ class Fields(object): Utils.setKwargs(self, **kwargs) self._fields = {} + if self.knownFields is None: + raise Exception('knownFields cannot be set to None') + if self.aliasFields is None: + self.aliasFields = {} + def _storageShape(self, nP): nTx = self.survey.nTx return (nP, nTx) @@ -271,17 +277,25 @@ class Fields(object): ind = np.in1d(self.survey.txList, txTestList) return ind - def _nameIndex(self, name): + def _nameIndex(self, name, accessType): if type(name) is slice: assert name == slice(None,None,None), 'Fancy field name slicing is not supported... yet.' name = None - if name is not None and name not in self.knownFields: - raise KeyError('Invalid field name') + if name is None: + return + if accessType=='set' and name not in self.knownFields: + if name in self.aliasFields: + raise KeyError("Invalid field name (%s) for setter, you can't set an aliased property"%name) + else: + raise KeyError('Invalid field name (%s) for setter'%name) + + elif accessType=='get' and (name not in self.knownFields and name not in self.aliasFields): + raise KeyError('Invalid field name (%s) for getter'%name) return name - def _indexAndNameFromKey(self, key): + def _indexAndNameFromKey(self, key, accessType): if type(key) is not tuple: key = (key,) if len(key) == 1: @@ -290,12 +304,12 @@ class Fields(object): assert len(key) == 2, 'must be [Tx, fieldName]' txTestList, name = key - name = self._nameIndex(name) + name = self._nameIndex(name, accessType) ind = self._txIndex(txTestList) return ind, name def __setitem__(self, key, value): - ind, name = self._indexAndNameFromKey(key) + ind, name = self._indexAndNameFromKey(key, 'set') if name is None: freq = key assert type(value) is dict, 'New fields must be a dictionary, if field is not specified.' @@ -310,7 +324,7 @@ class Fields(object): self._setField(field, newFields[name], ind) def __getitem__(self, key): - ind, name = self._indexAndNameFromKey(key) + ind, name = self._indexAndNameFromKey(key, 'get') if name is None: out = {} for name in self._fields: @@ -324,13 +338,23 @@ class Fields(object): field[:,ind] = val def _getField(self, name, ind): - out = self._fields[name][:,ind] + if name in self._fields: + out = self._fields[name][:,ind] + else: + out = self._getAliasField(name, ind) + if out.shape[1] == 1: out = Utils.mkvc(out) return out + def _getAliasField(self, name, ind): + alias, func = self.aliasFields[name] + return func(self, self._fields[alias][:,ind]) + def __contains__(self, other): - return self._fields.__contains__(other) + if other in self.aliasFields: + other = self.aliasFields[other][0] + return self._fields.__contains__(other) class TimeFields(Fields): """Fancy Field Storage for time domain problems @@ -345,7 +369,7 @@ class TimeFields(Fields): nT = self.survey.prob.nT return (nP, nTx, nT + 1) - def _indexAndNameFromKey(self, key): + def _indexAndNameFromKey(self, key, accessType): if type(key) is not tuple: key = (key,) if len(key) == 1: @@ -357,7 +381,7 @@ class TimeFields(Fields): txTestList, name, timeInd = key - name = self._nameIndex(name) + name = self._nameIndex(name, accessType) txInd = self._txIndex(txTestList) return (txInd, timeInd), name diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Survey.py index 163ce1de..0c3e70f8 100644 --- a/SimPEG/Tests/test_Survey.py +++ b/SimPEG/Tests/test_Survey.py @@ -40,6 +40,15 @@ class DataAndFieldsTest(unittest.TestCase): D2 = Survey.Data(self.D.survey, V) self.assertTrue(np.all(Utils.mkvc(D2) == Utils.mkvc(self.D))) + def test_contains(self): + F = self.F + nTx = F.survey.nTx + self.assertTrue('b' not in F) + self.assertTrue('b' not in F) + e = np.random.rand(F.mesh.nE, nTx) + F[:, 'e'] = e + self.assertTrue('b' not in F) + self.assertTrue('e' in F) def test_uniqueTxs(self): txs = self.D.survey.txList @@ -96,6 +105,56 @@ class DataAndFieldsTest(unittest.TestCase): self.assertRaises(KeyError, fun) +class FieldsTest_Alias(unittest.TestCase): + + def setUp(self): + mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30]) + x = np.linspace(5,10,3) + XYZ = Utils.ndgrid(x,x,np.r_[0.]) + txLoc = np.r_[0,0,0.] + rxList0 = Survey.BaseRx(XYZ, 'exi') + Tx0 = Survey.BaseTx(txLoc, 'VMD', [rxList0]) + rxList1 = Survey.BaseRx(XYZ, 'bxi') + Tx1 = Survey.BaseTx(txLoc, 'VMD', [rxList1]) + rxList2 = Survey.BaseRx(XYZ, 'bxi') + Tx2 = Survey.BaseTx(txLoc, 'VMD', [rxList2]) + rxList3 = Survey.BaseRx(XYZ, 'bxi') + Tx3 = Survey.BaseTx(txLoc, 'VMD', [rxList3]) + Tx4 = Survey.BaseTx(txLoc, 'VMD', [rxList0, rxList1, rxList2, rxList3]) + txList = [Tx0,Tx1,Tx2,Tx3,Tx4] + survey = Survey.BaseSurvey(txList=txList) + self.D = Survey.Data(survey) + self.F = Survey.Fields(mesh, survey, knownFields={'e':'E'}, aliasFields={'b':['e',(lambda F, e: F.mesh.edgeCurl * e)]}) + self.Tx0 = Tx0 + self.Tx1 = Tx1 + self.mesh = mesh + self.XYZ = XYZ + + def test_contains(self): + F = self.F + nTx = F.survey.nTx + self.assertTrue('b' not in F) + self.assertTrue('b' not in F) + e = np.random.rand(F.mesh.nE, nTx) + F[:, 'e'] = e + self.assertTrue('b' in F) + self.assertTrue('e' in F) + + def test_simpleAlias(self): + F = self.F + nTx = F.survey.nTx + e = np.random.rand(F.mesh.nE, nTx) + F[:, 'e'] = e + self.assertTrue(np.all(F[:, 'b'] == F.mesh.edgeCurl * e )) + + e = np.random.rand(F.mesh.nE,1) + F[self.Tx0, 'e'] = e + self.assertTrue(np.all(F[self.Tx0, 'b'] == F.mesh.edgeCurl * Utils.mkvc(e))) + + def f(): + F[self.Tx0, 'b'] = F[self.Tx0, 'b'] + self.assertRaises(KeyError, f) # can't set a alias attr. + class FieldsTest_Time(unittest.TestCase):