Coverage for gpaw/new/pwfd/davidson.py: 99%

153 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2 

3from pprint import pformat 

4 

5import numpy as np 

6from gpaw import debug 

7from gpaw.core.matrix import Matrix 

8from gpaw.gpu import as_np 

9from gpaw.mpi import broadcast_exception 

10from gpaw.new.pwfd.eigensolver import PWFDEigensolver, calculate_residuals 

11from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions 

12from gpaw.typing import Array2D 

13from gpaw.new import trace, tracectx 

14 

15 

class Davidson(PWFDEigensolver):
    """Davidson eigensolver for PWFD wave functions.

    Each call to :meth:`iterate1` subspace-diagonalizes the current wave
    functions and then performs ``niter`` Davidson steps.  In each step
    the preconditioned residuals (psi2) are added to the basis, giving a
    ``2B x 2B`` generalized eigenproblem (B = number of bands).  That
    problem is solved on the rank that is master of both the domain and
    band communicators, and the result is broadcast.
    """

    def __init__(self,
                 nbands: int,
                 wf_grid,
                 band_comm,
                 hamiltonian,
                 converge_bands='occupied',
                 niter=2,
                 scalapack_parameters=None,
                 max_buffer_mem: int = 200 * 1024 ** 2):
        # NOTE(review): nbands, wf_grid, band_comm and scalapack_parameters
        # are accepted for interface compatibility but are not used here.
        super().__init__(
            hamiltonian,
            converge_bands,
            max_buffer_mem=max_buffer_mem)
        # Number of Davidson steps per call to iterate1().
        self.niter = niter
        # The matrices below are allocated in _initialize():
        self.H_NN: Matrix  # 2B x 2B Hamiltonian (domain/band master only)
        self.S_NN: Matrix  # 2B x 2B overlap (domain/band master only)
        self.M_nn: Matrix  # B x B work matrix, distributed over bands
        self.M2_nn: Matrix  # second B x B work matrix, same layout as M_nn

    def __str__(self):
        return pformat(dict(name='Davidson',
                            niter=self.niter,
                            converge_bands=self.converge_bands))

    def _initialize(self, ibzwfs):
        """Allocate work/buffer arrays and the subspace matrices.

        H_NN and S_NN (2B x 2B) are only allocated on the rank that is
        master of both the domain and band communicators.  All other
        ranks get empty placeholders.
        """
        super()._initialize(ibzwfs)
        self._allocate_work_arrays(ibzwfs, shape=(1,))
        self._allocate_buffer_arrays(ibzwfs, shape=(1,))

        wfs = ibzwfs.wfs_qs[0][0]
        assert isinstance(wfs, PWFDWaveFunctions)
        domain_comm = wfs.psit_nX.desc.comm
        band_comm = wfs.band_comm

        B = ibzwfs.nbands
        xp = ibzwfs.xp
        dtype = wfs.psit_nX.desc.dtype
        if domain_comm.rank == 0 and band_comm.rank == 0:
            self.H_NN = Matrix(2 * B, 2 * B, dtype=dtype, xp=xp)
            self.S_NN = Matrix(2 * B, 2 * B, dtype=dtype, xp=xp)
        else:
            # Only the master solves the eigenproblem.
            self.H_NN = self.S_NN = Matrix(0, 0)

        self.M_nn = Matrix(B, B, dtype=dtype,
                           dist=(band_comm, band_comm.size),
                           xp=xp)
        self.M2_nn = self.M_nn.new()

    def iterate1(self,
                 wfs: PWFDWaveFunctions,
                 Ht, dH, dS_aii, weight_n):
        """Update *wfs* in place with ``self.niter`` Davidson steps.

        Parameters
        ----------
        wfs:
            Wave functions.  ``psit_nX``, ``P_ani`` and the eigenvalues
            are updated in place.
        Ht:
            Called as ``Ht(psit, out=...)`` to apply the Hamiltonian.
        dH:
            Called as ``dH(P_ani, out_ani=...)`` for the projector part
            of the Hamiltonian.
        dS_aii:
            Passed to ``block_diag_multiply`` for the projector part of
            the overlap.
        weight_n:
            Per-band weights for the returned error, or None.

        Returns
        -------
        The weighted sum ``weight_n @ psit2_nX.norm2()`` of the residuals,
        or ``np.inf`` if *weight_n* is None.  The residuals used are those
        of the last step, or the initial residuals when ``niter < 1``.
        """
        H_NN = self.H_NN
        S_NN = self.S_NN
        M_nn = self.M_nn
        M2_nn = self.M2_nn

        xp = M_nn.xp

        psit_nX = wfs.psit_nX
        B = psit_nX.dims[0]  # number of bands
        eig_N = xp.empty(2 * B)
        b = psit_nX.mydims[0]  # number of bands on this rank

        psit2_nX = psit_nX.new(data=self.work_arrays[0, :b])
        data_buffer = self.data_buffers[0]

        wfs.subspace_diagonalize(Ht, dH,
                                 psit2_nX=psit2_nX,
                                 data_buffer=data_buffer)

        P_ani = wfs.P_ani
        P2_ani = P_ani.new()
        P3_ani = P_ani.new()

        domain_comm = psit_nX.desc.comm
        band_comm = psit_nX.comm
        is_domain_band_master = domain_comm.rank == 0 and band_comm.rank == 0

        # Full B x B copy held on band rank 0.  It moves blocks between the
        # band-distributed M_nn and the master-only H_NN/S_NN.
        M0_nn = M_nn.new(dist=(band_comm, 1, 1))

        if domain_comm.rank == 0:
            eig_N[:B] = xp.asarray(wfs.eig_n)

        me_buffer_mX = psit_nX.create_work_buffer(data_buffer)

        @trace
        def me(a, b, function=None):
            """Matrix elements"""
            return a.matrix_elements(b,
                                     domain_sum=False,
                                     out=M_nn,
                                     function=function,
                                     cc=True)

        # psit2_nX <- residuals of the subspace-diagonalized states.
        calculate_residuals(wfs.psit_nX,
                            psit2_nX,
                            wfs.pt_aiX,
                            wfs.P_ani,
                            wfs.myeig_n,
                            dH, dS_aii, P2_ani, P3_ani)

        def copy(C_nn: Array2D, M_nn: Matrix) -> None:
            """Sum M_nn over the domain and store it in block C_nn.

            Only the domain/band master writes to C_nn.
            """
            domain_comm.sum(M_nn.data, 0)
            if domain_comm.rank == 0:
                M_nn.redist(M0_nn)
                if band_comm.rank == 0:
                    C_nn[:] = M0_nn.data

        def residual_error():
            """Weighted error of the residuals currently in psit2_nX."""
            if weight_n is None:
                return np.inf
            return (weight_n @ as_np(psit2_nX.norm2())).sum()

        if self.niter < 1:
            # No Davidson steps will run, so the loop below would never set
            # ``error``.  Report the error of the residuals computed above.
            error = residual_error()

        for i in range(self.niter):
            if i == self.niter - 1:  # last iteration
                # Calculate error before we destroy residuals:
                error = residual_error()

            sliced_preconditioner(psit_nX, psit2_nX,
                                  buffer=me_buffer_mX,
                                  precon=self.preconditioner)

            # Calculate projections
            wfs.pt_aiX.integrate(psit2_nX, out=P2_ani)

            with tracectx('Matrix elements'):
                # Sliced matrix elements with hamiltonian. See
                # sliced_matrix_elements docstring.
                sliced_matrix_elements(psit_nX, psit2_nX,
                                       buffer_mX=me_buffer_mX,
                                       Ht=Ht,
                                       M1_nn=M_nn,
                                       M2_nn=M2_nn)

                # <psi2 | H | psi2>
                dH(P2_ani, out_ani=P3_ani)
                P2_ani.matrix.multiply(P3_ani, opb='C', symmetric=True,
                                       beta=1, out=M2_nn)
                copy(H_NN.data[B:, B:], M2_nn)

                # <psi2 | H | psi>
                P3_ani.matrix.multiply(P_ani, opb='C', beta=1.0, out=M_nn)
                copy(H_NN.data[B:, :B], M_nn)

                # <psi2 | S | psi2>
                me(psit2_nX, psit2_nX)
                P2_ani.block_diag_multiply(dS_aii, out_ani=P3_ani)
                P2_ani.matrix.multiply(P3_ani, opb='C', symmetric=True,
                                       beta=1, out=M_nn)
                copy(S_NN.data[B:, B:], M_nn)

                # <psi2 | S | psi>
                me(psit2_nX, psit_nX)
                P3_ani.matrix.multiply(P_ani, opb='C', beta=1.0, out=M_nn)
                copy(S_NN.data[B:, :B], M_nn)

            with tracectx('Diagonalize'):
                # The psi-psi blocks are set to diag(eig) and the identity,
                # i.e. psi is treated as already diagonal in the subspace.
                # Only the lower blocks ([B:, ...]) are filled above;
                # NOTE(review): presumably eigh reads only the lower
                # triangle - confirm.
                with broadcast_exception(domain_comm):
                    with broadcast_exception(band_comm):
                        if is_domain_band_master:
                            H_NN.data[:B, :B] = xp.diag(eig_N[:B])
                            S_NN.data[:B, :B] = xp.eye(B)
                            eig_N[:] = H_NN.eigh(S_NN)
                            wfs._eig_n = as_np(eig_N[:B])
                if domain_comm.rank == 0:
                    band_comm.broadcast(wfs.eig_n, 0)
                domain_comm.broadcast(wfs.eig_n, 0)

                # H_NN is read back as the coefficient matrix below, so
                # eigh presumably leaves the eigenvectors in H_NN.
                # First the psi -> psi coefficients:
                if domain_comm.rank == 0:
                    if band_comm.rank == 0:
                        M0_nn.data[:] = H_NN.data[:B, :B]
                    M0_nn.complex_conjugate()
                    M0_nn.redist(M_nn)
                domain_comm.broadcast(M_nn.data, 0)

            with tracectx('Rotate Psi'):
                M_nn.multiply(psit_nX, out=psit_nX,
                              data_buffer=data_buffer)
                M_nn.multiply(P_ani, out=P3_ani)

                # ... then add the psi2 contribution (beta=1 accumulates).
                if domain_comm.rank == 0:
                    if band_comm.rank == 0:
                        M0_nn.data[:] = H_NN.data[:B, B:]
                    M0_nn.complex_conjugate()
                    M0_nn.redist(M_nn)
                domain_comm.broadcast(M_nn.data, 0)

                M_nn.multiply(psit2_nX, beta=1.0, out=psit_nX)
                M_nn.multiply(P2_ani, beta=1.0, out=P3_ani)
                # Swap so the rotated projections become wfs.P_ani.
                P_ani, P3_ani = P3_ani, P_ani
                wfs._P_ani = P_ani

            if i < self.niter - 1:
                # Fresh residuals for the next step.
                Ht(psit_nX, out=psit2_nX)
                calculate_residuals(
                    wfs.psit_nX,
                    psit2_nX,
                    wfs.pt_aiX, wfs.P_ani, wfs.myeig_n,
                    dH, dS_aii, P2_ani, P3_ani)

        if debug:
            psit_nX.sanity_check()

        return error

