mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-03 12:37:13 +08:00
interp updates
This commit is contained in:
+33
-10
@@ -367,7 +367,7 @@ class TensorMesh(BaseRectangularMesh, TensorView, DiffOperators, InnerProducts):
|
||||
return [t for t in ten if t is not None]
|
||||
|
||||
|
||||
def isInside(self, pts):
|
||||
def isInside(self, pts, locType='N'):
|
||||
"""
|
||||
Determines if a set of points are inside a mesh.
|
||||
|
||||
@@ -376,15 +376,23 @@ class TensorMesh(BaseRectangularMesh, TensorView, DiffOperators, InnerProducts):
|
||||
:return inside, numpy array of booleans
|
||||
"""
|
||||
|
||||
pts = np.atleast_2d(pts)
|
||||
inside = (pts[:,0] >= self.vectorNx.min()) & (pts[:,0] <= self.vectorNx.max())
|
||||
tensors = self.getTensor(locType)
|
||||
if type(pts) == list:
|
||||
pts = np.array(pts)
|
||||
assert type(pts) == np.ndarray, "must be a numpy array"
|
||||
if self.dim > 1:
|
||||
inside = inside & ((pts[:,1] >= self.vectorNy.min()) & (pts[:,1] <= self.vectorNy.max()))
|
||||
if self.dim > 2:
|
||||
inside = inside & ((pts[:,2] >= self.vectorNz.min()) & (pts[:,2] <= self.vectorNz.max()))
|
||||
assert pts.shape[1] == self.dim, "must be a column vector of shape (nPts, mesh.dim)"
|
||||
elif len(pts.shape) == 1:
|
||||
pts = pts[:,np.newaxis]
|
||||
else:
|
||||
assert pts.shape[1] == self.dim, "must be a column vector of shape (nPts, mesh.dim)"
|
||||
|
||||
inside = np.ones(pts.shape[0],dtype=bool)
|
||||
for i, tensor in enumerate(tensors):
|
||||
inside = inside & (pts[:,i] >= tensor.min()) & (pts[:,i] <= tensor.max())
|
||||
return inside
|
||||
|
||||
def getInterpolationMat(self, loc, locType):
|
||||
def getInterpolationMat(self, loc, locType, zerosOutside=False):
|
||||
""" Produces interpolation matrix
|
||||
|
||||
:param numpy.ndarray loc: Location of points to interpolate to
|
||||
@@ -404,8 +412,21 @@ class TensorMesh(BaseRectangularMesh, TensorView, DiffOperators, InnerProducts):
|
||||
'CC' -> scalar field defined on cell centers
|
||||
"""
|
||||
|
||||
loc = np.atleast_2d(loc)
|
||||
assert np.all(self.isInside(loc)), "Points outside of mesh"
|
||||
if type(loc) == list:
|
||||
loc = np.array(loc)
|
||||
assert type(loc) == np.ndarray, "must be a numpy array"
|
||||
if self.dim > 1:
|
||||
assert loc.shape[1] == self.dim, "must be a column vector of shape (nPts, mesh.dim)"
|
||||
elif len(loc.shape) == 1:
|
||||
loc = loc[:,np.newaxis]
|
||||
else:
|
||||
assert loc.shape[1] == self.dim, "must be a column vector of shape (nPts, mesh.dim)"
|
||||
|
||||
if zerosOutside is False:
|
||||
assert np.all(self.isInside(loc)), "Points outside of mesh"
|
||||
else:
|
||||
indZeros = np.logical_not(self.isInside(loc))
|
||||
loc[indZeros, :] = np.array([v.mean() for v in self.getTensor('CC')])
|
||||
|
||||
ind = 0 if 'x' in locType else 1 if 'y' in locType else 2 if 'z' in locType else -1
|
||||
if locType in ['Fx','Fy','Fz','Ex','Ey','Ez'] and self.dim >= ind:
|
||||
@@ -417,7 +438,9 @@ class TensorMesh(BaseRectangularMesh, TensorView, DiffOperators, InnerProducts):
|
||||
Q = Utils.interpmat(loc, *self.getTensor(locType))
|
||||
else:
|
||||
raise NotImplementedError('getInterpolationMat: locType=='+locType+' and mesh.dim=='+str(self.dim))
|
||||
return Q
|
||||
if zerosOutside:
|
||||
Q[indZeros, :] = 0
|
||||
return Q.tocsr()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Welcome to tensor mesh!')
|
||||
|
||||
@@ -20,6 +20,8 @@ def _interp_point_1D(x, xr_i):
|
||||
elif xr_i - x[im] < 0: # Point on the right
|
||||
ind_x1 = im-1
|
||||
ind_x2 = im
|
||||
ind_x1 = max(min(ind_x1, x.size-1), 0)
|
||||
ind_x2 = max(min(ind_x2, x.size-1), 0)
|
||||
dx1 = xr_i - x[ind_x1]
|
||||
dx2 = x[ind_x2] - xr_i
|
||||
return ind_x1, ind_x2, dx1, dx2
|
||||
@@ -77,7 +79,7 @@ def _interpmat1D(locs, x):
|
||||
inds = [ind_x1, ind_x2]
|
||||
vals = [(1-dx1/Dx),(1-dx2/Dx)]
|
||||
Q[i, inds] = vals
|
||||
return Q.tocsr()
|
||||
return Q
|
||||
|
||||
|
||||
|
||||
@@ -114,7 +116,7 @@ def _interpmat2D(locs, x, y):
|
||||
|
||||
Q[i, mkvc(inds)] = vals
|
||||
|
||||
return Q.tocsr()
|
||||
return Q
|
||||
|
||||
|
||||
|
||||
@@ -162,7 +164,7 @@ def _interpmat3D(locs, x, y, z):
|
||||
|
||||
Q[i, mkvc(inds)] = vals
|
||||
|
||||
return Q.tocsr()
|
||||
return Q
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user