Coverage for gpaw/new/pwfd/wave_functions.py: 86%
273 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from __future__ import annotations
3from functools import partial
4from math import pi
5from typing import Optional, Callable
7import numpy as np
8from gpaw.core.arrays import DistributedArrays as XArray
9from gpaw.core.atom_arrays import AtomArrays, AtomDistribution
10from gpaw.core.atom_centered_functions import AtomCenteredFunctions
11from gpaw.core.plane_waves import PWArray
12from gpaw.core.uniform_grid import UGArray, UGDesc
13from gpaw.fftw import get_efficient_fft_size
14from gpaw.gpu import as_np, XP
15from gpaw.mpi import receive, send
16from gpaw.new import prod, trace, zips
17from gpaw.new.potential import Potential
18from gpaw.new.wave_functions import WaveFunctions
19from gpaw.setup import Setups
20from gpaw.typing import Array2D, Array3D, Vector
21from gpaw.utilities import as_real_dtype
24class PWFDWaveFunctions(WaveFunctions, XP):
def __init__(self,
             psit_nX: XArray,
             *,
             spin: int,
             q: int,
             k: int,
             setups: Setups,
             relpos_ac: Array2D,
             atomdist: AtomDistribution,
             weight: float = 1.0,
             ncomponents: int = 1,
             qspiral_v: Vector | None = None):
    """Wrap the wave-function array ``psit_nX``.

    The number of bands, the k-point, the dtype and the domain/band
    communicators handed to the base class are all read off ``psit_nX``
    and its descriptor; everything else is forwarded unchanged.
    """
    desc = psit_nX.desc
    self.psit_nX = psit_nX
    super().__init__(setups=setups,
                     nbands=psit_nX.dims[0],
                     spin=spin,
                     q=q,
                     k=k,
                     kpt_c=desc.kpt_c,
                     relpos_ac=relpos_ac,
                     atomdist=atomdist,
                     weight=weight,
                     ncomponents=ncomponents,
                     qspiral_v=qspiral_v,
                     dtype=desc.dtype,
                     domain_comm=desc.comm,
                     band_comm=psit_nX.comm)

    # Projector functions are created on first use (see pt_aiX).
    self._pt_aiX: Optional[AtomCenteredFunctions] = None
    self.orthonormalized = False

    # Size in bytes of one band, based on the global array shape.
    nelements = prod(self.array_shape(global_shape=True))
    self.bytes_per_band = nelements * desc.itemsize

    XP.__init__(self, self.psit_nX.xp)
@classmethod
def from_wfs(cls,
             wfs: PWFDWaveFunctions,
             psit_nX: XArray,
             relpos_ac=None,
             atomdist=None) -> PWFDWaveFunctions:
    """Create a new object around ``psit_nX``, copying settings from ``wfs``.

    ``spin``, ``q``, ``k``, ``setups``, ``weight``, ``ncomponents`` and
    ``qspiral_v`` are always taken from ``wfs``.  ``relpos_ac`` and
    ``atomdist`` are taken from ``wfs`` only when the corresponding
    argument is ``None``.
    """
    if relpos_ac is None:
        relpos_ac = wfs.relpos_ac
    # Compare against None (not truthiness) so that an explicitly passed
    # object is never discarded just because it evaluates as false.
    if atomdist is None:
        atomdist = wfs.atomdist
    return cls(
        psit_nX,
        spin=wfs.spin,
        q=wfs.q,
        k=wfs.k,
        setups=wfs.setups,
        relpos_ac=relpos_ac,
        atomdist=atomdist,
        weight=wfs.weight,
        ncomponents=wfs.ncomponents,
        qspiral_v=wfs.qspiral_v)
def __del__(self):
    """Close the file behind the wave-function data, if there is one."""
    # __del__ also runs for objects whose construction failed before
    # psit_nX was assigned, so the attribute may be missing.
    psit_nX = getattr(self, 'psit_nX', None)
    # We could be reading from a gpw-file
    data = getattr(psit_nX, 'data', None)
    if hasattr(data, 'fd'):
        data.fd.close()
def _short_string(self, global_shape: tuple[int]) -> str:
    """Delegate to the ``_short_string`` method of the descriptor."""
    desc = self.psit_nX.desc
    return desc._short_string(global_shape)
