Coverage for gpaw/hybrids/scf.py: 99%
221 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
import numpy as np

from gpaw.mpi import serial_comm
from gpaw.kpt_descriptor import KPointDescriptor
from gpaw.pw.descriptor import PWDescriptor
from gpaw.pw.lfc import PWLFC
from gpaw.hybrids.kpts import PWKPoint, RSKPoint, to_real_space
from gpaw.utilities.blas import mmm
def apply1(kpt, Htpsit_xG, wfs, coulomb, sym, paw):
    """Wrap one spin channel's k-points as ``PWKPoint`` and run ``calculate``.

    Parameters
    ----------
    kpt:
        K-point object; only ``kpt.s`` is read, to select which entries of
        ``wfs.kpt_u`` are processed.
    Htpsit_xG:
        Accepted for interface compatibility; not used in this function.
    wfs, coulomb, sym, paw:
        Forwarded to ``calculate``.

    Returns
    -------
    tuple
        ``(evc, evv, ekin, v_knG)``.  ``calculate`` returns its first two
        values in the order ``(evv, evc)``; they are swapped here.
    """
    kd = wfs.kd
    # Every ``nspins``-th entry of ``kpt_u`` starting at ``kpt.s`` --
    # presumably all k-points with the same spin as ``kpt``.
    selected_kpts = wfs.kpt_u[kpt.s::wfs.nspins]
    kpts = [PWKPoint(other.psit,
                     other.projections,
                     other.f_n / other.weight,  # scale to [0, 1] range
                     kd.ibzk_kc[other.k],
                     kd.weight_k[other.k])
            for other in selected_kpts]
    evv, evc, ekin, v_knG = calculate(kpts, wfs, paw, sym, coulomb)
    return evc, evv, ekin, v_knG
