Coverage for gpaw/hybrids/scf.py: 99%

221 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import numpy as np 

2 

3from gpaw.mpi import serial_comm 

4from gpaw.kpt_descriptor import KPointDescriptor 

5from gpaw.pw.descriptor import PWDescriptor 

6from gpaw.pw.lfc import PWLFC 

7from gpaw.hybrids.kpts import PWKPoint, RSKPoint, to_real_space 

8from gpaw.utilities.blas import mmm 

9 

10 

def apply1(kpt, Htpsit_xG, wfs, coulomb, sym, paw):
    """Exchange energies and potentials for the spin channel of ``kpt``.

    A PWKPoint is built for every entry of ``wfs.kpt_u[kpt.s::wfs.nspins]``.
    Occupations are divided by the k-point weight (scaled to the [0, 1]
    range), and each k-point carries its IBZ coordinates and weight from
    ``wfs.kd``.  ``Htpsit_xG`` is not used.

    Returns ``(evc, evv, ekin, v_knG)``.  Note that evc comes first,
    whereas ``calculate`` returns the valence-valence energy first.
    """
    kd = wfs.kd
    same_spin_kpts = wfs.kpt_u[kpt.s::wfs.nspins]
    kpts = []
    for spin_kpt in same_spin_kpts:
        kpts.append(PWKPoint(spin_kpt.psit,
                             spin_kpt.projections,
                             spin_kpt.f_n / spin_kpt.weight,  # to [0, 1]
                             kd.ibzk_kc[spin_kpt.k],
                             kd.weight_k[spin_kpt.k]))
    exxvv, exxvc, ekin, v_knG = calculate(kpts, wfs, paw, sym, coulomb)
    return exxvc, exxvv, ekin, v_knG

21 

22 

