Coverage for gpaw/new/pw/hybrids.py: 24%

207 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from functools import cached_property 

5from math import pi, nan 

6 

7import numpy as np 

8from gpaw.core import PWArray, PWDesc, UGArray, UGDesc 

9from gpaw.core.arrays import DistributedArrays as XArray 

10from gpaw.core.atom_arrays import AtomArrays 

11from gpaw.hybrids.paw import pawexxvv 

12from gpaw.hybrids.wstc import WignerSeitzTruncatedCoulomb 

13from gpaw.new import zips 

14from gpaw.new.ibzwfs import IBZWaveFunctions 

15from gpaw.new.pw.hamiltonian import PWHamiltonian 

16from gpaw.typing import Array1D 

17from gpaw.utilities import unpack_hermitian 

18from gpaw.utilities.blas import mmm 

19 

20 

def coulomb(pw: PWDesc,
            grid: UGDesc,
            omega: float,
            yukawa: bool = False) -> PWArray:
    """Return the Coulomb kernel used for the exchange term.

    A non-zero ``omega`` delegates to :func:`truncated_coulomb` (passing
    ``yukawa`` along).  For ``omega == 0.0`` the kernel is obtained from
    ``WignerSeitzTruncatedCoulomb.get_potential_new`` instead, and
    ``yukawa`` is not consulted.
    """
    if omega != 0.0:
        return truncated_coulomb(pw, omega, yukawa)

    # Second constructor argument is a fixed (1, 1, 1) integer triple.
    N_c = np.array([1, 1, 1])
    wstc = WignerSeitzTruncatedCoulomb(pw.cell_cv, N_c)
    return wstc.get_potential_new(pw, grid)

30 

31 

def truncated_coulomb(pw: PWDesc,
                      omega: float = 0.11,
                      yukawa: bool = False) -> PWArray:
    """Reciprocal-space kernel of a screened Coulomb interaction.

    Default (``yukawa=False``) -- real space::

        erfc(ωr)
        --------
           r

    Reciprocal space::

          4π                _ _ 2     2
        ------ (1 - exp(-(G+k) / (4 ω )))
         _ _ 2
        (G+k)

    with the G+k=0 element set to its limit, pi/ω^2.

    With ``yukawa=True`` the kernel is instead::

              4π
        -------------
         _ _ 2     2
        (G+k)  +  ω

    The squared wave-vector lengths are taken as ``2 * pw.ekin_G``.
    The result is written into a fresh array from ``pw.empty()``.
    """
    kernel_G = pw.empty()
    q2_G = 2 * pw.ekin_G

    if yukawa:
        kernel_G.data[:] = 4 * pi / (q2_G + omega**2)
        return kernel_G

    kernel_G.data[:] = 4 * pi * (1 - np.exp(-q2_G / (4 * omega**2)))
    # Divide only where |G+k|^2 is safely non-zero; the remaining
    # element(s) get the analytic limit.
    nonzero_G = q2_G > 1e-10
    kernel_G.data[nonzero_G] /= q2_G[nonzero_G]
    kernel_G.data[~nonzero_G] = pi / omega**2
    return kernel_G

62 

63 

@dataclass
class Psi:
    """Container for one set of wave functions and its companion data.

    Holds the plane-wave coefficients (``psit_nG``), the ``P_ani``
    arrays, optionally the ``f_n`` array and optionally a real-space
    copy (``psit_nR``).  The ``send``/``receive``/``wait`` methods move
    the first three between ranks of ``psit_nG.comm`` using
    non-blocking calls.
    """
    psit_nG: PWArray
    P_ani: AtomArrays
    f_n: Array1D | None = None
    psit_nR: UGArray | None = None

    def empty(self):
        """Return a new ``Psi`` with freshly allocated storage.

        ``psit_nG`` and ``P_ani`` come from their ``new()`` methods and
        ``f_n`` from ``np.empty_like``; ``psit_nR`` is left unset.
        """
        fresh_psit_nG = self.psit_nG.new()
        fresh_P_ani = self.P_ani.new()
        fresh_f_n = np.empty_like(self.f_n)
        return Psi(fresh_psit_nG, fresh_P_ani, fresh_f_n)

    @cached_property
    def comm(self):
        """Communicator of ``psit_nG`` (looked up once, then cached)."""
        return self.psit_nG.comm

    def _buffers(self):
        # The three arrays that are exchanged, in transfer order.
        return (self.psit_nG.data, self.P_ani.data, self.f_n)

    def send(self, rank):
        """Start non-blocking sends of all buffers to ``rank``."""
        self.requests = [self.comm.send(buf, rank, block=False)
                         for buf in self._buffers()]

    def receive(self, rank):
        """Start non-blocking receives of all buffers from ``rank``."""
        self.requests = [self.comm.receive(buf, rank, block=False)
                         for buf in self._buffers()]

    def wait(self):
        """Block until the requests from the last send/receive finish."""
        self.comm.waitall(self.requests)