def calculate(kpts, wfs, paw, sym, coulomb):
    """Accumulate pair energies and potentials over all k-point pairs.

    For every pair produced by ``sym.pairs`` the pair-energy matrix from
    ``calculate_exx_for_pair`` is contracted with the occupations of both
    k-points, while that function accumulates potentials in place into the
    per-k-point buffers allocated here.  PAW correction terms built from
    ``paw.VV_aii`` and ``paw.VC_aii`` are then added.

    Parameters
    ----------
    kpts:
        List of k-point objects exposing ``psit``, ``proj``, ``f_n`` and
        ``weight`` (``PWKPoint`` instances in ``apply1``).
    wfs, paw, sym, coulomb:
        Project objects; see the attribute accesses in the body for what is
        required of each.

    Returns
    -------
    tuple
        ``(exxvv, exxvc, ekin, w_knG)``.  The three scalars are summed over
        ``wfs.world``.  ``w_knG`` is a dict mapping the position of each
        k-point in ``kpts`` to this rank's slice ``[G1:G2)`` of the summed
        potential array.
    """
    pd = kpts[0].psit.pd
    gd = pd.gd.new_descriptor(comm=serial_comm)
    kd = wfs.kd
    comm = wfs.world
    nbands = len(kpts[0].psit.array)

    # One zero-initialised (nbands, ni) array per atom and one global
    # plane-wave array per k-point; both are filled in place below.
    shapes = [(nbands, len(Delta_iiL))
              for Delta_iiL in paw.Delta_aiiL]
    v_kani = [{a: np.zeros(shape, pd.dtype)
               for a, shape in enumerate(shapes)}
              for _ in kpts]
    v_knG = [k.psit.pd.zeros(nbands, global_array=True, q=k.psit.kpt)
             for k in kpts]

    exxvv = 0.0
    ekin = 0.0
    for i1, i2, s, k1, k2, count in sym.pairs(kpts, wfs, wfs.spos_ac):
        # Descriptor and compensation-charge functions for the pair density
        # at wave vector -(k2 - k1).
        q_c = k2.k_c - k1.k_c
        qd = KPointDescriptor([-q_c])

        pd12 = PWDescriptor(pd.ecut, gd, pd.dtype, kd=qd)
        ghat = PWLFC([data.ghat_l for data in wfs.setups], pd12)
        ghat.set_positions(wfs.spos_ac)

        v1_nG = v_knG[i1]
        v1_ani = v_kani[i1]
        v2_nG = v_knG[i2]
        v2_ani = v_kani[i2]

        v_G = coulomb.get_potential(pd12)
        # The list position must coincide with the k-point index stored on
        # the wave-function object, because both are used interchangeably.
        assert i1 == kpts[i1].psit.kpt
        assert i2 == kpts[i2].psit.kpt
        e_nn = calculate_exx_for_pair(k1, k2, ghat, v_G,
                                      kpts[i1].psit.pd,
                                      kpts[i2].psit.pd,
                                      kpts[i1].psit.kpt,
                                      kpts[i2].psit.kpt,
                                      k1.f_n,
                                      k2.f_n,
                                      s,
                                      count,
                                      v1_nG, v1_ani,
                                      v2_nG, v2_ani,
                                      wfs, sym, paw)

        e_nn *= count
        e = k1.f_n.dot(e_nn).dot(k2.f_n) / kd.nbzkpts
        exxvv -= 0.5 * e
        ekin += e

    # PAW corrections: per-band expectation values of VV_ii and VC_ii.
    exxvc = 0.0
    for kpt in kpts:
        for a, VV_ii in paw.VV_aii.items():
            P_ni = kpt.proj[a]
            vv_n = np.einsum('ni, ij, nj -> n',
                             P_ni.conj(), VV_ii, P_ni).real
            vc_n = np.einsum('ni, ij, nj -> n',
                             P_ni.conj(), paw.VC_aii[a], P_ni).real
            exxvv -= vv_n.dot(kpt.f_n) * kpt.weight
            exxvc -= vc_n.dot(kpt.f_n) * kpt.weight

    # Sum the potentials over ranks and keep only this rank's G-slice.
    w_knG = {}
    G1 = comm.rank * pd.maxmyng
    G2 = (comm.rank + 1) * pd.maxmyng
    for i, (v_nG, v_ani, kpt) in enumerate(zip(v_knG, v_kani, kpts)):
        comm.sum(v_nG)
        w_nG = v_nG[:, G1:G2].copy()
        w_knG[i] = w_nG
        for v_ni in v_ani.values():
            comm.sum(v_ni)
        v1_ani = {}
        for a, VV_ii in paw.VV_aii.items():
            P_ni = kpt.proj[a]
            v_ni = P_ni.dot(paw.VC_aii[a] + 2 * VV_ii)
            v1_ani[a] = v_ani[a] - v_ni
            ekin += (np.einsum('n, ni, ni',
                               kpt.f_n, P_ni.conj(), v_ni).real *
                     kpt.weight)
        # Hand the atomic coefficients to the projector functions so they
        # are applied to ``w_nG`` (modified in place).
        wfs.pt.add(w_nG, v1_ani, kpt.psit.kpt)

    return (comm.sum_scalar(exxvv),
            comm.sum_scalar(exxvc),
            comm.sum_scalar(ekin),
            w_knG)
