Coverage for gpaw/response/groundstate.py: 94%
253 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
2from dataclasses import dataclass
3from typing import Union
4from pathlib import Path
5from functools import cached_property
6from types import SimpleNamespace
7from typing import TYPE_CHECKING
8import numpy as np
10from ase.units import Ha, Bohr
12import gpaw.mpi as mpi
13from gpaw.ibz2bz import IBZ2BZMaps
14from gpaw.calculator import GPAW as OldGPAW
15from gpaw.new.ase_interface import ASECalculator as NewGPAW
16from gpaw.response.paw import LeanPAWDataset
18if TYPE_CHECKING:
19 from gpaw.setup import Setups, LeanSetup
class PAWDatasetCollection:
    """Response PAW datasets indexed both by species and by atom.

    A single ResponsePAWDataset is created per distinct species id found in
    ``setups.id_a``; every atom of that species refers to the same object.

    Attributes
    ----------
    by_species : dict
        Maps species id -> ResponsePAWDataset.
    by_atom : list
        ResponsePAWDataset for each atom, in the order of ``setups``.
    id_by_atom : list
        Species id for each atom, in the order of ``setups``.
    """

    def __init__(self, setups: Setups):
        species_map = {}
        atom_datasets = []
        atom_species = []

        for a, atom_setup in enumerate(setups):
            sid = setups.id_a[a]
            dataset = species_map.get(sid)
            if dataset is None:
                # First atom of this species: build its dataset once
                dataset = ResponsePAWDataset(atom_setup)
                species_map[sid] = dataset
            atom_datasets.append(dataset)
            atom_species.append(sid)

        self.by_species = species_map
        self.by_atom = atom_datasets
        self.id_by_atom = atom_species
# A live GPAW calculator, using either the old or the new interface.
GPAWCalculator = Union[OldGPAW, NewGPAW]
# Location of a .gpw ground-state file on disk.
GPWFilename = Union[Path, str]
# Anything the response code can turn into a ResponseGroundStateAdapter:
# an existing adapter, a calculator, or a .gpw file name.
ResponseGroundStateAdaptable = Union[
    'ResponseGroundStateAdapter', GPAWCalculator, GPWFilename]