def calculate(kpts, wfs, paw, sym, coulomb):
    """Exchange energies and potentials for a list of k-points.

    Parameters
    ----------
    kpts: list of PWKPoint objects for one spin channel, as built by
        ``apply1``, with occupations scaled to the [0, 1] range.
    wfs, paw, sym, coulomb: wave functions, PAW data, symmetry helper and
        Coulomb-kernel provider passed on to the pair calculation.

    Returns
    -------
    ``(exxvv, exxvc, ekin, w_knG)``.  The three energies are summed over
    ``wfs.world``.  ``w_knG`` is a dict mapping the position of each
    k-point in ``kpts`` to the rank-local plane-wave slice of its
    potential, taken after the sum over ``wfs.world`` and after the
    projector term added by ``wfs.pt.add``.
    """
    pd = kpts[0].psit.pd
    gd = pd.gd.new_descriptor(comm=serial_comm)
    kd = wfs.kd
    comm = wfs.world
    nbands = len(kpts[0].psit.array)

    # Zeroed accumulators per k-point: one (nbands, ni) projector-space
    # array per atom, plus a global plane-wave array.
    ni_a = [len(Delta_iiL) for Delta_iiL in paw.Delta_aiiL]
    v_kani = [{a: np.zeros((nbands, ni), pd.dtype)
               for a, ni in enumerate(ni_a)}
              for _ in kpts]
    v_knG = [kpt.psit.pd.zeros(nbands, global_array=True, q=kpt.psit.kpt)
             for kpt in kpts]

    # Pair interactions for every (k1, k2) pair yielded by sym.pairs.
    exxvv = 0.0
    ekin = 0.0
    for i1, i2, s, k1, k2, count in sym.pairs(kpts, wfs, wfs.spos_ac):
        q_c = k2.k_c - k1.k_c
        pd12 = PWDescriptor(pd.ecut, gd, pd.dtype,
                            kd=KPointDescriptor([-q_c]))
        ghat = PWLFC([data.ghat_l for data in wfs.setups], pd12)
        ghat.set_positions(wfs.spos_ac)

        v_G = coulomb.get_potential(pd12)
        kpt1, kpt2 = kpts[i1], kpts[i2]
        assert i1 == kpt1.psit.kpt
        assert i2 == kpt2.psit.kpt
        e_nn = calculate_exx_for_pair(
            k1, k2, ghat, v_G,
            pd1=kpt1.psit.pd, pd2=kpt2.psit.pd,
            index1=kpt1.psit.kpt, index2=kpt2.psit.kpt,
            f1_n=k1.f_n, f2_n=k2.f_n,
            s=s, count=count,
            v1_nG=v_knG[i1], v1_ani=v_kani[i1],
            v2_nG=v_knG[i2], v2_ani=v_kani[i2],
            wfs=wfs, sym=sym, paw=paw)

        # Occupation-weighted pair energy, multiplied by the pair count.
        e = k1.f_n.dot(e_nn * count).dot(k2.f_n) / kd.nbzkpts
        exxvv -= 0.5 * e
        ekin += e

    # On-site terms for the atoms listed in paw.VV_aii.
    exxvc = 0.0
    for kpt in kpts:
        for a, VV_ii in paw.VV_aii.items():
            P_ni = kpt.proj[a]
            vv_n = np.einsum('ni, ij, nj -> n',
                             P_ni.conj(), VV_ii, P_ni).real
            vc_n = np.einsum('ni, ij, nj -> n',
                             P_ni.conj(), paw.VC_aii[a], P_ni).real
            exxvv -= vv_n.dot(kpt.f_n) * kpt.weight
            exxvc -= vc_n.dot(kpt.f_n) * kpt.weight

    # Reduce the accumulators over all ranks, keep this rank's G-slice and
    # add the projector-space part via wfs.pt.add.
    G1 = comm.rank * pd.maxmyng
    G2 = (comm.rank + 1) * pd.maxmyng
    w_knG = {}
    for k, (v_nG, v_ani, kpt) in enumerate(zip(v_knG, v_kani, kpts)):
        comm.sum(v_nG)
        w_nG = v_nG[:, G1:G2].copy()
        w_knG[k] = w_nG
        for v_ni in v_ani.values():
            comm.sum(v_ni)
        dv_ani = {}
        for a, VV_ii in paw.VV_aii.items():
            P_ni = kpt.proj[a]
            v_ni = P_ni.dot(paw.VC_aii[a] + 2 * VV_ii)
            dv_ani[a] = v_ani[a] - v_ni
            ekin += (np.einsum('n, ni, ni',
                               kpt.f_n, P_ni.conj(), v_ni).real *
                     kpt.weight)
        wfs.pt.add(w_nG, dv_ani, kpt.psit.kpt)

    return (comm.sum_scalar(exxvv),
            comm.sum_scalar(exxvc),
            comm.sum_scalar(ekin),
            w_knG)

107 

108 

