Coverage for gpaw/nlopt/adapters.py: 84%
64 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from __future__ import annotations
2from types import SimpleNamespace
3from typing import TYPE_CHECKING
5import numpy as np
7if TYPE_CHECKING:
8 from gpaw.core.atom_arrays import AtomArrays
9 from gpaw.core.plane_waves import PWArray
10 from gpaw.new.ase_interface import ASECalculator
11 from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
12 from gpaw.typing import ArrayND
class GSInfo:
    """
    Base class for the ground state adapters in the non-linear optics module.

    It is only compatible with GPAW_NEW.

    This class should never be instantiated directly; use the
    CollinearGSInfo or NoncollinearGSInfo subclasses instead. The subclasses
    exist because the spin index is handled differently in collinear and
    noncollinear ground state calculations. They implement ``get_wfs``,
    ``_pw_data`` and ``_proj_data``.
    """
    def __init__(self,
                 calc: ASECalculator):
        """
        Collect the ground state data needed by the non-linear optics code.

        Parameters
        ----------
        calc:
            ASE calculator wrapping a (new) GPAW ground state calculation.

        Raises
        ------
        ValueError
            If the calculator is not in plane wave mode, is parallelised over
            domains or bands, or carries no wave function data.
        """
        # Raise instead of assert: this validates user input, and assert
        # statements are removed when Python runs with -O.
        if calc.params.mode.name != 'pw':
            raise ValueError('Calculator must be in plane wave mode.')

        dft = calc.dft
        self.nabla_aiiv = [setup.nabla_iiv for setup in dft.setups]

        ibzwfs = self.ibzwfs = dft.ibzwfs
        if not (ibzwfs.domain_comm.size == 1 and ibzwfs.band_comm.size == 1):
            raise ValueError('Calculator must be initialised with '
                             'only k-point parallelisation.')
        # Without wave function data, psit_nX is a SimpleNamespace placeholder
        # instead of an array (presumably the case for some .gpw restarts).
        if isinstance(ibzwfs.wfs_qs[0][0].psit_nX, SimpleNamespace):
            raise ValueError('Calculator is missing wfs data. If loading from '
                             'a .gpw file, please recalculate wave functions.')

        density = dft.density
        self.collinear = density.collinear
        self.ndensities = density.ndensities

        grid = density.nt_sR.desc
        # Real-space unit cell volume and reciprocal (Brillouin zone) volume,
        # the latter from the reciprocal cell scaled by 2*pi.
        self.ucvol = np.abs(np.linalg.det(grid.cell))
        self.bzvol = np.abs(np.linalg.det(2 * np.pi * grid.icell))

    def get_plane_wave_coefficients(self,
                                    wfs: PWFDWaveFunctions,
                                    bands: slice,
                                    spin: int) -> tuple[ArrayND, ArrayND]:
        """
        Return the reciprocal vectors and the plane wave coefficients.

        Returns a tuple ``(G_plus_k_Gv, coefficients)``. ``G_plus_k_Gv`` comes
        from the plane wave descriptor of the selected bands. ``coefficients``
        has shape (band index, reciprocal vector index) for the given spin.
        """
        psit_nG = wfs.psit_nX[bands]
        G_plus_k_Gv = psit_nG.desc.G_plus_k_Gv
        return G_plus_k_Gv, self._pw_data(psit_nG, spin)

    def get_wave_function_projections(self,
                                      wfs: PWFDWaveFunctions,
                                      bands: slice,
                                      spin: int) -> dict[int, ArrayND]:
        """
        Return the projections of the pseudo wfs onto the partial waves.

        Output is a dictionary with atom index keys and array values with
        shape (band index, partial wave index).
        """
        return self._proj_data(wfs.P_ani, bands, spin)

    def get_wfs(self,
                wfs_s: list[PWFDWaveFunctions],
                spin: int) -> PWFDWaveFunctions:
        """Select the wave function object for ``spin`` (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _pw_data(psit: PWArray,
                 spin: int) -> ArrayND:
        """Return the coefficient array for ``spin`` (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _proj_data(P: AtomArrays,
                   bands: slice,
                   spin: int) -> dict[int, ArrayND]:
        """Return per-atom projections for ``bands``/``spin`` (subclass hook)."""
        raise NotImplementedError
class CollinearGSInfo(GSInfo):
    """
    Ground state adapter for collinear calculations.

    Every spin channel has its own wave function object, so the spin index
    selects an entry of ``wfs_s``. Coefficient and projection arrays carry
    no spin axis. The number of spin channels ``ns`` equals the number of
    densities of the ground state.
    """
    def __init__(self,
                 calc: ASECalculator):
        super().__init__(calc)
        self.ns = self.ndensities

    @staticmethod
    def _pw_data(psit_nG: PWArray,
                 _: int | None = None) -> ArrayND:
        """Return the raw coefficient array; the spin argument is ignored."""
        return psit_nG.data

    @staticmethod
    def _proj_data(P_ani: AtomArrays,
                   bands: slice,
                   _: int | None = None) -> dict[int, ArrayND]:
        """Slice every atom's projections to ``bands`` (spin ignored)."""
        projections = {}
        for a, P_ni in P_ani.items():
            projections[a] = P_ni[bands]
        return projections

    def get_wfs(self,
                wfs_s: list[PWFDWaveFunctions],
                spin: int) -> PWFDWaveFunctions:
        """Pick the wave function object belonging to ``spin``."""
        selected = wfs_s[spin]
        return selected
class NoncollinearGSInfo(GSInfo):
    """
    Ground state adapter for noncollinear calculations.

    Both spin components live in a single wave function object (the first
    entry of ``wfs_s``). A spin component is selected by indexing the spin
    axis (axis 1) of the coefficient and projection arrays. ``ns`` is
    always 2.
    """
    def __init__(self,
                 calc: ASECalculator):
        super().__init__(calc)
        self.ns = 2

    @staticmethod
    def _pw_data(psit_nsG: PWArray,
                 spin: int) -> ArrayND:
        """Return the coefficients of spin component ``spin``."""
        data_nsG = psit_nsG.data
        return data_nsG[:, spin]

    @staticmethod
    def _proj_data(P_ansi: AtomArrays,
                   bands: slice,
                   spin: int) -> dict[int, ArrayND]:
        """Slice every atom's projections to ``bands`` and ``spin``."""
        projections = {}
        for a, P_nsi in P_ansi.items():
            projections[a] = P_nsi[bands, spin]
        return projections

    def get_wfs(self,
                wfs_s: list[PWFDWaveFunctions],
                _: int | None = None) -> PWFDWaveFunctions:
        """Return the single spinor wave function object (spin ignored)."""
        (first, *_rest) = wfs_s
        return first