Coverage for gpaw/eigensolvers/eigensolver.py: 92%

144 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

1"""Module defining an eigensolver base-class.""" 

2from functools import partial 

3 

4import numpy as np 

5from ase.dft.bandgap import _bandgap 

6from ase.units import Ha 

7from ase.utils.timing import timer 

8 

9from gpaw.matrix import matrix_matrix_multiply as mmm 

10from gpaw.utilities.mblas import multi_axpy 

11from gpaw.xc.hybrid import HybridXC 

12from gpaw.mpi import broadcast_exception 

13 

14 

def reshape(a_x, shape):
    """Get an ndarray of size shape from a_x buffer.

    The buffer is flattened with ``ravel()`` and its leading
    ``prod(shape)`` elements are returned reshaped to *shape*.

    Parameters
    ----------
    a_x:
        Array used as a work buffer; anything providing ``ravel()``.
    shape:
        Requested shape: an int or a sequence of ints.  The empty shape
        ``()`` is supported and yields a 0-dimensional array.

    Raises
    ------
    ValueError
        If the buffer holds fewer elements than *shape* requires.
    """
    # np.prod() of an empty shape is the *float* 1.0, which is not a valid
    # slice bound -- force an integer element count.
    size = int(np.prod(shape, dtype=int))
    flat = a_x.ravel()
    if flat.size < size:
        raise ValueError(
            f'Buffer of {flat.size} elements is too small '
            f'for shape {shape!r}')
    return flat[:size].reshape(shape)

18 

19 

