mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-28 11:53:55 +08:00
Updates to check derivative code
(should now work for complex, without throwing warnings.)
This commit is contained in:
+23
-14
@@ -220,35 +220,44 @@ def checkDerivative(fctn, x0, num=7, plotIt=True, dx=None, expectedOrder=2, tole
|
||||
"""
|
||||
|
||||
print "%s checkDerivative %s" % ('='*20, '='*20)
|
||||
print "iter\th\t\t|J0-Jt|\t\t|J0+h*dJ'*dx-Jt|\tOrder\n%s" % ('-'*57)
|
||||
print "iter h |f0-ft| |f0-ft-h*J0*dx| Order\n%s" % ('-'*57)
|
||||
|
||||
Jc = fctn(x0)
|
||||
f0, J0 = fctn(x0)
|
||||
|
||||
x0 = mkvc(x0)
|
||||
|
||||
if dx is None:
|
||||
dx = np.random.randn(len(x0))
|
||||
|
||||
t = np.logspace(-1, -num, num)
|
||||
E0 = np.ones(t.shape)
|
||||
E1 = np.ones(t.shape)
|
||||
h = np.logspace(-1, -num, num)
|
||||
E0 = np.ones(h.shape)
|
||||
E1 = np.ones(h.shape)
|
||||
|
||||
def l2norm(x):
|
||||
# because np.norm breaks if they are scalars?
|
||||
return np.sqrt(np.real(np.vdot(x, x)))
|
||||
|
||||
l2norm = lambda x: np.sqrt(np.inner(x, x)) # because np.norm breaks if they are scalars?
|
||||
for i in range(num):
|
||||
Jt = fctn(x0+t[i]*dx)
|
||||
E0[i] = l2norm(Jt[0]-Jc[0]) # 0th order Taylor
|
||||
if inspect.isfunction(Jc[1]):
|
||||
E1[i] = l2norm(Jt[0]-Jc[0]-t[i]*Jc[1](dx)) # 1st order Taylor
|
||||
# Evaluate at test point
|
||||
ft, Jt = fctn( x0 + h[i]*dx )
|
||||
# 0th order Taylor
|
||||
E0[i] = l2norm( ft - f0 )
|
||||
# 1st order Taylor
|
||||
if inspect.isfunction(J0):
|
||||
E1[i] = l2norm( ft - f0 - h[i]*J0(dx) )
|
||||
else:
|
||||
# We assume it is a numpy.ndarray
|
||||
E1[i] = l2norm(Jt[0]-Jc[0]-t[i]*Jc[1].dot(dx)) # 1st order Taylor
|
||||
E1[i] = l2norm( ft - f0 - h[i]*J0.dot(dx) )
|
||||
|
||||
order0 = np.log10(E0[:-1]/E0[1:])
|
||||
order1 = np.log10(E1[:-1]/E1[1:])
|
||||
print "%d\t%1.2e\t%1.3e\t\t%1.3e\t\t%1.3f" % (i, t[i], E0[i], E1[i], np.nan if i == 0 else order1[i-1])
|
||||
print " %d %1.2e %1.3e %1.3e %1.3f" % (i, h[i], E0[i], E1[i], np.nan if i == 0 else order1[i-1])
|
||||
|
||||
# Ensure we are about precision
|
||||
order0 = order0[E0[1:] > eps]
|
||||
order1 = order1[E1[1:] > eps]
|
||||
belowTol = order1.size == 0 and order0.size > 0
|
||||
# Make sure we get the correct order
|
||||
correctOrder = order1.size > 0 and np.mean(order1) > tolerance * expectedOrder
|
||||
|
||||
passTest = belowTol or correctOrder
|
||||
@@ -264,8 +273,8 @@ def checkDerivative(fctn, x0, num=7, plotIt=True, dx=None, expectedOrder=2, tole
|
||||
if plotIt:
|
||||
plt.figure()
|
||||
plt.clf()
|
||||
plt.loglog(t, E0, 'b')
|
||||
plt.loglog(t, E1, 'g--')
|
||||
plt.loglog(h, E0, 'b')
|
||||
plt.loglog(h, E1, 'g--')
|
||||
plt.title('checkDerivative')
|
||||
plt.xlabel('h')
|
||||
plt.ylabel('error of Taylor approximation')
|
||||
|
||||
Reference in New Issue
Block a user