diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index 2e2a24e3..5fd3318c 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -446,16 +446,23 @@ class TimeFields(Fields): pointerFields = pointerFields.reshape(pointerShape, order='F') timeII = np.arange(self.survey.prob.nT + 1)[timeInd] + txII = list(np.array(self.survey.txList)[txInd]) if timeII.size == 1: pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) pointerFields = pointerFields.reshape(pointerShapeDeflated, order='F') - out = func(pointerFields, txInd, timeII) + out = func(pointerFields, txII, timeII) else: #loop over the time steps nT = pointerShape[2] out = range(nT) for i, TIND_i in enumerate(timeII): - out[i] = func(pointerFields[:,:,i], txInd, TIND_i) - out[i] = out[i][:,:,np.newaxis] + fieldI = pointerFields[:,:,i] + if fieldI.ndim == 2 and fieldI.shape[1] == 1: + fieldI = Utils.mkvc(fieldI) + out[i] = func(fieldI, txII, TIND_i) + if out[i].ndim == 1: + out[i] = out[i][:,np.newaxis,np.newaxis] + elif out[i].ndim == 2: + out[i] = out[i][:,:,np.newaxis] out = np.concatenate(out, axis=2) shape = self._correctShape(name, ind, deflate=True)