Coverage for gpaw/tddft/solvers/cscg.py: 94%

63 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1# Written by Lauri Lehtovaara, 2008 

2 

3"""This module defines CSCG-class, which implements conjugate gradient 

4for complex symmetric matrices. Requires Numpy and GPAW's own BLAS.""" 

5 

6import numpy as np 

7 

8from gpaw.utilities.blas import axpy 

9from gpaw.mpi import rank 

10 

11from .base import BaseSolver 

12 

13 

class CSCG(BaseSolver):
    """Conjugate gradient for complex symmetric matrices.

    This class solves a set of linear equations A.x = b using conjugate
    gradient for complex symmetric matrices. The matrix A is a complex,
    symmetric, and non-singular matrix. The method requires only access
    to the matrix-vector product A.x = b, which is called A.dot(x). Thus A
    must provide the member function dot(self, x, b), where x and b are
    complex arrays (numpy.array([], complex)), x is the known vector, and
    b is the result.

    Now x and b are multivectors, i.e., lists of vectors.
    """

    def solve(self, A, x, b):
        """Solve A.x = b for every vector of the multivectors x and b.

        Parameters:

        A:
            Operator providing dot(x, b) (b := A.x) and
            apply_preconditioner(r, z) (z := M^-1.r).
        x:
            Multivector; initial guess on entry, solution on exit
            (updated in place).
        b:
            Multivector; right-hand side.

        Returns the number of iterations used (also stored in
        self.iterations). Raises RuntimeError if the right-hand side
        underflows, if the method breaks down, or if it fails to
        converge within self.max_iter iterations.
        """
        if self.timer is not None:
            self.timer.start('CSCG')

        # number of vectors
        nvec = len(x)

        # r_0 = b - A x_0
        r = self.gd.zeros(nvec, dtype=complex)
        A.dot(-x, r)
        r += b

        p = self.gd.zeros(nvec, dtype=complex)
        q = self.gd.zeros(nvec, dtype=complex)
        z = self.gd.zeros(nvec, dtype=complex)

        alpha = np.zeros((nvec,), dtype=complex)
        rho = np.zeros((nvec,), dtype=complex)
        rhop = np.zeros((nvec,), dtype=complex)
        scale = np.zeros((nvec,), dtype=complex)
        tmp = np.zeros((nvec,), dtype=complex)

        # rho of the "previous" iteration; 1 makes the first beta equal
        # rho, so p is effectively initialized to z below.
        rhop[:] = 1.

        # Multivector dot product, a^T b, where ^T is transpose
        # (unconjugated, as required for complex *symmetric* systems).
        def multi_zdotu(s, x, y, nvec):
            for i in range(nvec):
                s[i] = x[i].ravel().dot(y[i].ravel())
            self.gd.comm.sum(s)
            return s

        # Multivector ZAXPY: a x + y => y
        def multi_zaxpy(a, x, y, nvec):
            for i in range(nvec):
                axpy(a[i] * (1 + 0J), x[i], y[i])

        # Multiscale: a x => x
        def multi_scale(a, x, nvec):
            for i in range(nvec):
                x[i] *= a[i]

        # scale = square of the norm of b
        multi_zdotu(scale, b, b, nvec)
        scale = np.abs(scale)

        # if scale < eps, then convergence check breaks down.
        # NOTE: scale is an array; format its minimum (formatting the
        # array itself with %le raised TypeError for nvec > 1).
        if (scale < self.eps).any():
            raise RuntimeError(
                "CSCG method detected underflow for squared norm of "
                "right-hand side (scale = %le < eps = %le)." %
                (np.min(scale), self.eps))

        slow_convergence_iters = 100

        converged = False
        for i in range(self.max_iter):
            # z_i = (M^-1.r)
            A.apply_preconditioner(r, z)

            # rho_i-1 = r^T z_i-1
            multi_zdotu(rho, r, z, nvec)

            # if i=1, p_i = r_i-1
            # else beta = rho_i-1 / rho_i-2 and p_i = z_i-1 + beta p_i-1
            beta = rho / rhop

            # if abs(beta) / scale < eps, then CSCG breaks down
            if ((i > 0) and
                ((np.abs(beta) / scale) < self.eps).any()):
                raise RuntimeError(
                    "Conjugate gradient method failed "
                    "(abs(beta)=%le < eps = %le)." %
                    (np.min(np.abs(beta)), self.eps))

            # p = z + beta p
            multi_scale(beta, p, nvec)
            p += z

            # q = A.p
            A.dot(p, q)

            # alpha_i = rho_i-1 / (p^T q_i)
            multi_zdotu(alpha, p, q, nvec)
            alpha = rho / alpha

            # x_i = x_i-1 + alpha_i p_i
            multi_zaxpy(alpha, p, x, nvec)
            # r_i = r_i-1 - alpha_i q_i
            multi_zaxpy(-alpha, q, r, nvec)

            # if ( |r|^2 / scale < tol^2 ) for every vector, we are done
            multi_zdotu(tmp, r, r, nvec)
            if ((np.abs(tmp) / scale) < self.tol * self.tol).all():
                converged = True
                break

            # print if slow convergence
            if ((i + 1) % slow_convergence_iters) == 0:
                print('R2 of proc #', rank, ' = ', tmp,
                      ' after ', i + 1, ' iterations')

            # finally update rho
            rhop[:] = rho

        # Raise only if the loop really failed to converge.  The old
        # check (i >= max_iter - 1) also fired when convergence happened
        # exactly on the last allowed iteration.
        if not converged:
            raise RuntimeError(
                "Conjugate gradient method failed to converged "
                "within given number of iterations (= %d)." % self.max_iter)

        self.iterations = i + 1

        if self.timer is not None:
            self.timer.stop('CSCG')

        return self.iterations

157 # print self.iterations