def calculate_exx_for_pair(k1,
                           k2,
                           ghat,
                           v_G,
                           pd1, pd2,
                           index1, index2,
                           f1_n, f2_n,
                           s,
                           count,
                           v1_nG,
                           v1_ani,
                           v2_nG,
                           v2_ani,
                           wfs,
                           sym,
                           paw,
                           F_av=None):
    """Compute the band-pair energy matrix for one pair of k-points.

    For every band ``n1`` of ``k1`` and the bands ``n2`` of ``k2`` handled
    by this rank, the pair density ``u1 * conj(u2)`` is Fourier transformed,
    augmented through ``add`` and integrated against ``v_G * rho_G``.

    Parameters
    ----------
    k1, k2:
        Objects exposing ``u_nR`` (real-space bands), ``proj`` and ``k_c``.
        ``k1 is k2`` selects the symmetric (same k-point) code path.
    ghat:
        Object providing ``pd`` (fft/ifft/integrate/empty) and ``expand()``.
    v_G:
        Array multiplied element-wise with each pair density.
    pd1, pd2:
        Descriptors used to transform potential contributions back to the
        plane-wave bases of ``k1`` and ``k2``.
    index1, index2:
        Passed to ``pd1.fft`` / ``pd2.fft`` and used to index
        ``wfs.kd.weight_k``.
    f1_n, f2_n:
        Per-band weights for ``k1`` and ``k2``.
    s:
        Symmetry-operation index forwarded to ``sym.symmetry_operation``.
    count:
        Multiplicity factor applied to the potential contributions.
    v1_nG, v1_ani, v2_nG, v2_ani:
        Output accumulators, updated in place.  ``v1_nG is None`` skips all
        potential work; ``v2_nG is None`` updates only the ``k1`` side.
    wfs, sym, paw:
        Project objects; see the attribute accesses in the body.
    F_av:
        Not used in this function.

    Returns
    -------
    numpy.ndarray
        ``(N1, N2)`` matrix of pair energies scaled by ``1 / kd.nbzkpts``.
        Only the entries computed by this rank are non-zero.
    """
    kd = wfs.kd
    comm = wfs.world
    factor = 1.0 / kd.nbzkpts

    N1 = len(k1.u_nR)
    N2 = len(k2.u_nR)

    size = comm.size
    rank = comm.rank

    # Per-atom coefficients Q[m, n, L] built from the projections of both
    # k-points and paw.Delta_aiiL.
    Q_annL = [np.einsum('mi, ijL, nj -> mnL',
                        k1.proj[a],
                        Delta_iiL,
                        k2.proj[a].conj(),
                        optimize=True)
              for a, Delta_iiL in enumerate(paw.Delta_aiiL)]

    if v2_nG is not None:
        # T, T_a and cc are only needed (and only defined) when the k2 side
        # is updated as well.
        T, T_a, cc = sym.symmetry_operation(s, wfs, inverse=True)

    # Largest number of n2 bands this rank ever handles for a single n1.
    if k1 is k2:
        n2max = (N1 + size - 1) // size
    else:
        n2max = N2

    e_nn = np.zeros((N1, N2))
    rho_nG = ghat.pd.empty(n2max, k1.u_nR.dtype)
    vrho_nG = ghat.pd.empty(n2max, k1.u_nR.dtype)

    f_GI = ghat.expand()

    for n1, u1_R in enumerate(k1.u_nR):
        if k1 is k2:
            # Same k-point: only n2 >= n1 is needed (the matrix is filled
            # symmetrically below); that range is split into blocks of B
            # bands, one block per rank.
            B = (N1 - n1 + size - 1) // size
            n20 = 0
            n2a = min(n1 + rank * B, N2)
            n2b = min(n2a + B, N2)
        else:
            # Different k-points: all local bands of k2 are used.  n20 is
            # the offset applied when writing into the v2 arrays --
            # presumably k2 holds only this rank's band slice; confirm
            # against the caller.
            B = (N1 + size - 1) // size
            n20 = min(B * rank, N1)
            n2a = 0
            n2b = N2

        for n2, rho_G in enumerate(rho_nG[:n2b - n2a], n2a):
            rho_G[:] = ghat.pd.fft(u1_R * k2.u_nR[n2].conj())

        # Add the atom-centred contributions to all pair densities at once.
        add(ghat, rho_nG[:n2b - n2a],
            {a: Q_nnL[n1, n2a:n2b]
             for a, Q_nnL in enumerate(Q_annL)},
            f_GI)

        for n2, rho_G in enumerate(rho_nG[:n2b - n2a], n2a):
            vrho_G = v_G * rho_G
            e = ghat.pd.integrate(rho_G, vrho_G).real
            e_nn[n1, n2] = e
            if k1 is k2:
                e_nn[n2, n1] = e
            vrho_nG[n2 - n2a] = vrho_G

            if v1_nG is not None:
                vrho_R = ghat.pd.ifft(vrho_G)
                if v2_nG is None:
                    assert k1 is not k2
                    v1_nG[n1] -= f2_n[n2] * factor * pd1.fft(
                        vrho_R * k2.u_nR[n2], index1, local=True)
                else:
                    x = factor * count / 2
                    if k1 is k2 and n1 != n2:
                        # Off-diagonal pairs are visited once but count
                        # for both (n1, n2) and (n2, n1).
                        x *= 2
                    x1 = x / (kd.weight_k[index1] * kd.nbzkpts)
                    x2 = x / (kd.weight_k[index2] * kd.nbzkpts)
                    v1_nG[n1] -= f2_n[n2] * x1 * pd1.fft(
                        vrho_R * k2.u_nR[n2], index1, local=True)
                    v2_nG[n2 + n20] -= f1_n[n1] * x2 * pd2.fft(
                        T(vrho_R.conj() * u1_R), index2,
                        local=True)

        if v1_nG is not None and v2_nG is None:
            for a, v_nL in integrate(ghat, vrho_nG[:n2b - n2a], f_GI):
                v_iin = paw.Delta_aiiL[a].dot(v_nL.T)
                v1_ani[a][n1] -= np.einsum('ijn, nj, n -> i',
                                           v_iin,
                                           k2.proj[a][n2a:n2b],
                                           f2_n[n2a:n2b] * factor)

        if v1_nG is not None and v2_nG is not None:
            x = factor * count / kd.nbzkpts / 2
            x1 = x / kd.weight_k[index1]
            x2 = x / kd.weight_k[index2]
            if k1 is k2:
                x1 *= 2
                x2 *= 2

            for a, v_nL in integrate(ghat, vrho_nG[:n2b - n2a], f_GI):
                if k1 is k2 and n2a <= n1 < n2b:
                    # Halve the diagonal (n2 == n1) entry, compensating
                    # the factor 2 applied to x1/x2 above.
                    v_nL[n1 - n2a] *= 0.5
                v_iin = paw.Delta_aiiL[a].dot(v_nL.T)
                v1_ani[a][n1] -= np.einsum('ijn, nj, n -> i',
                                           v_iin,
                                           k2.proj[a][n2a:n2b],
                                           f2_n[n2a:n2b] * x1)
                b, S_c, U_ii = T_a[a]
                v_ni = np.einsum('ijn, j, ik -> nk',
                                 v_iin.conj(),
                                 k1.proj[b][n1],
                                 U_ii)
                if v_nL.dtype == complex:
                    v_ni *= np.exp(2j * np.pi * k2.k_c.dot(S_c))
                if cc:
                    v_ni = v_ni.conj()
                v2_ani[a][n20 + n2a:n20 + n2b] -= v_ni * f1_n[n1] * x2

    return e_nn * factor
