Coverage for gpaw/new/pw/hybrids.py: 24%
207 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1from __future__ import annotations
3from dataclasses import dataclass
4from functools import cached_property
5from math import pi, nan
7import numpy as np
8from gpaw.core import PWArray, PWDesc, UGArray, UGDesc
9from gpaw.core.arrays import DistributedArrays as XArray
10from gpaw.core.atom_arrays import AtomArrays
11from gpaw.hybrids.paw import pawexxvv
12from gpaw.hybrids.wstc import WignerSeitzTruncatedCoulomb
13from gpaw.new import zips
14from gpaw.new.ibzwfs import IBZWaveFunctions
15from gpaw.new.pw.hamiltonian import PWHamiltonian
16from gpaw.typing import Array1D
17from gpaw.utilities import unpack_hermitian
18from gpaw.utilities.blas import mmm
def coulomb(pw: PWDesc,
            grid: UGDesc,
            omega: float,
            yukawa: bool = False) -> PWArray:
    """Return the Coulomb kernel v(G) on the plane-wave descriptor *pw*.

    A non-zero *omega* selects the screened kernel from
    ``truncated_coulomb`` (erfc-screened, or Yukawa when *yukawa* is true).
    For ``omega == 0.0`` a Wigner-Seitz truncated Coulomb kernel is built
    from the cell of *pw* and evaluated on *pw*/*grid*.  The second
    constructor argument ``np.array([1, 1, 1])`` is presumably a single
    k-point sampling -- confirm against WignerSeitzTruncatedCoulomb.
    """
    if omega != 0.0:
        return truncated_coulomb(pw, omega, yukawa)
    kernel = WignerSeitzTruncatedCoulomb(pw.cell_cv, np.array([1, 1, 1]))
    return kernel.get_potential_new(pw, grid)
def truncated_coulomb(pw: PWDesc,
                      omega: float = 0.11,
                      yukawa: bool = False) -> PWArray:
    """Return a screened Coulomb kernel in reciprocal space.

    With ``yukawa=False`` (default) this is the Fourier transform of the
    erfc-screened interaction erfc(ωr)/r::

        v(G) = 4π (1 - exp(-|G+k|² / (4ω²))) / |G+k|²

    For |G+k|² <= 1e-10 the analytic limit π/ω² is used instead.

    With ``yukawa=True`` the Yukawa kernel is returned::

        v(G) = 4π / (|G+k|² + ω²)

    |G+k|² is taken as ``2 * pw.ekin_G``.
    """
    G2_G = 2 * pw.ekin_G
    v_G = pw.empty()

    if yukawa:
        v_G.data[:] = 4 * pi / (G2_G + omega**2)
        return v_G

    v_G.data[:] = 4 * pi * (1 - np.exp(-G2_G / (4 * omega**2)))
    finite_G = G2_G > 1e-10
    # Divide only where |G+k|² is safely non-zero ...
    v_G.data[finite_G] /= G2_G[finite_G]
    # ... and use the analytic G+k -> 0 limit everywhere else.
    v_G.data[~finite_G] = pi / omega**2
    return v_G
@dataclass
class Psi:
    """Wave-function data that can be shipped between ranks of a communicator.

    Holds plane-wave coefficients, PAW projections, optional occupation
    numbers and an optional real-space copy of the coefficients.
    Communication is non-blocking: ``send``/``receive`` store the request
    handles in ``self.requests`` and ``wait`` completes them.
    """

    # Plane-wave expansion coefficients.
    psit_nG: PWArray
    # PAW projections <p_ai|psi_n>.
    P_ani: AtomArrays
    # Occupation numbers; None when the object is only a target of the
    # exchange operator.
    f_n: Array1D | None = None
    # Real-space version of psit_nG; filled in by the caller when needed.
    psit_nR: UGArray | None = None

    def empty(self):
        """Return a new Psi with uninitialized buffers shaped like this one.

        ``psit_nR`` is not carried over.
        """
        return Psi(psit_nG=self.psit_nG.new(),
                   P_ani=self.P_ani.new(),
                   f_n=np.empty_like(self.f_n))

    @cached_property
    def comm(self):
        """Communicator of the plane-wave coefficient array."""
        return self.psit_nG.comm

    def send(self, rank):
        """Start non-blocking sends of psit_nG, P_ani and f_n to *rank*."""
        comm = self.comm
        buffers = (self.psit_nG.data, self.P_ani.data, self.f_n)
        self.requests = [comm.send(buffer, rank, block=False)
                         for buffer in buffers]

    def receive(self, rank):
        """Start non-blocking receives into psit_nG, P_ani and f_n."""
        comm = self.comm
        buffers = (self.psit_nG.data, self.P_ani.data, self.f_n)
        self.requests = [comm.receive(buffer, rank, block=False)
                         for buffer in buffers]

    def wait(self):
        """Block until the last send/receive has completed."""
        self.psit_nG.comm.waitall(self.requests)