def array_shape(self, global_shape=False):
    """Return the shape of one band.

    With ``global_shape=True`` the descriptor's ``global_shape()`` is
    used, otherwise its ``myshape``.  When ``ncomponents == 4`` a leading
    axis of length 2 is prepended.
    """
    desc = self.psit_nX.desc
    shape = desc.global_shape() if global_shape else desc.myshape
    if self.ncomponents == 4:
        return (2,) + shape
    return shape
@property
def pt_aiX(self) -> AtomCenteredFunctions:
    """PAW projector functions.

    :::

        ~a _
        p (r)
         i

    Created on first access from the ``pt_j`` of every setup and cached
    in ``_pt_aiX``.
    """
    if self._pt_aiX is not None:
        return self._pt_aiX

    psit_nX = self.psit_nX
    pt_aj = [setup.pt_j for setup in self.setups]
    self._pt_aiX = psit_nX.desc.atom_centered_functions(
        pt_aj,
        self.relpos_ac,
        atomdist=self.atomdist,
        qspiral_v=self.qspiral_v,
        xp=psit_nX.xp)
    return self._pt_aiX
@property
def P_ani(self) -> AtomArrays:
    """PAW projections.

    :::

         ~a  ~
        <p | ψ >
          i   n

    Computed on first access by integrating the projector functions
    against ``psit_nX`` and cached in ``_P_ani``.

    Raises
    ------
    RuntimeError:
        if there is no cached result and ``psit_nX.data`` is None.
    """
    if self._P_ani is None:
        psit_nX = self.psit_nX
        # Check before allocating anything: caching the empty array
        # first would make a later access return uninitialized data
        # instead of raising again.
        if psit_nX.data is None:
            raise RuntimeError('There are no projections or wavefunctions')
        P_ani = self.pt_aiX.empty(psit_nX.dims, psit_nX.comm)
        self.pt_aiX.integrate(psit_nX, P_ani)
        # Only cache once the projections have actually been computed.
        self._P_ani = P_ani
    return self._P_ani
def move(self,
         relpos_ac: Array2D,
         atomdist: AtomDistribution,
         move_wave_functions: Callable[..., None]) -> None:
    """Update the object for new atomic positions.

    If wave-function data is present, ``move_wave_functions`` is called
    with the old and new positions, the projections, the wave functions
    and the setups.  The base class is then updated, the orthonormalized
    flag is cleared and the projector functions are moved.
    """
    psit_nX = self.psit_nX
    if psit_nX.data is not None:
        move_wave_functions(self.relpos_ac,
                            relpos_ac,
                            self.P_ani,
                            psit_nX,
                            self.setups)

    super().move(relpos_ac, atomdist, move_wave_functions)
    self.orthonormalized = False

    pt_aiX = self.pt_aiX
    assert pt_aiX is not None
    pt_aiX.move(relpos_ac, atomdist)
def add_to_density(self,
                   nt_sR: UGArray,
                   D_asii: AtomArrays) -> None:
    """Add this object's contribution to ``nt_sR`` and ``D_asii``.

    The band weights are ``weight * spin_degeneracy * myocc_n``.  For
    ``ncomponents < 4`` the weighted squared wave functions go into
    ``nt_sR[self.spin]``; otherwise all four components of ``nt_sR`` are
    updated from the two spinor components of each band.
    """
    occ_n = self.weight * self.spin_degeneracy * self.myocc_n

    self.add_to_atomic_density_matrices(occ_n, D_asii)

    if self.ncomponents < 4:
        self.psit_nX.abs_square(weights=occ_n, out=nt_sR[self.spin])
        return

    psit_nsG = self.psit_nX
    assert isinstance(psit_nsG, PWArray)

    # Complex work array holding the two components of one band.
    work_sR = nt_sR.desc.new(dtype=complex).empty(2)
    up_R, dn_R = work_sR.data
    dens_xR = nt_sR.data

    for f, psit_sG in zips(occ_n, psit_nsG):
        psit_sG.ifft(out=work_sR)
        uu_R = up_R.real**2 + up_R.imag**2
        dd_R = dn_R.real**2 + dn_R.imag**2
        ud_R = up_R.conj() * dn_R
        dens_xR[0] += f * (uu_R + dd_R)
        dens_xR[1] += 2 * f * ud_R.real
        dens_xR[2] += 2 * f * ud_R.imag
        dens_xR[3] += f * (uu_R - dd_R)