class ResponseGroundStateAdapter:
    """Adapter exposing the ground-state data used by the response code.

    Wraps a GPAW calculator (old or new interface) and exposes its
    descriptors, wave functions, density, Hamiltonian and occupation
    information under a fixed set of attribute and method names.
    """

    def __init__(self, calc: GPAWCalculator):
        wfs = calc.wfs  # wavefunction object from gpaw.wavefunctions

        self.atoms = calc.atoms  # ASE Atoms object
        self.kd = wfs.kd  # KPointDescriptor object from gpaw.kpt_descriptor.
        self.world = calc.world  # _Communicator object from gpaw.mpi

        # GridDescriptor from gpaw.grid_descriptor.
        # Describes a grid in real space
        self.gd = wfs.gd

        # Also a GridDescriptor, with a finer grid...
        self.finegd = calc.density.finegd
        self.bd = wfs.bd  # BandDescriptor from gpaw.band_descriptor
        self.nspins = wfs.nspins  # number of spins: int
        self.dtype = wfs.dtype  # data type of wavefunctions, real or complex

        self.spos_ac = calc.spos_ac  # scaled position vector: np.ndarray

        self.kpt_u = wfs.kpt_u  # kpoints: list of Kpoint from gpaw.kpoint
        # k-points indexed by (q, s); presumably nested lists -- confirm
        self.kpt_qs = wfs.kpt_qs

        self.fermi_level = wfs.fermi_level  # float
        self.pawdatasets = PAWDatasetCollection(calc.setups)

        self.pbc = self.atoms.pbc
        self.volume = self.gd.volume

        # The number of valence electrons is required to be an integer
        self.nvalence = int(round(wfs.nvalence))
        assert self.nvalence == wfs.nvalence

        self.nocc1, self.nocc2 = self.count_occupied_bands()

        self.ibz2bz = IBZ2BZMaps.from_calculator(calc)

        self._wfs = wfs
        self._density = calc.density
        self._hamiltonian = calc.hamiltonian
        self._calc = calc

    @staticmethod
    def from_input(
            gs: ResponseGroundStateAdaptable) -> ResponseGroundStateAdapter:
        """Create an adapter from an adapter, a calculator or a .gpw file.

        An existing ResponseGroundStateAdapter is returned unchanged.

        Raises
        ------
        ValueError
            If ``gs`` is none of the supported input types.
        """
        if isinstance(gs, ResponseGroundStateAdapter):
            return gs
        elif isinstance(gs, (OldGPAW, NewGPAW)):
            return ResponseGroundStateAdapter(calc=gs)
        elif isinstance(gs, (Path, str)):  # GPWFilename
            return ResponseGroundStateAdapter.from_gpw_file(gpw=gs)
        raise ValueError(
            f'Expected ResponseGroundStateAdaptable, got {gs!r}')

    @classmethod
    def from_gpw_file(cls, gpw: GPWFilename) -> ResponseGroundStateAdapter:
        """Initiate the ground state adapter directly from a .gpw file.

        The file is loaded on a serial communicator without text output.

        Raises
        ------
        FileNotFoundError
            If ``gpw`` is not an existing file.
        """
        from gpaw import GPAW, disable_dry_run
        # Explicit check instead of assert, which is stripped under -O
        if not Path(gpw).is_file():
            raise FileNotFoundError(f'No such .gpw file: {gpw}')
        with disable_dry_run():
            calc = GPAW(gpw, txt=None, communicator=mpi.serial_comm)
        return cls(calc)

    @property
    def pd(self):
        # This is an attribute error in FD/LCAO mode.
        # We need to abstract away "calc" in all places used by response
        # code, and that includes places that are also compatible with FD.
        return self._wfs.pd

    def is_parallelized(self):
        """Are we dealing with a parallel calculator?"""
        return self.world.size > 1

    @cached_property
    def global_pd(self):
        """Get a PWDescriptor that includes all k-points.

        In particular, this is necessary to allow all cores to be able to work
        on all k-points in the case where calc is parallelized over k-points,
        see gpaw.response.kspair
        """
        from gpaw.pw.descriptor import PWDescriptor

        assert self.gd.comm.size == 1
        kd = self.kd.copy()  # global KPointDescriptor without a comm
        return PWDescriptor(self.pd.ecut, self.gd,
                            dtype=self.pd.dtype,
                            kd=kd, fftwflags=self.pd.fftwflags,
                            gammacentered=self.pd.gammacentered)

    def get_occupations_width(self):
        """Return the occupation smearing width divided by Ha.

        Only 'fermi-dirac' and 'zero-width' occupations are accepted; a
        missing ``_width`` attribute gives 0.0.
        """
        # Ugly hack only used by pair.intraband_pair_density I think.
        # Actually: was copy-pasted in chi0 also.
        # More duplication can probably be eliminated around those.

        # Only works with Fermi-Dirac distribution
        occs = self._wfs.occupations
        assert occs.name in {'fermi-dirac', 'zero-width'}

        # No carriers when T=0
        width = getattr(occs, '_width', 0.0) / Ha
        return width

    @cached_property
    def cd(self):
        """CellDescriptor for the real-space cell and pbc."""
        return CellDescriptor(self.gd.cell_cv, self.pbc)

    @property
    def nt_sR(self):
        """Pseudo density on the coarse grid."""
        # Used by localft and fxc_kernels
        return self._density.nt_sG

    @property
    def nt_sr(self):
        """Pseudo density on the fine grid, interpolated on first access."""
        # Used by localft
        if self._density.nt_sg is None:
            self._density.interpolate_pseudo_density()
        return self._density.nt_sg

    @cached_property
    def n_sR(self):
        """All-electron density on the coarse grid (gridrefinement=1)."""
        return self._density.get_all_electron_density(
            atoms=self.atoms, gridrefinement=1)[0]

    @cached_property
    def n_sr(self):
        """All-electron density on the fine grid (gridrefinement=2)."""
        return self._density.get_all_electron_density(
            atoms=self.atoms, gridrefinement=2)[0]

    @property
    def D_asp(self):
        """Atomic density matrices of the density object."""
        # Used by fxc_kernels
        return self._density.D_asp

    def get_pseudo_density(self, gridrefinement=2):
        """Return (pseudo density, grid descriptor) for refinement 1 or 2.

        Raises
        ------
        ValueError
            For any other ``gridrefinement``.
        """
        # Used by localft
        if gridrefinement == 1:
            return self.nt_sR, self.gd
        elif gridrefinement == 2:
            return self.nt_sr, self.finegd
        else:
            raise ValueError(f'Invalid gridrefinement {gridrefinement}')

    def get_all_electron_density(self, gridrefinement=2):
        """Return (all-electron density, grid descriptor) for refinement 1/2.

        Raises
        ------
        ValueError
            For any other ``gridrefinement``.
        """
        # Used by fxc, fxc_kernels and localft
        if gridrefinement == 1:
            return self.n_sR, self.gd
        elif gridrefinement == 2:
            return self.n_sr, self.finegd
        else:
            raise ValueError(f'Invalid gridrefinement {gridrefinement}')

    # Things used by EXX. This is getting pretty involved.
    #
    # EXX naughtily accesses the density object in order to
    # interpolate_pseudo_density() which is in principle mutable.

    def hacky_all_electron_density(self, **kwargs):
        """Calculator all-electron density multiplied by Bohr**3."""
        # fxc likes to get all electron densities. It calls
        # calc.get_all_electron_density() and so we wrap that here.
        # But it also collects to serial (bad), and it also zeropads
        # nonperiodic directions (probably WRONG!).
        #
        # Also this one returns in user units, whereas the calling
        # code actually wants internal units. Very silly then.
        #
        # Also, the calling code often wants the gd, which is not
        # returned, so it is redundantly reconstructed in multiple
        # places by refining the "right" number of times.
        n_g = self._calc.get_all_electron_density(**kwargs)
        n_g *= Bohr**3
        return n_g

    # Used by EXX.
    @property
    def hamiltonian(self):
        return self._hamiltonian

    # Used by EXX.
    @property
    def density(self):
        return self._density

    # Ugh SOC
    def soc_eigenstates(self, **kwargs):
        """Forward to gpaw.spinorbit.soc_eigenstates for the calculator."""
        from gpaw.spinorbit import soc_eigenstates
        return soc_eigenstates(self._calc, **kwargs)

    @property
    def xcname(self):
        """Name of the Hamiltonian's XC functional."""
        return self.hamiltonian.xc.name

    def get_xc_difference(self, xc):
        # XXX used by gpaw/xc/tools.py
        return self._calc.get_xc_difference(xc)

    def get_wave_function_array(self, u, n):
        # XXX used by gpaw/xc/tools.py in a hacky way
        return self._wfs._get_wave_function_array(
            u, n, realspace=True)

    def pair_density_paw_corrections(self, qpd):
        """PAW corrections to pair densities for the plane-wave basis qpd."""
        from gpaw.response.paw import get_pair_density_paw_corrections
        return get_pair_density_paw_corrections(
            pawdatasets=self.pawdatasets, qpd=qpd, spos_ac=self.spos_ac,
            atomrotations=self.atomrotations)

    def matrix_element_paw_corrections(self, qpd, rshe_a):
        """PAW corrections to matrix elements for qpd and rshe_a."""
        from gpaw.response.paw import get_matrix_element_paw_corrections
        return get_matrix_element_paw_corrections(
            qpd, self.pawdatasets, rshe_a, self.spos_ac)

    def get_pos_av(self):
        """Atomic positions: scaled positions times the grid cell."""
        # gd.cell_cv must always be the same as pd.gd.cell_cv, right??
        return np.dot(self.spos_ac, self.gd.cell_cv)

    def count_occupied_bands(self, ftol: float = 1e-6) -> tuple[int, int]:
        """Count the number of filled (nocc1) and nonempty bands (nocc2).

        ftol : float
            Threshold determining whether a band is completely filled
            (f > 1 - ftol) or completely empty (f < ftol).
        """
        # Count the number of occupied bands for this rank
        nocc1, nocc2 = self._count_occupied_bands(ftol=ftol)
        # Minimize/maximize over k-points
        nocc1 = self.kd.comm.min_scalar(nocc1)  # bands filled for all k
        nocc2 = self.kd.comm.max_scalar(nocc2)  # bands filled for any k
        # Sum over band distribution
        nocc1 = self.bd.comm.sum_scalar(nocc1)  # number of filled bands
        nocc2 = self.bd.comm.sum_scalar(nocc2)  # number of nonempty bands
        return int(nocc1), int(nocc2)

    def _count_occupied_bands(self, *, ftol: float) -> tuple[int, int]:
        """Rank-local (min filled, max nonempty) band counts over kpt_u."""
        nocc1 = 9999999  # number of completely filled bands
        nocc2 = 0  # number of nonempty bands
        for kpt in self.kpt_u:
            # Normalize occupations by the k-point weight
            f_n = kpt.f_n / kpt.weight
            nocc1 = min((f_n > 1 - ftol).sum(), nocc1)
            nocc2 = max((f_n > ftol).sum(), nocc2)
        return int(nocc1), int(nocc2)

    def get_band_transitions(self, nbands: int | slice | None = None):
        """Determine the indices that define the range of occupied bands
        n1, n2 and unoccupied bands m1, m2.

        Parameters
        ----------
        nbands : int, slice or None
            None: include all bands. int: include the first ``nbands``
            bands. slice: include bands ``start`` to ``stop`` (step must be
            None or 1); a missing start defaults to 0 and a missing stop
            to ``self.nbands``.

        Returns
        -------
        n1, n2, m1, m2 : int
            n2 is the number of nonempty bands (nocc2) and m1 the number of
            filled bands (nocc1).

        Raises
        ------
        ValueError
            If ``nbands`` is not None, an integer or a slice.
        """
        if nbands is None:
            n1 = 0
            m2 = self.nbands
        elif isinstance(nbands, (int, np.integer)):
            n1 = 0
            m2 = int(nbands)
            assert 1 <= m2 <= self.nbands
        elif isinstance(nbands, slice):
            # Open-ended slices default to the full band range
            n1 = 0 if nbands.start is None else nbands.start
            m2 = self.nbands if nbands.stop is None else nbands.stop
            assert n1 >= 0 and m2 >= 0
            assert nbands.step in {None, 1}
            assert n1 < m2 <= self.nbands
            assert n1 <= self.nocc1
        else:
            raise ValueError(
                f"Invalid type for nbands: {type(nbands)}. "
                "Expected None, int, or slice.")

        n2 = self.nocc2
        m1 = self.nocc1

        assert n1 < n2

        return n1, n2, m1, m2

    def get_eigenvalue_range(self, nbands: int | slice | None = None):
        """Get smallest and largest Kohn-Sham eigenvalues.

        Uses bands n1 to m2 - 1 from get_band_transitions(nbands) and the
        k-points in kpt_u.
        """
        n1, _, _, m2 = self.get_band_transitions(nbands)
        epsmin = np.inf
        epsmax = -np.inf
        for kpt in self.kpt_u:
            epsmin = min(epsmin, kpt.eps_n[n1])  # the eigenvalues are ordered
            epsmax = max(epsmax, kpt.eps_n[m2 - 1])
        return epsmin, epsmax

    @property
    def nbands(self):
        """Total number of bands of the band descriptor."""
        return self.bd.nbands

    @property
    def metallic(self):
        """True when the filled and nonempty band counts differ."""
        # Does the number of filled bands equal the number of non-empty bands?
        return self.nocc1 != self.nocc2

    @cached_property
    def ibzq_qc(self):
        """Irreducible q-points from the first BZ q-points and symmetries."""
        # For G0W0Kernel
        kd = self.kd
        bzq_qc = kd.get_bz_q_points(first=True)
        U_scc = kd.symmetry.op_scc
        ibzq_qc = kd.get_ibz_q_points(bzq_qc, U_scc)[0]

        return ibzq_qc

    def get_ibz_vertices(self):
        """Vertices of the IBZ as returned by gpaw.bztools.get_bz."""
        # For the tetrahedron method in Chi0
        from gpaw.bztools import get_bz
        # NB: We are ignoring the pbc_c keyword to get_bz() in order to mimic
        # find_high_symmetry_monkhorst_pack() in gpaw.bztools. XXX
        _, ibz_vertices_kc, _ = get_bz(self._calc)
        return ibz_vertices_kc

    def get_aug_radii(self):
        """Largest cutoff radius rcut_j of each atom's PAW dataset."""
        return np.array([max(pawdata.rcut_j)
                         for pawdata in self.pawdatasets.by_atom])

    @cached_property
    def micro_setups(self):
        """Micro setups per atom built from PAW data and D_asp."""
        from gpaw.response.localft import extract_micro_setup
        micro_setups = []
        for a, pawdata in enumerate(self.pawdatasets.by_atom):
            micro_setups.append(extract_micro_setup(pawdata, self.D_asp[a]))
        return micro_setups

    @property
    def atomrotations(self):
        return self._wfs.setups.atomrotations

    @cached_property
    def kpoints(self):
        """ResponseKPointGrid wrapping the k-point descriptor."""
        from gpaw.response.kpoints import ResponseKPointGrid
        return ResponseKPointGrid(self.kd)
