Coverage for gpaw/new/wave_functions.py: 85%
108 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1from __future__ import annotations
3from types import ModuleType
5import numpy as np
6from gpaw.core.atom_arrays import AtomArrays, AtomDistribution
7from gpaw.core.uniform_grid import UGArray
8from gpaw.mpi import MPIComm, serial_comm
9from gpaw.new import trace, zips
10from gpaw.new.potential import Potential
11from gpaw.setup import Setups
12from gpaw.typing import Array1D, Array2D, ArrayND
class WaveFunctions:
    """Common base for the wave functions of one spin/k-point block.

    Stores the bookkeeping shared by all implementations (spin and k-point
    indices, weight, dtype, PAW setups, atomic positions and distribution,
    domain/band communicators and the slice of bands owned by this rank)
    together with lazily filled caches for projections (``P_ani``),
    eigenvalues (``eig_n``) and occupation numbers (``occ_n``).

    Methods that raise ``NotImplementedError`` here must be provided by
    subclasses.
    """

    bytes_per_band: int
    xp: ModuleType  # numpy or cupy

    def __init__(self,
                 *,
                 setups: Setups,
                 nbands: int,
                 relpos_ac: Array2D,
                 atomdist: AtomDistribution,
                 spin: int = 0,
                 q: int = 0,
                 k: int = 0,
                 kpt_c=(0.0, 0.0, 0.0),
                 weight: float = 1.0,
                 ncomponents: int = 1,
                 dtype=float,
                 qspiral_v=None,
                 domain_comm: MPIComm = serial_comm,
                 band_comm: MPIComm = serial_comm):
        """Initialize the attributes shared by all implementations.

        All arguments are keyword-only and are stored as attributes of the
        same name.

        Parameters
        ----------
        setups:
            PAW setups.
        nbands:
            Total number of bands, before distribution over *band_comm*.
        relpos_ac:
            Atomic positions, one row per atom (presumably relative to the
            unit cell -- confirm with callers).
        atomdist:
            Atom distribution; its communicator must have the same size as
            *domain_comm*.
        spin:
            Spin index; must be smaller than *ncomponents*.
        q, k:
            k-point indices (only stored here; see callers for their exact
            meaning).
        kpt_c:
            k-point coordinates.
        weight:
            Weight of this spin/k-point block.
        ncomponents:
            Number of spin components (1, 2 or 4).  Determines ``nspins``
            and ``spin_degeneracy`` (see below).
        dtype:
            Data type of the coefficients.
        qspiral_v:
            Presumably a spin-spiral wave vector; only stored here.
        domain_comm, band_comm:
            Communicators for domain and band parallelization.
        """
        assert spin < ncomponents

        self.spin = spin
        self.q = q
        self.k = k
        self.setups = setups
        self.weight = weight
        self.ncomponents = ncomponents
        self.dtype = dtype
        self.kpt_c = kpt_c
        self.relpos_ac = relpos_ac
        self.atomdist = atomdist
        self.domain_comm = domain_comm
        self.band_comm = band_comm
        self.nbands = nbands
        self.qspiral_v = qspiral_v

        assert domain_comm.size == atomdist.comm.size

        # ncomponents -> nspins:          1 -> 1, 2 -> 2, 4 -> 1
        # ncomponents -> spin_degeneracy: 1 -> 2, 2 -> 1, 4 -> 1
        self.nspins = ncomponents % 3
        self.spin_degeneracy = ncomponents % 2 + 1

        # Caches filled in later; accessed through the properties below.
        self._P_ani: AtomArrays | None = None
        self._eig_n: Array1D | None = None
        self._occ_n: Array1D | None = None

        # Bands are split into contiguous blocks of ceil(nbands / size);
        # this rank owns bands n1:n2 (possibly empty on the last ranks).
        mynbands = (nbands + band_comm.size - 1) // band_comm.size
        self.n1 = min(band_comm.rank * mynbands, nbands)
        self.n2 = min((band_comm.rank + 1) * mynbands, nbands)

    def __repr__(self):
        dc = f'{self.domain_comm.rank}/{self.domain_comm.size}'
        bc = f'{self.band_comm.rank}/{self.band_comm.size}'
        return (f'{self.__class__.__name__}(nbands={self.nbands}, '
                f'spin={self.spin}, q={self.q}, k={self.k}, '
                f'weight={self.weight}, kpt_c={self.kpt_c}, '
                f'ncomponents={self.ncomponents}, dtype={self.dtype}, '
                f'domain_comm={dc}, band_comm={bc})')

    def array_shape(self, global_shape: bool = False) -> tuple[int, ...]:
        """Return an array shape; must be implemented by subclasses."""
        raise NotImplementedError

    def add_to_density(self,
                       nt_sR: UGArray,
                       D_asii: AtomArrays) -> None:
        """Subclass hook: add this block's contribution to *nt_sR* and
        *D_asii* (presumably pseudo density and atomic density
        matrices)."""
        raise NotImplementedError

    def add_to_ked(self,
                   taut_sR: UGArray) -> None:
        """Subclass hook: add this block's contribution to *taut_sR*."""
        raise NotImplementedError

    def orthonormalize(self, work_array_nX: ArrayND | None = None):
        """Subclass hook; *work_array_nX* is optional scratch space."""
        raise NotImplementedError

    def move(self,
             relpos_ac: Array2D,
             atomdist: AtomDistribution,
             move_wave_functions) -> None:
        """Update atomic positions/distribution and invalidate caches.

        Projections and eigenvalues are discarded.  Occupation numbers
        are kept.  *move_wave_functions* is not used by this base
        implementation.
        """
        self.relpos_ac = relpos_ac
        self.atomdist = atomdist
        self._P_ani = None
        self._eig_n = None
        # NOTE: occupation numbers (self._occ_n) are deliberately not
        # reset here.

    def collect(self,
                n1: int = 0,
                n2: int = 0) -> WaveFunctions | None:
        """Subclass hook returning a collected copy (or None)."""
        raise NotImplementedError

    @property
    def has_eigs(self) -> bool:
        # Checks if eigenvalues have been calculated,
        # that is, one scf step has been performed.
        return self._eig_n is not None

    @property
    def has_occs(self) -> bool:
        # Checks if occupations have been calculated,
        # that is, one scf step has been performed.
        # XXX: In theory, this should be the same as has_eigs,
        # however, there seems to be a discrepancy during
        # fixed density calculations.
        return self._occ_n is not None

    @property
    def eig_n(self) -> Array1D:
        """Eigenvalues of all bands.

        Raises ValueError if they have not been set yet.
        """
        if self._eig_n is None:
            raise ValueError('Eigenvalues eig_n not present')
        return self._eig_n

    @property
    def occ_n(self) -> Array1D:
        """Occupation numbers of all bands.

        Raises ValueError if they have not been set yet.
        """
        if self._occ_n is None:
            raise ValueError('Occupation numbers occ_n not present')
        return self._occ_n

    @property
    def myeig_n(self) -> Array1D:
        """Eigenvalues of the bands owned by this rank (``n1:n2``)."""
        return self.eig_n[self.n1:self.n2]

    @property
    def myocc_n(self) -> Array1D:
        """Occupation numbers of the bands owned by this rank."""
        return self.occ_n[self.n1:self.n2]

    @property
    def P_ani(self) -> AtomArrays:
        """Projections; raises RuntimeError if not present."""
        if self._P_ani is None:
            raise RuntimeError('Projections P_ani not present')
        return self._P_ani

    @trace
    def add_to_atomic_density_matrices(self,
                                       occ_n,
                                       D_asii: AtomArrays) -> None:
        """Add this block's projections to the atomic density matrices.

        For fewer than 4 spin components, every atom gets the real part of
        ``sum_n conj(P_ni[n, i]) * occ_n[n] * P_ni[n, j]`` added to
        ``D_sii[self.spin]``.  For 4 components the accumulation is done
        by ``add_to_4component_density_matrix``.

        *occ_n* holds one weight per band row of the projections.  It is
        converted to the array module of *D_asii*.
        """
        xp = D_asii.layout.xp
        occ_n = xp.asarray(occ_n)
        P_ani = self.P_ani
        if self.ncomponents < 4:
            for D_sii, P_ni in zips(D_asii.values(), P_ani.values()):
                D_sii[self.spin] += xp.einsum('ni, n, nj -> ij',
                                              P_ni.conj(), occ_n, P_ni).real
        else:
            for D_xii, P_nsi in zips(D_asii.values(), P_ani.values()):
                add_to_4component_density_matrix(D_xii, P_nsi, occ_n, xp)

    def send(self, kpt_comm, rank):
        """Subclass hook."""
        raise NotImplementedError

    def receive(self, kpt_comm, rank):
        """Subclass hook."""
        raise NotImplementedError

    def force_contribution(self, potential: Potential, F_av: Array2D):
        """Subclass hook."""
        raise NotImplementedError

    def gather_wave_function_coefficients(self) -> np.ndarray | None:
        """Subclass hook."""
        raise NotImplementedError
