diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index a6acd1a5..276beebf 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -250,6 +250,16 @@ class Fields(object): ind = np.in1d(self.survey.txList, txTestList) return ind + def _nameIndex(self, name): + + if type(name) is slice: + assert name == slice(None,None,None), 'Fancy field name slicing is not supported... yet.' + name = None + + if name is not None and name not in self.knownFields: + raise KeyError('Invalid field name') + return name + def _indexAndNameFromKey(self, key): if type(key) is not tuple: key = (key,) @@ -259,10 +269,7 @@ class Fields(object): assert len(key) == 2, 'must be [Tx, fieldName]' txTestList, name = key - - if name is not None and name not in self.knownFields: - raise KeyError('Invalid field name') - + name = self._nameIndex(name) ind = self._txIndex(txTestList) return ind, name @@ -328,9 +335,7 @@ class TimeFields(Fields): txTestList, name, timeInd = key - if name is not None and name not in self.knownFields: - raise KeyError('Invalid field name') - + name = self._nameIndex(name) txInd = self._txIndex(txTestList) return (txInd, timeInd), name diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Survey.py index 4c0b1a8a..1d4867ba 100644 --- a/SimPEG/Tests/test_Survey.py +++ b/SimPEG/Tests/test_Survey.py @@ -20,7 +20,7 @@ class DataAndFieldsTest(unittest.TestCase): txList = [Tx0,Tx1,Tx2,Tx3,Tx4] survey = Survey.BaseSurvey(txList=txList) self.D = Survey.Data(survey) - self.F = Survey.Fields(mesh, survey, knownFields={'phi':'CC','e':'E','b':'F'}) + self.F = Survey.Fields(mesh, survey, knownFields={'phi':'CC','e':'E','b':'F'}, dtype={"phi":float,"e":complex,"b":complex}) self.Tx0 = Tx0 self.Tx1 = Tx1 self.mesh = mesh @@ -49,9 +49,9 @@ class DataAndFieldsTest(unittest.TestCase): def test_SetGet(self): F = self.F nTx = F.survey.nTx - e = np.random.rand(F.mesh.nE, nTx) + e = np.random.rand(F.mesh.nE, nTx) + np.random.rand(F.mesh.nE, nTx)*1j F[:, 'e'] = e - b = np.random.rand(F.mesh.nF, nTx) + b = np.random.rand(F.mesh.nF, nTx) + np.random.rand(F.mesh.nF, nTx)*1j F[:, 'b'] = b self.assertTrue(np.all(F[:, 'e'] == e)) @@ -72,7 +72,7 @@ class DataAndFieldsTest(unittest.TestCase): F[[self.Tx0,self.Tx1], 'phi'] = phi self.assertTrue(np.all(F[[self.Tx0,self.Tx1], 'phi'] == phi)) - fdict = F[:] + fdict = F[:,:] self.assertTrue(type(fdict) is dict) self.assertTrue(sorted([k for k in fdict]) == ['b','e','phi'])