def add_to_ked(self, taut_sR) -> None:
    """Let ``psit_nX.add_ked`` add to ``taut_sR[self.spin]``.

    The band weights are ``weight * spin_degeneracy * myocc_n``.
    """
    occ_n = self.weight * self.spin_degeneracy * self.myocc_n
    taut_R = taut_sR[self.spin]
    self.psit_nX.add_ked(occ_n, taut_R)
@trace
def orthonormalize(self, psit2_nX):
    r"""Orthonormalize wave functions.

    Computes the overlap matrix:::

               / ~ _ *~ _   _   ---  a  *  a   a
        S   =  | ψ(r) ψ(r) dr + >  (P  )  P  ΔS
         mn    /  m    n        ---  im   jn   ij
                                aij

    With `LSL^\dagger=1`, we update the wave functions and projections
    inplace like this:::

              --  *
        Ψ  <- >  L   Ψ ,
         m    --  mn  n
              n

    and:::

         a    --  *   a
        P  <- >  L   P  .
         mi   --  mn  ni
              n

    ``psit2_nX`` is used as work space for the rotated wave functions;
    when it is None a new array is created with ``psit_nX.new()``.
    Nothing is done if ``self.orthonormalized`` is already set.
    """
    if self.orthonormalized:
        return

    psit_nX = self.psit_nX
    comm = psit_nX.desc.comm
    P_ani = self.P_ani

    work_ani = P_ani.new()
    if psit2_nX is None:
        psit2_nX = psit_nX.new()
    dS_aii = self.setups.get_overlap_corrections(
        P_ani.layout.atomdist,
        self.xp,
        dtype=as_real_dtype(P_ani.data.dtype))

    # We are actually calculating S^*:
    S_nn = psit_nX.matrix_elements(psit_nX, domain_sum=False, cc=True)
    P_ani.block_diag_multiply(dS_aii, out_ani=work_ani)
    P_ani.matrix.multiply(work_ani, opb='C', symmetric=True,
                          out=S_nn, beta=1.0)
    comm.sum(S_nn.data, 0)

    # Only rank 0 of the domain communicator does the factorization;
    # the result is then broadcast to the other ranks.
    if comm.rank == 0:
        S_nn.invcholesky()
    comm.broadcast(S_nn.data, 0)
    # S_nn now contains L^*

    S_nn.multiply(psit_nX, out=psit2_nX)
    S_nn.multiply(P_ani, out=work_ani)
    psit_nX.data[:] = psit2_nX.data
    P_ani.data[:] = work_ani.data
    self.orthonormalized = True
@trace
def subspace_diagonalize(self,
                         Ht,
                         dH,
                         psit2_nX,
                         data_buffer=None,
                         scalapack_parameters=(None, 1, 1, None)):
    """Diagonalize the Hamiltonian in the subspace of the wave functions.

    If data_buffer is None, psit2_nX will be used as a buffer
    for the wave functions.

    Ht(in, out):::

       ~   ^   ~
       H = T + v

    dH:::

        ~  ~    a    ~  ~
       <𝜓 |p> ΔH   <p |𝜓>
         m  i   ij   j  n

    The wave functions are orthonormalized first.  Afterwards the
    eigenvalues are stored in ``self._eig_n`` and both the wave
    functions and the projections are rotated in place.
    """
    self.orthonormalize(psit2_nX)
    psit_nX = self.psit_nX
    P_ani = self.P_ani
    work_ani = P_ani.new()
    comm = psit_nX.desc.comm

    apply_Ht = partial(Ht, out=psit2_nX, spin=self.spin)
    H_nn = psit_nX.matrix_elements(psit_nX,
                                   function=apply_Ht,
                                   domain_sum=False,
                                   cc=True)
    dH(P_ani, out_ani=work_ani, spin=self.spin)
    P_ani.matrix.multiply(work_ani, opb='C', symmetric=True,
                          out=H_nn, beta=1.0)
    comm.sum(H_nn.data, 0)

    if comm.rank == 0:
        slcomm, r, c, b = scalapack_parameters
        if r == c == 1:
            slcomm = None
        self._eig_n = as_np(H_nn.eigh(scalapack=(slcomm, r, c, b)))
        H_nn.complex_conjugate()
        # H_nn.data[n, :] now contains the nth eigenvector and
        # self._eig_n[n] the nth eigenvalue
    else:
        # Receive buffer for the broadcast below.
        self._eig_n = np.empty(psit_nX.dims)

    comm.broadcast(H_nn.data, 0)
    comm.broadcast(self._eig_n, 0)

    if data_buffer is None:
        H_nn.multiply(psit_nX, out=psit2_nX)
        psit_nX.data[:] = psit2_nX.data
    else:
        H_nn.multiply(psit_nX, out=psit_nX, data_buffer=data_buffer)
        H_nn.multiply(psit2_nX, out=psit2_nX, data_buffer=data_buffer)

    # The projections are rotated the same way in both cases.
    H_nn.multiply(P_ani, out=work_ani)
    P_ani.data[:] = work_ani.data