217 

218 

def sliced_preconditioner(psit_nX, psit2_nX, buffer, precon):
    """Precondition the residuals in *psit2_nX* slice by slice.

    Bands are processed in chunks of ``buffer.data.shape[0]`` rows.
    ``precon(psit_chunk, residual_chunk, out=view)`` writes each chunk's
    result into a view of *buffer*, which is then copied back into
    ``psit2_nX.data``.  Only one buffer-sized temporary is needed.
    Does nothing when this rank holds no bands.
    """
    chunk = buffer.data.shape[0]
    nlocal = psit_nX.data.shape[0]
    if nlocal == 0:
        return

    for start in range(0, nlocal, chunk):
        stop = start + chunk
        # Slicing past the end is clipped, so the last view may be shorter.
        view = buffer[:nlocal - start]
        precon(psit_nX[start:stop], psit2_nX[start:stop], out=view)
        psit2_nX.data[start:stop] = view.data[:]

232 

233 

234def sliced_matrix_elements(psit1_nX, psit2_nX, buffer_mX, Ht, M1_nn, M2_nn): 

235 ''' Method for calculating matrix elements in a sliced manner: 

236 <psi2 | H | psi2> -> M2_nn 

237 <psi2 | H | psi1> -> M1_nn 

238 

239 This function uses less memory than, but is otherwise identical to: 

240 psit3_nX = psit2_nX.new() 

241 psit2_nX.matrix_elements(psit2_nX, 

242 out=M2_nn, 

243 domain_sum=False, 

244 function=partial(Ht, out=psit3_nX), 

245 cc=True) 

246 psit3_nX.matrix_elements(psit1_nX, 

247 out=M_nn, 

248 domain_sum=False, 

249 cc=True) 

250 ''' 

