mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 21:08:35 +08:00
type in fields storage.
This commit is contained in:
+26
-20
@@ -200,6 +200,7 @@ class Fields(object):
|
||||
|
||||
knownFields = None
|
||||
txPair = BaseTx
|
||||
dtype = float
|
||||
|
||||
def __init__(self, mesh, survey, **kwargs):
|
||||
self.survey = survey
|
||||
@@ -207,6 +208,10 @@ class Fields(object):
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
self._fields = {}
|
||||
|
||||
def _storageShape(self, nP):
|
||||
nTx = self.survey.nTx
|
||||
return (nP, nTx)
|
||||
|
||||
def _initStore(self, name):
|
||||
if name in self._fields:
|
||||
return self._fields[name]
|
||||
@@ -219,13 +224,31 @@ class Fields(object):
|
||||
'F': self.mesh.nF,
|
||||
'E': self.mesh.nE}[loc]
|
||||
|
||||
nTx = self.survey.nTx
|
||||
field = np.empty((nP, nTx))
|
||||
if type(self.dtype) is dict:
|
||||
dtype = self.dtype[name]
|
||||
else:
|
||||
dtype = self.dtype
|
||||
field = np.empty(self._storageShape(nP), dtype=dtype)
|
||||
|
||||
self._fields[name] = field
|
||||
|
||||
return field
|
||||
|
||||
def _txIndex(self, txTestList):
|
||||
if type(txTestList) is slice:
|
||||
ind = txTestList
|
||||
else:
|
||||
if type(txTestList) is not list:
|
||||
txTestList = [txTestList]
|
||||
for txTest in txTestList:
|
||||
if not isinstance(txTest, self.txPair):
|
||||
raise KeyError('First index must be a Transmitter')
|
||||
if txTest not in self.survey.txList:
|
||||
raise KeyError('Invalid Transmitter, not in survey list.')
|
||||
|
||||
ind = np.in1d(self.survey.txList, txTestList)
|
||||
return ind
|
||||
|
||||
def _indexAndNameFromKey(self, key):
|
||||
if type(key) is not tuple:
|
||||
key = (key,)
|
||||
@@ -239,19 +262,7 @@ class Fields(object):
|
||||
if name is not None and name not in self.knownFields:
|
||||
raise KeyError('Invalid field name')
|
||||
|
||||
if type(txTestList) is slice:
|
||||
ind = txTestList
|
||||
else:
|
||||
if type(txTestList) is not list:
|
||||
txTestList = [txTestList]
|
||||
for txTest in txTestList:
|
||||
if not isinstance(txTest, self.txPair):
|
||||
raise KeyError('First index must be a Transmitter')
|
||||
if txTest not in self.survey.txList:
|
||||
raise KeyError('Invalid Transmitter, not in survey list.')
|
||||
|
||||
ind = np.in1d(self.survey.txList, txTestList)
|
||||
|
||||
ind = self._txIndex(txTestList)
|
||||
return ind, name
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
@@ -289,7 +300,6 @@ class Fields(object):
|
||||
return out
|
||||
|
||||
|
||||
|
||||
class BaseSurvey(object):
|
||||
"""Survey holds the observed data, and the standard deviations."""
|
||||
|
||||
@@ -492,7 +502,3 @@ class BaseSurvey(object):
|
||||
# def phi_d_target(self, value):
|
||||
# self._phi_d_target = value
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
d = BaseData()
|
||||
d.dpred()
|
||||
|
||||
Reference in New Issue
Block a user