Coverage for gpaw/tddft/solvers/cscg.py: 94%
63 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1# Written by Lauri Lehtovaara, 2008
3"""This module defines CSCG-class, which implements conjugate gradient
4for complex symmetric matrices. Requires Numpy and GPAW's own BLAS."""
6import numpy as np
8from gpaw.utilities.blas import axpy
9from gpaw.mpi import rank
11from .base import BaseSolver
class CSCG(BaseSolver):
    """Conjugate gradient for complex symmetric matrices.

    This class solves a set of linear equations A.x = b using the
    conjugate gradient method for complex symmetric matrices. The matrix
    A is complex, symmetric (A^T = A), and non-singular. The method only
    needs access to the matrix-vector product, which is called as
    A.dot(x, y): x is the known (input) vector and y receives the
    result; both are complex arrays. A must also provide
    apply_preconditioner(r, z), which stores the preconditioned residual
    M^-1.r in z.

    x and b are multivectors, i.e., stacks of vectors. Every vector has
    its own alpha/beta coefficients, but the iteration only stops once
    all of them have converged.
    """

    def solve(self, A, x, b):
        """Solve A.x = b, updating the multivector x in place.

        Parameters
        ----------
        A : object
            Operator providing ``dot(x, y)`` (y = A.x) and
            ``apply_preconditioner(r, z)`` (z = M^-1.r).
        x : multivector
            Initial guess on entry, solution on exit (updated in place).
        b : multivector
            Right-hand sides, one per vector of x.

        Returns
        -------
        int
            Number of iterations performed; also stored in
            ``self.iterations``.

        Raises
        ------
        RuntimeError
            If the squared norm of some right-hand side is below
            ``self.eps``, if the iteration breaks down, or if it does not
            converge within ``self.max_iter`` iterations.
        """
        # self.gd, self.eps, self.tol, self.max_iter and self.timer are
        # presumably set up by BaseSolver (not visible here) -- confirm.
        if self.timer is not None:
            self.timer.start('CSCG')
        try:
            iterations = self._iterate(A, x, b)
        finally:
            # Stop the timer on failure too, so a caller that catches the
            # RuntimeError is not left with a dangling running timer.
            if self.timer is not None:
                self.timer.stop('CSCG')
        self.iterations = iterations
        return self.iterations

    def _iterate(self, A, x, b):
        """Run the preconditioned CSCG iterations; return their number."""
        nvec = len(x)

        # r_0 = b - A x_0
        r = self.gd.zeros(nvec, dtype=complex)
        A.dot(-x, r)
        r += b

        p = self.gd.zeros(nvec, dtype=complex)
        q = self.gd.zeros(nvec, dtype=complex)
        z = self.gd.zeros(nvec, dtype=complex)

        rho = np.zeros(nvec, dtype=complex)    # r_i^T z_i
        rhop = np.ones(nvec, dtype=complex)    # rho of previous iteration
        pq = np.zeros(nvec, dtype=complex)     # p_i^T q_i
        r2 = np.zeros(nvec, dtype=complex)     # |r|^2 = r^H r
        scale = np.zeros(nvec, dtype=complex)  # |b|^2 = b^H b

        # Norms must use the conjugated product: the unconjugated b^T b of
        # a nonzero complex vector can be (close to) zero, which would give
        # a spurious underflow error here and, applied to r, a spurious
        # convergence below.
        self._multi_zdotc(scale, b, b, nvec)
        scale = np.abs(scale)

        # if scale < eps, then the relative convergence check breaks down
        if (scale < self.eps).any():
            # Report the smallest value: '%le' needs a scalar, not an array.
            raise RuntimeError(
                "CSCG method detected underflow for squared norm of "
                "right-hand side (scale = %le < eps = %le)." %
                (np.min(scale), self.eps))

        slow_convergence_iters = 100

        for i in range(self.max_iter):
            # z_i = M^-1 r_i
            A.apply_preconditioner(r, z)

            # rho_i = r_i^T z_i (unconjugated, since A is complex symmetric)
            self._multi_zdotu(rho, r, z, nvec)

            # beta_i = rho_i / rho_i-1. rhop starts at 1 and p at 0, so the
            # first search direction is simply p = z.
            beta = rho / rhop

            # a vanishing abs(beta) / scale is treated as breakdown
            if i > 0 and ((np.abs(beta) / scale) < self.eps).any():
                raise RuntimeError(
                    "Conjugate gradient method failed "
                    "(abs(beta)=%le < eps = %le)." %
                    (np.min(np.abs(beta)), self.eps))

            # p_i = z_i + beta_i p_i-1
            self._multi_scale(beta, p, nvec)
            p += z

            # q_i = A p_i
            A.dot(p, q)

            # alpha_i = rho_i / (p_i^T q_i)
            self._multi_zdotu(pq, p, q, nvec)
            alpha = rho / pq

            # x_i+1 = x_i + alpha_i p_i
            self._multi_zaxpy(alpha, p, x, nvec)
            # r_i+1 = r_i - alpha_i q_i
            self._multi_zaxpy(-alpha, q, r, nvec)

            # done once |r|^2 / |b|^2 < tol^2 for every vector
            self._multi_zdotc(r2, r, r, nvec)
            if ((np.abs(r2) / scale) < self.tol * self.tol).all():
                return i + 1

            # print if slow convergence
            if (i + 1) % slow_convergence_iters == 0:
                print('R2 of proc #', rank, ' = ', r2,
                      ' after ', i + 1, ' iterations')

            rhop[:] = rho

        # Reached only if no iteration converged. Converging on the last
        # allowed iteration returns above, and max_iter <= 0 lands here
        # instead of failing on an unbound loop variable.
        raise RuntimeError(
            "Conjugate gradient method failed to converged "
            "within given number of iterations (= %d)." % self.max_iter)

    def _multi_zdotu(self, s, x, y, nvec):
        """Store x[i]^T y[i] (no conjugation) in s[i] for each vector.

        The local results are then reduced with self.gd.comm.sum; s is
        modified in place and also returned.
        """
        for i in range(nvec):
            s[i] = x[i].ravel().dot(y[i].ravel())
        self.gd.comm.sum(s)
        return s

    def _multi_zdotc(self, s, x, y, nvec):
        """Store x[i]^H y[i] (first argument conjugated) in s[i].

        np.vdot flattens both arguments. The local results are then
        reduced with self.gd.comm.sum; s is modified in place and also
        returned.
        """
        for i in range(nvec):
            s[i] = np.vdot(x[i], y[i])
        self.gd.comm.sum(s)
        return s

    @staticmethod
    def _multi_zaxpy(a, x, y, nvec):
        """Multivector ZAXPY: y[i] += a[i] * x[i], in place."""
        for i in range(nvec):
            # pass a complex coefficient to the BLAS wrapper
            axpy(a[i] * (1 + 0j), x[i], y[i])

    @staticmethod
    def _multi_scale(a, x, nvec):
        """Multiscale: x[i] *= a[i], in place."""
        for i in range(nvec):
            x[i] *= a[i]