94 

95 

class PWHybridHamiltonian(PWHamiltonian):
    """Plane-wave Hamiltonian with an orbital-dependent exchange term.

    The reciprocal-space interaction kernel ``v_G`` is produced by
    :func:`coulomb` from the functional's ``exx_omega`` and
    ``exx_yukawa`` settings and is pre-multiplied by ``exx_fraction``.
    The PAW correction data (``exx_cc``, ``VC_aii``, ``delta_aiiL``,
    ``VV_app``) are gathered from the setups at construction time.
    """

    band_local = False

    def __init__(self,
                 grid: UGDesc,
                 pw: PWDesc,
                 xc,
                 setups,
                 relpos_ac,
                 atomdist,
                 comp_charge_in_real_space: bool = False):
        """Collect kernel, PAW data and compensation charges.

        Parameters
        ----------
        grid, pw:
            Forwarded to ``PWHamiltonian.__init__``; ``pw`` is also
            stored on the instance.
        xc:
            Object providing ``exx_fraction``, ``exx_omega`` and
            ``exx_yukawa``.  Its ``energies`` mapping is written by
            :meth:`apply_orbital_dependent`.
        setups:
            Iterable of setup objects (``ExxC``, ``X_p``, ``Delta_iiL``
            and ``M_pp`` are read from each) that also offers
            ``create_compensation_charges``.
        relpos_ac, atomdist:
            Forwarded to ``setups.create_compensation_charges``.
        comp_charge_in_real_space:
            If true, compensation charges are created on ``grid`` and
            :meth:`inner2_real_space` is used; otherwise they are
            created on ``pw`` and expanded into ``ghat_GA``.
        """
        super().__init__(grid, pw)
        self.comp_charge_in_real_space = comp_charge_in_real_space
        self.pw = pw
        self.exx_fraction = xc.exx_fraction
        self.exx_omega = xc.exx_omega
        self.exx_yukawa = xc.exx_yukawa
        self.xc = xc

        # Stuff for PAW core-core, core-valence and valence-valence
        # corrections:
        self.exx_cc = sum(setup.ExxC for setup in setups) * self.exx_fraction
        self.VC_aii = [unpack_hermitian(setup.X_p * self.exx_fraction)
                       for setup in setups]
        self.delta_aiiL = [setup.Delta_iiL for setup in setups]
        self.VV_app = [setup.M_pp * self.exx_fraction for setup in setups]

        # The Yukawa flag must reach coulomb(); otherwise a functional
        # requesting Yukawa screening would silently get the erfc kernel.
        self.v_G = coulomb(pw, grid, self.exx_omega, self.exx_yukawa)
        self.v_G.data *= self.exx_fraction

        desc = grid if comp_charge_in_real_space else pw

        self.ghat_aLX = setups.create_compensation_charges(
            desc, relpos_ac, atomdist)
        if not comp_charge_in_real_space:
            self.ghat_aLX._lazy_init()
            self.ghat_GA = self.ghat_aLX._lfc.expand()
        else:
            self.ghat_GA = None
        # NOTE(review): ``self.plan`` and ``self.grid_local`` are used by
        # the methods below but are not assigned here -- presumably they
        # are provided by PWHamiltonian; confirm.

    def apply_orbital_dependent(self,
                                ibzwfs: IBZWaveFunctions,
                                D_asii,
                                psit2_nG: XArray,
                                spin: int,
                                Htpsit2_nG: XArray) -> None:
        """Subtract the exchange contribution for ``psit2_nG``.

        The result is accumulated into ``Htpsit2_nG``.  The operator is
        defined by the wave functions ``ibzwfs.wfs_qs[0][spin]`` and by
        the ``spin`` slice of ``D_asii`` (halved when
        ``ibzwfs.nspins == 1``).

        When ``psit2_nG`` shares its data buffer with those wave
        functions, the energies ``'hybrid_xc'`` and
        ``'hybrid_kinetic_correction'`` in ``self.xc.energies`` are
        updated as well: they are overwritten for ``spin == 0`` and
        accumulated otherwise.
        """
        assert isinstance(psit2_nG, PWArray)
        assert isinstance(Htpsit2_nG, PWArray)
        wfs = ibzwfs.wfs_qs[0][spin]
        # ``.copy()`` already gives a private array, so it can be scaled
        # in place without touching D_asii.
        D_aii = D_asii[:, spin].copy()
        if ibzwfs.nspins == 1:
            D_aii.data *= 0.5
        psi1 = Psi(wfs.psit_nX, wfs.P_ani, wfs.myocc_n)
        pt_aiG = wfs.pt_aiX

        # We should pass a flag instead of this:
        if psi1.psit_nG.data is psit2_nG.data:
            # We are doing a subspace diagonalization ...
            evv, evc, ekin = self.apply1(D_aii, pt_aiG,
                                         psi1, psi1, Htpsit2_nG)
            for name, e in [('hybrid_xc', evv + evc),
                            ('hybrid_kinetic_correction', ekin)]:
                e *= ibzwfs.spin_degeneracy
                if spin == 0:
                    self.xc.energies[name] = e
                else:
                    self.xc.energies[name] += e
            # NOTE(review): as written, the core-core term is added on
            # every call that reaches this branch (i.e. once per spin) --
            # confirm this is intended for spin-polarized calculations.
            self.xc.energies['hybrid_xc'] += self.exx_cc
            return

        # We are applying the exchange operator (defined by psit1_nG,
        # P1_ani, f1_n and D_aii) to another set of wave functions
        # (psit2_nG):
        psi2 = Psi(psit2_nG, pt_aiG.integrate(psit2_nG))
        self.apply1(D_aii, pt_aiG, psi1, psi2, Htpsit2_nG)

    def apply1(self,
               D_aii,
               pt_aiG,
               psi1: Psi,
               psi2: Psi,
               Htpsit_nG: PWArray) -> tuple[float, float, float]:
        """Apply the operator defined by ``psi1`` to ``psi2``.

        ``psi1`` is passed around the ranks of ``Htpsit_nG.comm`` in a
        ring: it is sent to ``rank + 1`` and received from ``rank - 1``,
        so every rank eventually sees each rank's part.  :meth:`inner`
        is called once per part.  Afterwards the accumulated ``B_ani``
        coefficients are added to ``Htpsit_nG`` through
        ``pt_aiG.add_to``.

        Returns ``(evv, evc, ekin)`` when ``psi1 is psi2`` on entry;
        otherwise ``(nan, nan, nan)``.
        """
        comm = Htpsit_nG.comm
        mynbands1 = psi1.psit_nG.mydims[0]
        mynbands2 = psi2.psit_nG.mydims[0]
        # Remember this now: psi1 is rebound inside the ring loop below.
        same = psi1 is psi2
        evv = 0.0
        evc = 0.0
        ekin = 0.0
        B_ani = {}
        for a, D_ii in D_aii.items():
            VV_ii = pawexxvv(self.VV_app[a], D_ii)
            VC_ii = self.VC_aii[a]
            V_ii = -VC_ii - 2 * VV_ii
            B_ani[a] = psi2.P_ani[a] @ V_ii
            if same:
                ec = (D_ii * VC_ii).sum()
                ev = (D_ii * VV_ii).sum()
                ekin += ec + 2 * ev
                evv -= ev
                evc -= ec

        Q_anL = self.ghat_aLX.empty(mynbands1)
        Q_nA = Q_anL.data
        assert Q_nA.shape == (mynbands1,
                              sum(delta_iiL.shape[2]
                                  for delta_iiL in self.delta_aiiL))
        assert Q_nA.dtype == self.pw.dtype

        # Work arrays, allocated once and reused by inner()/inner2*().
        rhot_nR = self.grid_local.empty(mynbands1)
        rhot_nG = self.pw.empty(mynbands1)
        vrhot_G = self.pw.empty()

        # A separate real-space buffer for psi1 is only needed when psi1
        # will differ from psi2 at some point (see inner()).
        if psi1 is not psi2 or comm.size > 1:
            psit1_nR = self.grid_local.empty(mynbands1)
        else:
            psit1_nR = None

        e = 0.0
        for p in range(comm.size):
            if p < comm.size - 1:
                # Start the next transfer before doing the local work.
                psi1.send((comm.rank + 1) % comm.size)
                if p == 0:
                    psi = psi1.empty()
                psi.receive((comm.rank - 1) % comm.size)
            if p == 0:
                psi2.psit_nR = self.grid_local.empty(mynbands2)
                ifft(psi2.psit_nG, psi2.psit_nR, self.plan)
            e += self.inner(psi1, psi2,
                            Q_anL,
                            psit1_nR,
                            rhot_nG, rhot_nR, vrhot_G,
                            Htpsit_nG, B_ani)
            if p < comm.size - 1:
                psi.wait()
                psi1.wait()
                if p == 0:
                    # The caller's psi1 must not be reused as a receive
                    # buffer, so allocate a second scratch Psi.
                    psi1 = psi
                    psi = psi1.empty()
                else:
                    psi1, psi = psi, psi1

        pt_aiG.add_to(Htpsit_nG, B_ani)

        if same:
            e = comm.sum_scalar(e)
            evv -= 0.5 * e
            ekin += e
            return evv, evc, ekin

        return nan, nan, nan

    def inner(self, psi1, psi2,
              Q_anL,
              psit1_nR,
              rhot_nG, rhot_nR, vrhot_G,
              Htpsit_nG, B_ani):
        """Process every band of ``psi2`` against all bands of ``psi1``.

        For each band ``n2`` of ``psi2`` the products with all ``psi1``
        bands are formed in ``rhot_nR``, passed through :meth:`inner2`,
        multiplied by ``psi1`` again, transformed with :func:`fft` and
        subtracted (weighted by ``psi1.f_n``) from the matching row of
        ``Htpsit_nG``.  Returns the sum of the values returned by
        :meth:`inner2`.
        """
        Q1_aniL = {a: np.einsum('ijL, nj -> niL',
                                delta_iiL, psi1.P_ani[a])
                   for a, delta_iiL in enumerate(self.delta_aiiL)}

        if psi1 is psi2:
            # Real-space psi2 already exists; reuse it.
            psit1_nR = psi2.psit_nR
        else:
            ifft(psi1.psit_nG, psit1_nR, self.plan)

        e = 0.0
        for n2, (psit2_R, out_G) in enumerate(zips(psi2.psit_nR, Htpsit_nG)):
            rhot_nR.data[:] = psit1_nR.data * psit2_R.data.conj()
            for a, Q1_niL in Q1_aniL.items():
                P2_i = psi2.P_ani[a][n2]
                Q_anL[a][:] = P2_i.conj() @ Q1_niL
            e += self.inner2(
                psi1, psi2,
                rhot_nR, rhot_nG,
                vrhot_G,
                Q_anL, Q1_aniL, B_ani, n2)
            rhot_nR.data *= psit1_nR.data
            fft(rhot_nR, rhot_nG, self.plan)
            out_G.data -= psi1.f_n @ rhot_nG.data
        return e

    def inner2(self,
               psi1, psi2,
               rhot_nR, rhot_nG,
               vrhot_G,
               Q_anL, Q1_aniL, B_ani, n2) -> float:
        """Multiply the densities by ``v_G`` and update ``B_ani``.

        Dispatches to :meth:`inner2_real_space` when
        ``comp_charge_in_real_space`` is set.  Otherwise ``rhot_nR`` is
        transformed into ``rhot_nG``, the ``Q_anL`` contribution is
        added via ``ghat_GA``, every ``rhot_nG`` row is replaced by its
        product with ``v_G``, and the result is transformed back into
        ``rhot_nR``.

        Returns the occupation-weighted sum of
        ``rhot_G.integrate(vrhot_G).real`` (0.0 when ``psi2.f_n`` is
        ``None``).
        """
        if self.comp_charge_in_real_space:
            return self.inner2_real_space(psi1, psi2,
                                          rhot_nR, rhot_nG,
                                          vrhot_G,
                                          Q_anL, Q1_aniL, B_ani, n2)
        fft(rhot_nR, rhot_nG, plan=self.plan)
        if self.pw.dtype == float:
            # Note that G runs over
            # G0.real, G0.imag, G1.real, G1.imag, ...
            mmm(1.0 / self.pw.dv, Q_anL.data, 'N', self.ghat_GA, 'T',
                1.0, rhot_nG.data.view(float))
        else:
            mmm(1.0 / self.pw.dv, Q_anL.data, 'N', self.ghat_GA, 'T',
                1.0, rhot_nG.data)

        e = 0.0
        for n1, (rhot_R, rhot_G, f1) in enumerate(zips(rhot_nR,
                                                       rhot_nG,
                                                       psi1.f_n)):
            vrhot_G.data = rhot_G.data * self.v_G.data
            if psi2.f_n is not None:
                e += f1 * psi2.f_n[n2] * rhot_G.integrate(vrhot_G).real
            # Store the product before vrhot_G is modified below.
            rhot_G.data[:] = vrhot_G.data

            if self.pw.dtype == float:
                vrhot_G.data[0] *= 0.5
                A1_A = vrhot_G.data.view(float) @ self.ghat_GA * 2.0
            else:
                A1_A = vrhot_G.data @ self.ghat_GA
            # Walk through A1_A in per-atom slices of length
            # Q1_niL.shape[2].
            A1 = 0
            for a, Q1_niL in Q1_aniL.items():
                A2 = A1 + Q1_niL.shape[2]
                B_ani[a][n2] -= Q1_niL[n1] @ (f1 * A1_A[A1:A2])
                A1 = A2
        ifft(rhot_nG, rhot_nR, plan=self.plan)
        return e

    def inner2_real_space(self,
                          psi1, psi2,
                          rhot_nR, rhot_nG,
                          vrhot_G,
                          Q_anL, Q1_aniL, B_ani, n2) -> float:
        """Variant of :meth:`inner2` using real-space ``ghat_aLX``.

        The ``Q_anL`` contribution is added to ``rhot_nR`` with
        ``ghat_aLX.add_to`` before the forward transform, and the
        ``B_ani`` update uses ``ghat_aLX.integrate`` on the
        back-transformed ``rhot_nR``.  The return value is computed as
        in :meth:`inner2`.
        """
        self.ghat_aLX.add_to(rhot_nR, Q_anL)
        fft(rhot_nR, rhot_nG, plan=self.plan)
        e = 0.0
        for n1, (rhot_R, rhot_G, f1) in enumerate(zips(rhot_nR,
                                                       rhot_nG,
                                                       psi1.f_n)):
            vrhot_G.data = rhot_G.data * self.v_G.data
            if psi2.f_n is not None:
                e += f1 * psi2.f_n[n2] * rhot_G.integrate(vrhot_G).real
            rhot_G.data[:] = vrhot_G.data

        ifft(rhot_nG, rhot_nR, plan=self.plan)

        A1_anL = self.ghat_aLX.integrate(rhot_nR)
        for a, Q1_niL in Q1_aniL.items():
            B_ani[a][n2] -= np.einsum('niL, n, nL -> i',
                                      Q1_niL, psi1.f_n, A1_anL[a])
        return e

344 

345 

def ifft(psit_nG, out_nR, plan):
    """Call ``ifft`` on each element of ``psit_nG``, pairwise.

    The i-th result is written into the i-th element of ``out_nR``;
    ``plan`` is handed to every call unchanged.
    """
    pairs = zips(psit_nG, out_nR)
    for source_G, target_R in pairs:
        source_G.ifft(out=target_R, plan=plan)

349 

350 

def fft(rhot_nR, rhot_nG, plan):
    """Call ``fft`` on each element of ``rhot_nR``, pairwise.

    The i-th result is written into the i-th element of ``rhot_nG``;
    ``plan`` is handed to every call unchanged.
    """
    pairs = zips(rhot_nR, rhot_nG)
    for source_R, target_G in pairs:
        source_R.fft(out=target_G, plan=plan)