def calculate_exx_for_pair(k1,
                           k2,
                           ghat,
                           v_G,
                           pd1, pd2,
                           index1, index2,
                           f1_n, f2_n,
                           s,
                           count,
                           v1_nG,
                           v1_ani,
                           v2_nG,
                           v2_ani,
                           wfs,
                           sym,
                           paw,
                           F_av=None):
    """Exchange pair energies between the bands of ``k1`` and ``k2``.

    For each band pair (n1, n2), the pair density ``u1 * conj(u2)`` is
    Fourier transformed on ``ghat.pd``.  The ghat terms built from the
    projections and ``paw.Delta_aiiL`` are then added with ``add``, and
    ``ghat.pd.integrate(rho, v_G * rho)`` is stored in ``e_nn``.

    Potentials are optionally accumulated in place:

    * ``v1_nG is None``: energies only.
    * ``v2_nG is None``: only the k1 side (``v1_nG``, ``v1_ani``) is
      updated, weighted by ``f2_n``.  ``k1`` must not be ``k2``.
    * both given: the k1 side and the k2 side (``v2_nG``, ``v2_ani``) are
      updated.  The k2-side terms are mapped with the inverse of
      symmetry operation ``s``.

    Parameters
    ----------
    k1, k2: k-point objects with ``u_nR``, ``proj``, ``k_c`` and (for the
        callers in this module) ``f_n`` attributes.
    ghat: PWLFC on the pair plane-wave descriptor.
    v_G: Coulomb kernel on ``ghat.pd``.
    pd1, pd2: plane-wave descriptors used for the FFTs back to k1 / k2.
    index1, index2: k-point indices passed to ``pd.fft`` and used to look
        up ``kd.weight_k``.
    f1_n, f2_n: occupation numbers of the k1 / k2 bands.
    s: symmetry index passed to ``sym.symmetry_operation``.
    count: multiplicity factor applied to the two-sided update.
    F_av: unused.

    Returns
    -------
    ``e_nn * factor`` with ``factor = 1 / kd.nbzkpts``, shape (N1, N2).
    Entries this rank did not compute stay zero.
    """
    kd = wfs.kd
    comm = wfs.world
    # Every energy and potential contribution is scaled by 1 / N_BZ.
    factor = 1.0 / kd.nbzkpts

    N1 = len(k1.u_nR)
    N2 = len(k2.u_nR)

    size = comm.size
    rank = comm.rank

    # Q_annL[a][n1, n2, L]: coefficients of atom a's ghat functions for the
    # pair density of band n1 (k1) and band n2 (k2).
    Q_annL = [np.einsum('mi, ijL, nj -> mnL',
                        k1.proj[a],
                        Delta_iiL,
                        k2.proj[a].conj(),
                        optimize=True)
              for a, Delta_iiL in enumerate(paw.Delta_aiiL)]

    if v2_nG is not None:
        # Inverse symmetry operation used to map k1-side terms onto k2.
        # T acts on real-space arrays.  T_a[a] = (b, S_c, U_ii) for each
        # atom.  cc requests an extra complex conjugation (see usage below).
        T, T_a, cc = sym.symmetry_operation(s, wfs, inverse=True)

    # Scratch-buffer length: the largest number of n2 bands this rank
    # handles for a single n1.
    if k1 is k2:
        n2max = (N1 + size - 1) // size
    else:
        n2max = N2

    e_nn = np.zeros((N1, N2))
    rho_nG = ghat.pd.empty(n2max, k1.u_nR.dtype)
    vrho_nG = ghat.pd.empty(n2max, k1.u_nR.dtype)

    f_GI = ghat.expand()

    for n1, u1_R in enumerate(k1.u_nR):
        if k1 is k2:
            # Same k-point: only n2 >= n1 is visited.  That range is cut
            # into contiguous chunks of length B, one chunk per rank.  The
            # mirrored entry e_nn[n2, n1] is filled in below.
            B = (N1 - n1 + size - 1) // size
            n20 = 0
            n2a = min(n1 + rank * B, N2)
            n2b = min(n2a + B, N2)
        else:
            # Different k-points: every local k2 band is used.  n20 offsets
            # local k2 band indices into v2_nG / v2_ani.  This assumes k2
            # holds this rank's slice of ceil(N1 / size) bands -- TODO
            # confirm.
            B = (N1 + size - 1) // size
            n20 = min(B * rank, N1)
            n2a = 0
            n2b = N2

        # Plane-wave pair densities u1(r) * conj(u2(r)) ...
        for n2, rho_G in enumerate(rho_nG[:n2b - n2a], n2a):
            rho_G[:] = ghat.pd.fft(u1_R * k2.u_nR[n2].conj())

        # ... plus the ghat terms, added in place.
        add(ghat, rho_nG[:n2b - n2a],
            {a: Q_nnL[n1, n2a:n2b]
             for a, Q_nnL in enumerate(Q_annL)},
            f_GI)
        for n2, rho_G in enumerate(rho_nG[:n2b - n2a], n2a):
            # Pair energy <rho | v_G rho>.
            vrho_G = v_G * rho_G
            e = ghat.pd.integrate(rho_G, vrho_G).real
            e_nn[n1, n2] = e
            if k1 is k2:
                e_nn[n2, n1] = e
            vrho_nG[n2 - n2a] = vrho_G

            if v1_nG is not None:
                # Plane-wave part of the potential: multiply v*rho by the
                # partner wave function in real space, then FFT back.
                vrho_R = ghat.pd.ifft(vrho_G)
                if v2_nG is None:
                    assert k1 is not k2
                    v1_nG[n1] -= f2_n[n2] * factor * pd1.fft(
                        vrho_R * k2.u_nR[n2], index1, local=True)
                else:
                    # Off-diagonal same-k pairs get a factor 2, presumably
                    # because their (n2, n1) partner is never visited.
                    x = factor * count / 2
                    if k1 is k2 and n1 != n2:
                        x *= 2
                    x1 = x / (kd.weight_k[index1] * kd.nbzkpts)
                    x2 = x / (kd.weight_k[index2] * kd.nbzkpts)
                    v1_nG[n1] -= f2_n[n2] * x1 * pd1.fft(
                        vrho_R * k2.u_nR[n2], index1, local=True)
                    v2_nG[n2 + n20] -= f1_n[n1] * x2 * pd2.fft(
                        T(vrho_R.conj() * u1_R), index2,
                        local=True)

        # Projector-space part, k1 side only.
        if v1_nG is not None and v2_nG is None:
            for a, v_nL in integrate(ghat, vrho_nG[:n2b - n2a], f_GI):
                v_iin = paw.Delta_aiiL[a].dot(v_nL.T)
                v1_ani[a][n1] -= np.einsum('ijn, nj, n -> i',
                                           v_iin,
                                           k2.proj[a][n2a:n2b],
                                           f2_n[n2a:n2b] * factor)

        # Projector-space part for both the k1 and k2 sides.
        if v1_nG is not None and v2_nG is not None:
            x = factor * count / kd.nbzkpts / 2
            x1 = x / kd.weight_k[index1]
            x2 = x / kd.weight_k[index2]
            if k1 is k2:
                x1 *= 2
                x2 *= 2

            for a, v_nL in integrate(ghat, vrho_nG[:n2b - n2a], f_GI):
                # Halve the same-k diagonal term.  Combined with the doubled
                # x1 / x2 above, diagonal and off-diagonal terms then get the
                # same relative weights as in the plane-wave part.
                if k1 is k2 and n2a <= n1 < n2b:
                    v_nL[n1 - n2a] *= 0.5
                v_iin = paw.Delta_aiiL[a].dot(v_nL.T)
                v1_ani[a][n1] -= np.einsum('ijn, nj, n -> i',
                                           v_iin,
                                           k2.proj[a][n2a:n2b],
                                           f2_n[n2a:n2b] * x1)
                # k2 side: the k1 projections of atom b (from T_a), rotated
                # with U_ii.
                b, S_c, U_ii = T_a[a]
                v_ni = np.einsum('ijn, j, ik -> nk',
                                 v_iin.conj(),
                                 k1.proj[b][n1],
                                 U_ii)
                if v_nL.dtype == complex:
                    # Phase exp(2 pi i k2 . S_c).  S_c is presumably the
                    # lattice shift of the symmetry-mapped atom -- confirm.
                    v_ni *= np.exp(2j * np.pi * k2.k_c.dot(S_c))
                if cc:
                    v_ni = v_ni.conj()
                v2_ani[a][n20 + n2a:n20 + n2b] -= v_ni * f1_n[n1] * x2

    return e_nn * factor

