diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index e33b7611..2e2a24e3 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -385,8 +385,8 @@ class TimeFields(Fields): 'F': self.mesh.nF, 'E': self.mesh.nE}[loc] nTx = self.survey.nTx - nT = self.survey.prob.nT - return (nP, nTx, nT + 1) + nT = self.survey.prob.nT + 1 + return (nP, nTx, nT) def _indexAndNameFromKey(self, key, accessType): if type(key) is not tuple: @@ -442,15 +442,19 @@ class TimeFields(Fields): assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.' func = getattr(self, func) pointerFields = self._fields[alias][:,txInd,timeInd] - pointerShape = self._correctShape(alias, ind, deflate=True) + pointerShape = self._correctShape(alias, ind) pointerFields = pointerFields.reshape(pointerShape, order='F') - if len(pointerShape) < 3: - out = func(pointerFields, txInd) + + timeII = np.arange(self.survey.prob.nT + 1)[timeInd] + if timeII.size == 1: + pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) + pointerFields = pointerFields.reshape(pointerShapeDeflated, order='F') + out = func(pointerFields, txInd, timeII) else: #loop over the time steps nT = pointerShape[2] out = range(nT) - for i in range(nT): - out[i] = func(pointerFields[:,:,i], txInd) + for i, TIND_i in enumerate(timeII): + out[i] = func(pointerFields[:,:,i], txInd, TIND_i) out[i] = out[i][:,:,np.newaxis] out = np.concatenate(out, axis=2) diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Survey.py index df476ebb..1657e7db 100644 --- a/SimPEG/Tests/test_Survey.py +++ b/SimPEG/Tests/test_Survey.py @@ -287,8 +287,8 @@ class FieldsTest_Time_Aliased(unittest.TestCase): survey = Survey.BaseSurvey(txList=txList) prob = Problem.BaseTimeProblem(mesh, timeSteps=[(10.,3), (20.,2)]) survey.pair(prob) - def alias(b, ind): - return self.F.mesh.edgeCurl.T * b + def alias(b, txInd, timeInd): + return self.F.mesh.edgeCurl.T * b + timeInd self.F = Survey.TimeFields(mesh, survey, knownFields={'b':'F'}, aliasFields={'e':['b','E',alias]}) self.Tx0 = Tx0 self.Tx1 = Tx1 @@ -317,7 +317,7 @@ class FieldsTest_Time_Aliased(unittest.TestCase): e = range(nT) for i in range(nT): - e[i] = F.mesh.edgeCurl.T*b[:,:,i] + e[i] = F.mesh.edgeCurl.T*b[:,:,i] + i e[i] = e[i][:,:,np.newaxis] e = np.concatenate(e, axis=2) self.assertTrue(np.all(F[:, 'e', :] == e )) @@ -329,6 +329,8 @@ class FieldsTest_Time_Aliased(unittest.TestCase): b = np.random.rand(F.mesh.nF,nT) F[self.Tx0, 'b',:] = b Cb = F.mesh.edgeCurl.T * b + for i in range(Cb.shape[1]): + Cb[:,i] += i self.assertTrue(np.all(F[self.Tx0, 'e',:] == Cb)) def f():