Coverage for gpaw/response/pair.py: 97%
269 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1import numpy as np
3from gpaw.response import ResponseContext, ResponseGroundStateAdapter, timer
4from gpaw.response.pw_parallelization import block_partition
5from gpaw.utilities.blas import mmm
class KPoint:
    """Wave-function data for a single k-point and spin channel.

    Two band ranges are tracked: the requested range ``[n1, n2)`` and the
    locally owned block ``[na, nb)`` used with band-block parallelization.
    """

    def __init__(self, s, K, n1, n2, blocksize, na, nb,
                 ut_nR, eps_n, f_n, P_ani, k_c):
        # Spin index and Brillouin-zone k-point index.
        self.s, self.K = s, K
        # Requested band range: n1 is the first band, n2 the first excluded.
        self.n1, self.n2 = n1, n2
        # Number of bands per block in the band-block distribution.
        self.blocksize = blocksize
        # Locally owned block: na is the first band, nb the first excluded.
        self.na, self.nb = na, nb
        # Periodic part of the wave functions on the real-space grid.
        self.ut_nR = ut_nR
        # Eigenvalues and occupation numbers.
        self.eps_n, self.f_n = eps_n, f_n
        # PAW projections, one array per atom.
        self.P_ani = P_ani
        # k-point coordinates.
        self.k_c = k_c
class KPointPair:
    """Container for a pair of k-points used in pair-quantity calculations.

    Holds the left (``kpt1``) and right (``kpt2``) k-points together with
    the FFT indices ``Q_G`` of the associated Fourier components.
    """

    def __init__(self, kpt1, kpt2, Q_G):
        self.kpt1, self.kpt2, self.Q_G = kpt1, kpt2, Q_G

    def get_transition_energies(self):
        """Return eps1_n - eps2_m as an (n, m) array of energy differences."""
        return self.kpt1.eps_n[:, None] - self.kpt2.eps_n

    def get_occupation_differences(self):
        """Return f1_n - f2_m as an (n, m) array of occupation differences."""
        return self.kpt1.f_n[:, None] - self.kpt2.f_n