def add(ghat, a_xG, c_axi, f_GI):
    """Add atom-centred function contributions to ``a_xG`` in place.

    Parameters
    ----------
    ghat:
        Object providing ``nI``, ``my_indices`` (triples ``(a, I1, I2)``),
        ``eikR_qa`` and ``pd`` (with ``dtype`` and ``gd.dv``).
    a_xG:
        Array whose last axis is the plane-wave index; updated in place.
    c_axi:
        Mapping from atom index ``a`` to a coefficient array that fills
        columns ``I1:I2`` of the combined coefficient matrix.
    f_GI:
        Expanded function matrix, passed to ``mmm`` with the ``'T'`` flag.

    Notes
    -----
    ``mmm`` is assumed to follow the BLAS gemm convention
    ``c <- alpha * op(a) @ op(b) + beta * c`` -- confirm against
    ``gpaw.utilities.blas``.  The update reaches the caller only when
    ``reshape``/``view`` return views of ``a_xG`` rather than copies.
    """
    # Collect every atom's coefficients, times the conjugated phase factor,
    # into one array.  Columns not covered by my_indices stay
    # uninitialised (np.empty).
    c_xI = np.empty(a_xG.shape[:-1] + (ghat.nI,), ghat.pd.dtype)
    for a, I1, I2 in ghat.my_indices:
        c_xI[..., I1:I2] = c_axi[a] * ghat.eikR_qa[0][a].conj()

    # Flatten all leading axes so both operands are 2-D for mmm.
    nx = np.prod(c_xI.shape[:-1], dtype=int)
    c_xI = c_xI.reshape((nx, ghat.nI))
    a_xG = a_xG.reshape((nx, a_xG.shape[-1])).view(ghat.pd.dtype)
    mmm(1.0 / ghat.pd.gd.dv, c_xI, 'N', f_GI, 'T', 1.0, a_xG)
