Coverage for gpaw/new/ibzwfs.py: 84%
352 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1from __future__ import annotations
3from functools import cached_property
4from typing import TYPE_CHECKING, Callable, Generator, Generic, TypeVar
6import numpy as np
7from ase.io.ulm import Writer
8from ase.units import Bohr, Ha
9from gpaw.gpu import as_np, synchronize
10from gpaw.gpu.mpi import CuPyMPI
11from gpaw.mpi import MPIComm, serial_comm
12from gpaw.new import zips
13from gpaw.new.timer import trace
14from gpaw.new.brillouin import IBZ
15from gpaw.new.c import GPU_AWARE_MPI
16from gpaw.new.potential import Potential
17from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
18from gpaw.new.wave_functions import WaveFunctions
19from gpaw.typing import Array1D, Array2D, Self
20from gpaw.utilities import pack_density
22if TYPE_CHECKING:
23 from gpaw.new.density import Density
# Type variable for the kind of wave function object stored in an
# IBZWaveFunctions collection; restricted to WaveFunctions subclasses.
WFT = TypeVar('WFT', bound=WaveFunctions)
class IBZWaveFunctions(Generic[WFT]):
    """Wave function objects for the k-points of the irreducible BZ.

    ``wfs_qs[q][s]`` is the object for local k-point index ``q`` and spin
    index ``s``.  Iterating over an instance yields all of these objects,
    k-point by k-point.
    """

    def __init__(self,
                 ibz: IBZ,
                 *,
                 ncomponents: int,
                 wfs_qs: list[list[WFT]],
                 kpt_comm: MPIComm = serial_comm,
                 kpt_band_comm: MPIComm = serial_comm,
                 comm: MPIComm = serial_comm):
        """Collection of wave function objects for k-points in the IBZ.

        ``wfs_qs`` must contain at least one object: the band and domain
        communicators, dtype, number of bands and array module (``xp``)
        are all read from the stored objects.
        """
        self.ibz = ibz
        self.kpt_comm = kpt_comm
        self.kpt_band_comm = kpt_band_comm
        self.comm = comm
        self.ncomponents = ncomponents
        self.collinear = (ncomponents != 4)
        # ncomponents 1, 2, 4 -> spin_degeneracy 2, 1, 1
        self.spin_degeneracy = ncomponents % 2 + 1
        # ncomponents 1, 2, 4 -> nspins 1, 2, 1
        self.nspins = ncomponents % 3

        # Rank in kpt_comm for every k-point of the IBZ:
        self.rank_k = ibz.ranks(kpt_comm)

        self.wfs_qs = wfs_qs

        self.q_k = {}  # IBZ-index to local index
        for wfs in self:
            self.q_k[wfs.k] = wfs.q

        # ``wfs`` is the last object yielded by the loop above.
        self.band_comm = wfs.band_comm
        self.domain_comm = wfs.domain_comm
        self.dtype = wfs.dtype
        self.nbands = wfs.nbands

        self.fermi_levels: Array1D | None = None  # hartree

        self.xp = self.wfs_qs[0][0].xp
        if self.xp is not np:
            if not GPU_AWARE_MPI:
                # Arrays are not NumPy arrays and MPI is not GPU-aware:
                # go through the CuPyMPI wrapper for k-point communication.
                self.kpt_comm = CuPyMPI(self.kpt_comm)  # type: ignore

        # Callback handed to wfs.move(); does nothing by default.
        self.move_wave_functions: Callable[..., None] = lambda *args: None

        # Set to True by make_sure_wfs_are_read_from_gpw_file() when
        # file-backed wave function data had to be read into memory.
        self.read_from_file_init_wfs_dm = False

    @classmethod
    def create(cls,
               *,
               ibz: IBZ,
               ncomponents: int,
               create_wfs_func,
               kpt_comm: MPIComm = serial_comm,
               kpt_band_comm: MPIComm = serial_comm,
               comm: MPIComm = serial_comm,
               ) -> Self:
        """Build the collection for the k-points assigned to this rank.

        ``create_wfs_func(spin, q, k, kpt_c, weight)`` is called once per
        spin for every k-point ``k`` whose rank (from ``ibz.ranks``) equals
        ``kpt_comm.rank``; ``q`` counts those k-points locally.
        """
        rank_k = ibz.ranks(kpt_comm)
        mask_k = (rank_k == kpt_comm.rank)
        k_q = np.arange(len(ibz))[mask_k]

        nspins = ncomponents % 3

        wfs_qs: list[list[WFT]] = []
        for q, k in enumerate(k_q):
            wfs_s = []
            for spin in range(nspins):
                wfs = create_wfs_func(spin, q, k,
                                      ibz.kpt_kc[k], ibz.weight_k[k])
                wfs_s.append(wfs)
            wfs_qs.append(wfs_s)

        return cls(ibz,
                   ncomponents=ncomponents,
                   wfs_qs=wfs_qs,
                   kpt_comm=kpt_comm,
                   kpt_band_comm=kpt_band_comm,
                   comm=comm)

    @cached_property
    def mode(self):
        """Return ``'pw'``, ``'fd'`` or ``'lcao'``.

        Decided from the first stored object: a PWFDWaveFunctions whose
        ``psit_nX.desc`` has an ``ecut`` attribute gives ``'pw'``, any
        other PWFDWaveFunctions gives ``'fd'``, everything else ``'lcao'``.
        """
        wfs = self.wfs_qs[0][0]
        if isinstance(wfs, PWFDWaveFunctions):
            if hasattr(wfs.psit_nX.desc, 'ecut'):
                return 'pw'
            return 'fd'
        return 'lcao'

    def has_wave_functions(self) -> bool:
        """Not implemented in this class."""
        raise NotImplementedError

    def get_max_shape(self, global_shape: bool = False) -> tuple[int, ...]:
        """Find the largest wave function array shape.

        For a PW-calculation, this shape could depend on k-point.

        With ``global_shape=True`` the local maximum is additionally
        reduced with ``kpt_comm.max``.
        """
        if global_shape:
            shape = np.array(max(wfs.array_shape(global_shape=True)
                                 for wfs in self))
            self.kpt_comm.max(shape)
            return tuple(shape)
        return max(wfs.array_shape() for wfs in self)

    @property
    def fermi_level(self) -> float:
        """The single Fermi level.

        Asserts that Fermi levels have been set and that there is
        exactly one of them.
        """
        fl = self.fermi_levels
        assert fl is not None and len(fl) == 1
        return fl[0]

    def __str__(self):
        """Text summary: symmetries, k-points, memory, parallelization."""
        shape = self.get_max_shape(global_shape=True)
        wfs = self.wfs_qs[0][0]
        nbytes = (len(self.ibz) *
                  self.nbands *
                  len(self.wfs_qs[0]) *
                  wfs.bytes_per_band)
        ncores = (self.kpt_comm.size *
                  self.domain_comm.size *
                  self.band_comm.size)
        # NOTE(review): runs of spaces inside these literals may have been
        # collapsed when this listing was extracted -- confirm the intended
        # indentation/alignment of the emitted text.
        return (f'{self.ibz.symmetries}\n'
                f'{self.ibz}\n'
                f'{wfs._short_string(shape)}\n'
                f'spin-components: {self.ncomponents}'
                ' # (' +
                ('' if self.collinear else 'non-') + 'collinear spins)\n'
                f'bands: {self.nbands}\n'
                f'spin-degeneracy: {self.spin_degeneracy}\n'
                f'dtype: {self.dtype}\n\n'
                'memory:\n'
                f' storage: {"CPU" if self.xp is np else "GPU"}\n'
                f' wave functions: {nbytes:_} # bytes '
                f' ({nbytes // ncores:_} per core)\n\n'
                'parallelization:\n'
                f' kpt: {self.kpt_comm.size}\n'
                f' domain: {self.domain_comm.size}\n'
                f' band: {self.band_comm.size}\n')

    def __iter__(self) -> Generator[WFT, None, None]:
        """Yield every stored object: all spins of one k-point, then the
        next k-point."""
        for wfs_s in self.wfs_qs:
            yield from wfs_s

    def move(self, relpos_ac, atomdist):
        """Forward new atomic positions to every wave function object.

        The positions are first checked against the symmetries, and
        file-backed wave function data is read into memory.
        """
        self.ibz.symmetries.check_positions(relpos_ac)
        self.make_sure_wfs_are_read_from_gpw_file()
        for wfs in self:
            wfs.move(relpos_ac, atomdist, self.move_wave_functions)

    def orthonormalize(self, work_array_nX: np.ndarray | None = None):
        """Call ``orthonormalize(work_array_nX)`` on every object."""
        for wfs in self:
            wfs.orthonormalize(work_array_nX)

    @trace
    def calculate_occs(self,
                       occ_calc,
                       nelectrons: float,
                       fix_fermi_level=False) -> tuple[float, float, float]:
        """Calculate occupation numbers and store them on the objects.

        Eigenvalues and the Fermi-level guess are handed to
        ``occ_calc.calculate`` multiplied by ``Ha``, and the electron count
        divided by the spin degeneracy.  Unless ``fix_fermi_level`` is
        true, the returned Fermi levels are stored divided by ``Ha``.

        Returns ``(e_band, e_entropy,
        e_entropy * occ_calc.extrapolate_factor)``.
        """
        degeneracy = self.spin_degeneracy

        # u index is q and s combined
        occ_un, fermi_levels, e_entropy = occ_calc.calculate(
            nelectrons=nelectrons / degeneracy,
            eigenvalues=[wfs.eig_n * Ha for wfs in self],
            weights=[wfs.weight for wfs in self],
            fermi_levels_guess=(None
                                if self.fermi_levels is None else
                                self.fermi_levels * Ha),
            fix_fermi_level=fix_fermi_level)

        if not fix_fermi_level:
            self.fermi_levels = np.array(fermi_levels) / Ha
        else:
            # A fixed Fermi level must already have been set.
            assert self.fermi_levels is not None

        for occ_n, wfs in zips(occ_un, self):
            wfs._occ_n = occ_n

        e_entropy *= degeneracy / Ha
        e_band = 0.0
        for wfs in self:
            e_band += wfs.occ_n @ wfs.eig_n * wfs.weight * degeneracy
        e_band = self.kpt_comm.sum_scalar(float(e_band))  # XXX CPU float?

        return e_band, e_entropy, e_entropy * occ_calc.extrapolate_factor

    def add_to_density(self, nt_sR, D_asii) -> None:
        """Compute density and add to ``nt_sR`` and ``D_asii``."""
        for wfs in self:
            wfs.add_to_density(nt_sR, D_asii)

        if self.xp is not np:
            synchronize()

        # This should be done in a more efficient way!!!
        # Also: where do we want the density?
        self.kpt_comm.sum(nt_sR.data)
        self.kpt_comm.sum(D_asii.data)
        self.band_comm.sum(nt_sR.data)
        self.band_comm.sum(D_asii.data)

    def normalize_density(self, density: Density) -> None:
        """Do nothing here."""
        pass  # overwritten in LCAOIBZWaveFunctions class

    def add_to_ked(self, taut_sR) -> None:
        """Let every object add to ``taut_sR``, then sum ``taut_sR.data``
        over ``kpt_comm`` and ``band_comm``."""
        for wfs in self:
            wfs.add_to_ked(taut_sR)
        if self.xp is not np:
            synchronize()
        self.kpt_comm.sum(taut_sR.data)
        self.band_comm.sum(taut_sR.data)

    def get_all_electron_wave_function(self,
                                       band,
                                       kpt=0,
                                       spin=0,
                                       grid_spacing=0.05,
                                       skip_paw_correction=False):
        """Interpolate one wave function and add partial-wave corrections.

        The band is fetched with :meth:`get_wfs`; ``None`` is returned
        wherever that call returns ``None``.  Only PWFDWaveFunctions are
        accepted (asserted).  The corrections are left out when
        ``skip_paw_correction`` is true.
        """
        wfs = self.get_wfs(kpt=kpt, spin=spin, n1=band, n2=band + 1)
        if wfs is None:
            return None
        assert isinstance(wfs, PWFDWaveFunctions)
        psit_X = wfs.psit_nX[0].to_pbc_grid()
        grid = psit_X.desc.uniform_grid_with_grid_spacing(grid_spacing)
        psi_r = psit_X.interpolate(grid=grid)

        if not skip_paw_correction:
            dphi_aj = wfs.setups.partial_wave_corrections()
            dphi_air = grid.atom_centered_functions(dphi_aj, wfs.relpos_ac)
            dphi_air.add_to(psi_r, wfs.P_ani[:, 0])

        return psi_r

    def get_wfs(self,
                *,
                kpt: int = 0,
                spin: int = 0,
                n1=0,
                n2=0):
        """Collect bands of one k-point and spin.

        On the ``kpt_comm`` rank that owns ``kpt``, ``wfs.collect(n1, n2)``
        is called.  If that rank is 0 the result is returned directly;
        otherwise a non-``None`` result is sent to rank 0 of ``kpt_comm``
        and ``None`` is returned.  Rank 0 of ``comm`` receives and returns
        the object; all remaining ranks return ``None``.

        NOTE(review): ``n1``/``n2`` presumably delimit a half-open band
        range -- confirm against ``collect``.
        """
        rank = self.rank_k[kpt]
        if rank == self.kpt_comm.rank:
            wfs = self.wfs_qs[self.q_k[kpt]][spin]
            wfs2 = wfs.collect(n1, n2)
            if rank == 0:
                return wfs2
            if wfs2 is not None:
                wfs2.send(0, self.kpt_comm)
            return None
        if self.comm.rank == 0:
            return self.wfs_qs[0][0].receive(rank, self.kpt_comm)
        return None

    def get_eigs_and_occs(self, k=0, s=0):
        """Return eigenvalues and occupation numbers for one k-point/spin.

        Only ranks with ``domain_comm.rank == 0`` and
        ``band_comm.rank == 0`` take part.  Among those, rank 0 of
        ``kpt_comm`` returns the two arrays, either its own or received
        from the owning rank.  Every other rank returns two empty arrays.
        """
        if self.domain_comm.rank == 0 and self.band_comm.rank == 0:
            rank = self.rank_k[k]
            if rank == self.kpt_comm.rank:
                wfs = self.wfs_qs[self.q_k[k]][s]
                if rank == 0:
                    return wfs._eig_n, wfs._occ_n
                self.kpt_comm.send(wfs._eig_n, 0)
                self.kpt_comm.send(wfs._occ_n, 0)
            elif self.kpt_comm.rank == 0:
                eig_n = np.empty(self.nbands)
                occ_n = np.empty(self.nbands)
                self.kpt_comm.receive(eig_n, rank)
                self.kpt_comm.receive(occ_n, rank)
                return eig_n, occ_n
        return np.zeros(0), np.zeros(0)

    def get_all_eigs_and_occs(self, broadcast=False):
        """Return eigenvalues and occupations as ``(nspins, nkpts, nbands)``
        arrays.

        The arrays are filled on rank 0 of ``comm``.  On other ranks the
        last dimension is 0 unless ``broadcast`` is true, in which case
        the arrays are broadcast from rank 0 of ``comm``.
        """
        nkpts = len(self.ibz)
        mynbands = self.nbands if self.comm.rank == 0 or broadcast else 0
        eig_skn = np.empty((self.nspins, nkpts, mynbands))
        occ_skn = np.empty((self.nspins, nkpts, mynbands))
        for k in range(nkpts):
            for s in range(self.nspins):
                # Called on every rank: the call may involve send/receive.
                eig_n, occ_n = self.get_eigs_and_occs(k, s)
                if self.comm.rank == 0:
                    eig_skn[s, k, :] = eig_n
                    occ_skn[s, k, :] = occ_n
        if broadcast:
            self.comm.broadcast(eig_skn, 0)
            self.comm.broadcast(occ_skn, 0)
        return eig_skn, occ_skn

    def forces(self, potential: Potential) -> Array2D:
        """Accumulate force contributions from all objects.

        Returns an array of shape ``(len(potential.dH_asii), 3)`` that has
        been summed over ``kpt_band_comm``.
        """
        self.make_sure_wfs_are_read_from_gpw_file()
        F_av = self.xp.zeros((len(potential.dH_asii), 3))
        for wfs in self:
            wfs.force_contribution(potential, F_av)
        if self.xp is not np:
            synchronize()
        self.kpt_band_comm.sum(F_av)
        return F_av

    def write(self, writer: Writer, flags) -> None:
        """Write fermi-level(s), eigenvalues, occupation numbers, ...

        ... k-points, symmetry information, projections and possibly
        also the wave functions.

        Projections are written when ``flags.include_projections`` is
        true, wave functions when ``flags.include_wfs`` is true.
        """
        eig_skn, occ_skn = self.get_all_eigs_and_occs()
        if not self.collinear:
            # Drop the spin axis in the non-collinear case.
            eig_skn = eig_skn[0]
            occ_skn = occ_skn[0]
        assert self.fermi_levels is not None
        writer.write(fermi_levels=self.fermi_levels * Ha,
                     eigenvalues=eig_skn * Ha,
                     occupations=occ_skn)
        ibz = self.ibz
        writer.child('kpts').write(
            atommap=ibz.symmetries.atommap_sa,
            bz2ibz=ibz.bz2ibz_K,
            bzkpts=ibz.bz.kpt_Kc,
            ibzkpts=ibz.kpt_kc,
            rotations=ibz.symmetries.rotation_scc,
            translations=ibz.symmetries.translation_sc,
            weights=ibz.weight_k)

        nproj = self.wfs_qs[0][0].P_ani.layout.size

        spin_k_shape: tuple[int, ...]
        proj_shape: tuple[int, ...]

        if self.collinear:
            spin_k_shape = (self.ncomponents, len(ibz))
            proj_shape = (self.nbands, nproj)
        else:
            spin_k_shape = (len(ibz),)
            proj_shape = (self.nbands, 2, nproj)

        if flags.include_projections:
            proj_dtype = flags.storage_dtype(self.dtype)
            writer.add_array('projections', spin_k_shape + proj_shape,
                             proj_dtype)
            for spin in range(self.nspins):
                for k, rank in enumerate(self.rank_k):
                    if rank == self.kpt_comm.rank:
                        wfs = self.wfs_qs[self.q_k[k]][spin]
                        P_ani = wfs.P_ani.to_cpu().gather()  # gather atoms
                        if P_ani is not None:
                            P_nI = P_ani.matrix.gather()  # gather bands
                            if P_nI.dist.comm.rank == 0:
                                if rank == 0:
                                    writer.fill(P_nI.data.reshape(
                                        proj_shape).astype(proj_dtype))
                                else:
                                    self.kpt_comm.send(P_nI.data, 0)
                    elif self.comm.rank == 0:
                        # Receive from the rank owning this k-point.
                        data = np.empty(proj_shape, self.dtype)
                        self.kpt_comm.receive(data, rank)
                        writer.fill(data.astype(proj_dtype))

        if flags.include_wfs:
            self._write_wave_functions(writer, spin_k_shape, flags)

    def _write_wave_functions(self, writer, spin_k_shape, flags):
        """Write the wave function coefficients as array 'coefficients'.

        Coefficients are gathered per k-point and spin, and written by
        whichever process ends up holding them after the send to rank 0
        of ``kpt_comm``.
        """
        # We collect all bands to master. This may have to be changed
        # to only one band at a time XXX
        xshape = self.get_max_shape(global_shape=True)
        shape = spin_k_shape + (self.nbands,) + xshape
        dtype = complex if self.mode == 'pw' else self.dtype
        dtype_write = flags.storage_dtype(dtype)
        # Factor applied to the coefficients before writing
        # (presumably a unit conversion -- TODO confirm):
        c = 1.0 if self.mode == 'lcao' else Bohr**-1.5

        writer.add_array('coefficients', shape, dtype=dtype_write)
        # Reused both as zero-padding buffer and as receive buffer:
        buf_nX = np.empty((self.nbands,) + xshape, dtype=dtype)

        for spin in range(self.nspins):
            for k, rank in enumerate(self.rank_k):
                if rank == self.kpt_comm.rank:
                    wfs = self.wfs_qs[self.q_k[k]][spin]
                    coef_nX = wfs.gather_wave_function_coefficients()
                    if coef_nX is not None:
                        coef_nX = as_np(coef_nX)
                        if self.mode == 'pw':
                            x = coef_nX.shape[-1]
                            if x < xshape[-1]:
                                # For PW-mode, we may need to zero-pad the
                                # plane-wave coefficient up to the maximum
                                # for all k-points:
                                buf_nX[..., :x] = coef_nX
                                buf_nX[..., x:] = 0.0
                                coef_nX = buf_nX
                        if rank == 0:
                            writer.fill(flags.to_storage_dtype(coef_nX * c))
                        else:
                            self.kpt_comm.send(coef_nX, 0)
                elif self.comm.rank == 0:
                    self.kpt_comm.receive(buf_nX, rank)
                    writer.fill(flags.to_storage_dtype(buf_nX * c))

    def write_summary(self, log):
        """Log Fermi level(s), eigenvalues and occupations, and gap info.

        All ranks take part in collecting the eigenvalues; only rank 0 of
        ``comm`` logs the tables.  At most the first three k-points are
        listed.
        """
        # Same guard as in write() and fermi_level.
        assert self.fermi_levels is not None
        fl = self.fermi_levels * Ha
        if len(fl) == 1:
            log(f'\nFermi level: {fl[0]:.3f}')
        else:
            log(f'\nFermi levels: {fl[0]:.3f}, {fl[1]:.3f}')

        ibz = self.ibz

        eig_skn, occ_skn = self.get_all_eigs_and_occs()

        if self.comm.rank != 0:
            return

        eig_skn *= Ha

        D = self.spin_degeneracy
        nbands = eig_skn.shape[2]

        # NOTE(review): runs of spaces inside the table literals below may
        # have been collapsed when this listing was extracted -- confirm
        # the intended column alignment.
        for k, (x, y, z) in enumerate(ibz.kpt_kc):
            if k == 3:
                log(f'(only showing first 3 out of {len(ibz)} k-points)')
                break

            log(f'\nkpt = [{x:.3f}, {y:.3f}, {z:.3f}], '
                f'weight = {ibz.weight_k[k]:.3f}:')

            if self.nspins == 1:
                skipping = False
                log(f' Band eig [eV] occ [0-{D}]')
                eig_n = eig_skn[0, k]
                # Number of eigenvalues below the (first) Fermi level,
                # shifted by one half:
                n0 = (eig_n < fl[0]).sum() - 0.5
                for n, (e, f) in enumerate(zips(eig_n, occ_skn[0, k])):
                    # First, last and +-8 bands window around Fermi level:
                    if n == 0 or abs(n - n0) < 8 or n == nbands - 1:
                        log(f' {n:4} {e:13.3f} {D * f:9.3f}')
                        skipping = False
                    else:
                        # One ellipsis line per run of skipped bands.
                        if not skipping:
                            log(' ...')
                            skipping = True
            else:
                log(' Band eig [eV] occ [0-1]'
                    ' eig [eV] occ [0-1]')
                for n, (e1, f1, e2, f2) in enumerate(zips(eig_skn[0, k],
                                                          occ_skn[0, k],
                                                          eig_skn[1, k],
                                                          occ_skn[1, k])):
                    log(f' {n:4} {e1:13.3f} {f1:9.3f}'
                        f' {e2:10.3f} {f2:9.3f}')

        try:
            from ase.dft.bandgap import GapInfo
        except ImportError:
            log('No gapinfo -- requires new ASE')
            return

        try:
            log()
            fermilevel = fl[0]
            gapinfo = GapInfo(eigenvalues=eig_skn - fermilevel)
            log(gapinfo.description(ibz_kpoints=ibz.kpt_kc))
        except ValueError:
            # Maybe we only have the occupied bands and no empty bands
            log('Could not find a gap')

    def make_sure_wfs_are_read_from_gpw_file(self):
        """Read file-backed wave function data into memory.

        Returns at the first object that has no ``psit_nX`` (or has it
        set to ``None``).  Data with an ``fd`` attribute is replaced by a
        contiguous in-memory array and ``read_from_file_init_wfs_dm`` is
        set to True.
        """
        for wfs in self:
            psit_nX = getattr(wfs, 'psit_nX', None)
            if psit_nX is None:
                return
            if hasattr(psit_nX.data, 'fd'):  # fd=file-descriptor
                self.read_from_file_init_wfs_dm = True
                psit_nX.data = np.ascontiguousarray(psit_nX.data[:])  # read

    def get_homo_lumo(self, spin: int | None = None) -> Array1D:
        """Return HOMO and LUMO eigenvalues.

        With ``ncomponents == 2`` and ``spin=None`` the result is the
        highest of the two HOMOs and the lowest of the two LUMOs.  For
        other values of ``ncomponents``, ``spin`` must not be 1.

        The HOMO is ``-inf`` if no band is occupied and the LUMO is
        ``inf`` if all bands are occupied.
        """
        if self.ncomponents == 1:
            assert spin != 1
            spin = 0
        elif self.ncomponents == 2:
            if spin is None:
                h0, l0 = self.get_homo_lumo(0)
                h1, l1 = self.get_homo_lumo(1)
                return np.array([max(h0, h1), min(l0, l1)])
        else:
            assert spin != 1
            spin = 0

        # Weighted sum of occupation numbers over all k-points gives
        # the number of occupied bands for this spin:
        nocc = 0.0
        for wfs_s in self.wfs_qs:
            wfs = wfs_s[spin]
            nocc += wfs.occ_n.sum() * wfs.weight
        nocc = self.kpt_comm.sum_scalar(nocc)
        n = int(round(nocc))

        homo = -np.inf
        if n > 0:
            for wfs_s in self.wfs_qs:
                homo = max(homo, wfs_s[spin].eig_n[n - 1])
        homo = self.kpt_comm.max_scalar(homo)

        lumo = np.inf
        if n < self.nbands:
            for wfs_s in self.wfs_qs:
                lumo = min(lumo, wfs_s[spin].eig_n[n])
        lumo = self.kpt_comm.min_scalar(lumo)

        return np.array([homo, lumo])

    def calculate_kinetic_energy(self,
                                 hamiltonian,
                                 density: Density) -> float:
        """Return the kinetic energy including PAW corrections.

        The wave function part is accumulated with
        ``hamiltonian.calculate_kinetic_energy(wfs, skip_sum=True)`` and
        summed over ``comm``; the PAW part is summed over
        ``density.grid.comm``.
        """
        e_kin = 0.0
        for wfs in self:
            e_kin += hamiltonian.calculate_kinetic_energy(wfs, skip_sum=True)
        e_kin = self.comm.sum_scalar(e_kin)

        # PAW corrections:
        # (``wfs`` is the last object from the loop above; only its
        # ``setups`` are used.)
        e_kin_paw = 0.0
        for a, D_sii in density.D_asii.items():
            setup = wfs.setups[a]
            D_p = pack_density(D_sii.real[:density.ndensities].sum(0))
            e_kin_paw += setup.K_p @ D_p + setup.Kc
        e_kin_paw = density.grid.comm.sum_scalar(e_kin_paw)

        return e_kin + e_kin_paw