Coverage for gpaw/response/chi0_base.py: 97%
210 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-12 00:18 +0000
1from __future__ import annotations
2from abc import ABC, abstractmethod
4import numpy as np
6from typing import TYPE_CHECKING
8from ase.units import Ha
9from gpaw.bztools import convex_hull_volume
10from gpaw.response import timer
11from gpaw.response.pair import KPointPairFactory
12from gpaw.response.frequencies import NonLinearFrequencyDescriptor
13from gpaw.response.qpd import SingleQPWDescriptor
14from gpaw.response.pw_parallelization import block_partition
15from gpaw.response.integrators import (
16 Integrand, PointIntegrator, TetrahedronIntegrator, Domain)
17from gpaw.response.symmetry import QSymmetryInput, QSymmetryAnalyzer
18from gpaw.response.kpoints import KPointDomain, KPointDomainGenerator
20if TYPE_CHECKING:
21 from gpaw.response.pair import ActualPairDensityCalculator
22 from gpaw.response.context import ResponseContext
23 from gpaw.response.groundstate import ResponseGroundStateAdapter
class Chi0Integrand(Integrand):
    """Integrand of the chi0 k-point integral.

    Supplies pair-density matrix elements and eigenvalue differences for
    all transitions between the band ranges n1:n2 and m1:m2, evaluated at
    the points of an integration domain.
    """

    def __init__(self, chi0calc: Chi0ComponentPWCalculator,
                 optical: bool,
                 qpd: SingleQPWDescriptor,
                 generator: KPointDomainGenerator,
                 n1: int,
                 n2: int,
                 m1: int,
                 m2: int):
        """
        Parameters
        ----------
        chi0calc : Chi0ComponentPWCalculator
            Calculator providing the ground state adapter, context, k-point
            pair factory, pair-density calculator, PAW corrections cache,
            integration mode and block communicator.
        optical : bool
            If True, matrix_element() returns optical pair densities
            (head and wings) instead of plain pair densities.
        qpd : SingleQPWDescriptor
            Plane-wave descriptor for the momentum transfer q.
        generator : KPointDomainGenerator
            Integration-domain generator; queried for k-point weights and
            the number of symmetries.
        n1 : int
            Lower occupied band index.
        n2 : int
            Upper occupied band index.
        m1 : int
            Lower unoccupied band index.
        m2 : int
            Upper unoccupied band index.
        """
        assert m1 <= m2
        assert n1 < n2 <= chi0calc.gs.nocc2
        assert n1 <= chi0calc.gs.nocc1
        assert chi0calc.gs.nocc1 <= m1
        self.m1 = m1
        self.m2 = m2
        self.n1 = n1
        self.n2 = n2

        self._chi0calc = chi0calc

        self.gs: ResponseGroundStateAdapter = chi0calc.gs

        self.context: ResponseContext = chi0calc.context
        self.kptpair_factory: KPointPairFactory = chi0calc.kptpair_factory

        self.qpd = qpd
        self.generator = generator
        self.integrationmode = chi0calc.integrationmode
        self.optical = optical
        self.blockcomm = chi0calc.blockcomm

    @timer('Get matrix element')
    def matrix_element(self, point):
        """Return pair density matrix elements for integration.

        A pair density is defined as::

            <snk| e^(-i (q + G) r) |s'mk+q>,

        where s and s' are spins, n and m are band indices, k is
        the kpoint and q is the momentum transfer. For dielectric
        response s'=s, for the transverse magnetic response
        s' is flipped with respect to s.

        Parameters
        ----------
        point
            Integration point. ``point.kpt_c`` holds the k-point (in
            Cartesian coordinates despite its name, see
            _get_any_matrix_element) and ``point.spin`` the spin index.

        If self.optical, then return optical pair densities, that is, the
        head and wings matrix elements, indexed by::

            P = (x, y, z, G1, G2, ...)

        which gives qpd.ngmax + 2 entries per row.

        Returns
        -------
        n_nmG : ndarray
            Pair densities with all (n, m) band pairs flattened into the
            first axis: shape (nn * nm, qpd.ngmax), or
            (nn * nm, qpd.ngmax + 2) in the optical case.
        """
        if self.optical:
            # pair_calc: ActualPairDensityCalculator from gpaw.response.pair
            target_method = self._chi0calc.pair_calc.get_optical_pair_density
            out_ngmax = self.qpd.ngmax + 2
        else:
            target_method = self._chi0calc.pair_calc.get_pair_density
            out_ngmax = self.qpd.ngmax

        return self._get_any_matrix_element(
            point, target_method=target_method,
        ).reshape(-1, out_ngmax)

    def _get_any_matrix_element(self, point, target_method):
        """Evaluate weighted pair densities at an integration point.

        ``target_method`` is one of the pair calculator's pair-density
        methods (plain or optical). Its output is scaled by the square root
        of the occupation differences. In point-integration mode it is also
        scaled by the square root of the k-point weight per symmetry.
        """
        qpd = self.qpd

        # point.kpt_c actually holds Cartesian coordinates.
        # XXX c/v discrepancy
        k_v = point.kpt_c
        # Convert to scaled (reduced) coordinates for the k-point lookup
        k_c = np.dot(qpd.gd.cell_cv, k_v) / (2 * np.pi)
        K = self.gs.kpoints.kptfinder.find(k_c)

        # The PAW corrections are cached on the chi0 calculator and created
        # on first use for this q.
        if self._chi0calc.pawcorr is None:
            pairden_paw_corr = self.gs.pair_density_paw_corrections
            self._chi0calc.pawcorr = pairden_paw_corr(qpd)

        kptpair = self.kptpair_factory.get_kpoint_pair(
            qpd, point.spin, K, self.n1, self.n2,
            self.m1, self.m2, blockcomm=self.blockcomm)

        m_m = np.arange(self.m1, self.m2)
        n_n = np.arange(self.n1, self.n2)
        n_nmG = target_method(qpd, kptpair, n_n, m_m,
                              pawcorr=self._chi0calc.pawcorr,
                              block=True)

        if self.integrationmode == 'point integration':
            # Only point integration needs the explicit k-point weight.
            # The square root is taken, presumably because the integrator
            # multiplies two matrix elements together.
            weight = np.sqrt(self.generator.get_kpoint_weight(k_c)
                             / self.generator.how_many_symmetries())
            n_nmG *= weight

        df_nm = kptpair.get_occupation_differences()
        # Zero out non-positive/negligible differences so the sqrt is real
        df_nm[df_nm <= 1e-20] = 0.0
        n_nmG *= df_nm[..., np.newaxis]**0.5

        return n_nmG

    @timer('Get eigenvalues')
    def eigenvalues(self, point):
        """Return the eigenvalue differences for integration.

        Computes eps_n(k) - eps_m(k + q) for n in n1:n2 and m in m1:m2 at
        the k-point and spin of ``point``. The result is flattened to 1D in
        the same (n, m) order as the rows returned by matrix_element().
        """
        qpd = self.qpd
        gs = self.gs
        kd = gs.kd

        k_v = point.kpt_c  # XXX c/v discrepancy

        k_c = np.dot(qpd.gd.cell_cv, k_v) / (2 * np.pi)
        kptfinder = self.gs.kpoints.kptfinder
        K1 = kptfinder.find(k_c)
        K2 = kptfinder.find(k_c + qpd.q_c)

        # Map the BZ indices of k and k + q to irreducible k-points
        ik1 = kd.bz2ibz_k[K1]
        ik2 = kd.bz2ibz_k[K2]
        kpt1 = gs.kpt_qs[ik1][point.spin]
        assert kd.comm.size == 1
        kpt2 = gs.kpt_qs[ik2][point.spin]
        deps_nm = np.subtract(kpt1.eps_n[self.n1:self.n2][:, np.newaxis],
                              kpt2.eps_n[self.m1:self.m2])
        return deps_nm.reshape(-1)