class KPointPairFactory:
    """Create KPoint and KPointPair objects from a ground-state adapter.

    Only symmorphic symmetries and a serial ground-state world
    (``gs.world.size == 1``) are supported; both are asserted on creation.
    """

    def __init__(self, gs, context):
        self.gs = gs
        self.context = context
        assert self.gs.kd.symmetry.symmorphic
        assert self.gs.world.size == 1

    @timer('Get a k-point')
    def get_k_point(self, s, K, n1, n2, blockcomm=None):
        """Return wave functions for a specific k-point and spin.

        s: int
            Spin index (0 or 1).
        K: int
            BZ k-point index.
        n1, n2: int
            Range of bands to include.
        blockcomm: communicator or None
            Optional band-block communicator; when given, only this rank's
            block of bands gets real-space wave functions and projections.
        """
        assert n1 <= n2

        gs = self.gs
        kd = gs.kd

        nblocks, rank = (blockcomm.size, blockcomm.rank) if blockcomm \
            else (1, 0)

        # Ceiling division: spread n2 - n1 bands over nblocks blocks.
        blocksize = -(-(n2 - n1) // nblocks)
        na = min(n1 + rank * blocksize, n2)
        nb = min(na + blocksize, n2)

        ik = kd.bz2ibz_k[K]
        assert kd.comm.size == 1
        kpt = gs.kpt_qs[ik][s]

        assert n2 <= len(kpt.eps_n), \
            'Increase GS-nbands or decrease chi0-nbands!'
        eps_n = kpt.eps_n[n1:n2]
        f_n = kpt.f_n[n1:n2] / kpt.weight

        # Symmetry map from the irreducible k-point ik to BZ k-point K.
        symop = self.gs.ibz2bz[K]
        k_c = symop.map_kpoint()

        with self.context.timer('load wfs'):
            psit_nG = kpt.psit_nG
            ut_nR = gs.gd.empty(nb - na, gs.dtype)
            for i, n in enumerate(range(na, nb)):
                ut_nR[i] = symop.map_pseudo_wave(gs.pd.ifft(psit_nG[n], ik))

        with self.context.timer('Load projections'):
            if nb > na:
                proj = kpt.projections.new(nbands=nb - na, bcomm=None)
                proj.array[:] = kpt.projections.array[na:nb]
                proj = symop.map_projections(proj)
                P_ani = [P_ni for _, P_ni in proj.items()]
            else:
                P_ani = []

        return KPoint(s, K, n1, n2, blocksize, na, nb,
                      ut_nR, eps_n, f_n, P_ani, k_c)

    @timer('Get kpoint pair')
    def get_kpoint_pair(self, qpd, s, K, n1, n2, m1, m2,
                        blockcomm=None, flipspin=False):
        """Return the KPointPair (k, k + q) for BZ k-point K.

        Bands n1..n2 are loaded at k with spin s; bands m1..m2 are loaded at
        k + q (q from ``qpd.q_c``) with spin s, or the opposite spin when
        ``flipspin`` is true. Only the k + q point is block-distributed.
        """
        assert m1 <= m2
        assert n1 <= n2

        find = self.gs.kpoints.kptfinder.find
        k_c = self.gs.kd.bzk_kc[K]
        K1 = find(k_c)
        K2 = find(k_c + qpd.q_c)
        s1, s2 = s, (s + flipspin) % 2

        with self.context.timer('get k-points'):
            kpt1 = self.get_k_point(s1, K1, n1, n2)
            kpt2 = self.get_k_point(s2, K2, m1, m2, blockcomm=blockcomm)

        with self.context.timer('fft indices'):
            Q_G = phase_shifted_fft_indices(kpt1.k_c, kpt2.k_c, qpd)

        return KPointPair(kpt1, kpt2, Q_G)

    def pair_calculator(self, blockcomm=None):
        """Return an ActualPairDensityCalculator sharing this factory's data.

        Without an explicit ``blockcomm`` a single-block partition of the
        context communicator is used.
        """
        # The pair density calculator is decoupled from the k-point factory;
        # this shortcut is kept for convenience for now.
        if blockcomm is None:
            blockcomm, _ = block_partition(self.context.comm, nblocks=1)
        return ActualPairDensityCalculator(self, blockcomm)
class ActualPairDensityCalculator:
    """Calculate pair densities and optical-limit quantities for k-pairs.

    Uses the ground-state adapter and response context of the k-point pair
    factory it was built from, and ``blockcomm`` to gather band-block
    distributed results when ``block=True``.
    """

    def __init__(self, kptpair_factory, blockcomm):
        # it seems weird to use kptpair_factory only for this
        self.gs = kptpair_factory.gs
        self.context = kptpair_factory.context
        self.blockcomm = blockcomm
        # Gradient of wave functions for the optical limit, stored as
        # ut_sKnvR[s][K][n - n1] by calculate_derivatives (one k-point only).
        self.ut_sKnvR = None

    def get_optical_pair_density(self, qpd, kptpair, n_n, m_m, *,
                                 pawcorr, block=False):
        """Get the full optical pair density, including the optical limit head
        for q=0.

        Returns n_nmP where the last axis is P = (x, y, z, G1, G2, ...): the
        G=0 component of the ordinary pair density is replaced by the three
        Cartesian components of the optical-limit head.
        """
        tmp_nmG = self.get_pair_density(qpd, kptpair, n_n, m_m,
                                        pawcorr=pawcorr, block=block)

        # Size the P axis from the pair density actually computed, so the
        # G != 0 components always fit: 3 head components + (nG - 1).
        nG = tmp_nmG.shape[2]
        # P = (x, y, z, G1, G2, ...)
        n_nmP = np.empty((len(n_n), len(m_m), nG + 2), dtype=tmp_nmG.dtype)
        n_nmP[:, :, 3:] = tmp_nmG[:, :, 1:]
        n_nmv = self.get_optical_pair_density_head(qpd, kptpair, n_n, m_m,
                                                   block=block)
        n_nmP[:, :, :3] = n_nmv

        return n_nmP

    @timer('get_pair_density')
    def get_pair_density(self, qpd, kptpair, n_n, m_m, *,
                         pawcorr, block=False):
        """Get pair density for a kpoint pair.

        Returns n_nmG with shape (len(n_n), len(m_m), len(kptpair.Q_G)):
        one row of plane-wave components per left-hand band n.
        """
        cpd = self.calculate_pair_density

        kpt1 = kptpair.kpt1
        kpt2 = kptpair.kpt2
        Q_G = kptpair.Q_G  # Fourier components of kpoint pair
        nG = len(Q_G)

        n_nmG = np.zeros((len(n_n), len(m_m), nG), qpd.dtype)

        for j, n in enumerate(n_n):
            # Left-hand band data are stored relative to the block start na.
            with self.context.timer('conj'):
                ut1cc_R = kpt1.ut_nR[n - kpt1.na].conj()
            with self.context.timer('paw'):
                C1_aGi = pawcorr.multiply(kpt1.P_ani, band=n - kpt1.na)
            n_nmG[j] = cpd(ut1cc_R, C1_aGi, kpt2, qpd, Q_G, block=block)

        return n_nmG

    @timer('get_optical_pair_density_head')
    def get_optical_pair_density_head(self, qpd, kptpair, n_n, m_m,
                                      block=False):
        """Get the optical limit of the pair density head (G=0) for a k-pair.

        Only valid for q = 0 (asserted). Returns n_nmv with shape
        (len(n_n), len(m_m), 3).
        """
        assert np.allclose(qpd.q_c, 0.0), f"{qpd.q_c} is not the optical limit"

        kpt1 = kptpair.kpt1
        kpt2 = kptpair.kpt2

        # v = (x, y, z)
        n_nmv = np.zeros((len(n_n), len(m_m), 3), qpd.dtype)

        for j, n in enumerate(n_n):
            n_nmv[j] = self.calculate_optical_pair_density_head(n, m_m,
                                                                kpt1, kpt2,
                                                                block=block)

        return n_nmv

    @timer('Calculate pair-densities')
    def calculate_pair_density(self, ut1cc_R, C1_aGi, kpt2, qpd, Q_G,
                               block=True):
        """Calculate FFT of pair-densities and add PAW corrections.

        ut1cc_R: 3-d complex ndarray
            Complex conjugate of the periodic part of the left hand side
            wave function.
        C1_aGi: list of ndarrays
            PAW corrections for all atoms.
        kpt2: KPoint object
            Right hand side k-point object.
        qpd: SingleQPWDescriptor
            Plane-wave descriptor for q=k2-k1.
        Q_G: 1-d int ndarray
            Mapping from flattened 3-d FFT grid to 0.5(G+q)^2<ecut sphere.

        Returns the local block (kpt2.blocksize rows; rows beyond the locally
        owned bands are left uninitialized) or, when ``block`` is true and
        there is more than one block, the gathered n2 - n1 rows.
        """
        dv = qpd.gd.dv
        n_mG = qpd.empty(kpt2.blocksize)
        myblocksize = kpt2.nb - kpt2.na

        # zip stops after the locally owned bands in kpt2.ut_nR.
        for ut_R, n_G in zip(kpt2.ut_nR, n_mG):
            n_R = ut1cc_R * ut_R
            with self.context.timer('fft'):
                n_G[:] = qpd.fft(n_R, 0, Q_G) * dv
        # PAW corrections:
        with self.context.timer('gemm'):
            for C1_Gi, P2_mi in zip(C1_aGi, kpt2.P_ani):
                # gemm(1.0, C1_Gi, P2_mi, 1.0, n_mG[:myblocksize], 't')
                mmm(1.0, P2_mi, 'N', C1_Gi, 'T', 1.0, n_mG[:myblocksize])

        if not block or self.blockcomm.size == 1:
            return n_mG
        else:
            n_MG = qpd.empty(kpt2.blocksize * self.blockcomm.size)
            with self.context.timer('all_gather'):
                self.blockcomm.all_gather(n_mG, n_MG)
            # Drop the padding rows of the last (possibly short) block.
            return n_MG[:kpt2.n2 - kpt2.n1]

    @timer('Optical limit')
    def calculate_optical_pair_velocity(self, n, kpt1, kpt2, block=False):
        """Return -i times the nabla matrix elements <n, kpt1| nabla |m, kpt2>.

        The plane-wave part combines the real-space gradient of band n with
        the i k contribution from the Bloch phase; PAW corrections use the
        setups' ``nabla_iiv``. Returns an (m, 3) complex array.
        """
        # This has the effect of caching at most one kpoint.
        # This caching will be efficient only if we are looping over kpoints
        # in a particular way.
        #
        # It would be better to refactor so this caching is handled explicitly
        # by the caller providing the right thing.
        #
        # See https://gitlab.com/gpaw/gpaw/-/issues/625
        # NOTE(review): the cache is keyed on (s, K) only; a different band
        # range n1..n2 for the same k-point would reuse stale derivatives.
        if self.ut_sKnvR is None or kpt1.K not in self.ut_sKnvR[kpt1.s]:
            self.ut_sKnvR = self.calculate_derivatives(kpt1)

        gd = self.gs.gd
        # Cartesian k-vector from relative coordinates.
        k_v = 2 * np.pi * np.dot(kpt1.k_c, np.linalg.inv(gd.cell_cv).T)

        ut_vR = self.ut_sKnvR[kpt1.s][kpt1.K][n - kpt1.n1]
        atomdata_a = self.gs.pawdatasets.by_atom
        C_avi = [np.dot(atomdata.nabla_iiv.T, P_ni[n - kpt1.na])
                 for atomdata, P_ni in zip(atomdata_a, kpt1.P_ani)]

        blockbands = kpt2.nb - kpt2.na
        n0_mv = np.empty((kpt2.blocksize, 3), dtype=complex)
        nt_m = np.empty(kpt2.blocksize, dtype=complex)
        n0_mv[:blockbands] = -gd.integrate(ut_vR, kpt2.ut_nR).T
        nt_m[:blockbands] = gd.integrate(kpt1.ut_nR[n - kpt1.na],
                                         kpt2.ut_nR)

        n0_mv[:blockbands] += (1j * nt_m[:blockbands, np.newaxis] *
                               k_v[np.newaxis, :])

        for C_vi, P_mi in zip(C_avi, kpt2.P_ani):
            # gemm(1.0, C_vi, P_mi, 1.0, n0_mv[:blockbands], 'c')
            mmm(1.0, P_mi, 'N', C_vi, 'C', 1.0, n0_mv[:blockbands])

        if block and self.blockcomm.size > 1:
            n0_Mv = np.empty((kpt2.blocksize * self.blockcomm.size, 3),
                             dtype=complex)
            with self.context.timer('all_gather optical'):
                self.blockcomm.all_gather(n0_mv, n0_Mv)
            n0_mv = n0_Mv[:kpt2.n2 - kpt2.n1]

        return -1j * n0_mv

    def calculate_optical_pair_density_head(self, n, m_m, kpt1, kpt2,
                                            block=False):
        """Return the q -> 0 pair-density head for band n and bands m_m.

        The pair velocity is divided by -(eps_n - eps_m). Degenerate pairs
        (zero energy difference) give zero. Elements where
        |1e-3 * velocity / deps| exceeds the threshold are also set to zero.
        """
        # Numerical threshold for the optical limit k dot p perturbation
        # theory expansion:
        threshold = 1

        eps1 = kpt1.eps_n[n - kpt1.n1]
        deps_m = (eps1 - kpt2.eps_n)[m_m - kpt2.n1]
        n0_mv = self.calculate_optical_pair_velocity(n, kpt1, kpt2,
                                                     block=block)

        deps_m = deps_m.copy()
        # Dividing by inf below turns degenerate pairs into zeros.
        deps_m[deps_m == 0.0] = np.inf

        smallness_mv = np.abs(-1e-3 * n0_mv / deps_m[:, np.newaxis])
        inds_mv = (np.logical_and(np.inf > smallness_mv,
                                  smallness_mv > threshold))
        n0_mv *= - 1 / deps_m[:, np.newaxis]
        n0_mv[inds_mv] = 0

        return n0_mv

    @timer('Intraband')
    def intraband_pair_density(self, kpt, n_n):
        """Calculate intraband matrix elements of nabla.

        For each group of degenerate bands (energies within 1e-5) the matrix
        of -i nabla within the group is diagonalized per Cartesian component,
        and its eigenvalues are used as band velocities. Returns an
        (len(n_n), 3) complex array.

        Raises RuntimeError if a degenerate group is only partially contained
        in ``n_n``.
        """
        # Bands and check for block parallelization
        na, nb, n1 = kpt.na, kpt.nb, kpt.n1
        if len(n_n) == 0:
            # Nothing requested; np.max/np.min below would fail on an
            # empty band list.
            return np.zeros((0, 3), dtype=complex)
        vel_nv = np.zeros((nb - na, 3), dtype=complex)
        assert np.max(n_n) < nb, 'This is too many bands'
        assert np.min(n_n) >= na, 'This is too few bands'

        # Load kpoints
        gd = self.gs.gd
        k_v = 2 * np.pi * np.dot(kpt.k_c, np.linalg.inv(gd.cell_cv).T)
        atomdata_a = self.gs.pawdatasets.by_atom

        # Break bands into degenerate chunks
        degchunks_cn = []  # indexing c as chunk number
        for n in n_n:
            inds_n = np.nonzero(np.abs(kpt.eps_n[n - n1] -
                                       kpt.eps_n) < 1e-5)[0] + n1

            # Has this chunk already been computed?
            oldchunk = any(n in chunk for chunk in degchunks_cn)
            if not oldchunk:
                if not all(ind in n_n for ind in inds_n):
                    raise RuntimeError(
                        'You are cutting over a degenerate band '
                        'using block parallelization.')
                degchunks_cn.append(inds_n)

        # Calculate matrix elements by diagonalizing each block
        for ind_n in degchunks_cn:
            deg = len(ind_n)
            ut_nvR = gd.zeros((deg, 3), complex)
            vel_nnv = np.zeros((deg, deg, 3), dtype=complex)
            # States are included starting from kpt.na
            ut_nR = kpt.ut_nR[ind_n - na]

            # Get derivatives
            for ind, ut_vR in zip(ind_n, ut_nvR):
                ut_vR[:] = self.make_derivative(kpt.s, kpt.K,
                                                ind, ind + 1)[0]

            # Treat the whole degenerate chunk
            for n in range(deg):
                ut_vR = ut_nvR[n]
                C_avi = [np.dot(atomdata.nabla_iiv.T, P_ni[ind_n[n] - na])
                         for atomdata, P_ni in zip(atomdata_a, kpt.P_ani)]

                nabla0_nv = -gd.integrate(ut_vR, ut_nR).T
                nt_n = gd.integrate(ut_nR[n], ut_nR)
                nabla0_nv += 1j * nt_n[:, np.newaxis] * k_v[np.newaxis, :]

                for C_vi, P_ni in zip(C_avi, kpt.P_ani):
                    # gemm(1.0, C_vi, P_ni[ind_n - na], 1.0, nabla0_nv, 'c')
                    mmm(1.0, P_ni[ind_n - na], 'N', C_vi, 'C', 1.0, nabla0_nv)

                vel_nnv[n] = -1j * nabla0_nv

            for iv in range(3):
                vel, _ = np.linalg.eig(vel_nnv[..., iv])
                vel_nv[ind_n - na, iv] = vel  # Use eigenvalues

        return vel_nv[n_n - na]

    def calculate_derivatives(self, kpt):
        """Return a fresh cache ut_sKnvR holding gradients for ``kpt`` only.

        The gradients cover bands kpt.n1..kpt.n2 and are stored under
        ut_sKnvR[kpt.s][kpt.K].
        """
        ut_sKnvR = [{}, {}]
        ut_nvR = self.make_derivative(kpt.s, kpt.K, kpt.n1, kpt.n2)
        ut_sKnvR[kpt.s][kpt.K] = ut_nvR

        return ut_sKnvR

    @timer('Derivatives')
    def make_derivative(self, s, K, n1, n2):
        """Return the real-space gradient of pseudo wave functions n1..n2-1.

        Each band is multiplied by iG in plane waves (without q) at the
        irreducible k-point that K maps to, transformed to real space, mapped
        to BZ k-point K, and rotated to Cartesian directions with M_vv built
        from the cell and the symmetry operation's U_cc.

        Returns ut_nvR with leading shape (n2 - n1, 3) followed by the
        real-space grid shape.
        """
        gs = self.gs
        U_cc = gs.ibz2bz[K].U_cc
        A_cv = gs.gd.cell_cv
        M_vv = np.dot(np.dot(A_cv.T, U_cc.T), np.linalg.inv(A_cv).T)
        ik = gs.kd.bz2ibz_k[K]
        assert gs.kd.comm.size == 1
        kpt = gs.kpt_qs[ik][s]
        psit_nG = kpt.psit_nG
        iG_Gv = 1j * gs.pd.get_reciprocal_vectors(q=ik, add_q=False)
        ut_nvR = gs.gd.zeros((n2 - n1, 3), complex)
        for n in range(n1, n2):
            for v in range(3):
                ut_R = gs.ibz2bz[K].map_pseudo_wave(
                    gs.pd.ifft(iG_Gv[:, v] * psit_nG[n], ik))
                for v2 in range(3):
                    ut_nvR[n - n1, v2] += ut_R * M_vv[v, v2]

        return ut_nvR
def phase_shifted_fft_indices(k1_c, k2_c, qpd, coordinate_transformation=None):
    """Get phase shifted FFT indices for G-vectors inside the cutoff sphere.

    The output 1D FFT indices Q_G can be used to extract the plane-wave
    components G of the phase shifted Fourier transform

        n_kk'(G+q) = FFT_G[e^(-i[k+q-k']r) n_kk'(r)]

    where n_kk'(r) is some lattice periodic function and the wave vector
    difference k + q - k' is commensurate with the reciprocal lattice.

    Parameters
    ----------
    k1_c, k2_c : ndarray
        k and k' in relative coordinates.
    qpd : plane-wave descriptor
        Supplies the FFT grid shape ``qpd.gd.N_c``, the cutoff-sphere FFT
        indices ``qpd.Q_qG[0]`` and the wave vector ``qpd.q_c``.
    coordinate_transformation : callable, optional
        Applied to q_c and to the 3D FFT grid indices before shifting.

    Returns
    -------
    Q_G : ndarray of int
        1D FFT indices. When no shift and no transformation are needed this
        is ``qpd.Q_qG[0]`` itself.
    """
    N_c = qpd.gd.N_c
    Q_G = qpd.Q_qG[0]
    q_c = qpd.q_c
    if coordinate_transformation is not None:
        q_c = coordinate_transformation(q_c)

    shift_c = k1_c + q_c - k2_c
    assert np.allclose(shift_c.round(), shift_c), \
        f'k1 + q - k2 = {shift_c} is not a reciprocal lattice vector'
    shift_c = shift_c.round().astype(int)

    if shift_c.any() or coordinate_transformation is not None:
        # Get the 3D FFT grid indices (relative reciprocal space coordinates)
        # of the G-vectors inside the cutoff sphere
        i_cG = np.unravel_index(Q_G, N_c)
        if coordinate_transformation is not None:
            i_cG = coordinate_transformation(i_cG)
        # Shift the 3D FFT grid indices to account for the Bloch-phase shift
        # e^(-i[k+q-k']r). np.unravel_index returns a tuple of arrays, so
        # convert explicitly rather than relying on `tuple += ndarray`
        # falling back to ndarray.__radd__; the non-inplace add also avoids
        # mutating an array returned by the transformation.
        i_cG = np.asarray(i_cG) + shift_c[:, np.newaxis]
        # Transform back the FFT grid to 1D FFT indices
        Q_G = np.ravel_multi_index(i_cG, N_c, 'wrap')

    return Q_G
def get_gs_and_context(calc, txt, world, timer):
    """Build a (ground-state adapter, response context) pair.

    ``calc`` may be a live GPAW calculator (old or new interface), which
    must run on a single rank, or anything else accepted by
    ``ResponseGroundStateAdapter.from_gpw_file`` (e.g. a gpw file).
    Interface to initialize gs and context from old input arguments;
    should be phased out in the future!
    """
    from gpaw.calculator import GPAW as OldGPAW
    from gpaw.new.ase_interface import ASECalculator as NewGPAW

    context = ResponseContext(txt=txt, timer=timer, comm=world)

    if not isinstance(calc, (OldGPAW, NewGPAW)):
        return ResponseGroundStateAdapter.from_gpw_file(gpw=calc), context

    assert calc.wfs.world.size == 1
    return calc.gs_adapter(), context