mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 19:48:52 +08:00
Add and test time index passing to aliased time fields
This commit is contained in:
+11
-7
@@ -385,8 +385,8 @@ class TimeFields(Fields):
|
||||
'F': self.mesh.nF,
|
||||
'E': self.mesh.nE}[loc]
|
||||
nTx = self.survey.nTx
|
||||
nT = self.survey.prob.nT
|
||||
return (nP, nTx, nT + 1)
|
||||
nT = self.survey.prob.nT + 1
|
||||
return (nP, nTx, nT)
|
||||
|
||||
def _indexAndNameFromKey(self, key, accessType):
|
||||
if type(key) is not tuple:
|
||||
@@ -442,15 +442,19 @@ class TimeFields(Fields):
|
||||
assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.'
|
||||
func = getattr(self, func)
|
||||
pointerFields = self._fields[alias][:,txInd,timeInd]
|
||||
pointerShape = self._correctShape(alias, ind, deflate=True)
|
||||
pointerShape = self._correctShape(alias, ind)
|
||||
pointerFields = pointerFields.reshape(pointerShape, order='F')
|
||||
if len(pointerShape) < 3:
|
||||
out = func(pointerFields, txInd)
|
||||
|
||||
timeII = np.arange(self.survey.prob.nT + 1)[timeInd]
|
||||
if timeII.size == 1:
|
||||
pointerShapeDeflated = self._correctShape(alias, ind, deflate=True)
|
||||
pointerFields = pointerFields.reshape(pointerShapeDeflated, order='F')
|
||||
out = func(pointerFields, txInd, timeII)
|
||||
else: #loop over the time steps
|
||||
nT = pointerShape[2]
|
||||
out = range(nT)
|
||||
for i in range(nT):
|
||||
out[i] = func(pointerFields[:,:,i], txInd)
|
||||
for i, TIND_i in enumerate(timeII):
|
||||
out[i] = func(pointerFields[:,:,i], txInd, TIND_i)
|
||||
out[i] = out[i][:,:,np.newaxis]
|
||||
out = np.concatenate(out, axis=2)
|
||||
|
||||
|
||||
@@ -287,8 +287,8 @@ class FieldsTest_Time_Aliased(unittest.TestCase):
|
||||
survey = Survey.BaseSurvey(txList=txList)
|
||||
prob = Problem.BaseTimeProblem(mesh, timeSteps=[(10.,3), (20.,2)])
|
||||
survey.pair(prob)
|
||||
def alias(b, ind):
|
||||
return self.F.mesh.edgeCurl.T * b
|
||||
def alias(b, txInd, timeInd):
|
||||
return self.F.mesh.edgeCurl.T * b + timeInd
|
||||
self.F = Survey.TimeFields(mesh, survey, knownFields={'b':'F'}, aliasFields={'e':['b','E',alias]})
|
||||
self.Tx0 = Tx0
|
||||
self.Tx1 = Tx1
|
||||
@@ -317,7 +317,7 @@ class FieldsTest_Time_Aliased(unittest.TestCase):
|
||||
|
||||
e = range(nT)
|
||||
for i in range(nT):
|
||||
e[i] = F.mesh.edgeCurl.T*b[:,:,i]
|
||||
e[i] = F.mesh.edgeCurl.T*b[:,:,i] + i
|
||||
e[i] = e[i][:,:,np.newaxis]
|
||||
e = np.concatenate(e, axis=2)
|
||||
self.assertTrue(np.all(F[:, 'e', :] == e ))
|
||||
@@ -329,6 +329,8 @@ class FieldsTest_Time_Aliased(unittest.TestCase):
|
||||
b = np.random.rand(F.mesh.nF,nT)
|
||||
F[self.Tx0, 'b',:] = b
|
||||
Cb = F.mesh.edgeCurl.T * b
|
||||
for i in range(Cb.shape[1]):
|
||||
Cb[:,i] += i
|
||||
self.assertTrue(np.all(F[self.Tx0, 'e',:] == Cb))
|
||||
|
||||
def f():
|
||||
|
||||
Reference in New Issue
Block a user