238 

239 

def add(ghat, a_xG, c_axi, f_GI):
    """Add expansions in the ghat functions to plane-wave arrays, in place.

    For each ``(a, I1, I2)`` in ``ghat.my_indices``, columns ``I1:I2`` of
    the coefficient matrix are set to
    ``c_axi[a] * conj(ghat.eikR_qa[0][a])``.  The update is then
    ``a_xG += (1 / dv) * c_xI @ f_GI.T`` via ``mmm``.

    Parameters
    ----------
    ghat: PWLFC providing ``nI``, ``my_indices``, ``eikR_qa`` and ``pd``.
    a_xG: array updated in place.  All leading axes are flattened, so the
        reshape must yield a view (contiguous input) for the update to
        reach the caller's array.
    c_axi: mapping from atom index to coefficient arrays of shape
        ``a_xG.shape[:-1] + (I2 - I1,)``.
    f_GI: expanded ghat functions, as returned by ``ghat.expand()``.
    """
    # Zero-initialised: any column not covered by ghat.my_indices then
    # contributes nothing.  With np.empty it would feed uninitialised
    # memory into the matrix product.
    c_xI = np.zeros(a_xG.shape[:-1] + (ghat.nI,), ghat.pd.dtype)
    for a, I1, I2 in ghat.my_indices:
        c_xI[..., I1:I2] = c_axi[a] * ghat.eikR_qa[0][a].conj()
    nx = np.prod(c_xI.shape[:-1], dtype=int)
    c_xI = c_xI.reshape((nx, ghat.nI))
    a_xG = a_xG.reshape((nx, a_xG.shape[-1])).view(ghat.pd.dtype)
    mmm(1.0 / ghat.pd.gd.dv, c_xI, 'N', f_GI, 'T', 1.0, a_xG)