251 comm = psit1_nX.comm 

252 b = psit1_nX.data.shape[0] 

253 blocksize = buffer_mX.data.shape[0] 

254 blocksize_world = comm.sum_scalar(blocksize) 

255 totalbands = comm.sum_scalar(b) 

256 for i1, N1 in enumerate( 

257 range(0, totalbands, blocksize_world)): 

258 n1 = i1 * blocksize 

259 n2 = n1 + blocksize 

260 if n2 > b: 

261 n2 = b 

262 

263 world_N = min(blocksize_world, 

264 totalbands - N1) 

265 

266 buffer_view_aX = buffer_mX.new( 

267 data=buffer_mX.data[:n2 - n1], 

268 dims=(world_N,) + buffer_mX.dims[1:], 

269 ) 

270 Ht(psit2_nX[n1:n2], out=buffer_view_aX) 

271 

272 out1 = Matrix( 

273 M=world_N, 

274 N=M1_nn.shape[1], 

275 data=M1_nn.data[n1:n2, :], 

276 dist=(comm, -1, 1), 

277 xp=M1_nn.xp) 

278 out2 = Matrix( 

279 M=world_N, 

280 N=M2_nn.shape[1], 

281 data=M2_nn.data[n1:n2, :], 

282 dist=(comm, -1, 1), 

283 xp=M2_nn.xp) 

284 buffer_view_aX.matrix_elements(psit1_nX, 

285 out=out1, 

286 symmetric=False, 

287 domain_sum=False, 

288 cc=True) 

289 buffer_view_aX.matrix_elements(psit2_nX, 

290 out=out2, 

291 symmetric=False, 

292 domain_sum=False, 

293 cc=True)