def force_contribution(self,
                       potential: Potential,
                       F_av: Array2D) -> None:
    """Add this object's contribution to the forces ``F_av`` in place.

    Uses the derivatives of the projections, the projections themselves,
    ``potential.dH_asii`` and the ``dO_ii`` of each setup.  The band
    weights are ``weight * spin_degeneracy * myocc_n``.  When
    ``ncomponents == 4`` the work is handed to
    ``_non_collinear_force_contribution``.
    """
    xp = self.xp
    dH_asii = potential.dH_asii
    myeig_n = xp.asarray(self.myeig_n)
    myocc_n = xp.asarray(
        self.weight * self.spin_degeneracy * self.myocc_n)

    if self.ncomponents == 4:
        self._non_collinear_force_contribution(dH_asii, myocc_n, F_av)
        return

    # Band-wise factors shaped to broadcast over the (n, v, i) arrays.
    occ_n11 = myocc_n[:, np.newaxis, np.newaxis]
    eig_n11 = myeig_n[:, np.newaxis, np.newaxis]
    subscripts = 'nvi, nj, jk -> vik'

    for a, dPdR_nvi in self.pt_aiX.derivative(self.psit_nX).items():
        fF_nvi = dPdR_nvi.conj()
        fF_nvi *= occ_n11
        P_ni = self.P_ani[a]
        F_vii = xp.einsum(subscripts, fF_nvi, P_ni,
                          dH_asii[a][self.spin])
        # Same expression again, now also weighted by the eigenvalues
        # and contracted with dO_ii instead of dH_ii.
        fF_nvi *= eig_n11
        dO_ii = xp.asarray(self.setups[a].dO_ii)
        F_vii -= xp.einsum(subscripts, fF_nvi, P_ni, dO_ii)
        F_av[a] += 2 * F_vii.real.trace(0, 1, 2)
def _non_collinear_force_contribution(self,
                                      dH_asii,
                                      myocc_n,
                                      F_av):
    """Add the force contribution for ``ncomponents == 4`` to ``F_av``.

    For each atom, entries 0..3 of ``dH_asii[a]`` are combined into a
    2x2 block matrix in spin space, which is contracted with the
    derivatives of the projections and the projections themselves.
    """
    occ_n111 = myocc_n[:, np.newaxis, np.newaxis, np.newaxis]

    for a, dPdR_nsvi in self.pt_aiX.derivative(self.psit_nX).items():
        fF_nsvi = dPdR_nsvi.conj()
        fF_nsvi *= occ_n111

        dH_sii = dH_asii[a]
        dH_ii = dH_sii[0]
        dHx_ii = dH_sii[1]
        dHy_ii = dH_sii[2]
        dHz_ii = dH_sii[3]
        dH_ssii = np.array(
            [[dH_ii + dHz_ii, dHx_ii - 1j * dHy_ii],
             [dHx_ii + 1j * dHy_ii, dH_ii - dHz_ii]])

        P_nsi = self.P_ani[a]
        F_v = np.einsum('nsvi, stij, ntj -> v', fF_nsvi, dH_ssii, P_nsi)

        # Second term: additionally weighted by the eigenvalues and
        # contracted with dO_ii.
        fF_nsvi *= self.myeig_n[:, np.newaxis, np.newaxis, np.newaxis]
        dO_ii = self.setups[a].dO_ii
        F_v -= np.einsum('nsvi, ij, nsj -> v', fF_nsvi, dO_ii, P_nsi)

        F_av[a] += 2 * F_v.real