248 

249 

def integrate(ghat, a_xG, f_GI):
    """Project plane-wave arrays onto the expanded ghat functions.

    This is a generator.  On first iteration it computes
    ``c_xI = alpha * a_xG @ op(f_GI)`` with ``alpha = 1 / prod(N_c)``,
    then yields ``(a, ghat.eikR_qa[0][a] * c_xI[..., I1:I2])`` for each
    ``(a, I1, I2)`` in ``ghat.my_indices``.

    ``f_GI`` is modified temporarily for the product: row 0 is halved in
    the real case, and the array is complex-conjugated otherwise.  It is
    restored before the first yield, including when ``mmm`` raises.  The
    restore always mirrors the setup branch.
    """
    c_xI = np.zeros(a_xG.shape[:-1] + (ghat.nI,), ghat.pd.dtype)

    nx = np.prod(c_xI.shape[:-1], dtype=int)
    b_xI = c_xI.reshape((nx, ghat.nI))
    a_xG = a_xG.reshape((nx, a_xG.shape[-1]))

    alpha = 1.0 / ghat.pd.gd.N_c.prod()
    is_real = ghat.pd.dtype == float
    if is_real:
        # Real (float) descriptor: view the coefficients as floats, double
        # the result, and halve row 0 so it is counted once.  Row 0 is
        # presumably the G = 0 component -- confirm.
        alpha *= 2
        a_xG = a_xG.view(float)
        f_GI[0] *= 0.5
    else:
        # Complex descriptor: negating the imaginary part conjugates f_GI.
        f_GI.imag[:] = -f_GI.imag
    try:
        mmm(alpha, a_xG, 'N', f_GI, 'N', 0.0, b_xI)
    finally:
        # Undo exactly the modification made above.
        if is_real:
            f_GI[0] *= 2.0
        else:
            f_GI.imag[:] = -f_GI.imag
    for a, I1, I2 in ghat.my_indices:
        yield a, ghat.eikR_qa[0][a] * c_xI[..., I1:I2]

271 

272 

def apply2(kpt, psit_xG, Htpsit_xG, wfs, coulomb, sym, paw):
    """Apply the exchange operator to the vectors ``psit_xG`` at ``kpt``.

    ``psit_xG`` is wrapped in a new PW object and its projections are
    computed with ``wfs.pt``.  Its occupations and weight are NaN
    placeholders, presumably so that any accidental use shows up as NaN.
    The partner k-points are all entries of
    ``wfs.kpt_u[kpt.s::wfs.nspins]``, with occupations scaled to [0, 1].
    ``Htpsit_xG`` is not used.

    Returns the array produced by ``calculate2``.
    """
    kd = wfs.kd

    trial_psit = kpt.psit.new(buf=psit_xG)
    trial_P = kpt.projections.new()
    trial_psit.matrix_elements(wfs.pt, out=trial_P)

    kpt1 = PWKPoint(trial_psit,
                    trial_P,
                    kpt.f_n + np.nan,
                    kd.ibzk_kc[kpt.k],
                    np.nan)

    kpts2 = []
    for spin_kpt in wfs.kpt_u[kpt.s::wfs.nspins]:
        kpts2.append(PWKPoint(spin_kpt.psit,
                              spin_kpt.projections,
                              spin_kpt.f_n / spin_kpt.weight,  # to [0, 1]
                              kd.ibzk_kc[spin_kpt.k],
                              kd.weight_k[spin_kpt.k]))
    return calculate2(kpt1, kpts2, wfs, paw, sym, coulomb)

