diff --git a/SimPEG/Problem.py b/SimPEG/Problem.py index b629d723..1787a472 100644 --- a/SimPEG/Problem.py +++ b/SimPEG/Problem.py @@ -139,7 +139,7 @@ class Fields(object): return self._getField(name, ind) def _setField(self, field, val, name, ind): - if isinstance(val, np.ndarray) and (field.shape[1] == 1 or val.ndim == 1): + if isinstance(val, np.ndarray) and (field.shape[0] == field.size or val.ndim == 1): val = Utils.mkvc(val,2) field[:,ind] = val @@ -160,8 +160,8 @@ class Fields(object): assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.' func = getattr(self, func) out = func(self._fields[alias][:,ind], srcII) - if out.shape[0] == out.size: - out = Utils.mkvc(out) + if isinstance(out, np.ndarray) and (out.shape[0] == out.size or out.ndim == 1): + out = Utils.mkvc(out,2) return out def __contains__(self, other): @@ -216,6 +216,8 @@ class TimeFields(Fields): shape = nP, nSrc, nT if deflate: shape = tuple([s for s in shape if s > 1]) + if len(shape) == 1: + shape = shape + (1,) return shape def _setField(self, field, val, name, ind): @@ -260,8 +262,8 @@ class TimeFields(Fields): out = range(nT) for i, TIND_i in enumerate(timeII): fieldI = pointerFields[:,:,i] - if fieldI.ndim == 2 and fieldI.shape[1] == 1: - fieldI = Utils.mkvc(fieldI) + if fieldI.shape[0] == fieldI.size: + fieldI = Utils.mkvc(fieldI,2) out[i] = func(fieldI, srcII, TIND_i) if out[i].ndim == 1: out[i] = out[i][:,np.newaxis,np.newaxis] diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Survey.py index f5109b71..efc1e901 100644 --- a/SimPEG/Tests/test_Survey.py +++ b/SimPEG/Tests/test_Survey.py @@ -80,9 +80,9 @@ class DataAndFieldsTest(unittest.TestCase): b = np.random.rand(F.mesh.nF,1) F[self.Src0, 'b'] = b - self.assertTrue(np.all(F[self.Src0, 'b'] == Utils.mkvc(b))) + self.assertTrue(np.all(F[self.Src0, 'b'] == b)) - b = np.random.rand(F.mesh.nF) + b = np.random.rand(F.mesh.nF,1) F[self.Src0, 'b'] = b self.assertTrue(np.all(F[self.Src0, 'b'] == b)) @@ -96,10 +96,10 @@ class DataAndFieldsTest(unittest.TestCase): b = np.random.rand(F.mesh.nF, 2) F[[self.Src0, self.Src1],'b'] = b - self.assertTrue(F[self.Src0]['b'].shape == (F.mesh.nF,)) - self.assertTrue(F[self.Src0,'b'].shape == (F.mesh.nF,)) - self.assertTrue(np.all(F[self.Src0,'b'] == b[:,0])) - self.assertTrue(np.all(F[self.Src1,'b'] == b[:,1])) + self.assertTrue(F[self.Src0]['b'].shape == (F.mesh.nF,1)) + self.assertTrue(F[self.Src0,'b'].shape == (F.mesh.nF,1)) + self.assertTrue(np.all(F[self.Src0,'b'] == Utils.mkvc(b[:,0],2))) + self.assertTrue(np.all(F[self.Src1,'b'] == Utils.mkvc(b[:,1],2))) def test_assertions(self): freq = [self.Src0, self.Src1] @@ -158,7 +158,7 @@ class FieldsTest_Alias(unittest.TestCase): e = np.random.rand(F.mesh.nE,1) F[self.Src0, 'e'] = e - self.assertTrue(np.all(F[self.Src0, 'b'] == F.mesh.edgeCurl * Utils.mkvc(e))) + self.assertTrue(np.all(F[self.Src0, 'b'] == F.mesh.edgeCurl * e)) def f(): F[self.Src0, 'b'] = F[self.Src0, 'b'] @@ -249,7 +249,7 @@ class FieldsTest_Time(unittest.TestCase): b = np.random.rand(F.mesh.nF,1,nT) F[self.Src0, 'b', 0] = b[:,:,0] - self.assertTrue(np.all(F[self.Src0, 'b', 0] == b[:,0,0])) + self.assertTrue(np.all(F[self.Src0, 'b', 0] == Utils.mkvc(b[:,0,0],2))) phi = np.random.rand(F.mesh.nC,2,nT) F[[self.Src0,self.Src1], 'phi'] = phi @@ -265,10 +265,10 @@ class FieldsTest_Time(unittest.TestCase): self.assertTrue(F[self.Src0,'b'].shape == (F.mesh.nF,nT)) self.assertTrue(np.all(F[self.Src0,'b'] == b[:,0,:])) self.assertTrue(np.all(F[self.Src1,'b'] == b[:,1,:])) - self.assertTrue(np.all(F[self.Src0,'b',1] == b[:,0,1])) - self.assertTrue(np.all(F[self.Src1,'b',1] == b[:,1,1])) - self.assertTrue(np.all(F[self.Src0,'b',4] == b[:,0,4])) - self.assertTrue(np.all(F[self.Src1,'b',4] == b[:,1,4])) + self.assertTrue(np.all(F[self.Src0,'b',1] == Utils.mkvc(b[:,0,1],2))) + self.assertTrue(np.all(F[self.Src1,'b',1] == Utils.mkvc(b[:,1,1],2))) + self.assertTrue(np.all(F[self.Src0,'b',4] == Utils.mkvc(b[:,0,4],2))) + self.assertTrue(np.all(F[self.Src1,'b',4] == Utils.mkvc(b[:,1,4],2))) b = np.random.rand(F.mesh.nF, 2, nT) @@ -344,7 +344,7 @@ class FieldsTest_Time_Aliased(unittest.TestCase): self.assertTrue(np.all(F[self.Src0, 'e', :] == e[:,0,:] )) self.assertTrue(np.all(F[self.Src1, 'e', :] == e[:,1,:] )) for t in range(nT): - self.assertTrue(np.all(F[self.Src1, 'e', t] == e[:,1,t] )) + self.assertTrue(np.all(F[self.Src1, 'e', t] == Utils.mkvc(e[:,1,t],2) )) b = np.random.rand(F.mesh.nF,nT) F[self.Src0, 'b',:] = b