diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index 2d2a5a00..1ed649db 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -200,6 +200,7 @@ class Fields(object): knownFields = None txPair = BaseTx + dtype = float def __init__(self, mesh, survey, **kwargs): self.survey = survey @@ -207,6 +208,10 @@ class Fields(object): Utils.setKwargs(self, **kwargs) self._fields = {} + def _storageShape(self, nP): + nTx = self.survey.nTx + return (nP, nTx) + def _initStore(self, name): if name in self._fields: return self._fields[name] @@ -219,13 +224,31 @@ class Fields(object): 'F': self.mesh.nF, 'E': self.mesh.nE}[loc] - nTx = self.survey.nTx - field = np.empty((nP, nTx)) + if type(self.dtype) is dict: + dtype = self.dtype[name] + else: + dtype = self.dtype + field = np.empty(self._storageShape(nP), dtype=dtype) self._fields[name] = field return field + def _txIndex(self, txTestList): + if type(txTestList) is slice: + ind = txTestList + else: + if type(txTestList) is not list: + txTestList = [txTestList] + for txTest in txTestList: + if not isinstance(txTest, self.txPair): + raise KeyError('First index must be a Transmitter') + if txTest not in self.survey.txList: + raise KeyError('Invalid Transmitter, not in survey list.') + + ind = np.in1d(self.survey.txList, txTestList) + return ind + def _indexAndNameFromKey(self, key): if type(key) is not tuple: key = (key,) @@ -239,19 +262,7 @@ class Fields(object): if name is not None and name not in self.knownFields: raise KeyError('Invalid field name') - if type(txTestList) is slice: - ind = txTestList - else: - if type(txTestList) is not list: - txTestList = [txTestList] - for txTest in txTestList: - if not isinstance(txTest, self.txPair): - raise KeyError('First index must be a Transmitter') - if txTest not in self.survey.txList: - raise KeyError('Invalid Transmitter, not in survey list.') - - ind = np.in1d(self.survey.txList, txTestList) - + ind = self._txIndex(txTestList) return ind, name def __setitem__(self, key, value): @@ -289,7 +300,6 @@ class Fields(object): return out - class BaseSurvey(object): """Survey holds the observed data, and the standard deviations.""" @@ -492,7 +502,3 @@ class BaseSurvey(object): # def phi_d_target(self, value): # self._phi_d_target = value - -if __name__ == '__main__': - d = BaseData() - d.dpred()