def integrate(ghat, a_xG, f_GI):
    """Yield per-atom integrals of ``a_xG`` against the columns of ``f_GI``.

    This is a generator: nothing runs until the first item is requested.
    The matrix product is evaluated once, then one ``(a, array)`` pair is
    yielded for every ``(a, I1, I2)`` in ``ghat.my_indices``; the array is
    columns ``I1:I2`` of the result multiplied by ``ghat.eikR_qa[0][a]``.

    ``f_GI`` is modified in place for the duration of the ``mmm`` call
    (first row halved for real dtype, imaginary part negated otherwise) and
    restored before the first yield.  The restore runs in a ``finally``
    clause so the shared buffer is not left modified if ``mmm`` raises.

    Notes
    -----
    ``mmm`` is assumed to follow the BLAS gemm convention
    ``c <- alpha * op(a) @ op(b) + beta * c`` -- confirm against
    ``gpaw.utilities.blas``.
    """
    c_xI = np.zeros(a_xG.shape[:-1] + (ghat.nI,), ghat.pd.dtype)

    # Flatten all leading axes; b_xI is a 2-D reshape of the freshly
    # allocated c_xI, so mmm's output is visible through c_xI.
    nx = np.prod(c_xI.shape[:-1], dtype=int)
    b_xI = c_xI.reshape((nx, ghat.nI))
    a_xG = a_xG.reshape((nx, a_xG.shape[-1]))

    # A single flag drives both the temporary modification and its undo,
    # so the two can never disagree.
    is_real = ghat.pd.dtype == float

    alpha = 1.0 / ghat.pd.gd.N_c.prod()
    if is_real:
        alpha *= 2
        a_xG = a_xG.view(float)
        f_GI[0] *= 0.5
    else:
        f_GI.imag[:] = -f_GI.imag
    try:
        mmm(alpha, a_xG, 'N', f_GI, 'N', 0.0, b_xI)
    finally:
        # Undo the in-place change to the caller's f_GI.
        if is_real:
            f_GI[0] *= 2.0
        else:
            f_GI.imag[:] = -f_GI.imag

    for a, I1, I2 in ghat.my_indices:
        yield a, ghat.eikR_qa[0][a] * c_xI[..., I1:I2]
def apply2(kpt, psit_xG, Htpsit_xG, wfs, coulomb, sym, paw):
    """Run ``calculate2`` for the wave functions in ``psit_xG``.

    Parameters
    ----------
    kpt:
        K-point object supplying the template wave-function/projection
        objects, the k-point index ``kpt.k`` and the spin index ``kpt.s``.
    psit_xG:
        Buffer wrapped as a new wave-function object via
        ``kpt.psit.new(buf=...)``.
    Htpsit_xG:
        Accepted for interface compatibility; not used in this function.
    wfs, coulomb, sym, paw:
        Forwarded to ``calculate2``.

    Returns
    -------
    The array returned by ``calculate2``.
    """
    kd = wfs.kd

    # Wrap the input buffer and compute its projections into a fresh
    # projections object.
    psit = kpt.psit.new(buf=psit_xG)
    P = kpt.projections.new()
    psit.matrix_elements(wfs.pt, out=P)

    # Occupations and weight are set to NaN for kpt1 -- presumably because
    # they must not influence the result, so any accidental use would
    # show up as NaN.
    kpt1 = PWKPoint(psit,
                    P,
                    kpt.f_n + np.nan,
                    kd.ibzk_kc[kpt.k],
                    np.nan)

    # Every ``nspins``-th entry of ``kpt_u`` starting at ``kpt.s`` --
    # presumably all k-points with the same spin as ``kpt``.
    selected_kpts = wfs.kpt_u[kpt.s::wfs.nspins]
    kpts2 = [PWKPoint(other.psit,
                      other.projections,
                      other.f_n / other.weight,  # scale to [0, 1] range
                      kd.ibzk_kc[other.k],
                      kd.weight_k[other.k])
             for other in selected_kpts]
    v_nG = calculate2(kpt1, kpts2, wfs, paw, sym, coulomb)
    return v_nG
