diff --git a/SimPEG/Utils/meshutils.py b/SimPEG/Utils/meshutils.py index 6b47d690..155a6d9f 100644 --- a/SimPEG/Utils/meshutils.py +++ b/SimPEG/Utils/meshutils.py @@ -65,19 +65,22 @@ def points2nodes(mesh, pts): pts = np.atleast_2d(pts) - assert mesh._meshType == 'TENSOR' + assert mesh._meshType in ['TENSOR', 'CYL'] assert pts.shape[1] == mesh.dim nodeInds = np.empty(pts.shape[0], dtype=int) for i, pt in enumerate(pts): - nodeInds[i] = ((np.tile(pt, (mesh.nN,1)) - mesh.gridN)**2).sum(axis=1).argmin() + nodeInds[i] = ((np.tile(pt, (mesh.gridN.shape[0],1)) - mesh.gridN)**2).sum(axis=1).argmin() return nodeInds + + + if __name__ == '__main__': - from SimPEG import mesh + from SimPEG import Mesh import matplotlib.pyplot as plt - M = mesh.TensorMesh(meshTensors(((10,10),(40,10),(10,10)), ((10,10),(20,10),(0,0)))) + M = Mesh.TensorMesh(meshTensors(((10,10),(40,10),(10,10)), ((10,10),(20,10),(0,0)))) M.plotGrid() plt.gca().axis('tight') plt.show()