Add and test time index passing to aliased time fields

This commit is contained in:
rowanc1
2014-05-18 18:22:31 -07:00
parent 129d6d6b45
commit 9fe6973ffb
2 changed files with 16 additions and 10 deletions
+11 -7
View File
@@ -385,8 +385,8 @@ class TimeFields(Fields):
'F': self.mesh.nF,
'E': self.mesh.nE}[loc]
nTx = self.survey.nTx
nT = self.survey.prob.nT
return (nP, nTx, nT + 1)
nT = self.survey.prob.nT + 1
return (nP, nTx, nT)
def _indexAndNameFromKey(self, key, accessType):
if type(key) is not tuple:
@@ -442,15 +442,19 @@ class TimeFields(Fields):
assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.'
func = getattr(self, func)
pointerFields = self._fields[alias][:,txInd,timeInd]
pointerShape = self._correctShape(alias, ind, deflate=True)
pointerShape = self._correctShape(alias, ind)
pointerFields = pointerFields.reshape(pointerShape, order='F')
if len(pointerShape) < 3:
out = func(pointerFields, txInd)
timeII = np.arange(self.survey.prob.nT + 1)[timeInd]
if timeII.size == 1:
pointerShapeDeflated = self._correctShape(alias, ind, deflate=True)
pointerFields = pointerFields.reshape(pointerShapeDeflated, order='F')
out = func(pointerFields, txInd, timeII)
else: #loop over the time steps
nT = pointerShape[2]
out = range(nT)
for i in range(nT):
out[i] = func(pointerFields[:,:,i], txInd)
for i, TIND_i in enumerate(timeII):
out[i] = func(pointerFields[:,:,i], txInd, TIND_i)
out[i] = out[i][:,:,np.newaxis]
out = np.concatenate(out, axis=2)
+5 -3
View File
@@ -287,8 +287,8 @@ class FieldsTest_Time_Aliased(unittest.TestCase):
survey = Survey.BaseSurvey(txList=txList)
prob = Problem.BaseTimeProblem(mesh, timeSteps=[(10.,3), (20.,2)])
survey.pair(prob)
def alias(b, ind):
return self.F.mesh.edgeCurl.T * b
def alias(b, txInd, timeInd):
return self.F.mesh.edgeCurl.T * b + timeInd
self.F = Survey.TimeFields(mesh, survey, knownFields={'b':'F'}, aliasFields={'e':['b','E',alias]})
self.Tx0 = Tx0
self.Tx1 = Tx1
@@ -317,7 +317,7 @@ class FieldsTest_Time_Aliased(unittest.TestCase):
e = range(nT)
for i in range(nT):
e[i] = F.mesh.edgeCurl.T*b[:,:,i]
e[i] = F.mesh.edgeCurl.T*b[:,:,i] + i
e[i] = e[i][:,:,np.newaxis]
e = np.concatenate(e, axis=2)
self.assertTrue(np.all(F[:, 'e', :] == e ))
@@ -329,6 +329,8 @@ class FieldsTest_Time_Aliased(unittest.TestCase):
b = np.random.rand(F.mesh.nF,nT)
F[self.Tx0, 'b',:] = b
Cb = F.mesh.edgeCurl.T * b
for i in range(Cb.shape[1]):
Cb[:,i] += i
self.assertTrue(np.all(F[self.Tx0, 'e',:] == Cb))
def f():