Coverage for gpaw/response/chi0.py: 95%
193 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1from __future__ import annotations
3from time import ctime
4from typing import TYPE_CHECKING, Optional
6import numpy as np
7from ase.units import Ha
9import gpaw
10from gpaw.response import (ResponseGroundStateAdapter, ResponseContext,
11 ResponseGroundStateAdaptable, ResponseContextInput)
12from gpaw.response.symmetrize import (BodySymmetryOperators,
13 WingSymmetryOperators)
14from gpaw.response.chi0_data import (Chi0Data, Chi0BodyData,
15 Chi0OpticalExtensionData)
16from gpaw.response.frequencies import FrequencyDescriptor
17from gpaw.response.qpd import SingleQPWDescriptor
18from gpaw.response.hilbert import HilbertTransform
19from gpaw.response import timer
20from gpaw.response.pw_parallelization import PlaneWaveBlockDistributor
21from gpaw.utilities.memory import maxrss
22from gpaw.response.chi0_base import Chi0ComponentPWCalculator, Chi0Integrand
23from gpaw.response.integrators import (
24 HermitianOpticalLimit, HilbertOpticalLimit, OpticalLimit,
25 HilbertOpticalLimitTetrahedron,
26 Hermitian, Hilbert, HilbertTetrahedron, GenericUpdate)
28if TYPE_CHECKING:
29 from typing import Any
30 from gpaw.typing import ArrayLike1D
31 from gpaw.response.pair import ActualPairDensityCalculator
class Chi0Calculator:
    """Calculate chi0, combining a body calculator and an optical extension.

    The body (wGG data) is handled by a Chi0BodyCalculator. The head and
    wings, needed only in the optical limit, are handled by a
    Chi0OpticalExtensionCalculator. Both share the same ground state
    adapter and response context.
    """

    def __init__(self,
                 gs: ResponseGroundStateAdaptable,
                 context: ResponseContextInput = '-',
                 nblocks=1,
                 eshift=None,
                 intraband=True,
                 rate='eta',
                 **kwargs):
        """
        Parameters
        ----------
        gs : ResponseGroundStateAdaptable
            Ground state input, wrapped via ResponseGroundStateAdapter.
        context : ResponseContextInput
            Response context input, wrapped via ResponseContext.
        nblocks : int
            Number of blocks; forwarded to the body calculator only.
        eshift : float or None
            Energy shift of the conduction bands in eV. It is stored (and
            passed on) in Hartree; a falsy value (None or 0) means no shift.
        intraband : bool
            Forwarded to the optical extension calculator.
        rate : float or str
            Forwarded to the optical extension calculator.
        **kwargs
            Forwarded to both component calculators.
        """
        # User input is in eV, the component calculators work in Hartree.
        if eshift:
            self.eshift = eshift / Ha
        else:
            self.eshift = None
        self.gs = ResponseGroundStateAdapter.from_input(gs)
        self.context = ResponseContext.from_input(context)

        shared_args = (self.gs, self.context)
        self.chi0_body_calc = Chi0BodyCalculator(
            *shared_args, nblocks=nblocks, eshift=self.eshift, **kwargs)
        self.chi0_opt_ext_calc = Chi0OpticalExtensionCalculator(
            *shared_args, intraband=intraband, rate=rate,
            eshift=self.eshift, **kwargs)

    @property
    def wd(self) -> FrequencyDescriptor:
        """Frequency descriptor, asserted identical for both components."""
        body_wd = self.chi0_body_calc.wd
        assert body_wd is self.chi0_opt_ext_calc.wd
        return body_wd

    @property
    def pair_calc(self) -> ActualPairDensityCalculator:
        """The pair density calculator owned by the body calculator."""
        # XXX In a future refactor, we should find better ways to access the
        # pair density calculator (and the pair density PAW corrections).
        return self.chi0_body_calc.pair_calc

    def create_chi0(self, q_c: list | np.ndarray) -> Chi0Data:
        """Create a Chi0Data container for momentum q_c.

        Only the data structure is set up here; no integration is
        performed (see `calculate` / `update_chi0` for that).
        """
        empty_body = self.chi0_body_calc.create_chi0_body(q_c)
        return Chi0Data.from_chi0_body_data(empty_body)

    def calculate(self, q_c: list | np.ndarray) -> Chi0Data:
        """Calculate chi0 (possibly with optical extensions).

        The optical extension (head and wings) is only calculated when the
        plane-wave descriptor of the body is in the optical limit.

        Parameters
        ----------
        q_c : list or ndarray
            Momentum vector.

        Returns
        -------
        chi0 : Chi0Data
            Data object containing the chi0 data arrays along with basis
            representation descriptors and blocks distribution
        """
        chi0_body = self.chi0_body_calc.calculate(q_c)
        qpd = chi0_body.qpd

        chi0_opt_ext = (self.chi0_opt_ext_calc.calculate(qpd=qpd)
                        if qpd.optical_limit else None)

        self.context.print('\nFinished calculating chi0\n')
        return Chi0Data(chi0_body, chi0_opt_ext)

    @timer('Calculate CHI_0')
    def update_chi0(self,
                    chi0: Chi0Data,
                    *,
                    n1: int = 0, n2: Optional[int] = None,
                    m1: int, m2: int,
                    spins: list
                    ) -> Chi0Data:
        """In-place calculation of chi0 (with optical extension).

        Parameters
        ----------
        chi0 : Chi0Data
            Data and representation object, updated in place
        n1 : int
            Lower band cutoff of the n-band index (default 0)
        n2 : int or None
            Upper band cutoff of the n-band index; None means gs.nocc2
        m1 : int
            Lower band cutoff for band summation
        m2 : int
            Upper band cutoff for band summation
        spins : list
            List of spin indices to include in the calculation

        Returns
        -------
        chi0 : Chi0Data
            The same object that was passed in.
        """
        n2 = self.gs.nocc2 if n2 is None else n2
        band_limits = (n1, n2, m1, m2)

        self.chi0_body_calc.update_chi0_body(chi0.body, *band_limits, spins)
        if chi0.optical_limit:
            assert chi0.optical_extension is not None
            # Update the head and wings as well
            self.chi0_opt_ext_calc.update_chi0_optical_extension(
                chi0.optical_extension, *band_limits, spins)
        return chi0
