mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-30 05:09:41 +08:00
Alias fields option.
This commit is contained in:
+36
-12
@@ -219,7 +219,8 @@ class Fields(object):
|
||||
|
||||
"""
|
||||
|
||||
knownFields = None #: the known fields, a dict with locations, e.g. {"phi": "CC", "b": "F"}
|
||||
knownFields = None #: Known fields, a dict with locations, e.g. {"phi": "CC", "b": "F"}
|
||||
aliasFields = None #: Aliased fields, a dict with [alias, location, and function or float], e.g. {"b":["e",lambda(F,e)]}
|
||||
txPair = BaseTx
|
||||
dtype = float #: dtype is the type of the storage matrix. This can be a dictionary.
|
||||
|
||||
@@ -229,6 +230,11 @@ class Fields(object):
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
self._fields = {}
|
||||
|
||||
if self.knownFields is None:
|
||||
raise Exception('knownFields cannot be set to None')
|
||||
if self.aliasFields is None:
|
||||
self.aliasFields = {}
|
||||
|
||||
def _storageShape(self, nP):
|
||||
nTx = self.survey.nTx
|
||||
return (nP, nTx)
|
||||
@@ -271,17 +277,25 @@ class Fields(object):
|
||||
ind = np.in1d(self.survey.txList, txTestList)
|
||||
return ind
|
||||
|
||||
def _nameIndex(self, name):
|
||||
def _nameIndex(self, name, accessType):
|
||||
|
||||
if type(name) is slice:
|
||||
assert name == slice(None,None,None), 'Fancy field name slicing is not supported... yet.'
|
||||
name = None
|
||||
|
||||
if name is not None and name not in self.knownFields:
|
||||
raise KeyError('Invalid field name')
|
||||
if name is None:
|
||||
return
|
||||
if accessType=='set' and name not in self.knownFields:
|
||||
if name in self.aliasFields:
|
||||
raise KeyError("Invalid field name (%s) for setter, you can't set an aliased property"%name)
|
||||
else:
|
||||
raise KeyError('Invalid field name (%s) for setter'%name)
|
||||
|
||||
elif accessType=='get' and (name not in self.knownFields and name not in self.aliasFields):
|
||||
raise KeyError('Invalid field name (%s) for getter'%name)
|
||||
return name
|
||||
|
||||
def _indexAndNameFromKey(self, key):
|
||||
def _indexAndNameFromKey(self, key, accessType):
|
||||
if type(key) is not tuple:
|
||||
key = (key,)
|
||||
if len(key) == 1:
|
||||
@@ -290,12 +304,12 @@ class Fields(object):
|
||||
assert len(key) == 2, 'must be [Tx, fieldName]'
|
||||
|
||||
txTestList, name = key
|
||||
name = self._nameIndex(name)
|
||||
name = self._nameIndex(name, accessType)
|
||||
ind = self._txIndex(txTestList)
|
||||
return ind, name
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
ind, name = self._indexAndNameFromKey(key)
|
||||
ind, name = self._indexAndNameFromKey(key, 'set')
|
||||
if name is None:
|
||||
freq = key
|
||||
assert type(value) is dict, 'New fields must be a dictionary, if field is not specified.'
|
||||
@@ -310,7 +324,7 @@ class Fields(object):
|
||||
self._setField(field, newFields[name], ind)
|
||||
|
||||
def __getitem__(self, key):
|
||||
ind, name = self._indexAndNameFromKey(key)
|
||||
ind, name = self._indexAndNameFromKey(key, 'get')
|
||||
if name is None:
|
||||
out = {}
|
||||
for name in self._fields:
|
||||
@@ -324,13 +338,23 @@ class Fields(object):
|
||||
field[:,ind] = val
|
||||
|
||||
def _getField(self, name, ind):
|
||||
out = self._fields[name][:,ind]
|
||||
if name in self._fields:
|
||||
out = self._fields[name][:,ind]
|
||||
else:
|
||||
out = self._getAliasField(name, ind)
|
||||
|
||||
if out.shape[1] == 1:
|
||||
out = Utils.mkvc(out)
|
||||
return out
|
||||
|
||||
def _getAliasField(self, name, ind):
|
||||
alias, func = self.aliasFields[name]
|
||||
return func(self, self._fields[alias][:,ind])
|
||||
|
||||
def __contains__(self, other):
|
||||
return self._fields.__contains__(other)
|
||||
if other in self.aliasFields:
|
||||
other = self.aliasFields[other][0]
|
||||
return self._fields.__contains__(other)
|
||||
|
||||
class TimeFields(Fields):
|
||||
"""Fancy Field Storage for time domain problems
|
||||
@@ -345,7 +369,7 @@ class TimeFields(Fields):
|
||||
nT = self.survey.prob.nT
|
||||
return (nP, nTx, nT + 1)
|
||||
|
||||
def _indexAndNameFromKey(self, key):
|
||||
def _indexAndNameFromKey(self, key, accessType):
|
||||
if type(key) is not tuple:
|
||||
key = (key,)
|
||||
if len(key) == 1:
|
||||
@@ -357,7 +381,7 @@ class TimeFields(Fields):
|
||||
|
||||
txTestList, name, timeInd = key
|
||||
|
||||
name = self._nameIndex(name)
|
||||
name = self._nameIndex(name, accessType)
|
||||
txInd = self._txIndex(txTestList)
|
||||
|
||||
return (txInd, timeInd), name
|
||||
|
||||
@@ -40,6 +40,15 @@ class DataAndFieldsTest(unittest.TestCase):
|
||||
D2 = Survey.Data(self.D.survey, V)
|
||||
self.assertTrue(np.all(Utils.mkvc(D2) == Utils.mkvc(self.D)))
|
||||
|
||||
def test_contains(self):
|
||||
F = self.F
|
||||
nTx = F.survey.nTx
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('b' not in F)
|
||||
e = np.random.rand(F.mesh.nE, nTx)
|
||||
F[:, 'e'] = e
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('e' in F)
|
||||
|
||||
def test_uniqueTxs(self):
|
||||
txs = self.D.survey.txList
|
||||
@@ -96,6 +105,56 @@ class DataAndFieldsTest(unittest.TestCase):
|
||||
self.assertRaises(KeyError, fun)
|
||||
|
||||
|
||||
class FieldsTest_Alias(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30])
|
||||
x = np.linspace(5,10,3)
|
||||
XYZ = Utils.ndgrid(x,x,np.r_[0.])
|
||||
txLoc = np.r_[0,0,0.]
|
||||
rxList0 = Survey.BaseRx(XYZ, 'exi')
|
||||
Tx0 = Survey.BaseTx(txLoc, 'VMD', [rxList0])
|
||||
rxList1 = Survey.BaseRx(XYZ, 'bxi')
|
||||
Tx1 = Survey.BaseTx(txLoc, 'VMD', [rxList1])
|
||||
rxList2 = Survey.BaseRx(XYZ, 'bxi')
|
||||
Tx2 = Survey.BaseTx(txLoc, 'VMD', [rxList2])
|
||||
rxList3 = Survey.BaseRx(XYZ, 'bxi')
|
||||
Tx3 = Survey.BaseTx(txLoc, 'VMD', [rxList3])
|
||||
Tx4 = Survey.BaseTx(txLoc, 'VMD', [rxList0, rxList1, rxList2, rxList3])
|
||||
txList = [Tx0,Tx1,Tx2,Tx3,Tx4]
|
||||
survey = Survey.BaseSurvey(txList=txList)
|
||||
self.D = Survey.Data(survey)
|
||||
self.F = Survey.Fields(mesh, survey, knownFields={'e':'E'}, aliasFields={'b':['e',(lambda F, e: F.mesh.edgeCurl * e)]})
|
||||
self.Tx0 = Tx0
|
||||
self.Tx1 = Tx1
|
||||
self.mesh = mesh
|
||||
self.XYZ = XYZ
|
||||
|
||||
def test_contains(self):
|
||||
F = self.F
|
||||
nTx = F.survey.nTx
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('b' not in F)
|
||||
e = np.random.rand(F.mesh.nE, nTx)
|
||||
F[:, 'e'] = e
|
||||
self.assertTrue('b' in F)
|
||||
self.assertTrue('e' in F)
|
||||
|
||||
def test_simpleAlias(self):
|
||||
F = self.F
|
||||
nTx = F.survey.nTx
|
||||
e = np.random.rand(F.mesh.nE, nTx)
|
||||
F[:, 'e'] = e
|
||||
self.assertTrue(np.all(F[:, 'b'] == F.mesh.edgeCurl * e ))
|
||||
|
||||
e = np.random.rand(F.mesh.nE,1)
|
||||
F[self.Tx0, 'e'] = e
|
||||
self.assertTrue(np.all(F[self.Tx0, 'b'] == F.mesh.edgeCurl * Utils.mkvc(e)))
|
||||
|
||||
def f():
|
||||
F[self.Tx0, 'b'] = F[self.Tx0, 'b']
|
||||
self.assertRaises(KeyError, f) # can't set a alias attr.
|
||||
|
||||
|
||||
class FieldsTest_Time(unittest.TestCase):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user