@dataclass
class CellDescriptor:
    """Unit cell vectors together with their periodic boundary conditions."""

    cell_cv: np.ndarray  # cell vectors as rows, in Bohr (see Bohr -> Å below)
    pbc_c: np.ndarray  # periodicity flag for each cell vector

    @property
    def nonperiodic_hypervolume(self):
        """Get the hypervolume of the cell along nonperiodic directions.

        Returns the hypervolume Λ in units of Å^(3-D), where D is the number
        of periodic directions and

        Λ = 1 in 3D
        Λ = L in 2D, where L is the out-of-plane cell vector length
        Λ = A in 1D, where A is the transverse cell area
        Λ = V in 0D, where V is the cell volume

        ``cell_cv`` and ``pbc_c`` may be any array-like (e.g. lists); they
        are converted to numpy arrays (pbc_c as booleans) before use.
        """
        cell_cv = np.asarray(self.cell_cv)
        # Boolean array is needed for the ~ negation and mask indexing
        pbc_c = np.asarray(self.pbc_c, dtype=bool)
        if pbc_c.any():
            # In 1D and 2D, we assume the cartesian representation of the unit
            # cell to be block diagonal, separating the periodic and
            # nonperiodic cell vectors in different blocks.
            assert np.allclose(cell_cv[~pbc_c][:, pbc_c], 0.) and \
                np.allclose(cell_cv[pbc_c][:, ~pbc_c], 0.), \
                "In 1D and 2D, please put the periodic/nonperiodic axis " \
                "along a cartesian component"
        # Determinant of the nonperiodic block (1.0 for an empty 3D block)
        L = np.abs(np.linalg.det(cell_cv[~pbc_c][:, ~pbc_c]))
        n_nonperiodic = int(np.count_nonzero(~pbc_c))
        return L * Bohr**n_nonperiodic  # Bohr -> Å