def add_to_4component_density_matrix(D_xii, P_nsi, occ_n, xp):
    """Accumulate a 4-component atomic density matrix in place.

    ``P_nsi`` is indexed by band ``n``, spin ``s`` (two values) and
    projector ``i``.  ``occ_n`` holds one weight per band.  The 2x2 spin
    blocks

        rho[s, z, i, j] = sum_n conj(P[n, s, i]) occ_n[n] P[n, z, j]

    are folded into the four components of ``D_xii``:

        x=0: rho[0, 0] + rho[1, 1]
        x=1: rho[0, 1] + rho[1, 0]
        x=2: -1j * (rho[0, 1] - rho[1, 0])
        x=3: rho[0, 0] - rho[1, 1]

    These are presumably the total density and the x, y, z magnetization
    parts.  ``xp`` is the array module (numpy or cupy) that provides
    ``einsum``.  Nothing is returned; ``D_xii`` is updated in place.
    """
    rho_ssii = xp.einsum('nsi, n, nzj -> szij', P_nsi.conj(), occ_n, P_nsi)
    up_up, up_down = rho_ssii[0, 0], rho_ssii[0, 1]
    down_up, down_down = rho_ssii[1, 0], rho_ssii[1, 1]

    D_xii[0] += up_up + down_down
    D_xii[1] += up_down + down_up
    D_xii[2] += -1j * (up_down - down_up)
    D_xii[3] += up_up - down_down