class Chi0ComponentCalculator:
    """Base class for the Chi0XXXCalculator suite.

    Holds the ground state adapter, the response context, the k-point pair
    factory, the block/k-point communicator partition, the q-symmetry
    analyzer and the k-point integrator shared by all chi0 calculators.
    """

    def __init__(self, gs, context, *, nblocks,
                 qsymmetry: QSymmetryInput = True,
                 integrationmode='point integration'):
        """Set up attributes common to all chi0 related calculators.

        Parameters
        ----------
        gs : ResponseGroundStateAdapter
            Ground state adapter.
        context : ResponseContext
            Response context; its ``comm`` is partitioned into blocks and
            its ``print`` is used for output.
        nblocks : int
            Divide response function memory allocation in nblocks.
        qsymmetry: bool, dict, or QSymmetryAnalyzer
            QSymmetryAnalyzer, or bool to enable all/no symmetries,
            or dict with which to create QSymmetryAnalyzer.
            Disabling symmetries may be useful for debugging.
        integrationmode : str
            Integrator for the k-point integration.
            If == 'tetrahedron integration' then the kpoint integral is
            performed using the linear tetrahedron method.
            If == 'point integration', point integration is used.
        """
        self.gs = gs
        self.context = context
        self.kptpair_factory = KPointPairFactory(gs, context)

        self.nblocks = nblocks
        self.blockcomm, self.kncomm = block_partition(
            self.context.comm, self.nblocks)

        self.qsymmetry = QSymmetryAnalyzer.from_input(qsymmetry)

        # Set up integrator. This must come after self.qsymmetry, which
        # get_integrator_cls() consults in tetrahedron mode.
        self.integrationmode = integrationmode
        self.integrator = self.construct_integrator()

    @property
    def pbc(self):
        """Periodic boundary conditions of the ground state."""
        return self.gs.pbc

    def construct_integrator(self):  # -> Integrator or child of Integrator
        """Construct the k-point integrator for self.integrationmode."""
        cls = self.get_integrator_cls()
        return cls(
            cell_cv=self.gs.gd.cell_cv,
            context=self.context,
            blockcomm=self.blockcomm,
            kncomm=self.kncomm)

    def get_integrator_cls(self):  # -> Integrator or child of Integrator
        """Get the appointed k-point integrator class.

        Raises
        ------
        ValueError
            If self.integrationmode is unknown. Also raised (by
            check_high_symmetry_ibz_kpts) if tetrahedron integration with
            symmetries is requested on a k-point grid that lacks IBZ
            vertices.
        """
        if self.integrationmode == 'point integration':
            self.context.print('Using integrator: PointIntegrator')
            cls = PointIntegrator
        elif self.integrationmode == 'tetrahedron integration':
            self.context.print('Using integrator: TetrahedronIntegrator')
            cls = TetrahedronIntegrator
            if not self.qsymmetry.disabled:
                self.check_high_symmetry_ibz_kpts()
        else:
            raise ValueError(f'Integration mode "{self.integrationmode}"'
                             ' not implemented.')
        return cls

    def check_high_symmetry_ibz_kpts(self):
        """Check that the ground state includes all corners of the IBZ.

        Raises
        ------
        ValueError
            If some IBZ vertex does not coincide, modulo a reciprocal
            lattice vector and along the periodic directions, with a
            k-point of the ground state BZ grid.
        """
        ibz_vertices_kc = self.gs.get_ibz_vertices()
        # Here we mimic the k-point grid compatibility check of
        # gpaw.bztools.find_high_symmetry_monkhorst_pack()
        bzk_kc = self.gs.kd.bzk_kc
        for ibz_vertex_c in ibz_vertices_kc:
            # Relative coordinate difference to the k-point grid
            diff_kc = np.abs(bzk_kc - ibz_vertex_c)[:, self.gs.pbc].round(6)
            # The ibz vertex should exist in the BZ grid up to a reciprocal
            # lattice vector, meaning that the relative coordinate difference
            # is allowed to be an integer. Thus, at least one relative k-point
            # difference should vanish, modulo 1. The tolerance is applied
            # symmetrically around each integer, so that e.g. 0.999996 is
            # accepted just like 1.000004.
            frac_kc = np.mod(diff_kc, 1)
            dist_kc = np.minimum(frac_kc, 1 - frac_kc)
            nodiff_k = np.all(dist_kc < 1e-5, axis=1)
            if not np.any(nodiff_k):
                raise ValueError(
                    'The ground state k-point grid does not include all '
                    'vertices of the IBZ. '
                    'Please use find_high_symmetry_monkhorst_pack() from '
                    'gpaw.bztools to generate your k-point grid.')

    def get_integration_domain(self, q_c, spins):
        """Get integrator domain and prefactor for the integral.

        Returns
        -------
        symmetries, generator
            Result of the q-symmetry analysis (see get_kpoints()).
        domain : Domain
            Domain spanned by the reduced k-points and the given spins.
        prefactor : float
            Prefactor for the integral:
            2 * factor * nsymmetries / (nspins * (2 pi)^3). In point
            integration it is further multiplied by nkpoints / nbzkpts. In
            tetrahedron integration ``factor`` rescales by
            vol(BZ) / vol(domain).
        """
        for spin in spins:
            assert spin in range(self.gs.nspins)
        # The integration domain is determined by the following function
        # that reduces the integration domain to the irreducible zone
        # of the little group of q.
        symmetries, generator, kpoints = self.get_kpoints(
            q_c, integrationmode=self.integrationmode)

        domain = Domain(kpoints.k_kv, spins)
        nsymmetries = generator.how_many_symmetries()

        if self.integrationmode == 'tetrahedron integration':
            # If there are non-periodic directions it is possible that the
            # integration domain is not compatible with the symmetry
            # operations, which essentially means that too large domains
            # will be integrated. We therefore rescale by
            # vol(BZ) / vol(domain) to correct for this.
            domainvol = convex_hull_volume(kpoints.k_kv) * nsymmetries
            bzvol = (2 * np.pi)**3 / self.gs.volume
            factor = bzvol / domainvol
        else:
            factor = 1

        prefactor = (2 * factor * nsymmetries
                     / (self.gs.nspins * (2 * np.pi)**3))

        if self.integrationmode == 'point integration':
            nbzkpts = self.gs.kd.nbzkpts
            prefactor *= len(kpoints) / nbzkpts

        return symmetries, generator, domain, prefactor

    @timer('Get kpoints')
    def get_kpoints(self, q_c, integrationmode):
        """Get the integration domain.

        Returns
        -------
        symmetries, generator
            Result of self.qsymmetry.analyze() for q_c.
        kpoints : KPointDomain
            The k-points of the (symmetry-reduced) integration domain.

        Raises
        ------
        ValueError
            If integrationmode is neither 'point integration' nor
            'tetrahedron integration'.
        """
        symmetries, generator = self.qsymmetry.analyze(
            np.asarray(q_c), self.gs.kpoints, self.context)

        if integrationmode == 'point integration':
            k_kc = generator.get_kpt_domain()
        elif integrationmode == 'tetrahedron integration':
            k_kc = generator.get_tetrahedron_kpt_domain(
                pbc_c=self.pbc, cell_cv=self.gs.gd.cell_cv)
        else:
            # Fail explicitly instead of with an UnboundLocalError on k_kc
            raise ValueError(f'Integration mode "{integrationmode}"'
                             ' not implemented.')
        kpoints = KPointDomain(k_kc, self.gs.gd.icell_cv)

        # In the future, we probably want to put enough functionality on the
        # KPointDomain such that we don't need to also return the
        # KPointDomainGenerator XXX
        return symmetries, generator, kpoints

    def get_gs_info_string(self, tab=''):
        """Return a multi-line summary of the ground state adapter.

        Lines after the first are joined with a newline plus ``tab``.
        """
        gs = self.gs
        gd = gs.gd

        ns = gs.nspins
        nk = gs.kd.nbzkpts
        nik = gs.kd.nibzkpts

        nocc = gs.nocc1
        npocc = gs.nocc2
        ngridpoints = gd.N_c[0] * gd.N_c[1] * gd.N_c[2]
        nstat = ns * npocc
        # Memory for complex128 (16 bytes) wave functions on the grid, in MiB
        occsize = nstat * ngridpoints * 16. / 1024**2

        gs_list = [f'{tab}Ground state adapter containing:',
                   f'Number of spins: {ns}', f'Number of kpoints: {nk}',
                   f'Number of irreducible kpoints: {nik}',
                   f'Number of completely occupied states: {nocc}',
                   f'Number of partially occupied states: {npocc}',
                   f'Occupied states memory: {occsize} M / cpu']

        return f'\n{tab}'.join(gs_list)
