Coverage for gpaw/new/lcao/wave_functions.py: 81%
119 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1from __future__ import annotations
3import numpy as np
4from gpaw.core.atom_arrays import (AtomArrays, AtomArraysLayout,
5 AtomDistribution)
6from gpaw.core.matrix import Matrix
7from gpaw.mpi import MPIComm, receive, send, serial_comm
8from gpaw.new.potential import Potential
9from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
10from gpaw.new.wave_functions import WaveFunctions
11from gpaw.setup import Setups
12from gpaw.typing import Array2D
class LCAOWaveFunctions(WaveFunctions):
    """Wave functions expanded in an LCAO basis.

    The expansion coefficients are stored in ``C_nM`` (one row per band,
    one column per basis function).  PAW projections ``P_ani`` are not
    stored directly; they are computed lazily as ``C_nM @ P_Mi`` for every
    atom ``a`` present in ``P_aMi``.
    """

    xp = np

    def __init__(self,
                 *,
                 setups: Setups,
                 tci_derivatives,
                 basis,
                 C_nM: Matrix,
                 S_MM: Matrix,
                 T_MM: Matrix,
                 P_aMi,
                 relpos_ac: Array2D,
                 atomdist: AtomDistribution,
                 kpt_c=(0.0, 0.0, 0.0),
                 domain_comm: MPIComm = serial_comm,
                 spin: int = 0,
                 q: int = 0,
                 k: int = 0,
                 weight: float = 1.0,
                 ncomponents: int = 1):
        """Create LCAO wave functions.

        Parameters
        ----------
        C_nM:
            Coefficient matrix.  Its first dimension defines the number of
            bands, its dtype the dtype of the wave functions, and its
            distribution communicator is used as the band communicator.
        S_MM:
            Matrix that is inverse-Cholesky factorised by :attr:`L_MM`
            (presumably the basis-function overlap matrix).
        T_MM:
            Stored as-is; its ``data`` is used as a template for the
            density-matrix buffer on non-root domain ranks.
        P_aMi:
            Mapping from atom index to the matrix that turns coefficients
            into projections (``P_ni = C_nM @ P_Mi``).

        ``S_MM``, ``T_MM`` and ``P_aMi`` may be ``None`` for objects built
        by :meth:`receive`; such objects cannot compute projections.
        """
        super().__init__(setups=setups,
                         nbands=C_nM.shape[0],
                         spin=spin,
                         q=q,
                         k=k,
                         kpt_c=kpt_c,
                         weight=weight,
                         relpos_ac=relpos_ac,
                         atomdist=atomdist,
                         ncomponents=ncomponents,
                         dtype=C_nM.dtype,
                         domain_comm=domain_comm,
                         band_comm=C_nM.dist.comm)
        self.tci_derivatives = tci_derivatives
        self.basis = basis
        self.C_nM = C_nM
        self.T_MM = T_MM
        self.S_MM = S_MM
        self.P_aMi = P_aMi

        # Memory footprint of one band: number of basis functions times
        # the size of one coefficient.
        self.bytes_per_band = (self.array_shape(global_shape=True)[0] *
                               C_nM.data.itemsize)

        # This is for TB-mode (and MYPY):
        self.V_MM: Matrix

        # Lazily computed by the L_MM property; reset when atoms move.
        self._L_MM = None

    def move(self,
             relpos_ac: Array2D,
             atomdist: AtomDistribution,
             move_wave_functions) -> None:
        """Move atoms to ``relpos_ac``.

        The phase correction must run *before* the base-class move, because
        :meth:`_update_phases` compares the new positions against the
        current ``self.relpos_ac`` (presumably updated by the base class).
        The cached :attr:`L_MM` is invalidated afterwards.
        """
        self._update_phases(relpos_ac)
        super().move(relpos_ac, atomdist, move_wave_functions)
        self._L_MM = None

    def _update_phases(self, relpos_ac):
        """Complex-rotate coefficients compensating discontinuous phase shift.

        This changes the coefficients to counteract the phase discontinuity
        of overlaps when atoms move across a cell boundary."""

        # We don't want to apply any phase shift unless we crossed a cell
        # boundary. So we round the shift to either 0 or 1.
        #
        # Example: spos_ac goes from 0.01 to 0.99 -- this rounds to 1 and
        # we apply the phase. If someone moves an atom by half a cell
        # without crossing a boundary, then we are out of luck. But they
        # should have reinitialized from LCAO anyway.

        C_nM = self.C_nM.data
        if C_nM.dtype == float:
            # Real coefficients (Gamma-point only): no phase to apply.
            return
        diff_ac = (relpos_ac - self.relpos_ac).round()
        if not diff_ac.any():
            return
        phase_a = np.exp(2j * np.pi * diff_ac @ self.kpt_c)
        # Basis functions are laid out atom by atom; each atom owns
        # sphere.Mmax consecutive columns of C_nM.
        M1 = 0
        for phase, sphere in zip(phase_a, self.basis.sphere_a):
            M2 = M1 + sphere.Mmax
            C_nM[:, M1:M2] *= phase
            M1 = M2

    @property
    def L_MM(self):
        """Inverse Cholesky factor of ``S_MM`` (cached).

        For ``ncomponents == 4`` (presumably spinor wave functions) the
        factor is placed twice on the diagonal of a complex
        ``2M x 2M`` matrix with zero off-diagonal blocks.
        """
        if self._L_MM is None:
            S_MM = self.S_MM.copy()
            S_MM.invcholesky()
            if self.ncomponents < 4:
                self._L_MM = S_MM
            else:
                M, M = S_MM.shape
                L_sMsM = Matrix(2 * M, 2 * M, dtype=complex)
                L_sMsM.data[:] = 0.0
                L_sMsM.data[:M, :M] = S_MM.data
                L_sMsM.data[M:, M:] = S_MM.data
                self._L_MM = L_sMsM
        return self._L_MM

    def _short_string(self, global_shape):
        return f'basis functions: {global_shape[0]}'

    def array_shape(self, global_shape=False):
        """Return the shape of one band's coefficient array.

        Only the global shape (number of basis functions) is supported.

        Raises
        ------
        NotImplementedError
            If ``global_shape`` is false.
        """
        if global_shape:
            return self.C_nM.shape[1:]
        raise NotImplementedError(
            'Only global_shape=True is supported for LCAO wave functions')

    @property
    def _layout(self):
        """Atom-array layout matching the atoms that are keys of P_aMi."""
        atomdist = AtomDistribution.from_atom_indices(
            list(self.P_aMi),
            self.domain_comm,
            natoms=len(self.setups))
        return AtomArraysLayout([setup.ni for setup in self.setups],
                                atomdist=atomdist,
                                dtype=self.dtype)

    @property
    def P_ani(self):
        """PAW projections ``P_ni = C_nM @ P_Mi`` for each local atom.

        Computed on first access and cached in ``self._P_ani``.

        Raises
        ------
        RuntimeError
            If ``C_nM`` is not a :class:`Matrix` (no wave functions).
        """
        if self._P_ani is None:
            # Validate before touching C_nM.dist, and only cache a fully
            # computed result: previously the empty (uninitialised) array
            # was cached before this check, so a failed first access left
            # garbage behind for later callers.
            # NOTE(review): builder.py reportedly marks uninitialised data
            # by injecting a NaN into C_nM.data[0]; that marker is not
            # inspected here — only the type of C_nM is checked.
            if not isinstance(self.C_nM, Matrix):
                raise RuntimeError('There are no projections or wavefunctions')
            P_ani = self._layout.empty(self.nbands,
                                       comm=self.C_nM.dist.comm)
            for a, P_Mi in self.P_aMi.items():
                P_ani[a][:] = self.C_nM.data @ P_Mi
            self._P_ani = P_ani
        return self._P_ani

    def add_to_density(self,
                       nt_sR,
                       D_asii: AtomArrays) -> None:
        """Add density from wave functions.

        Adds to ``nt_sR`` and ``D_asii``.
        """
        rho_MM = self.calculate_density_matrix()
        self.basis.construct_density(rho_MM, nt_sR.data[self.spin], q=self.q)
        f_n = self.weight * self.spin_degeneracy * self.myocc_n
        self.add_to_atomic_density_matrices(f_n, D_asii)

    def gather_wave_function_coefficients(self) -> np.ndarray | None:
        """Gather ``C_nM`` and return its data.

        Returns ``None`` on ranks where ``C_nM.gather()`` returns ``None``.
        """
        C_nM = self.C_nM.gather()
        if C_nM is not None:
            return C_nM.data
        return None

    def calculate_density_matrix(self,
                                 *,
                                 eigs=False,
                                 transposed=False) -> np.ndarray:
        """Calculate the density matrix.

        The density matrix is::

            rho_μν = sum_n f_n conj(C_nμ) C_nν

        where ``f_n = weight * spin_degeneracy * occ_n``.

        Parameters
        ----------
        eigs:
            If true, ``f_n`` is additionally multiplied by the eigenvalues
            (energy-weighted density matrix).
        transposed:
            If true, the complex conjugate of the matrix is returned.

        Returns
        -------
        The density matrix in the LCAO basis.  It is computed on domain
        rank 0 and broadcast to all ranks of the domain communicator.
        """
        if self.domain_comm.rank == 0:
            f_n = self.weight * self.spin_degeneracy * self.myocc_n
            if eigs:
                f_n *= self.myeig_n
            TempC_nM = self.C_nM.copy()
            TempC_nM.data *= f_n[:, None]
            rho_MM = TempC_nM.multiply(self.C_nM, opa='C')
            if transposed:
                rho_MM.complex_conjugate()
            rho_MM_data = rho_MM.data
        else:
            # Receive buffer with the same shape/dtype as an MM matrix.
            rho_MM_data = np.empty_like(self.T_MM.data)
        self.domain_comm.broadcast(rho_MM_data, 0)

        return rho_MM_data

    def to_uniform_grid_wave_functions(self,
                                       grid,
                                       basis):
        """Expand the LCAO coefficients on a uniform real-space grid.

        Returns PWFD wave functions; eigenvalues are copied if present.
        """
        grid = grid.new(kpt=self.kpt_c, dtype=self.dtype)
        psit_nR = grid.zeros(self.nbands, self.band_comm)
        basis.lcao_to_grid(self.C_nM.data, psit_nR.data, self.q)

        wfs = PWFDWaveFunctions.from_wfs(self, psit_nR)
        if self._eig_n is not None:
            wfs._eig_n = self._eig_n.copy()
        return wfs

    def collect(self,
                n1: int = 0,
                n2: int = 0) -> LCAOWaveFunctions | None:
        """Return a copy restricted to bands ``n1:n2``.

        ``n2 <= 0`` counts from the end like a slice index (``0`` means all
        bands).  Only works without band parallelization.
        """
        # Quick'n'dirty implementation
        # We should generalize the PW+FD method
        assert self.band_comm.size == 1
        if n2 <= 0:
            # Previously ``n2 or nbands + n2`` left negative n2 untouched,
            # giving a Matrix row count (n2 - n1) that disagreed with the
            # slice below.
            n2 += self.nbands
        return LCAOWaveFunctions(
            setups=self.setups,
            tci_derivatives=self.tci_derivatives,
            basis=self.basis,
            C_nM=Matrix(n2 - n1,
                        self.C_nM.shape[1],
                        data=self.C_nM.data[n1:n2].copy()),
            S_MM=self.S_MM,
            T_MM=self.T_MM,
            P_aMi=self.P_aMi,
            relpos_ac=self.relpos_ac,
            atomdist=self.atomdist.gather(),
            kpt_c=self.kpt_c,
            spin=self.spin,
            q=self.q,
            k=self.k,
            weight=self.weight,
            ncomponents=self.ncomponents)

    def force_contribution(self, potential: Potential, F_av: Array2D):
        """Add this k-point/spin's force contributions to ``F_av``."""
        from gpaw.new.lcao.forces import add_force_contributions
        add_force_contributions(self, potential, F_av)
        return F_av

    def send(self, rank, comm):
        """Send k-point data and coefficients to ``rank`` (see receive)."""
        stuff = (self.kpt_c,
                 self.C_nM.data,
                 self.spin,
                 self.q,
                 self.k,
                 self.weight,
                 self.ncomponents)
        send(stuff, rank, comm)

    def receive(self, rank, comm):
        """Build wave functions from data sent by :meth:`send`.

        The result has ``S_MM``, ``T_MM`` and ``P_aMi`` set to ``None``,
        so it cannot compute projections or ``L_MM``.
        """
        kpt_c, data, spin, q, k, weight, ncomponents = receive(rank, comm)
        return LCAOWaveFunctions(setups=self.setups,
                                 tci_derivatives=self.tci_derivatives,
                                 basis=self.basis,
                                 C_nM=Matrix(*data.shape, data=data),
                                 S_MM=None,
                                 T_MM=None,
                                 P_aMi=None,
                                 relpos_ac=self.relpos_ac,
                                 atomdist=self.atomdist.gather(),
                                 kpt_c=kpt_c,
                                 spin=spin,
                                 q=q,
                                 k=k,
                                 weight=weight,
                                 ncomponents=ncomponents)