mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-04 13:57:43 +08:00
+7
-5
@@ -139,7 +139,7 @@ class Fields(object):
|
||||
return self._getField(name, ind)
|
||||
|
||||
def _setField(self, field, val, name, ind):
|
||||
if isinstance(val, np.ndarray) and (field.shape[1] == 1 or val.ndim == 1):
|
||||
if isinstance(val, np.ndarray) and (field.shape[0] == field.size or val.ndim == 1):
|
||||
val = Utils.mkvc(val,2)
|
||||
field[:,ind] = val
|
||||
|
||||
@@ -160,8 +160,8 @@ class Fields(object):
|
||||
assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.'
|
||||
func = getattr(self, func)
|
||||
out = func(self._fields[alias][:,ind], srcII)
|
||||
if out.shape[0] == out.size:
|
||||
out = Utils.mkvc(out)
|
||||
if isinstance(out, np.ndarray) and (out.shape[0] == out.size or out.ndim == 1):
|
||||
out = Utils.mkvc(out,2)
|
||||
return out
|
||||
|
||||
def __contains__(self, other):
|
||||
@@ -216,6 +216,8 @@ class TimeFields(Fields):
|
||||
shape = nP, nSrc, nT
|
||||
if deflate:
|
||||
shape = tuple([s for s in shape if s > 1])
|
||||
if len(shape) == 1:
|
||||
shape = shape + (1,)
|
||||
return shape
|
||||
|
||||
def _setField(self, field, val, name, ind):
|
||||
@@ -260,8 +262,8 @@ class TimeFields(Fields):
|
||||
out = range(nT)
|
||||
for i, TIND_i in enumerate(timeII):
|
||||
fieldI = pointerFields[:,:,i]
|
||||
if fieldI.ndim == 2 and fieldI.shape[1] == 1:
|
||||
fieldI = Utils.mkvc(fieldI)
|
||||
if fieldI.shape[0] == fieldI.size:
|
||||
fieldI = Utils.mkvc(fieldI,2)
|
||||
out[i] = func(fieldI, srcII, TIND_i)
|
||||
if out[i].ndim == 1:
|
||||
out[i] = out[i][:,np.newaxis,np.newaxis]
|
||||
|
||||
+13
-13
@@ -80,9 +80,9 @@ class DataAndFieldsTest(unittest.TestCase):
|
||||
|
||||
b = np.random.rand(F.mesh.nF,1)
|
||||
F[self.Src0, 'b'] = b
|
||||
self.assertTrue(np.all(F[self.Src0, 'b'] == Utils.mkvc(b)))
|
||||
self.assertTrue(np.all(F[self.Src0, 'b'] == b))
|
||||
|
||||
b = np.random.rand(F.mesh.nF)
|
||||
b = np.random.rand(F.mesh.nF,1)
|
||||
F[self.Src0, 'b'] = b
|
||||
self.assertTrue(np.all(F[self.Src0, 'b'] == b))
|
||||
|
||||
@@ -96,10 +96,10 @@ class DataAndFieldsTest(unittest.TestCase):
|
||||
|
||||
b = np.random.rand(F.mesh.nF, 2)
|
||||
F[[self.Src0, self.Src1],'b'] = b
|
||||
self.assertTrue(F[self.Src0]['b'].shape == (F.mesh.nF,))
|
||||
self.assertTrue(F[self.Src0,'b'].shape == (F.mesh.nF,))
|
||||
self.assertTrue(np.all(F[self.Src0,'b'] == b[:,0]))
|
||||
self.assertTrue(np.all(F[self.Src1,'b'] == b[:,1]))
|
||||
self.assertTrue(F[self.Src0]['b'].shape == (F.mesh.nF,1))
|
||||
self.assertTrue(F[self.Src0,'b'].shape == (F.mesh.nF,1))
|
||||
self.assertTrue(np.all(F[self.Src0,'b'] == Utils.mkvc(b[:,0],2)))
|
||||
self.assertTrue(np.all(F[self.Src1,'b'] == Utils.mkvc(b[:,1],2)))
|
||||
|
||||
def test_assertions(self):
|
||||
freq = [self.Src0, self.Src1]
|
||||
@@ -158,7 +158,7 @@ class FieldsTest_Alias(unittest.TestCase):
|
||||
|
||||
e = np.random.rand(F.mesh.nE,1)
|
||||
F[self.Src0, 'e'] = e
|
||||
self.assertTrue(np.all(F[self.Src0, 'b'] == F.mesh.edgeCurl * Utils.mkvc(e)))
|
||||
self.assertTrue(np.all(F[self.Src0, 'b'] == F.mesh.edgeCurl * e))
|
||||
|
||||
def f():
|
||||
F[self.Src0, 'b'] = F[self.Src0, 'b']
|
||||
@@ -249,7 +249,7 @@ class FieldsTest_Time(unittest.TestCase):
|
||||
|
||||
b = np.random.rand(F.mesh.nF,1,nT)
|
||||
F[self.Src0, 'b', 0] = b[:,:,0]
|
||||
self.assertTrue(np.all(F[self.Src0, 'b', 0] == b[:,0,0]))
|
||||
self.assertTrue(np.all(F[self.Src0, 'b', 0] == Utils.mkvc(b[:,0,0],2)))
|
||||
|
||||
phi = np.random.rand(F.mesh.nC,2,nT)
|
||||
F[[self.Src0,self.Src1], 'phi'] = phi
|
||||
@@ -265,10 +265,10 @@ class FieldsTest_Time(unittest.TestCase):
|
||||
self.assertTrue(F[self.Src0,'b'].shape == (F.mesh.nF,nT))
|
||||
self.assertTrue(np.all(F[self.Src0,'b'] == b[:,0,:]))
|
||||
self.assertTrue(np.all(F[self.Src1,'b'] == b[:,1,:]))
|
||||
self.assertTrue(np.all(F[self.Src0,'b',1] == b[:,0,1]))
|
||||
self.assertTrue(np.all(F[self.Src1,'b',1] == b[:,1,1]))
|
||||
self.assertTrue(np.all(F[self.Src0,'b',4] == b[:,0,4]))
|
||||
self.assertTrue(np.all(F[self.Src1,'b',4] == b[:,1,4]))
|
||||
self.assertTrue(np.all(F[self.Src0,'b',1] == Utils.mkvc(b[:,0,1],2)))
|
||||
self.assertTrue(np.all(F[self.Src1,'b',1] == Utils.mkvc(b[:,1,1],2)))
|
||||
self.assertTrue(np.all(F[self.Src0,'b',4] == Utils.mkvc(b[:,0,4],2)))
|
||||
self.assertTrue(np.all(F[self.Src1,'b',4] == Utils.mkvc(b[:,1,4],2)))
|
||||
|
||||
|
||||
b = np.random.rand(F.mesh.nF, 2, nT)
|
||||
@@ -344,7 +344,7 @@ class FieldsTest_Time_Aliased(unittest.TestCase):
|
||||
self.assertTrue(np.all(F[self.Src0, 'e', :] == e[:,0,:] ))
|
||||
self.assertTrue(np.all(F[self.Src1, 'e', :] == e[:,1,:] ))
|
||||
for t in range(nT):
|
||||
self.assertTrue(np.all(F[self.Src1, 'e', t] == e[:,1,t] ))
|
||||
self.assertTrue(np.all(F[self.Src1, 'e', t] == Utils.mkvc(e[:,1,t],2) ))
|
||||
|
||||
b = np.random.rand(F.mesh.nF,nT)
|
||||
F[self.Src0, 'b',:] = b
|
||||
|
||||
Reference in New Issue
Block a user