def calculate2(kpt1, kpts2, wfs, paw, sym, coulomb):
    """Accumulate the potential acting on ``kpt1`` from all of ``kpts2``.

    Each k-point in ``kpts2`` is mapped by ``sym.apply_symmetry`` to every
    Brillouin-zone point that ``kd.bz2ibz_k`` assigns to it, and
    ``calculate_exx_for_pair`` is called with the ``k2``-side outputs set
    to ``None`` so that only the ``kpt1`` accumulators are updated.  The
    pair energies that function returns are discarded.

    Parameters
    ----------
    kpt1:
        K-point object whose wave functions receive the potential.
    kpts2:
        List of k-point objects supplying the other band of each pair.
    wfs, paw, sym, coulomb:
        Project objects; see the attribute accesses in the body.

    Returns
    -------
    This rank's slice ``[G1:G2)`` of the rank-summed potential array, after
    ``wfs.pt.add`` has been applied to it.
    """
    pd = kpt1.psit.pd
    gd = pd.gd.new_descriptor(comm=serial_comm)
    kd = wfs.kd
    comm = wfs.world
    nbands = len(kpt1.psit.array)
    shapes = [(nbands, len(Delta_iiL))
              for Delta_iiL in paw.Delta_aiiL]
    v_ani = {a: np.zeros(shape, pd.dtype)
             for a, shape in enumerate(shapes)}
    v_nG = kpt1.psit.pd.zeros(nbands, global_array=True, q=kpt1.psit.kpt)

    u1_nR = to_real_space(kpt1.psit)
    proj1 = kpt1.proj.broadcast()
    k1 = RSKPoint(u1_nR,
                  proj1,
                  kpt1.f_n,
                  kpt1.k_c,
                  kpt1.weight)

    N2 = len(kpts2[0].psit.array)
    nsym = len(kd.symmetry.op_scc)

    # Each rank handles the contiguous band block [na, nb) of every kpt2.
    size = comm.size
    rank = comm.rank
    B = (N2 + size - 1) // size
    na = min(B * rank, N2)
    nb = min(na + B, N2)
    for i2, kpt2 in enumerate(kpts2):
        u2_nR = to_real_space(kpt2.psit, na, nb)
        k0 = RSKPoint(u2_nR,
                      kpt2.proj.broadcast().view(na, nb),
                      kpt2.f_n[na:nb],
                      kpt2.k_c,
                      kpt2.weight)
        for k, i in enumerate(kd.bz2ibz_k):
            if i != i2:
                continue
            # Combined symmetry index: time-reversed operations are offset
            # by the number of spatial operations.
            s = kd.sym_k[k] + kd.time_reversal_k[k] * nsym
            k2 = sym.apply_symmetry(s, k0, wfs, wfs.spos_ac)
            q_c = k2.k_c - k1.k_c
            qd = KPointDescriptor([-q_c])

            pd12 = PWDescriptor(pd.ecut, gd, pd.dtype, kd=qd)
            ghat = PWLFC([data.ghat_l for data in wfs.setups], pd12)
            ghat.set_positions(wfs.spos_ac)

            v_G = coulomb.get_potential(pd12)
            calculate_exx_for_pair(k1, k2, ghat, v_G,
                                   kpt1.psit.pd,
                                   kpt2.psit.pd,
                                   kpt1.psit.kpt,
                                   kpt2.psit.kpt,
                                   k1.f_n,
                                   k2.f_n,
                                   s,
                                   1.0,
                                   v_nG, v_ani,
                                   None, None,
                                   wfs, sym, paw)

    # Sum over ranks and keep only this rank's G-slice.
    G1 = comm.rank * pd.maxmyng
    G2 = (comm.rank + 1) * pd.maxmyng
    comm.sum(v_nG)
    w_nG = v_nG[:, G1:G2].copy()
    for v_ni in v_ani.values():
        comm.sum(v_ni)
    v1_ani = {}
    for a, VV_ii in paw.VV_aii.items():
        P_ni = kpt1.proj[a]
        v_ni = P_ni.dot(paw.VC_aii[a] + 2 * VV_ii)
        v1_ani[a] = v_ani[a] - v_ni
    # Hand the atomic coefficients to the projector functions so they are
    # applied to ``w_nG`` (modified in place).
    wfs.pt.add(w_nG, v1_ani, kpt1.psit.kpt)

    return w_nG