def collect(self,
            n1: int = 0,
            n2: int = 0) -> PWFDWaveFunctions | None:
    """Collect range of bands to master of band and domain comms.

    Bands ``n1`` to ``n2`` (exclusive) are collected; ``n2 <= 0`` counts
    from ``self.nbands``.  A new object is returned only on rank 0 of
    both the band and the domain communicator; all other ranks get None.
    """
    # Also collect projections instead of recomputing XXX
    if n2 <= 0:
        n2 = self.nbands + n2
    spinors = (2,) if self.ncomponents == 4 else ()
    band_comm = self.psit_nX.comm
    domain_comm = self.psit_nX.desc.comm
    nbands = self.nbands
    # Bands per band-rank, rounded up.
    mynbands = (nbands + band_comm.size - 1) // band_comm.size
    # (band-rank, local band index) of both ends of the range.
    rank1, b1 = divmod(n1, mynbands)
    rank2, b2 = divmod(n2, mynbands)

    if band_comm.rank != 0:
        # Send whatever part of the requested range this rank holds.
        rank = band_comm.rank
        ranka, ba = max((rank1, b1), (rank, 0))
        rankb, bb = min((rank2, b2), (rank, self.psit_nX.mydims[0]))
        if (ranka, ba) < (rankb, bb):
            assert ranka == rankb == rank
            band_comm.send(self.psit_nX.data[ba:bb], dest=0)
        return None

    on_domain_master = domain_comm.rank == 0
    if on_domain_master:
        psit_nX = self.psit_nX.desc.new(comm=None).empty(
            (n2 - n1, *spinors))

    src = rank1
    start = b1
    offset = n1
    while (src, start) < (rank2, b2):
        # Local end index on band-rank ``src``, clipped to the range.
        stop = min((src + 1) * mynbands, nbands) - src * mynbands
        if src == rank2 and stop > b2:
            stop = b2
        next_offset = offset + stop - start
        if stop > start:
            if src == 0:
                psit_bX = self.psit_nX[start:stop].gather()
                if on_domain_master:
                    psit_nX.data[:stop - start] = psit_bX.data
            elif on_domain_master:
                band_comm.receive(
                    psit_nX.data[offset - n1:next_offset - n1], src)
        src += 1
        start = 0
        offset = next_offset

    if on_domain_master:
        wfs = PWFDWaveFunctions.from_wfs(
            self,
            psit_nX,
            atomdist=self.atomdist.gather())
        wfs._eig_n = self.eig_n[n1:n2]
        return wfs

    return None
def send(self, rank, comm):
    """Send k-point, wave-function data, spin, q, k and weight to ``rank``.

    The values are packed in a tuple and passed to the module-level
    ``send`` function together with ``rank`` and ``comm``.
    """
    payload = (
        self.kpt_c,
        self.psit_nX.data,
        self.spin,
        self.q,
        self.k,
        self.weight,
    )
    send(payload, rank, comm)
def receive(self, rank, comm):
    """Receive wave functions from ``rank`` and wrap them in a new object.

    The tuple received is unpacked as k-point, data, spin, q, k and
    weight.  The setups, positions, gathered atom distribution,
    ``ncomponents`` and ``qspiral_v`` are taken from ``self``.
    """
    kpt_c, data, spin, q, k, weight = receive(rank, comm)
    desc = self.psit_nX.desc.new(kpt=kpt_c, comm=None)
    psit_nX = desc.from_data(data)
    return PWFDWaveFunctions(
        psit_nX,
        spin=spin,
        q=q,
        k=k,
        setups=self.setups,
        relpos_ac=self.relpos_ac,
        atomdist=self.atomdist.gather(),
        weight=weight,
        ncomponents=self.ncomponents,
        qspiral_v=self.qspiral_v)
