Alias fields option.

This commit is contained in:
rowanc1
2014-05-15 09:10:51 -07:00
parent ab249d31b3
commit ebc2853615
2 changed files with 95 additions and 12 deletions
+36 -12
View File
@@ -219,7 +219,8 @@ class Fields(object):
"""
knownFields = None #: the known fields, a dict with locations, e.g. {"phi": "CC", "b": "F"}
knownFields = None #: Known fields, a dict with locations, e.g. {"phi": "CC", "b": "F"}
aliasFields = None #: Aliased fields, a dict with [alias, location, and function or float], e.g. {"b":["e",lambda(F,e)]}
txPair = BaseTx
dtype = float #: dtype is the type of the storage matrix. This can be a dictionary.
@@ -229,6 +230,11 @@ class Fields(object):
Utils.setKwargs(self, **kwargs)
self._fields = {}
if self.knownFields is None:
raise Exception('knownFields cannot be set to None')
if self.aliasFields is None:
self.aliasFields = {}
def _storageShape(self, nP):
nTx = self.survey.nTx
return (nP, nTx)
@@ -271,17 +277,25 @@ class Fields(object):
ind = np.in1d(self.survey.txList, txTestList)
return ind
def _nameIndex(self, name):
def _nameIndex(self, name, accessType):
if type(name) is slice:
assert name == slice(None,None,None), 'Fancy field name slicing is not supported... yet.'
name = None
if name is not None and name not in self.knownFields:
raise KeyError('Invalid field name')
if name is None:
return
if accessType=='set' and name not in self.knownFields:
if name in self.aliasFields:
raise KeyError("Invalid field name (%s) for setter, you can't set an aliased property"%name)
else:
raise KeyError('Invalid field name (%s) for setter'%name)
elif accessType=='get' and (name not in self.knownFields and name not in self.aliasFields):
raise KeyError('Invalid field name (%s) for getter'%name)
return name
def _indexAndNameFromKey(self, key):
def _indexAndNameFromKey(self, key, accessType):
if type(key) is not tuple:
key = (key,)
if len(key) == 1:
@@ -290,12 +304,12 @@ class Fields(object):
assert len(key) == 2, 'must be [Tx, fieldName]'
txTestList, name = key
name = self._nameIndex(name)
name = self._nameIndex(name, accessType)
ind = self._txIndex(txTestList)
return ind, name
def __setitem__(self, key, value):
ind, name = self._indexAndNameFromKey(key)
ind, name = self._indexAndNameFromKey(key, 'set')
if name is None:
freq = key
assert type(value) is dict, 'New fields must be a dictionary, if field is not specified.'
@@ -310,7 +324,7 @@ class Fields(object):
self._setField(field, newFields[name], ind)
def __getitem__(self, key):
ind, name = self._indexAndNameFromKey(key)
ind, name = self._indexAndNameFromKey(key, 'get')
if name is None:
out = {}
for name in self._fields:
@@ -324,13 +338,23 @@ class Fields(object):
field[:,ind] = val
def _getField(self, name, ind):
out = self._fields[name][:,ind]
if name in self._fields:
out = self._fields[name][:,ind]
else:
out = self._getAliasField(name, ind)
if out.shape[1] == 1:
out = Utils.mkvc(out)
return out
def _getAliasField(self, name, ind):
alias, func = self.aliasFields[name]
return func(self, self._fields[alias][:,ind])
def __contains__(self, other):
return self._fields.__contains__(other)
if other in self.aliasFields:
other = self.aliasFields[other][0]
return self._fields.__contains__(other)
class TimeFields(Fields):
"""Fancy Field Storage for time domain problems
@@ -345,7 +369,7 @@ class TimeFields(Fields):
nT = self.survey.prob.nT
return (nP, nTx, nT + 1)
def _indexAndNameFromKey(self, key):
def _indexAndNameFromKey(self, key, accessType):
if type(key) is not tuple:
key = (key,)
if len(key) == 1:
@@ -357,7 +381,7 @@ class TimeFields(Fields):
txTestList, name, timeInd = key
name = self._nameIndex(name)
name = self._nameIndex(name, accessType)
txInd = self._txIndex(txTestList)
return (txInd, timeInd), name
+59
View File
@@ -40,6 +40,15 @@ class DataAndFieldsTest(unittest.TestCase):
D2 = Survey.Data(self.D.survey, V)
self.assertTrue(np.all(Utils.mkvc(D2) == Utils.mkvc(self.D)))
def test_contains(self):
F = self.F
nTx = F.survey.nTx
self.assertTrue('b' not in F)
self.assertTrue('b' not in F)
e = np.random.rand(F.mesh.nE, nTx)
F[:, 'e'] = e
self.assertTrue('b' not in F)
self.assertTrue('e' in F)
def test_uniqueTxs(self):
txs = self.D.survey.txList
@@ -96,6 +105,56 @@ class DataAndFieldsTest(unittest.TestCase):
self.assertRaises(KeyError, fun)
class FieldsTest_Alias(unittest.TestCase):
def setUp(self):
mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30])
x = np.linspace(5,10,3)
XYZ = Utils.ndgrid(x,x,np.r_[0.])
txLoc = np.r_[0,0,0.]
rxList0 = Survey.BaseRx(XYZ, 'exi')
Tx0 = Survey.BaseTx(txLoc, 'VMD', [rxList0])
rxList1 = Survey.BaseRx(XYZ, 'bxi')
Tx1 = Survey.BaseTx(txLoc, 'VMD', [rxList1])
rxList2 = Survey.BaseRx(XYZ, 'bxi')
Tx2 = Survey.BaseTx(txLoc, 'VMD', [rxList2])
rxList3 = Survey.BaseRx(XYZ, 'bxi')
Tx3 = Survey.BaseTx(txLoc, 'VMD', [rxList3])
Tx4 = Survey.BaseTx(txLoc, 'VMD', [rxList0, rxList1, rxList2, rxList3])
txList = [Tx0,Tx1,Tx2,Tx3,Tx4]
survey = Survey.BaseSurvey(txList=txList)
self.D = Survey.Data(survey)
self.F = Survey.Fields(mesh, survey, knownFields={'e':'E'}, aliasFields={'b':['e',(lambda F, e: F.mesh.edgeCurl * e)]})
self.Tx0 = Tx0
self.Tx1 = Tx1
self.mesh = mesh
self.XYZ = XYZ
def test_contains(self):
F = self.F
nTx = F.survey.nTx
self.assertTrue('b' not in F)
self.assertTrue('b' not in F)
e = np.random.rand(F.mesh.nE, nTx)
F[:, 'e'] = e
self.assertTrue('b' in F)
self.assertTrue('e' in F)
def test_simpleAlias(self):
F = self.F
nTx = F.survey.nTx
e = np.random.rand(F.mesh.nE, nTx)
F[:, 'e'] = e
self.assertTrue(np.all(F[:, 'b'] == F.mesh.edgeCurl * e ))
e = np.random.rand(F.mesh.nE,1)
F[self.Tx0, 'e'] = e
self.assertTrue(np.all(F[self.Tx0, 'b'] == F.mesh.edgeCurl * Utils.mkvc(e)))
def f():
F[self.Tx0, 'b'] = F[self.Tx0, 'b']
self.assertRaises(KeyError, f) # can't set a alias attr.
class FieldsTest_Time(unittest.TestCase):