class Chi0BodyCalculator(Chi0ComponentPWCalculator):
    """Calculator for the body (wGG part) of chi0."""

    def __init__(self, *args,
                 eshift: float | None = None,
                 **kwargs):
        """Construct the Chi0BodyCalculator.

        Parameters
        ----------
        eshift : float or None
            Energy shift of the conduction bands in Hartree.
        """
        # Set before the base-class constructor runs: the construct_*_task
        # methods below read self.eshift (NOTE(review): presumably invoked
        # from the base constructor -- confirm in chi0_base).
        self.eshift = eshift

        super().__init__(*args, **kwargs)

        if self.gs.metallic:
            assert self.eshift is None, \
                'A rigid energy shift cannot be applied to the conduction '\
                'bands if there is no band gap'

    def create_chi0_body(self, q_c: list | np.ndarray) -> Chi0BodyData:
        """Create an (unintegrated) Chi0BodyData container for q_c."""
        # qpd: SingleQPWDescriptor from gpaw.response.qpd
        qpd = self.get_pw_descriptor(q_c)
        return self._create_chi0_body(qpd)

    def _create_chi0_body(self, qpd: SingleQPWDescriptor) -> Chi0BodyData:
        """Create a Chi0BodyData container for the given PW descriptor."""
        return Chi0BodyData(self.wd, qpd, self.get_blockdist())

    def get_blockdist(self) -> PlaneWaveBlockDistributor:
        """Block distributor built from the integrator's communicators."""
        # integrator: Integrator from gpaw.response.integrators
        # (or a child of this class)
        return PlaneWaveBlockDistributor(
            self.context.comm,  # _Communicator object from gpaw.mpi
            self.integrator.blockcomm,  # _Communicator object from gpaw.mpi
            self.integrator.kncomm)  # _Communicator object from gpaw.mpi

    def calculate(self, q_c: list | np.ndarray) -> Chi0BodyData:
        """Calculate the chi0 body.

        Parameters
        ----------
        q_c : list or ndarray
            Momentum vector.

        Returns
        -------
        chi0_body : Chi0BodyData
            Body data including all band transitions given by
            gs.get_band_transitions(nbands=self.nbands), for all spins.
        """
        # Construct the output data structure
        # qpd: SingleQPWDescriptor from gpaw.response.qpd
        qpd = self.get_pw_descriptor(q_c)
        self.print_info(qpd)
        # chi0_body: Chi0BodyData from gpaw.response.chi0_data
        chi0_body = self._create_chi0_body(qpd)

        # Integrate all transitions into partially filled and empty bands
        n1, n2, m1, m2 = self.gs.get_band_transitions(nbands=self.nbands)
        self.update_chi0_body(chi0_body, n1, n2, m1, m2,
                              spins=range(self.gs.nspins))

        return chi0_body

    def update_chi0_body(self,
                         chi0_body: Chi0BodyData,
                         n1: int, n2: int,
                         m1: int, m2: int,
                         spins: list | range):
        """In-place calculation of the body.

        The contribution from the requested band transitions is added to
        whatever is already stored in ``chi0_body.data_WgG``, after which the
        full array is symmetrized.

        Parameters
        ----------
        chi0_body : Chi0BodyData
            Data object updated in place.
        n1 : int
            Lower band cutoff of the n-band index
        n2 : int
            Upper band cutoff of the n-band index
        m1 : int
            Lower band cutoff for band summation
        m2 : int
            Upper band cutoff for band summation
        spins : list or range
            List of spin indices to include in the calculation
        """
        qpd = chi0_body.qpd

        # Reset PAW correction in case the momentum has changed
        pairden_paw_corr = self.gs.pair_density_paw_corrections
        self.pawcorr = pairden_paw_corr(qpd)

        self.context.print('Integrating chi0 body.')

        # symmetries: QSymmetries from gpaw.response.symmetry
        # generator: KPointDomainGenerator from gpaw.response.kpoints
        # domain: Domain from from gpaw.response.integrators
        symmetries, generator, domain, prefactor = self.get_integration_domain(
            qpd.q_c, spins)
        integrand = Chi0Integrand(self, qpd=qpd, generator=generator,
                                  optical=False, n1=n1, n2=n2, m1=m1, m2=m2)

        # The integrated contribution carries no prefactor. Dividing the
        # existing data by it first, adding the new contribution and then
        # multiplying once at the end yields old + prefactor * new.
        chi0_body.data_WgG[:] /= prefactor
        if self.hilbert:
            # Allocate a temporary array for the spectral function
            out_WgG = chi0_body.zeros()
        else:
            # Use the preallocated array for direct updates
            out_WgG = chi0_body.data_WgG
        self.integrator.integrate(domain=domain,  # Integration domain
                                  integrand=integrand,
                                  task=self.task,
                                  wd=self.wd,  # Frequency Descriptor
                                  out_wxx=out_WgG)  # Output array

        if self.hilbert:
            # The integrator only returns the spectral function and a Hilbert
            # transform is performed to return the real part of the density
            # response function.
            with self.context.timer('Hilbert transform'):
                # Make Hilbert transform
                ht = HilbertTransform(np.array(self.wd.omega_w), self.eta,
                                      timeordered=self.timeordered)
                ht(out_WgG)
            # Update the actual chi0 array
            chi0_body.data_WgG[:] += out_WgG
        chi0_body.data_WgG[:] *= prefactor

        # Symmetrize in the wGG distribution, then redistribute back to WgG
        tmp_chi0_wGG = chi0_body.copy_array_with_distribution('wGG')
        with self.context.timer('symmetrize_wGG'):
            operators = BodySymmetryOperators(symmetries, chi0_body.qpd)
            operators.symmetrize_wGG(tmp_chi0_wGG)
        chi0_body.data_WgG[:] = chi0_body.blockdist.distribute_as(
            tmp_chi0_wGG, chi0_body.nw, 'WgG')

    def construct_hermitian_task(self):
        return Hermitian(self.integrator.blockcomm, eshift=self.eshift)

    def construct_point_hilbert_task(self):
        return Hilbert(self.integrator.blockcomm, eshift=self.eshift)

    def construct_tetra_hilbert_task(self):
        # HilbertTetrahedron accepts no eshift argument, so a requested
        # energy shift would be silently dropped here. Fail loudly instead,
        # consistent with Chi0OpticalExtensionCalculator.
        assert self.eshift is None, \
            'energy shift is not applied here'
        return HilbertTetrahedron(self.integrator.blockcomm)

    def construct_literal_task(self):
        return GenericUpdate(
            self.eta, self.integrator.blockcomm, eshift=self.eshift)

    def print_info(self, qpd: SingleQPWDescriptor):
        """Print parametrization, parallelization and memory estimates."""
        if gpaw.dry_run:
            # Report the communicator size requested for the dry run
            from gpaw.mpi import SerialCommunicator
            size = gpaw.dry_run
            comm = SerialCommunicator()
            comm.size = size
        else:
            comm = self.context.comm

        q_c = qpd.q_c
        nw = len(self.wd)
        csize = comm.size
        knsize = self.integrator.kncomm.size
        bsize = self.integrator.blockcomm.size
        # Estimate in MiB at 16 bytes per element, split over the blocks
        chisize = nw * qpd.ngmax**2 * 16. / 1024**2 / bsize

        isl = ['', f'{ctime()}',
               'Calculating chi0 body with:',
               self.get_gs_info_string(tab=' '), '',
               ' Linear response parametrization:',
               f' q_c: [{q_c[0]}, {q_c[1]}, {q_c[2]}]',
               self.get_response_info_string(qpd, tab=' '),
               f' comm.size: {csize}',
               f' kncomm.size: {knsize}',
               f' blockcomm.size: {bsize}']
        if bsize > nw:
            isl.append(
                'WARNING! Your nblocks is larger than number of frequency'
                ' points. Errors might occur, if your submodule does'
                ' not know how to handle this.')
        isl.extend(['',
                    ' Memory estimate of potentially large arrays:',
                    f' chi0_wGG: {chisize} M / cpu',
                    ' Memory usage before allocation: '
                    f'{(maxrss() / 1024**2)} M / cpu'])
        self.context.print('\n'.join(isl))
