type in fields storage.

This commit is contained in:
rowanc1
2014-04-26 17:04:22 -07:00
parent cb2c2aec3b
commit 169ee9dd0b
+26 -20
View File
@@ -200,6 +200,7 @@ class Fields(object):
knownFields = None
txPair = BaseTx
dtype = float
def __init__(self, mesh, survey, **kwargs):
self.survey = survey
@@ -207,6 +208,10 @@ class Fields(object):
Utils.setKwargs(self, **kwargs)
self._fields = {}
def _storageShape(self, nP):
nTx = self.survey.nTx
return (nP, nTx)
def _initStore(self, name):
if name in self._fields:
return self._fields[name]
@@ -219,13 +224,31 @@ class Fields(object):
'F': self.mesh.nF,
'E': self.mesh.nE}[loc]
nTx = self.survey.nTx
field = np.empty((nP, nTx))
if type(self.dtype) is dict:
dtype = self.dtype[name]
else:
dtype = self.dtype
field = np.empty(self._storageShape(nP), dtype=dtype)
self._fields[name] = field
return field
def _txIndex(self, txTestList):
if type(txTestList) is slice:
ind = txTestList
else:
if type(txTestList) is not list:
txTestList = [txTestList]
for txTest in txTestList:
if not isinstance(txTest, self.txPair):
raise KeyError('First index must be a Transmitter')
if txTest not in self.survey.txList:
raise KeyError('Invalid Transmitter, not in survey list.')
ind = np.in1d(self.survey.txList, txTestList)
return ind
def _indexAndNameFromKey(self, key):
if type(key) is not tuple:
key = (key,)
@@ -239,19 +262,7 @@ class Fields(object):
if name is not None and name not in self.knownFields:
raise KeyError('Invalid field name')
if type(txTestList) is slice:
ind = txTestList
else:
if type(txTestList) is not list:
txTestList = [txTestList]
for txTest in txTestList:
if not isinstance(txTest, self.txPair):
raise KeyError('First index must be a Transmitter')
if txTest not in self.survey.txList:
raise KeyError('Invalid Transmitter, not in survey list.')
ind = np.in1d(self.survey.txList, txTestList)
ind = self._txIndex(txTestList)
return ind, name
def __setitem__(self, key, value):
@@ -289,7 +300,6 @@ class Fields(object):
return out
class BaseSurvey(object):
"""Survey holds the observed data, and the standard deviations."""
@@ -492,7 +502,3 @@ class BaseSurvey(object):
# def phi_d_target(self, value):
# self._phi_d_target = value
if __name__ == '__main__':
d = BaseData()
d.dpred()