294 

295 

def calculate2(kpt1, kpts2, wfs, paw, sym, coulomb):
    """Exchange potential acting on the states of ``kpt1``.

    For each k-point in ``kpts2`` and each BZ point that ``kd.bz2ibz_k``
    maps onto it, the symmetry-transformed partner is paired with
    ``kpt1`` through ``calculate_exx_for_pair`` (k1 side only,
    ``v2_nG=None``).  The accumulators are then summed over
    ``wfs.world``, ``P @ (VC + 2 VV)`` is subtracted in projector space,
    and the result is added via ``wfs.pt.add``.  The bands of the
    ``kpts2`` states are split into contiguous chunks over the ranks.

    Returns this rank's plane-wave slice of the result.
    """
    pd = kpt1.psit.pd
    gd = pd.gd.new_descriptor(comm=serial_comm)
    kd = wfs.kd
    comm = wfs.world
    nbands = len(kpt1.psit.array)

    # Zeroed accumulators: one (nbands, ni) array per atom, plus a global
    # plane-wave array.
    v_ani = {}
    for a, Delta_iiL in enumerate(paw.Delta_aiiL):
        v_ani[a] = np.zeros((nbands, len(Delta_iiL)), pd.dtype)
    v_nG = kpt1.psit.pd.zeros(nbands, global_array=True, q=kpt1.psit.kpt)

    # All bands of the states being acted on, in real space.
    u1_nR = to_real_space(kpt1.psit)
    k1 = RSKPoint(u1_nR,
                  kpt1.proj.broadcast(),
                  kpt1.f_n,
                  kpt1.k_c,
                  kpt1.weight)

    N2 = len(kpts2[0].psit.array)
    nsym = len(kd.symmetry.op_scc)

    # Band slice [na, nb) of the partner states handled by this rank.
    rank, size = comm.rank, comm.size
    chunk = (N2 + size - 1) // size
    na = min(chunk * rank, N2)
    nb = min(na + chunk, N2)

    for i2, kpt2 in enumerate(kpts2):
        u2_nR = to_real_space(kpt2.psit, na, nb)
        k0 = RSKPoint(u2_nR,
                      kpt2.proj.broadcast().view(na, nb),
                      kpt2.f_n[na:nb],
                      kpt2.k_c,
                      kpt2.weight)
        bz_points = [k for k, ibz in enumerate(kd.bz2ibz_k) if ibz == i2]
        for k in bz_points:
            # Symmetry operation index; time reversal adds nsym.
            s = kd.sym_k[k] + kd.time_reversal_k[k] * nsym
            k2 = sym.apply_symmetry(s, k0, wfs, wfs.spos_ac)
            q_c = k2.k_c - k1.k_c
            pd12 = PWDescriptor(pd.ecut, gd, pd.dtype,
                                kd=KPointDescriptor([-q_c]))
            ghat = PWLFC([data.ghat_l for data in wfs.setups], pd12)
            ghat.set_positions(wfs.spos_ac)

            v_G = coulomb.get_potential(pd12)
            calculate_exx_for_pair(k1, k2, ghat, v_G,
                                   kpt1.psit.pd, kpt2.psit.pd,
                                   kpt1.psit.kpt, kpt2.psit.kpt,
                                   k1.f_n, k2.f_n,
                                   s, 1.0,
                                   v_nG, v_ani,
                                   None, None,
                                   wfs, sym, paw)

    G1 = rank * pd.maxmyng
    G2 = (rank + 1) * pd.maxmyng
    comm.sum(v_nG)
    w_nG = v_nG[:, G1:G2].copy()
    for v_ni in v_ani.values():
        comm.sum(v_ni)
    dv_ani = {a: v_ani[a] - kpt1.proj[a].dot(paw.VC_aii[a] + 2 * VV_ii)
              for a, VV_ii in paw.VV_aii.items()}
    wfs.pt.add(w_nG, dv_ani, kpt1.psit.kpt)

    return w_nG