class Chi0OpticalExtensionCalculator(Chi0ComponentPWCalculator):
    """Calculator for the head and wings of chi0 in the optical limit."""

    def __init__(self, *args,
                 eshift: float | None = None,
                 intraband: bool = True,
                 rate: float | str = 'eta',
                 **kwargs):
        """Construct the Chi0OpticalExtensionCalculator.

        Parameters
        ----------
        intraband : bool
            Flag for including the intraband contribution to the chi0 head.
        rate : float, str
            Phenomenological scattering rate to use in optical limit Drude term
            (in eV). If rate='eta', it uses input artificial broadening eta as
            rate. Please note that for consistency the rate is implemented as
            omegap^2 / (omega + 1j * rate)^2, which differs from some
            literature by a factor of 2.
        eshift : float or None
            Energy shift of the conduction bands in Hartree.
        """
        self.eshift = eshift

        # Serial block distribution
        super().__init__(*args, nblocks=1, **kwargs)

        if self.gs.metallic:
            assert self.eshift is None, \
                'A rigid energy shift cannot be applied to the conduction '\
                'bands if there is no band gap'

        # In the optical limit of metals, one must add the Drude dielectric
        # response from the free-space plasma frequency of the intraband
        # transitions to the head of chi0. This is handled by a separate
        # calculator, provided that intraband is set to True.
        if self.gs.metallic and intraband:
            from gpaw.response.chi0_drude import Chi0DrudeCalculator
            if rate == 'eta':
                rate = self.eta * Ha  # external units
            self.rate = rate
            self.drude_calc = Chi0DrudeCalculator(
                self.gs, self.context,
                qsymmetry=self.qsymmetry,
                integrationmode=self.integrationmode)
        else:
            self.drude_calc = None
            self.rate = None

    def calculate(self,
                  qpd: SingleQPWDescriptor | None = None
                  ) -> Chi0OpticalExtensionData:
        """Calculate the chi0 head and wings.

        Parameters
        ----------
        qpd : SingleQPWDescriptor or None
            Plane-wave descriptor; if None, one is created for q_c = 0.

        Returns
        -------
        chi0_opt_ext : Chi0OpticalExtensionData
            Head and wings including all band transitions given by
            gs.get_band_transitions(nbands=self.nbands), for all spins, plus
            the Drude (intraband) head contribution when enabled.
        """
        # Create data object
        if qpd is None:
            qpd = self.get_pw_descriptor(q_c=[0., 0., 0.])

        # wd: FrequencyDescriptor from gpaw.response.frequencies
        chi0_opt_ext = Chi0OpticalExtensionData(self.wd, qpd)

        self.print_info(qpd)

        # Define band transitions
        n1, n2, m1, m2 = self.gs.get_band_transitions(nbands=self.nbands)

        # Perform the actual integration
        self.update_chi0_optical_extension(chi0_opt_ext, n1, n2, m1, m2,
                                           spins=range(self.gs.nspins))

        if self.drude_calc is not None:
            # Add intraband contribution
            # drude_calc: Chi0DrudeCalculator from gpaw.response.chi0_drude
            # chi0_drude: Chi0DrudeData from gpaw.response.chi0_data
            chi0_drude = self.drude_calc.calculate(self.wd, self.rate)
            chi0_opt_ext.head_Wvv[:] += chi0_drude.chi_Zvv

        return chi0_opt_ext

    def update_chi0_optical_extension(
            self,
            chi0_optical_extension: Chi0OpticalExtensionData,
            n1: int, n2: int,
            m1: int, m2: int,
            spins: list | range):
        """In-place calculation of the chi0 head and wings.

        The contribution from the requested band transitions is added to the
        stored head and wings, after which both are symmetrized.

        Parameters
        ----------
        chi0_optical_extension : Chi0OpticalExtensionData
            Data object updated in place.
        n1 : int
            Lower band cutoff of the n-band index
        n2 : int
            Upper band cutoff of the n-band index
        m1 : int
            Lower band cutoff for band summation
        m2 : int
            Upper band cutoff for band summation
        spins : list or range
            List of spin indices to include in the calculation
        """
        self.context.print('Integrating chi0 head and wings.')
        chi0_opt_ext = chi0_optical_extension
        qpd = chi0_opt_ext.qpd

        symmetries, generator, domain, prefactor = self.get_integration_domain(
            qpd.q_c, spins)
        integrand = Chi0Integrand(self, qpd=qpd, generator=generator,
                                  optical=True, n1=n1, n2=n2, m1=m1, m2=m2)

        # We integrate the head and wings together, using the combined index P
        # index v = (x, y, z)
        # index G = (G0, G1, G2, ...)
        # index P = (x, y, z, G1, G2, ...)
        # G0 is replaced by the three Cartesian directions, so the last axis
        # of P is two elements longer than that of G.
        WxvP_shape = list(chi0_opt_ext.WxvG_shape)
        WxvP_shape[-1] += 2
        tmp_chi0_WxvP = np.zeros(WxvP_shape, complex)
        self.integrator.integrate(domain=domain,  # Integration domain
                                  integrand=integrand,
                                  task=self.task,
                                  wd=self.wd,  # Frequency Descriptor
                                  out_wxx=tmp_chi0_WxvP)  # Output array
        if self.hilbert:
            with self.context.timer('Hilbert transform'):
                ht = HilbertTransform(np.array(self.wd.omega_w), self.eta,
                                      timeordered=self.timeordered)
                ht(tmp_chi0_WxvP)
        tmp_chi0_WxvP *= prefactor

        # Fill in wings part of the data, but leave out the head part (G0)
        chi0_opt_ext.wings_WxvG[..., 1:] += tmp_chi0_WxvP[..., 3:]
        # Fill in the head
        chi0_opt_ext.head_Wvv[:] += tmp_chi0_WxvP[:, 0, :3, :3]
        # Symmetrize
        operators = WingSymmetryOperators(symmetries, qpd)
        operators.symmetrize_wxvG(chi0_opt_ext.wings_WxvG)
        operators.symmetrize_wvv(chi0_opt_ext.head_Wvv)

    def construct_hermitian_task(self):
        return HermitianOpticalLimit(eshift=self.eshift)

    def construct_point_hilbert_task(self):
        return HilbertOpticalLimit(eshift=self.eshift)

    def construct_tetra_hilbert_task(self):
        # HilbertOpticalLimitTetrahedron takes no eshift argument
        assert self.eshift is None, \
            'energy shift is not applied here'
        return HilbertOpticalLimitTetrahedron()

    def construct_literal_task(self):
        return OpticalLimit(eta=self.eta, eshift=self.eshift)

    def print_info(self, qpd: SingleQPWDescriptor):
        """Print information about optical extension calculation."""
        isl = ['',
               f'{ctime()}',
               'Calculating chi0 optical extensions with:',
               self.get_gs_info_string(tab=' '),
               '',
               ' Linear response parametrization:',
               self.get_response_info_string(qpd, tab=' ')]
        self.context.print('\n'.join(isl))
