diff --git a/SimPEG/Utils/__init__.py b/SimPEG/Utils/__init__.py index 6637a138..5280ae79 100644 --- a/SimPEG/Utils/__init__.py +++ b/SimPEG/Utils/__init__.py @@ -7,4 +7,4 @@ from ipythonutils import easyAnimate as animate from CounterUtils import * import ModelBuilder import SolverUtils - +from coordutils import * diff --git a/SimPEG/Utils/coordutils.py b/SimPEG/Utils/coordutils.py new file mode 100644 index 00000000..260e1a3b --- /dev/null +++ b/SimPEG/Utils/coordutils.py @@ -0,0 +1,62 @@ +import numpy as np +from SimPEG.Utils import mkvc + +def rotationMatrixFromNormals(v0,v1,tol=1e-20): + """ + Performs the minimum number of rotations to define a rotation from the direction indicated by the vector n0 to the direction indicated by n1. + The axis of rotation is n0 x n1 + https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula + + :param numpy.array v0: vector of length 3 + :param numpy.array v1: vector of length 3 + :param tol = 1e-20: tolerance. If the norm of the cross product between the two vectors is below this, no rotation is performed + :rtype: numpy.array, 3x3 + :return: rotation matrix which rotates the frame so that n0 is aligned with n1 + + """ + + # ensure both n0, n1 are vectors of length 1 + assert len(v0) == 3, "Length of n0 should be 3" + assert len(v1) == 3, "Length of n1 should be 3" + + # ensure both are true normals + n0 = v0*1./np.linalg.norm(v0) + n1 = v1*1./np.linalg.norm(v1) + + n0dotn1 = n0.dot(n1) + + # define the rotation axis, which is the cross product of the two vectors + rotAx = np.cross(n0,n1) + + if np.linalg.norm(rotAx) < tol: + return np.eye(3,dtype=float) + + rotAx *= 1./np.linalg.norm(rotAx) + + cosT = n0dotn1/(np.linalg.norm(n0)*np.linalg.norm(n1)) + sinT = np.sqrt(1.-n0dotn1**2) + + ux = np.array([[0., -rotAx[2], rotAx[1]], [rotAx[2], 0., -rotAx[0]], [-rotAx[1], rotAx[0], 0.]],dtype=float) + + return np.eye(3,dtype=float) + sinT*ux + (1.-cosT)*(ux.dot(ux)) + + +def rotatePointsFromNormals(XYZ,n0,n1,x0=np.r_[0.,0.,0.]): + """ + rotates a grid so that the vector n0 is aligned with the vector n1 + + :param numpy.array n0: vector of length 3, should have norm 1 + :param numpy.array n1: vector of length 3, should have norm 1 + :param numpy.array x0: vector of length 3, point about which we perform the rotation + :rtype: numpy.array, 3x3 + :return: rotation matrix which rotates the frame so that n0 is aligned with n1 + """ + + R = rotationMatrixFromNormals(n0, n1) + + assert XYZ.shape[1] == 3, "Grid XYZ should be 3 wide" + assert len(x0) == 3, "x0 should have length 3" + + X0 = np.ones([XYZ.shape[0],1])*mkvc(x0) + + return (XYZ - X0).dot(R.T) + X0 # equivalent to (R*(XYZ - X0)).T + X0 \ No newline at end of file diff --git a/tests/utils/test_coordutils.py b/tests/utils/test_coordutils.py new file mode 100644 index 00000000..b17afcf7 --- /dev/null +++ b/tests/utils/test_coordutils.py @@ -0,0 +1,45 @@ +import unittest, os +import numpy as np +from SimPEG import Utils + +tol = 1e-15 + +class coorUtilsTest(unittest.TestCase): + + def test_rotationMatrixFromNormals(self): + v0 = np.random.rand(3) + v0 *= 1./np.linalg.norm(v0) + v1 = np.random.rand(3) + v1 *= 1./np.linalg.norm(v1) + Rf = Utils.coordutils.rotationMatrixFromNormals(v0,v1) + Ri = Utils.coordutils.rotationMatrixFromNormals(v1,v0) + + self.assertTrue(np.linalg.norm(Utils.mkvc(Rf.dot(v0) - v1)) < tol) + self.assertTrue(np.linalg.norm(Utils.mkvc(Ri.dot(v1) - v0)) < tol) + + def test_rotatePointsFromNormals(self): + v0 = np.random.rand(3) + v0*= 1./np.linalg.norm(v0) + v1 = np.random.rand(3) + v1*= 1./np.linalg.norm(v1) + + v2 = Utils.mkvc(Utils.coordutils.rotatePointsFromNormals(Utils.mkvc(v0,2).T,v0,v1)) + + self.assertTrue(np.linalg.norm(v2-v1) < tol) + + def test_rotateMatrixFromNormals(self): + n0 = np.random.rand(3) + n0*= 1./np.linalg.norm(n0) + n1 = np.random.rand(3) + n1*= 1./np.linalg.norm(n1) + + scale = np.random.rand(100,1) + XYZ0 = scale * n0 + XYZ1 = scale * n1 + + XYZ2 = Utils.coordutils.rotatePointsFromNormals(XYZ0,n0,n1) + self.assertTrue(np.linalg.norm(Utils.mkvc(XYZ1) - Utils.mkvc(XYZ2))/np.linalg.norm(Utils.mkvc(XYZ1)) < tol) + +if __name__ == '__main__': + unittest.main() +