diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index 276beebf..7954bc21 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -103,7 +103,7 @@ class BaseTx(object): def __init__(self, loc, txType, rxList, **kwargs): assert type(rxList) is list, 'rxList must be a list' for rx in rxList: - assert isinstance(rx, self.rxPair), 'rxList must be a %s'%self.rxListPair.__name__ + assert isinstance(rx, self.rxPair), 'rxList must be a %s'%self.rxPair.__name__ assert len(set(rxList)) == len(rxList), 'The rxList must be unique' self.loc = loc @@ -321,7 +321,7 @@ class TimeFields(Fields): def _storageShape(self, nP): nTx = self.survey.nTx nT = self.survey.prob.nT - return (nP, nTx, nT) + return (nP, nTx, nT + 1) def _indexAndNameFromKey(self, key): if type(key) is not tuple: @@ -342,8 +342,6 @@ class TimeFields(Fields): def _setField(self, field, val, ind): txInd, timeInd = ind - if val.ndim == 2: - val = val[:, np.newaxis, :] field[:,txInd,timeInd] = val def _getField(self, name, ind): diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Survey.py index 1d4867ba..163ce1de 100644 --- a/SimPEG/Tests/test_Survey.py +++ b/SimPEG/Tests/test_Survey.py @@ -126,7 +126,7 @@ class FieldsTest_Time(unittest.TestCase): def test_SetGet(self): F = self.F nTx = F.survey.nTx - nT = F.survey.prob.nT + nT = F.survey.prob.nT + 1 e = np.random.rand(F.mesh.nE, nTx, nT) F[:, 'e'] = e b = np.random.rand(F.mesh.nF, nTx, nT) @@ -138,9 +138,13 @@ class FieldsTest_Time(unittest.TestCase): self.assertTrue(np.all(F[:, 'e'] == e)) self.assertTrue(np.all(F[:, 'b'] == b)) - b = np.random.rand(F.mesh.nF,nT) + b = np.random.rand(F.mesh.nF,1,nT) F[self.Tx0, 'b'] = b - self.assertTrue(np.all(F[self.Tx0, 'b'] == b)) + self.assertTrue(np.all(F[self.Tx0, 'b'] == b[:,0,:])) + + b = np.random.rand(F.mesh.nF,1,nT) + F[self.Tx0, 'b', 0] = b[:,:,0] + self.assertTrue(np.all(F[self.Tx0, 'b', 0] == b[:,0,0])) phi = np.random.rand(F.mesh.nC,2,nT) F[[self.Tx0,self.Tx1], 'phi'] = phi @@ -161,6 +165,10 @@ class FieldsTest_Time(unittest.TestCase): self.assertTrue(np.all(F[self.Tx0,'b',4] == b[:,0,4])) self.assertTrue(np.all(F[self.Tx1,'b',4] == b[:,1,4])) + + b = np.random.rand(F.mesh.nF, 2, nT) + F[[self.Tx0, self.Tx1],'b', 0] = b[:,:,0] + def test_assertions(self): freq = [self.Tx0, self.Tx1] bWrongSize = np.random.rand(self.F.mesh.nE, self.F.survey.nTx)