class Chi0ComponentPWCalculator(Chi0ComponentCalculator, ABC):
    """Base class for Chi0XXXCalculators, which utilize a plane-wave basis.

    Subclasses choose the concrete integral task by implementing the
    abstract construct_*_task() methods.
    """

    def __init__(self, gs, context,
                 *,
                 wd,
                 hilbert=True,
                 nbands=None,
                 timeordered=False,
                 ecut=50.0,
                 eta=0.2,
                 **kwargs):
        """Initialize the plane-wave chi0 component calculator.

        Parameters
        ----------
        wd : FrequencyDescriptor
            Frequency grid on which the chi0 component is evaluated.
        hilbert : bool
            If True, only the dissipative (spectral) part is integrated and
            the reactive part follows from a hilbert transform. Requires a
            NonLinearFrequencyDescriptor on the real frequency axis.
        nbands : int or slice
            Number of bands to include; a falsy value means all bands of
            the ground state.
        timeordered : bool
            Request the time-ordered chi0 component (used by G0W0, which
            performs its own hilbert transform). Requires hilbert=True.
        ecut : float | dict
            Plane-wave cutoff in eV, or a dictionary specifying the
            plane-wave descriptor (see response/qpd.py).
        eta : float
            Artificial broadening in eV.
        **kwargs
            Passed on to Chi0ComponentCalculator.

        Raises
        ------
        ValueError
            If exactly one direction is periodic (1-D systems).
        """
        super().__init__(gs, context, **kwargs)

        # Store the cutoff in Hartree; dictionaries are passed through as is
        if not isinstance(ecut, dict):
            ecut /= Ha
        self.ecut = ecut
        self.nbands = nbands or self.gs.nbands

        self.wd = wd
        self.context.print(self.wd, flush=False)
        self.eta = eta / Ha
        self.hilbert = hilbert
        self.task = self.construct_integral_task()

        self.timeordered = bool(timeordered)
        # The time-ordered component is only needed for G0W0
        assert not self.timeordered or self.hilbert

        self.pawcorr = None

        self.context.print('Nonperiodic BCs: ', ~self.pbc)
        if np.count_nonzero(self.pbc) == 1:
            raise ValueError('1-D not supported atm.')

    @property
    def pair_calc(self) -> ActualPairDensityCalculator:
        """Pair density calculator bound to the block communicator."""
        factory = self.kptpair_factory
        return factory.pair_calculator(self.blockcomm)

    def construct_integral_task(self):
        """Choose the integral task from eta and the hilbert flag.

        * eta != 0 and not hilbert: literal evaluation with broadening eta.
        * eta != 0 and hilbert: spectral part, to be hilbert transformed.
        * eta == 0: hermitian part at purely imaginary frequencies.
        """
        if self.eta != 0:
            if not self.hilbert:
                # Evaluate the response function literally at the given
                # frequencies, broadened by eta
                return self.construct_literal_task()
            # Integrate the dissipative (spectral) part; the full response
            # follows from a hilbert transform on a nonlinear grid
            assert isinstance(self.wd, NonLinearFrequencyDescriptor)
            return self.construct_hilbert_task()

        # eta == 0 stands for the hermitian part of the response function
        # on a set of purely imaginary frequencies
        assert not self.hilbert
        assert not self.wd.omega_w.real.any()
        return self.construct_hermitian_task()

    @abstractmethod
    def construct_hermitian_task(self):
        """Integral task for the hermitian part of chi0."""

    def construct_hilbert_task(self):
        """Dispatch to the spectral task matching the integrator type."""
        integrator = self.integrator
        if not isinstance(integrator, PointIntegrator):
            assert isinstance(integrator, TetrahedronIntegrator)
            return self.construct_tetra_hilbert_task()
        return self.construct_point_hilbert_task()

    @abstractmethod
    def construct_point_hilbert_task(self):
        """Integral task for point integrating the spectral part of chi0."""

    @abstractmethod
    def construct_tetra_hilbert_task(self):
        """Integral task for tetrahedron integration of the spectral part."""

    @abstractmethod
    def construct_literal_task(self):
        """Integral task for a literal evaluation of chi0."""

    def get_pw_descriptor(self, q_c):
        """Plane-wave descriptor for momentum transfer q_c at self.ecut."""
        gd = self.gs.gd
        return SingleQPWDescriptor.from_q(q_c, self.ecut, gd)

    def get_response_info_string(self, qpd, tab=''):
        """Summarize the response parameters, one line per quantity.

        Energies are reported in eV; lines after the first are joined
        with a newline plus ``tab``.
        """
        ecut = self.ecut if isinstance(self.ecut, dict) else self.ecut * Ha
        lines = [
            f'{tab}Number of frequency points: {len(self.wd)}',
            f'Planewave cutoff: {ecut}',
            f'Number of bands: {self.nbands}',
            f'Number of planewaves: {qpd.ngmax}',
            f'Broadening (eta): {self.eta * Ha}',
        ]
        return f'\n{tab}'.join(lines)