class Eigensolver:
    """Base class for iterative eigensolvers.

    Concrete solvers are expected to override ``iterate_one_k_point``;
    this class provides the surrounding loop over k-points, the
    convergence weights, residual evaluation and subspace
    diagonalization.

    Parameters
    ----------
    keep_htpsit:
        If true, ``initialize`` allocates ``Htpsit_nG`` (same shape as
        ``wfs.work_array``) and ``subspace_diagonalize`` stores the
        rotated H*psit there.  Forced to False by ``initialize`` when the
        band communicator has more than one rank.
    blocksize:
        Passed to ``wfs.make_preconditioner``.

    Notes
    -----
    ``weights`` reads ``self.nbands_converge``, which is *not* set by
    ``__init__``; presumably the owner of the eigensolver assigns it
    before the first call to ``iterate`` -- confirm against the caller.
    """

    def __init__(self, keep_htpsit=True, blocksize=1):
        self.keep_htpsit = keep_htpsit
        self.initialized = False
        self.Htpsit_nG = None
        self.error = np.inf
        self.blocksize = blocksize
        self.orthonormalization_required = True

    def initialize(self, wfs):
        """Cache communicators and sizes from *wfs* and allocate buffers."""
        self.timer = wfs.timer
        self.world = wfs.world
        self.kpt_comm = wfs.kd.comm
        self.band_comm = wfs.bd.comm
        self.dtype = wfs.dtype
        self.bd = wfs.bd
        self.nbands = wfs.bd.nbands
        self.mynbands = wfs.bd.mynbands

        # Htpsit is not kept when bands are distributed over several ranks.
        if wfs.bd.comm.size > 1:
            self.keep_htpsit = False

        if self.keep_htpsit:
            self.Htpsit_nG = np.empty_like(wfs.work_array)

        # Preconditioner for the electronic gradients:
        self.preconditioner = wfs.make_preconditioner(self.blocksize)

        for kpt in wfs.kpt_u:
            if kpt.eps_n is None:
                kpt.eps_n = np.empty(self.mynbands)

        self.initialized = True

    def reset(self):
        """Mark the solver as uninitialized.

        The next ``iterate`` call will then run ``initialize`` again.
        """
        self.initialized = False

    def weights(self, wfs):
        """Calculate convergence weights for all eigenstates.

        The mode is selected by ``self.nbands_converge``:

        * an integer: the first ``nbands_converge`` bands get the k-point
          weight,
        * ``'occupied'``: the weight is the absolute occupation number,
        * ``'CBM+<delta>'``: bands with eigenvalue below CBM + delta get
          the k-point weight.

        Weights are set to infinity while no occupation numbers are
        available, and in the CBM mode also when the highest band lies
        below the energy cutoff (not enough bands).

        Returns an array of shape ``(len(wfs.kpt_u), self.bd.mynbands)``.
        """
        weight_un = np.zeros((len(wfs.kpt_u), self.bd.mynbands))

        # Accept NumPy integers as well as plain ints; previously a
        # np.int64 fell through to the string branch and failed there.
        if isinstance(self.nbands_converge, (int, np.integer)):
            # Converge fixed number of bands:
            n = self.nbands_converge - self.bd.beg
            if n > 0:
                for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                    weight_n[:n] = kpt.weight
        elif self.nbands_converge == 'occupied':
            # Converge occupied bands:
            for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                if kpt.f_n is None:  # no eigenvalues yet
                    weight_n[:] = np.inf
                else:
                    # Methfessel-Paxton distribution can give negative
                    # occupation numbers - so we take the absolute value:
                    weight_n[:] = np.abs(kpt.f_n)
        else:
            # Converge state with energy up to CBM + delta:
            assert self.nbands_converge.startswith('CBM+'), (
                f'Unknown nbands_converge: {self.nbands_converge!r}')
            # Divided by ase.units.Ha -- presumably an eV -> Hartree
            # conversion of the user-supplied delta.
            delta = float(self.nbands_converge[4:]) / Ha

            if wfs.kpt_u[0].f_n is None:
                weight_un[:] = np.inf  # no eigenvalues yet
            else:
                # Collect all eigenvalues and calculate band gap:
                efermi = np.mean(wfs.fermi_levels)
                eps_skn = np.array(
                    [[wfs.collect_eigenvalues(k, spin) - efermi
                      for k in range(wfs.kd.nibzkpts)]
                     for spin in range(wfs.nspins)])
                # Ranks other than 0 only need a correctly shaped
                # receive buffer for the broadcast below.
                if wfs.world.rank > 0:
                    eps_skn = np.empty((wfs.nspins,
                                        wfs.kd.nibzkpts,
                                        wfs.bd.nbands))
                wfs.world.broadcast(eps_skn, 0)
                try:
                    # Find bandgap + positions of CBM:
                    try:
                        gap, _, (s, k, n) = _bandgap(eps_skn, direct=False)
                    except TypeError:
                        # Old ASE:
                        gap, _, (s, k, n) = _bandgap(eps_skn, spin=None,
                                                     direct=False)
                except ValueError:
                    gap = 0.0

                if gap == 0.0:
                    # s, k, n may be unbound here, so they are not used.
                    cbm = efermi
                else:
                    cbm = efermi + eps_skn[s, k, n]

                ecut = cbm + delta
                for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                    weight_n[kpt.eps_n < ecut] = kpt.weight

                # eps_skn is stored relative to efermi, hence the shift.
                if (eps_skn[:, :, -1] < ecut - efermi).any():
                    # We don't have enough bands!
                    weight_un[:] = np.inf

        return weight_un

    def iterate(self, ham, wfs):
        """Solves eigenvalue problem iteratively

        This method is inherited by the actual eigensolver which should
        implement *iterate_one_k_point* method for a single iteration of
        a single kpoint.

        The per-k-point errors are accumulated, summed over the k-point
        and band communicators and stored in ``self.error``.
        """

        if not self.initialized:
            if isinstance(ham.xc, HybridXC):
                self.blocksize = wfs.bd.mynbands
            self.initialize(wfs)

        weight_un = self.weights(wfs)

        error = 0.0
        with broadcast_exception(self.kpt_comm):
            for kpt, weights in zip(wfs.kpt_u, weight_un):
                if not wfs.orthonormalized:
                    wfs.orthonormalize(kpt)
                error += self.iterate_one_k_point(ham, wfs, kpt, weights)
                if self.orthonormalization_required:
                    wfs.orthonormalize(kpt)

        wfs.orthonormalized = True
        self.error = self.band_comm.sum_scalar(self.kpt_comm.sum_scalar(error))

    def iterate_one_k_point(self, ham, wfs=None, kpt=None, weights=None):
        """Implemented in subclasses.

        ``iterate`` calls this as
        ``iterate_one_k_point(ham, wfs, kpt, weights)`` and adds the
        returned value to the accumulated error, so overrides must accept
        those four arguments and return a number.

        The defaults exist only so that calls using the former
        ``(ham, kpt)`` signature still reach ``NotImplementedError``
        instead of failing with an arity ``TypeError``.
        """
        raise NotImplementedError

    def calculate_residuals(self, kpt, wfs, ham, psit, P, eps_n,
                            R, C, n_x=None, calculate_change=False):
        """Calculate residual.

        From R=Ht*psit calculate R=H*psit-eps*S*psit.

        *R* and *C* are modified in place; nothing is returned.
        """

        # R <- R - eps * psit
        multi_axpy(-eps_n, psit.array, R.array)

        # C <- dH * P - eps * P * dO, one slice of the projections at a
        # time (the slices come from P.indices).
        ham.dH(P, out=C)
        for a, I1, I2 in P.indices:
            dS_ii = ham.setups[a].dO_ii
            C.array[..., I1:I2] -= np.dot((P.array[..., I1:I2].T * eps_n).T,
                                          dS_ii)

        # Each callee gets its own freshly built dict, as before.
        ham.xc.add_correction(kpt, psit.array, R.array,
                              dict(P.items()),
                              dict(C.items()),
                              n_x,
                              calculate_change)
        wfs.pt.add(R.array, dict(C.items()), kpt.q)

    @timer('Subspace diag')
    def subspace_diagonalize(self, ham, wfs, kpt, rotate_psi=True):
        """Diagonalize the Hamiltonian in the subspace of kpt.psit_nG

        *Htpsit_nG* is a work array of same size as psit_nG which contains
        the local part of the Hamiltonian times psit on exit

        First, the Hamiltonian (defined by *kin*, *vt_sG*, and
        *dH_asp*) is applied to the wave functions, then the *H_nn*
        matrix is calculated and diagonalized, and finally, the wave
        functions (and also Htpsit_nG) are rotated.  Also the
        projections *P_ani* are rotated.

        It is assumed that the wave functions *psit_nG* are orthonormal
        and that the integrals of projector functions and wave functions
        *P_ani* are already calculated.

        Nothing is returned.  The eigenvalues are written to
        ``kpt.eps_n``; unless *rotate_psi* is false, ``kpt.psit`` and
        ``kpt.projections`` are replaced by their rotated counterparts
        and, if ``self.keep_htpsit`` is True, H applied to the rotated
        wave functions is stored in ``self.Htpsit_nG``.

        Raises NotImplementedError for strided band distribution over
        more than one band-communicator rank.
        """

        if self.band_comm.size > 1 and wfs.bd.strided:
            raise NotImplementedError

        psit = kpt.psit
        tmp = psit.new(buf=wfs.work_array)
        H = wfs.work_matrix_nn
        P2 = kpt.projections.new()

        Ht = partial(wfs.apply_pseudo_hamiltonian, kpt, ham)

        with self.timer('calc_h_matrix'):
            # We calculate the complex conjugate of H, because
            # that is what is most efficient with BLAS given the layout of
            # our matrices.
            psit.matrix_elements(operator=Ht, result=tmp, out=H,
                                 symmetric=True, cc=True)
            ham.dH(kpt.projections, out=P2)
            mmm(1.0, kpt.projections, 'N', P2, 'C', 1.0, H, symmetric=True)
            ham.xc.correct_hamiltonian_matrix(kpt, H.array)

        with wfs.timer('diagonalize'):
            slcomm, r, c, b = wfs.scalapack_parameters
            if r == c == 1:
                slcomm = None
            # Complex conjugate before diagonalizing:
            eps_n = H.eigh(cc=True, scalapack=(slcomm, r, c, b))
            # H.array[n, :] now contains the n'th eigenvector and eps_n[n]
            # the n'th eigenvalue
            kpt.eps_n[:] = eps_n[wfs.bd.get_slice()]

        with self.timer('rotate_psi'):
            if not rotate_psi:
                return
            if self.keep_htpsit:
                Htpsit = psit.new(buf=self.Htpsit_nG)
                mmm(1.0, H, 'N', tmp, 'N', 0.0, Htpsit)
            # Rotate into tmp first, then copy back into psit.
            mmm(1.0, H, 'N', psit, 'N', 0.0, tmp)
            psit[:] = tmp
            mmm(1.0, H, 'N', kpt.projections, 'N', 0.0, P2)
            kpt.projections.matrix = P2.matrix
            # Rotate orbital dependent XC stuff:
            ham.xc.rotate(kpt, H.array.T)

    def estimate_memory(self, mem, wfs):
        """Register this solver's memory estimates as subnodes of *mem*."""
        gridmem = wfs.bytes_per_wave_function()

        # Htpsit is only counted when this rank holds all bands.
        keep_htpsit = self.keep_htpsit and (wfs.bd.mynbands == wfs.bd.nbands)

        if keep_htpsit:
            mem.subnode('Htpsit', wfs.bd.nbands * gridmem)
        else:
            mem.subnode('No Htpsit', 0)

        mem.subnode('eps_n', wfs.bd.mynbands * mem.floatsize)
        mem.subnode('eps_N', wfs.bd.nbands * mem.floatsize)
        mem.subnode('Preconditioner', 4 * gridmem)
        mem.subnode('Work', gridmem)