def get_frequency_descriptor(
        frequencies: ArrayLike1D | dict[str, Any] | None = None, *,
        gs: ResponseGroundStateAdapter | None = None,
        nbands: int | slice | None = None):
    """Helper function to generate frequency descriptors.

    In most cases, the `frequencies` input can be processed directly via
    wd = FrequencyDescriptor.from_array_or_dict(frequencies),
    but in cases where `frequencies` does not specify omegamax, it is
    calculated from the input ground state adapter.

    Parameters
    ----------
    frequencies : array-like, dict or None
        Frequency grid specification. None selects the default
        {'type': 'nonlinear'} grid. A dict passed by the caller is never
        modified; a copy is made when omegamax has to be filled in.
    gs : ResponseGroundStateAdapter or None
        Ground state adapter; required when a dict input lacks omegamax.
    nbands : int, slice or None
        Band specification passed on to get_omegamax.

    Returns
    -------
    FrequencyDescriptor
    """
    if frequencies is None:
        frequencies = {'type': 'nonlinear'}  # default frequency grid
    if isinstance(frequencies, dict) and frequencies.get('omegamax') is None:
        assert gs is not None, \
            'A ground state is required to determine omegamax'
        # Build a new dict rather than writing into the caller's one, so
        # that reusing the same input dict later does not pick up a stale
        # omegamax belonging to a different ground state/nbands.
        frequencies = {**frequencies,
                       'omegamax': get_omegamax(gs, nbands)}
    return FrequencyDescriptor.from_array_or_dict(frequencies)
def get_omegamax(gs: ResponseGroundStateAdapter,
                 nbands: int | slice | None = None):
    """Return the maximum eigenvalue difference including nbands, in eV.

    The width of the eigenvalue range reported by
    ``gs.get_eigenvalue_range(nbands=nbands)`` is scaled by Ha.
    """
    lowest, highest = gs.get_eigenvalue_range(nbands=nbands)
    spread = highest - lowest
    return spread * Ha