mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 17:14:10 +08:00
39 lines
923 B
Python
39 lines
923 B
Python
import numpy as np
|
|
|
|
|
|
def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
    """Solve ``A x = b`` with the conjugate gradient (CG) method.

    Follows the CG algorithm as given in Demmel, p. 312 (presumably
    "Applied Numerical Linear Algebra"). ``A`` is never formed explicitly;
    it is accessed only through matrix-vector products ``f_Ax(p)``. As with
    any CG solver, ``A`` is assumed to be symmetric positive definite.

    Args:
        f_Ax: callable mapping a vector ``p`` (same shape as ``b``) to ``A @ p``.
        b: right-hand side vector. It is not modified. Integer/boolean inputs
            are promoted to float64 so the in-place float updates are valid.
        cg_iters: maximum number of CG steps.
        callback: optional callable invoked with the current iterate ``x``
            before each step and once more after the loop. It receives the
            live array, which later steps update in place; copy it if you
            need to keep a snapshot.
        verbose: if True, print a per-iteration table. The "residual norm"
            column shows the *squared* residual norm ``r . r``.
        residual_tol: stop once the squared residual norm ``r . r`` drops
            below this value.

    Returns:
        The approximate solution ``x`` (same shape as ``b``). It is all zeros
        when ``b`` is zero or ``cg_iters`` is 0.
    """
    b = np.asarray(b)
    if not np.issubdtype(b.dtype, np.inexact):
        # np.zeros_like on an integer array would give an integer x, and the
        # in-place float updates (x += ..., r -= ...) would raise a casting error.
        b = b.astype(np.float64)

    p = b.copy()  # search direction
    r = b.copy()  # residual b - A x (x starts at 0)
    x = np.zeros_like(b)
    rdotr = r.dot(r)
    n_steps = 0  # completed CG steps; defined even when cg_iters == 0

    fmtstr = "%10i %10.3g %10.3g"
    titlestr = "%10s %10s %10s"
    if verbose:
        print(titlestr % ("iter", "residual norm", "soln norm"))

    for i in range(cg_iters):
        if rdotr == 0:
            # Exact solution already reached (e.g. b == 0). Another step would
            # compute 0/0 and fill x with NaN. With residual_tol > 0 this only
            # triggers before the first step; mid-loop the tolerance check
            # below breaks first.
            break
        if callback is not None:
            callback(x)
        if verbose:
            print(fmtstr % (i, rdotr, np.linalg.norm(x)))

        z = f_Ax(p)
        v = rdotr / p.dot(z)  # step length along p
        x += v * p
        r -= v * z
        newrdotr = r.dot(r)
        mu = newrdotr / rdotr  # weight of the old direction in the new one
        p = r + mu * p
        rdotr = newrdotr
        n_steps = i + 1

        if rdotr < residual_tol:
            break

    if callback is not None:
        callback(x)
    if verbose:
        print(fmtstr % (n_steps, rdotr, np.linalg.norm(x)))

    return x
|