# Contains all the relevant information
# from Setups class for response calculators
class ResponsePAWDataset(LeanPAWDataset):
    """LeanPAWDataset extended with the setup data response code needs.

    Copies n_j, N0_q, nabla_iiv and hubbard_u from the setup, and exposes a
    SimpleNamespace snapshot of the setup's xc_correction (or None).
    """

    # Attributes copied from setup.xc_correction, in construction order
    _XC_CORRECTION_FIELDS = ()

    def __init__(self, setup: LeanSetup, **kwargs):
        super().__init__(
            rgd=setup.rgd,
            l_j=setup.l_j,
            rcut_j=setup.rcut_j,
            phit_jg=setup.data.phit_jg,
            phi_jg=setup.data.phi_jg,
            **kwargs)
        assert setup.ni == self.ni

        self.n_j = setup.n_j
        self.N0_q = setup.N0_q
        self.nabla_iiv = setup.nabla_iiv

        self.xc_correction: SimpleNamespace | None
        source = setup.xc_correction
        if source is None:
            # If there is no `xc_correction` in the setup, we assume to be
            # using pseudo potentials.
            self.xc_correction = None
            # In this case, we set l_j to an empty list in order to bypass the
            # calculation of PAW corrections to pair densities etc.
            # This is quite an ugly hack...
            # If we want to support pseudo potential calculations for real, we
            # should skip the PAW corrections at the matrix element calculator
            # level, not by an odd hack.
            self.l_j = np.array([], dtype=float)
        else:
            fields = ('rgd', 'Y_nL', 'n_qg', 'nt_qg', 'nc_g', 'nct_g',
                      'nc_corehole_g', 'B_pqL', 'e_xc0')
            self.xc_correction = SimpleNamespace(
                **{name: getattr(source, name) for name in fields})
        self.hubbard_u = setup.hubbard_u