def dipole_matrix_elements(self) -> Array3D:
    """Calculate dipole matrix-elements.

    :::

       _    / ~  ~  _ _   ---  a   a   _a
       μ  = | 𝜓  𝜓  rdr + >   P   P   Δμ
        mn  /  m  n       ---  im  jn   ij
                          aij

    Returns
    -------
    Array3D:
        matrix elements in atomic units.
    """
    psit_nX = self.psit_nX
    cell_cv = psit_nX.desc.cell_cv

    dipole_nnv = np.zeros((self.nbands, self.nbands, 3))

    position_av = self.relpos_ac @ cell_cv

    # Prefactors for the L=1 and L=0 entries of Delta_iiL.
    c1 = (4 * pi / 3)**0.5
    c0 = (4 * pi)**0.5

    R_aiiv = []
    for setup, position_v in zips(self.setups, position_av):
        R_iiv = setup.Delta_iiL[:, :, [3, 1, 2]] * c1
        R_iiv += position_v * setup.Delta_iiL[:, :, :1] * c0
        R_aiiv.append(R_iiv)

    # Atomic corrections, summed over the domain communicator.
    for a, P_ni in self.P_ani.items():
        dipole_nnv += np.einsum('mi, ijv, nj -> mnv',
                                P_ni, R_aiiv[a], P_ni)

    psit_nX.desc.comm.sum(dipole_nnv)

    if isinstance(psit_nX, UGArray):
        psit_nR = psit_nX
    else:
        assert isinstance(psit_nX, PWArray)
        # Find size of fft grid large enough to store square of wfs.
        pw = psit_nX.desc
        s1, s2, s3 = np.ptp(pw.indices_cG, axis=1)  # type: ignore
        assert pw.dtype == float
        # Last dimension is special because dtype=float:
        size_c = [get_efficient_fft_size(N, 2)
                  for N in (2 * s1 + 2, 2 * s2 + 2, 4 * s3 + 2)]
        grid = UGDesc(cell=pw.cell_cv, size=size_c)
        psit_nR = psit_nX.ifft(grid=grid)

    # Only pairs with n <= m are computed; the result is added to both
    # (m, n) and (n, m).
    for m, psitm_R in enumerate(psit_nR):
        for n, psitn_R in enumerate(psit_nR[:m + 1]):
            d_v = (psitm_R * psitn_R).moment()
            dipole_nnv[m, n] += d_v
            if m != n:
                dipole_nnv[n, m] += d_v

    return dipole_nnv
def gather_wave_function_coefficients(self) -> np.ndarray | None:
    """Gather the wave-function coefficients.

    The array is gathered first with ``psit_nX.gather()`` and then with
    ``matrix.gather()``.  The data is returned only where both gathers
    yield a result and the matrix communicator rank is 0; everywhere
    else None is returned.
    """
    psit_nX = self.psit_nX.gather()  # gather X
    if psit_nX is None:
        return None

    data_nX = psit_nX.matrix.gather()  # gather n
    if data_nX.dist.comm.rank != 0:
        return None

    # XXX PW-gamma-point mode: float or complex matrix.dtype?
    dtype = psit_nX.data.dtype
    shape = (-1,) + psit_nX.data.shape[1:]
    return data_nX.data.view(dtype).reshape(shape)
def to_uniform_grid_wave_functions(self,
                                   grid,
                                   basis):
    """Return wave functions stored on a uniform grid.

    If ``psit_nX`` already is a ``UGArray``, ``self`` is returned.
    Otherwise the wave functions are transformed with ``ifft`` onto
    ``grid`` (given this object's k-point and dtype) and wrapped in a
    new object.  ``basis`` is not used here.
    """
    psit_nX = self.psit_nX
    if isinstance(psit_nX, UGArray):
        return self

    ug = grid.new(kpt=self.kpt_c, dtype=self.dtype)
    psit_nR = ug.zeros(self.nbands, self.band_comm)
    psit_nX.ifft(out=psit_nR)
    return PWFDWaveFunctions.from_wfs(self, psit_nR)
def morph(self, desc, relpos_ac, atomdist):
    """Return new wave functions morphed onto the descriptor ``desc``.

    ``desc`` is first given the k-point of the current wave functions.
    After morphing, the data, projections and projector functions of
    ``self`` are dropped to save memory, so ``self`` no longer holds
    wave functions.

    ``relpos_ac`` and ``atomdist`` are handed to ``from_wfs``, which
    falls back to the values stored on ``self`` for any that is None.
    """
    desc = desc.new(kpt=self.psit_nX.desc.kpt_c)
    psit_nX = self.psit_nX.morph(desc)

    # Save memory:
    self.psit_nX.data = None
    self._P_ani = None
    self._pt_aiX = None

    # Forward atomdist as well: it used to be accepted but ignored, so
    # the new object kept the old atom distribution together with the
    # new positions.
    return PWFDWaveFunctions.from_wfs(self, psit_nX,
                                      relpos_ac=relpos_ac,
                                      atomdist=atomdist)