class PWHybridHamiltonian(PWHamiltonian):
    """Plane-wave Hamiltonian including a (screened) exact-exchange term.

    The exchange operator is built from the occupied orbitals of the first
    entry of ``ibzwfs.wfs_qs`` and applied in reciprocal space using the
    kernel from ``coulomb()`` scaled by ``xc.exx_fraction``.  PAW
    corrections (core-core, core-valence and valence-valence) are taken
    from the setups.  Energies are written to ``xc.energies``.
    """

    band_local = False

    def __init__(self,
                 grid: UGDesc,
                 pw: PWDesc,
                 xc,
                 setups,
                 relpos_ac,
                 atomdist,
                 comp_charge_in_real_space: bool = False):
        """Set up the exchange kernel, PAW corrections and compensation charges.

        Parameters
        ----------
        grid, pw:
            Real-space grid and plane-wave descriptors.
        xc:
            Functional object providing ``exx_fraction``, ``exx_omega``,
            ``exx_yukawa`` and an ``energies`` dict.
        setups:
            PAW setups (``ExxC``, ``X_p``, ``Delta_iiL``, ``M_pp`` are read,
            and ``create_compensation_charges`` is called).
        relpos_ac, atomdist:
            Forwarded to ``setups.create_compensation_charges``.
        comp_charge_in_real_space:
            Add compensation charges on the real-space grid instead of
            in reciprocal space.
        """
        super().__init__(grid, pw)
        self.comp_charge_in_real_space = comp_charge_in_real_space
        self.pw = pw
        self.exx_fraction = xc.exx_fraction
        self.exx_omega = xc.exx_omega
        self.exx_yukawa = xc.exx_yukawa
        self.xc = xc

        # PAW core-core, core-valence and valence-valence corrections:
        self.exx_cc = sum(setup.ExxC for setup in setups) * self.exx_fraction
        self.VC_aii = [unpack_hermitian(setup.X_p * self.exx_fraction)
                       for setup in setups]
        self.delta_aiiL = [setup.Delta_iiL for setup in setups]
        self.VV_app = [setup.M_pp * self.exx_fraction for setup in setups]

        # Pass the Yukawa flag on; otherwise a Yukawa functional would
        # silently get the erfc-screened kernel.
        self.v_G = coulomb(pw, grid, self.exx_omega, self.exx_yukawa)
        self.v_G.data *= self.exx_fraction

        desc = grid if comp_charge_in_real_space else pw

        self.ghat_aLX = setups.create_compensation_charges(
            desc, relpos_ac, atomdist)
        if not comp_charge_in_real_space:
            # Dense (G, A) matrix of all compensation-charge functions, used
            # with mmm() in inner2().
            self.ghat_aLX._lazy_init()
            self.ghat_GA = self.ghat_aLX._lfc.expand()
        else:
            self.ghat_GA = None
        # NOTE(review): self.plan and self.grid_local are used below but not
        # set here -- presumably provided by PWHamiltonian.__init__; confirm.

    def apply_orbital_dependent(self,
                                ibzwfs: IBZWaveFunctions,
                                D_asii,
                                psit2_nG: XArray,
                                spin: int,
                                Htpsit2_nG: XArray) -> None:
        """Add the exchange operator applied to *psit2_nG* to *Htpsit2_nG*.

        Only ``ibzwfs.wfs_qs[0][spin]`` is used to build the operator.  When
        *psit2_nG* shares its data with those wave functions (subspace
        diagonalization), the exchange and kinetic-correction energies are
        also stored in ``self.xc.energies``.  Spin 0 resets them and later
        spins accumulate, so spin 0 must be processed first.
        """
        assert isinstance(psit2_nG, PWArray)
        assert isinstance(Htpsit2_nG, PWArray)
        wfs = ibzwfs.wfs_qs[0][spin]
        D_aii = D_asii[:, spin].copy()
        if ibzwfs.nspins == 1:
            # Spin-paired: use the density matrix of a single spin channel.
            D_aii.data *= 0.5
        psi1 = Psi(wfs.psit_nX, wfs.P_ani, wfs.myocc_n)
        pt_aiG = wfs.pt_aiX

        # We should pass a flag instead of this:
        if psi1.psit_nG.data is psit2_nG.data:
            # We are doing a subspace diagonalization ...
            evv, evc, ekin = self.apply1(D_aii, pt_aiG,
                                         psi1, psi1, Htpsit2_nG)
            for name, e in [('hybrid_xc', evv + evc),
                            ('hybrid_kinetic_correction', ekin)]:
                e *= ibzwfs.spin_degeneracy
                if spin == 0:
                    self.xc.energies[name] = e
                else:
                    self.xc.energies[name] += e
            if spin == 0:
                # The core-core exchange energy does not depend on spin:
                # add it once, not once per spin channel.
                self.xc.energies['hybrid_xc'] += self.exx_cc
            return

        # We are applying the exchange operator (defined by psit1_nG,
        # P1_ani, f1_n and D_aii) to another set of wave functions
        # (psit2_nG):
        psi2 = Psi(psit2_nG, pt_aiG.integrate(psit2_nG))
        self.apply1(D_aii, pt_aiG, psi1, psi2, Htpsit2_nG)

    def apply1(self,
               D_aii,
               pt_aiG,
               psi1: Psi,
               psi2: Psi,
               Htpsit_nG: PWArray) -> tuple[float, float, float]:
        """Apply the exchange operator built from *psi1* to *psi2*.

        The result is accumulated into *Htpsit_nG*.  The bands of *psi1* are
        passed around ``Htpsit_nG.comm`` in a ring: each rank sends to
        rank + 1 and receives from rank - 1.  This way every rank applies
        the contribution of all psi1 bands to its own psi2 bands.  PAW
        corrections are collected in ``B_ani`` and added with
        ``pt_aiG.add_to``.

        Returns ``(evv, evc, ekin)`` when *psi1* is *psi2*, otherwise
        ``(nan, nan, nan)``.
        """
        comm = Htpsit_nG.comm
        mynbands1 = psi1.psit_nG.mydims[0]
        mynbands2 = psi2.psit_nG.mydims[0]
        same = psi1 is psi2
        evv = 0.0
        evc = 0.0
        ekin = 0.0
        B_ani = {}
        for a, D_ii in D_aii.items():
            VV_ii = pawexxvv(self.VV_app[a], D_ii)
            VC_ii = self.VC_aii[a]
            V_ii = -VC_ii - 2 * VV_ii
            B_ani[a] = psi2.P_ani[a] @ V_ii
            if same:
                ec = (D_ii * VC_ii).sum()
                ev = (D_ii * VV_ii).sum()
                ekin += ec + 2 * ev
                evv -= ev
                evc -= ec

        Q_anL = self.ghat_aLX.empty(mynbands1)
        Q_nA = Q_anL.data
        assert Q_nA.shape == (mynbands1,
                              sum(delta_iiL.shape[2]
                                  for delta_iiL in self.delta_aiiL))
        assert Q_nA.dtype == self.pw.dtype

        # Work buffers reused for every band pair.
        rhot_nR = self.grid_local.empty(mynbands1)
        rhot_nG = self.pw.empty(mynbands1)
        vrhot_G = self.pw.empty()

        if psi1 is not psi2 or comm.size > 1:
            psit1_nR = self.grid_local.empty(mynbands1)
        else:
            # inner() reuses psi2.psit_nR in this case.
            psit1_nR = None

        e = 0.0
        for p in range(comm.size):
            if p < comm.size - 1:
                psi1.send((comm.rank + 1) % comm.size)
                if p == 0:
                    psi = psi1.empty()
                psi.receive((comm.rank - 1) % comm.size)
            if p == 0:
                psi2.psit_nR = self.grid_local.empty(mynbands2)
                ifft(psi2.psit_nG, psi2.psit_nR, self.plan)
            e += self.inner(psi1, psi2,
                            Q_anL,
                            psit1_nR,
                            rhot_nG, rhot_nR, vrhot_G,
                            Htpsit_nG, B_ani)
            if p < comm.size - 1:
                psi.wait()
                psi1.wait()
                if p == 0:
                    # Never receive into the original psi1: it may be psi2
                    # or the wave functions owned by ibzwfs.
                    psi1 = psi
                    psi = psi1.empty()
                else:
                    psi1, psi = psi, psi1

        pt_aiG.add_to(Htpsit_nG, B_ani)

        if same:
            e = comm.sum_scalar(e)
            evv -= 0.5 * e
            ekin += e
            return evv, evc, ekin

        return nan, nan, nan

    def inner(self, psi1, psi2,
              Q_anL,
              psit1_nR,
              rhot_nG, rhot_nR, vrhot_G,
              Htpsit_nG, B_ani):
        """Apply the contribution of the current psi1 bands to all psi2 bands.

        For each psi2 band n2, pair densities psi1(r) * conj(psi2(r)) and
        their compensation-charge coefficients Q_anL are formed.  The
        Coulomb kernel is applied via ``inner2``, and
        ``sum_n1 f_n1 * psi1 * (v rho)`` is subtracted from ``Htpsit_nG``.
        Returns the accumulated pair energy.
        """
        Q1_aniL = {a: np.einsum('ijL, nj -> niL',
                                delta_iiL, psi1.P_ani[a])
                   for a, delta_iiL in enumerate(self.delta_aiiL)}

        if psi1 is psi2:
            psit1_nR = psi2.psit_nR
        else:
            ifft(psi1.psit_nG, psit1_nR, self.plan)

        e = 0.0
        for n2, (psit2_R, out_G) in enumerate(zips(psi2.psit_nR, Htpsit_nG)):
            rhot_nR.data[:] = psit1_nR.data * psit2_R.data.conj()
            for a, Q1_niL in Q1_aniL.items():
                P2_i = psi2.P_ani[a][n2]
                Q_anL[a][:] = P2_i.conj() @ Q1_niL
            # On return, rhot_nR holds the potential of the pair densities.
            e += self.inner2(
                psi1, psi2,
                rhot_nR, rhot_nG,
                vrhot_G,
                Q_anL, Q1_aniL, B_ani, n2)
            rhot_nR.data *= psit1_nR.data
            fft(rhot_nR, rhot_nG, self.plan)
            out_G.data -= psi1.f_n @ rhot_nG.data
        return e

    def inner2(self,
               psi1, psi2,
               rhot_nR, rhot_nG,
               vrhot_G,
               Q_anL, Q1_aniL, B_ani, n2) -> float:
        """Apply v_G to the pair densities in *rhot_nR* (band n2 of psi2).

        Compensation charges are added in reciprocal space using the
        expanded ``ghat_GA`` matrix.  Their potential matrix elements are
        subtracted from ``B_ani[a][n2]``.  On return *rhot_nR* holds the
        potential in real space.  Returns the pair energy (zero unless
        ``psi2.f_n`` is set).
        """
        if self.comp_charge_in_real_space:
            return self.inner2_real_space(psi1, psi2,
                                          rhot_nR, rhot_nG,
                                          vrhot_G,
                                          Q_anL, Q1_aniL, B_ani, n2)
        fft(rhot_nR, rhot_nG, plan=self.plan)
        if self.pw.dtype == float:
            # Note that G runs over
            # G0.real, G0.imag, G1.real, G1.imag, ...
            mmm(1.0 / self.pw.dv, Q_anL.data, 'N', self.ghat_GA, 'T',
                1.0, rhot_nG.data.view(float))
        else:
            mmm(1.0 / self.pw.dv, Q_anL.data, 'N', self.ghat_GA, 'T',
                1.0, rhot_nG.data)

        e = 0.0
        for n1, (_, rhot_G, f1) in enumerate(zips(rhot_nR,
                                                  rhot_nG,
                                                  psi1.f_n)):
            # Write into the preallocated buffer instead of rebinding
            # vrhot_G.data to a new array for every band.
            np.multiply(rhot_G.data, self.v_G.data, out=vrhot_G.data)
            if psi2.f_n is not None:
                e += f1 * psi2.f_n[n2] * rhot_G.integrate(vrhot_G).real
            rhot_G.data[:] = vrhot_G.data

            if self.pw.dtype == float:
                # Real wave functions: the factor 2 accounts for the G
                # vectors that are not stored, and G=0 must not be doubled.
                # NOTE(review): assumes index 0 is G=0 on this rank; confirm.
                vrhot_G.data[0] *= 0.5
                A1_A = vrhot_G.data.view(float) @ self.ghat_GA * 2.0
            else:
                A1_A = vrhot_G.data @ self.ghat_GA
            A1 = 0
            for a, Q1_niL in Q1_aniL.items():
                A2 = A1 + Q1_niL.shape[2]
                B_ani[a][n2] -= Q1_niL[n1] @ (f1 * A1_A[A1:A2])
                A1 = A2
        ifft(rhot_nG, rhot_nR, plan=self.plan)
        return e

    def inner2_real_space(self,
                          psi1, psi2,
                          rhot_nR, rhot_nG,
                          vrhot_G,
                          Q_anL, Q1_aniL, B_ani, n2) -> float:
        """Variant of ``inner2`` with compensation charges on the real grid.

        Charges are added with ``ghat_aLX.add_to`` and their potential
        matrix elements are obtained with ``ghat_aLX.integrate``.  On return
        *rhot_nR* holds the potential in real space.
        """
        self.ghat_aLX.add_to(rhot_nR, Q_anL)
        fft(rhot_nR, rhot_nG, plan=self.plan)
        e = 0.0
        for _, rhot_G, f1 in zips(rhot_nR, rhot_nG, psi1.f_n):
            # Write into the preallocated buffer instead of rebinding
            # vrhot_G.data to a new array for every band.
            np.multiply(rhot_G.data, self.v_G.data, out=vrhot_G.data)
            if psi2.f_n is not None:
                e += f1 * psi2.f_n[n2] * rhot_G.integrate(vrhot_G).real
            rhot_G.data[:] = vrhot_G.data

        ifft(rhot_nG, rhot_nR, plan=self.plan)

        A1_anL = self.ghat_aLX.integrate(rhot_nR)
        for a, Q1_niL in Q1_aniL.items():
            B_ani[a][n2] -= np.einsum('niL, n, nL -> i',
                                      Q1_niL, psi1.f_n, A1_anL[a])
        return e
def ifft(psit_nG, out_nR, plan):
    """Inverse-FFT every band of *psit_nG* into the matching band of *out_nR*.

    Bands are paired with ``zips``; *plan* is forwarded to each call.
    """
    band_pairs = zips(psit_nG, out_nR)
    for coefs_G, target_R in band_pairs:
        coefs_G.ifft(plan=plan, out=target_R)
def fft(rhot_nR, rhot_nG, plan):
    """Forward-FFT every band of *rhot_nR* into the matching band of *rhot_nG*.

    Bands are paired with ``zips``; *plan* is forwarded to each call.
    """
    band_pairs = zips(rhot_nR, rhot_nG)
    for values_R, target_G in band_pairs:
        values_R.fft(plan=plan, out=target_G)