Coverage for gpaw/new/builder.py: 82%
288 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1from __future__ import annotations
2from functools import cached_property
3from types import ModuleType, SimpleNamespace
4from typing import Any, TYPE_CHECKING
5import warnings
6import numpy as np
7from gpaw import GPAW_USE_GPUS, GPAW_CPUPY
8from ase import Atoms
9from ase.calculators.calculator import kpts2sizeandoffsets
10from ase.geometry.cell import cell_to_cellpar
11from ase.units import Bohr
12from gpaw.core import UGDesc
13from gpaw.core.atom_arrays import (AtomArrays, AtomArraysLayout,
14 AtomDistribution)
15from gpaw.core.domain import Domain
16from gpaw.gpu import cpupy as fake_cupy
17from gpaw.gpu.mpi import CuPyMPI
18from gpaw.lfc import BasisFunctions
19from gpaw.mixer import MixerWrapper, get_mixer_from_keywords
20from gpaw.mpi import (MPIComm, Parallelization, broadcast, serial_comm,
21 synchronize_atoms, world)
22from gpaw.new import prod
23from gpaw.new.basis import create_basis
24from gpaw.new.brillouin import BZPoints, MonkhorstPackKPoints
25from gpaw.new.c import GPU_AWARE_MPI
26from gpaw.new.density import Density
27from gpaw.new.ibzwfs import IBZWaveFunctions
28from gpaw.new.logger import Logger
29from gpaw.new.potential import Potential
30from gpaw.new.scf import SCFLoop
31from gpaw.new.smearing import OccupationNumberCalculator
32from gpaw.new.xc import create_functional
33from gpaw.setup import Setups
34from gpaw.typing import Array2D, ArrayLike1D, ArrayLike2D, DTypeLike
35from gpaw.utilities.gpts import get_number_of_grid_points
36if TYPE_CHECKING:
37 from gpaw.dft import Parameters
class DFTComponentsBuilder:
    """Base builder for the components of a DFT calculation.

    The constructor derives, from *atoms* and *params*:
    - initial magnetic moments and number of wave-function components;
    - setups, symmetries and the irreducible Brillouin zone;
    - communicators, number of bands and dtype;
    - uniform grids, scaled positions and the XC functional.

    Mode-specific pieces (grids, wave-function descriptor, pseudo core
    densities, Hamiltonian, eigensolver, ...) raise NotImplementedError
    here and must be provided by subclasses.
    """

    def __init__(self,
                 atoms: Atoms,
                 params: Parameters,
                 *,
                 log=None,
                 comm=None):
        if not isinstance(log, Logger):
            log = Logger(log, comm)
        self.log = log
        comm = log.comm

        # Synchronize the caller's atoms across ranks *before* taking the
        # private copy.  synchronize_atoms() overwrites positions/cell of
        # the object it is given, so copying first left self.atoms (and
        # the relpos_ac/atomdist derived from it) unsynchronized.
        synchronize_atoms(atoms, comm)
        self.atoms = atoms.copy()
        self.mode = params.mode.name
        self.params = params

        parallel = params.parallel

        self.check_cell(atoms.cell)

        self.initial_magmom_av, self.ncomponents = normalize_initial_magmoms(
            atoms, params.magmoms, params.spinpol or params.hund)

        self.soc = params.soc
        # ncomponents is 1, 2 or 4 (see normalize_initial_magmoms), giving
        # nspins = 1, 2, 1 and spin_degeneracy = 2, 1, 1 respectively.
        self.nspins = self.ncomponents % 3
        self.spin_degeneracy = self.ncomponents % 2 + 1

        xcfunc = params.xc.functional(collinear=(self.ncomponents < 4))

        if self.ncomponents == 4 and xcfunc.type != 'LDA':
            raise ValueError('Only LDA supported for '
                             'SC Non-collinear calculations')

        self._backwards_comatible = params.experimental.get(
            'backwards_compatible', True)

        self.setups = Setups(
            atoms.numbers,
            params.setups,
            params.basis,
            xcfunc,
            world=comm,
            backwards_compatible=self._backwards_comatible)
        if params.hund:
            # Hund's-rule moments, with the net charge spread evenly
            # over the atoms.
            c = params.charge / len(atoms)
            for a, setup in enumerate(self.setups):
                self.initial_magmom_av[a, 2] = setup.get_hunds_rule_moment(c)

        symmetries = params.symmetry.build(
            atoms,
            setup_ids=self.setups.id_a,
            magmoms=self.initial_magmom_av,
            _backwards_compatible=self._backwards_comatible)

        use_time_reversal = params.symmetry.time_reversal

        symmetries._old_symmetry.time_reversal = use_time_reversal  # legacy
        self.setups.set_symmetry(symmetries._old_symmetry)  # legacy

        if self.ncomponents == 4:
            assert (len(symmetries) == 1 and not use_time_reversal)

        bz = params.kpts.build(atoms)
        self.ibz = bz.reduce(
            symmetries,
            strict=False,
            comm=comm,
            use_time_reversal=use_time_reversal)

        # Hybrid functionals default to no domain decomposition.
        d = parallel.get('domain', 1 if xcfunc.type == 'HYB' else None)
        k = parallel.get('kpt', None)
        b = parallel.get('band', None)
        self.communicators = create_communicators(comm, len(self.ibz),
                                                  d, k, b, self.xp)

        if self.mode == 'fd':
            pass  # filter = create_fourier_filter(grid)
            # setups = setups.filter(filter)

        self.nbands = calculate_number_of_bands(params.nbands,
                                                self.setups,
                                                params.charge,
                                                self.initial_magmom_av,
                                                self.mode == 'lcao')
        if self.ncomponents == 4:
            self.nbands *= 2

        self.dtype: DTypeLike
        if params.mode.dtype is None:
            if self.params.mode.force_complex_dtype:
                self.dtype = complex
            else:
                if self.ibz.bz.gamma_only and self.ncomponents < 4:
                    self.dtype = float
                else:
                    self.dtype = complex
        else:
            self.dtype = params.mode.dtype

        self.grid, self.fine_grid = self.create_uniform_grids()

        self.relpos_ac = self.atoms.get_scaled_positions()
        self.relpos_ac %= 1
        self.relpos_ac %= 1  # yes, we need to do this twice!

        self.xc = create_functional(xcfunc, self.fine_grid, self.xp)

        self.interpolation_desc: Domain
        self.electrostatic_potential_desc: Domain

    def __repr__(self):
        return f'{self.__class__.__name__}({self.atoms}, {self.params})'

    def get_extensions(self):
        """Build every extension listed in ``params.extensions``."""
        return [ext.build(self.atoms,
                          self.communicators,
                          self.log) for ext in self.params.extensions]

    @cached_property
    def charge(self) -> float:
        """Sum of ``setups.core_charge`` and the requested ``params.charge``."""
        return self.setups.core_charge + self.params.charge

    @cached_property
    def nelectrons(self) -> float:
        """``setups.nvalence`` minus :attr:`charge`."""
        return self.setups.nvalence - self.charge

    @cached_property
    def atomdist(self) -> AtomDistribution:
        """Distribution of atoms over the ranks of the grid communicator."""
        return AtomDistribution(
            self.grid.ranks_from_fractional_positions(self.relpos_ac),
            self.grid.comm)

    def create_uniform_grids(self):
        """Return ``(grid, fine_grid)``; implemented by subclasses."""
        raise NotImplementedError

    def check_cell(self, cell):
        """Validate the unit cell.

        Raises ValueError if *cell* has fewer than 3 lattice vectors and
        warns if any angle between the cell vectors is outside the open
        interval (40, 140) degrees.
        """
        number_of_lattice_vectors = cell.rank
        if number_of_lattice_vectors < 3:
            raise ValueError(
                'GPAW requires 3 lattice vectors. '
                f'Your system has {number_of_lattice_vectors}.')
        angles = cell_to_cellpar(cell)[3:]
        if not all(40.0 < a < 140.0 for a in angles):
            a, b, c = angles
            # '.1f' = one decimal; plain '.1' is one *significant* digit
            # and would print e.g. 120.0 as '1e+02'.
            warnings.warn(
                'The angles between your unit-cell vectors are '
                f'{a:.1f}, {b:.1f} and {c:.1f} degrees. '
                'Results may be wrong! '
                'Please Niggli-reduce your unit-cell so that the angles '
                'are closer to 90 degrees:\n\n'
                '    from ase.build import niggli_reduce\n'
                '    niggli_reduce(atoms)\n')

    @cached_property
    def wf_desc(self) -> Domain:
        """Wave-function descriptor from :meth:`create_wf_description`."""
        return self.create_wf_description()

    @cached_property
    def gpu(self) -> bool:
        """Are we running on a GPU?

        If parallel dict does not specify 'gpu': True or False,
        GPAW_USE_GPUS environment variable will be used to
        determine whether we use GPUs or not.
        """
        if self.params.parallel.get('gpu', GPAW_USE_GPUS):
            from gpaw.gpu import cupy_is_fake
            if cupy_is_fake and not GPAW_CPUPY:
                parallel_source = ('the `parallel` parameter'
                                   if self.params.parallel.get('gpu') else
                                   'the environment variable `GPAW_USE_GPUS`')
                raise ValueError(
                    f'GPU calculation is requested via {parallel_source}, '
                    'but the requisite CuPy library is not found; '
                    'please set GPAW_CPUPY=1 if you really want to do "GPU" '
                    'calculations with GPAW\'s fake CuPy library '
                    '(gpaw.gpu.cpupy)')
            return True
        return False

    @cached_property
    def xp(self) -> ModuleType:
        """Array module: Numpy or Cupy."""
        if self.gpu:
            from gpaw.gpu import cupy
            if cupy is fake_cupy:
                self.log(fake_cupy.FAKE_CUPY_WARNING)
            return cupy
        return np

    def create_wf_description(self) -> Domain:
        raise NotImplementedError

    def get_pseudo_core_densities(self):
        raise NotImplementedError

    def get_pseudo_core_ked(self):
        raise NotImplementedError

    def create_basis_set(self):
        """Create the atomic basis set via :func:`create_basis`."""
        return create_basis(self.ibz,
                            self.ncomponents % 3,
                            self.atoms.pbc,
                            self.grid,
                            self.setups,
                            self.dtype,
                            self.relpos_ac,
                            self.communicators['w'],
                            self.communicators['k'],
                            self.communicators['b'])

    def density_from_superposition(self, basis_set):
        """Initial density from ``Density.from_superposition``."""
        return Density.from_superposition(
            grid=self.grid,
            nct_aX=self.get_pseudo_core_densities(),
            tauct_aX=self.get_pseudo_core_ked(),
            atomdist=self.atomdist,
            setups=self.setups,
            basis_set=basis_set,
            magmom_av=self.initial_magmom_av,
            ncomponents=self.ncomponents,
            charge=self.charge,
            hund=self.params.hund,
            mgga=self.xc.type == 'MGGA')

    def create_occupation_number_calculator(self):
        """Occupation-number calculator built from ``params.occupations``."""
        return OccupationNumberCalculator(
            self.params.occupations.params,
            self.atoms.pbc,
            self.ibz,
            self.nbands,
            self.communicators,
            self.initial_magmom_av.sum(0),
            self.ncomponents,
            self.nelectrons,
            np.linalg.inv(self.atoms.cell.complete()).T)

    def create_ibz_wave_functions(self,
                                  basis: BasisFunctions,
                                  potential: Potential) -> IBZWaveFunctions:
        raise NotImplementedError

    def create_hamiltonian_operator(self):
        raise NotImplementedError

    def create_eigensolver(self, hamiltonian):
        raise NotImplementedError

    def create_scf_loop(self):
        """Assemble an SCFLoop from Hamiltonian, occupations, eigensolver
        and mixer."""
        hamiltonian = self.create_hamiltonian_operator()
        occ_calc = self.create_occupation_number_calculator()
        eigensolver = self.create_eigensolver(hamiltonian)

        mixer = MixerWrapper(
            get_mixer_from_keywords(self.atoms.pbc.any(),
                                    self.ncomponents,
                                    **self.params.mixer.params),
            self.ncomponents,
            self.grid._gd,
            world=self.communicators['w'])

        # NOTE(review): the 'bands' criterion is not passed on to SCFLoop;
        # presumably it is handled elsewhere - confirm.
        return SCFLoop(hamiltonian, occ_calc,
                       eigensolver, mixer, self.communicators['w'],
                       {key: value
                        for key, value in self.params.convergence.items()
                        if key != 'bands'},
                       self.params.maxiter)

    def read_ibz_wave_functions(self, reader):
        raise NotImplementedError

    def create_potential_calculator(self):
        raise NotImplementedError

    def read_wavefunction_values(self,
                                 reader,
                                 ibzwfs: IBZWaveFunctions) -> None:
        """Read eigenvalues, occupations, projections and Fermi levels.

        The values are read using reader and set as the appropriate
        properties of (the already instantiated) wavefunctions contained
        in ibzwfs.  Eigenvalues and Fermi levels are divided by
        ``reader.ha``.
        """
        ha = reader.ha

        domain_comm = self.communicators['d']
        band_comm = self.communicators['b']

        eig_skn = reader.wave_functions.eigenvalues
        occ_skn = reader.wave_functions.occupations

        for wfs in ibzwfs:
            index: tuple[int, ...]
            if self.ncomponents < 4:
                dims = [self.nbands]
                index = (wfs.spin, wfs.k)
            else:
                # Non-collinear: no spin index in the stored arrays, and the
                # projections get an extra axis of length 2.
                dims = [self.nbands, 2]
                index = (wfs.k,)

            wfs._eig_n = eig_skn[index] / ha
            wfs._occ_n = occ_skn[index]
            layout = AtomArraysLayout([(setup.ni,) for setup in self.setups],
                                      atomdist=self.atomdist,
                                      dtype=self.dtype)
            P_ani = AtomArrays(layout, dims=dims, comm=band_comm)

            if domain_comm.rank == 0:
                try:
                    P_nI = reader.wave_functions.proxy('projections', *index)
                except KeyError:
                    data = None
                else:
                    b1, b2 = P_ani.my_slice()  # my bands
                    data = P_nI[b1:b2].astype(ibzwfs.dtype)  # read from file
            else:
                data = None

            # Only domain rank 0 knows whether projections were found.
            have_projections = broadcast(
                data is not None if domain_comm.rank == 0 else None,
                comm=domain_comm)

            if have_projections:
                P_ani.scatter_from(data)  # distribute over atoms
                wfs._P_ani = P_ani
            else:
                wfs._P_ani = None

        try:
            ibzwfs.fermi_levels = reader.wave_functions.fermi_levels / ha
        except AttributeError:
            # old gpw-file
            ibzwfs.fermi_levels = np.array(
                [reader.occupations.fermilevel / ha])

    def create_environment(self, grid):
        """Build ``params.environment`` for *grid*."""
        return self.params.environment.build(
            setups=self.setups,
            grid=grid, relpos_ac=self.relpos_ac, log=self.log,
            comm=self.communicators['w'])
def create_communicators(comm: MPIComm | None = None,
                         nibzkpts: int = 1,
                         domain: int | tuple[int, int, int] | None = None,
                         kpt: int | None = None,
                         band: int | None = None,
                         xp: ModuleType = np) -> dict[str, MPIComm]:
    """Split *comm* into the communicators used by a calculation.

    Parameters
    ----------
    comm:
        Communicator to split (defaults to ``world``).
    nibzkpts:
        Number of irreducible k-points.
    domain:
        Domain-decomposition size, either an integer or a per-axis tuple
        whose product is used.
    kpt, band:
        Requested k-point / band parallelization; None leaves the choice
        to :class:`Parallelization`.
    xp:
        Array module.  If it is not NumPy and MPI is not GPU-aware, every
        communicator is wrapped in :class:`CuPyMPI`.

    Returns
    -------
    Dict with the communicators from
    ``Parallelization.build_communicators()`` plus ``'w'`` for *comm*
    itself.  Size-1 communicators are replaced by ``serial_comm``.
    """
    # Resolve the default before storing it as 'w': previously comms['w']
    # was None for comm=None and the ``.size`` check below raised.
    comm = comm or world
    parallelization = Parallelization(comm, nibzkpts)
    if domain is not None and not isinstance(domain, int):
        domain = prod(domain)
    parallelization.set(kpt=kpt,
                        domain=domain,
                        band=band)
    comms = parallelization.build_communicators()
    comms['w'] = comm

    # We replace size=1 MPI communications with serial_comm so that
    # serial_comm.sum(<cupy-array>) works: XXX
    comms = {key: c if c.size > 1 else serial_comm
             for key, c in comms.items()}

    if xp is not np and not GPU_AWARE_MPI:
        comms = {key: CuPyMPI(c) for key, c in comms.items()}

    return comms
def create_fourier_filter(grid):
    """Return a callable that Fourier-filters radial functions for *grid*.

    The spacing ``h`` is the largest value of
    ``|grid.icell[c]|**-1 / grid.size[c]`` over the three axes.  The
    returned ``f(rgd, rcut, f_r, l=0)`` has these properties:
    - it calls ``rgd.filter(f_r, rcut * gamma, gcut, l)`` with
      ``gamma = 1.6`` and ``gcut = pi / h - 2 / rcut / gamma``;
    - it copies the first ``len(f_r)`` values of the result back into
      *f_r* in place.
    """
    gamma = 1.6
    spacing_c = (grid.icell**2).sum(1)**-0.5 / grid.size
    h = spacing_c.max()

    def fourier_filter(rgd, rcut, f_r, l=0):
        # Largest wave vector representable on the grid, reduced by a
        # term that depends on the enlarged cutoff radius.
        gcut = np.pi / h - 2 / rcut / gamma
        filtered_r = rgd.filter(f_r, rcut * gamma, gcut, l)
        f_r[:] = filtered_r[:len(f_r)]

    return fourier_filter
def normalize_initial_magmoms(
        atoms: Atoms,
        magmoms: ArrayLike2D | ArrayLike1D | float | None = None,
        force_spinpol_calculation: bool = False) -> tuple[Array2D, int]:
    """Convert magnetic moments to (natoms, 3)-shaped array.

    Also return number of wave function components (1, 2 or 4).

    Parameters
    ----------
    atoms:
        Atoms object; its initial magnetic moments are used when
        *magmoms* is None.
    magmoms:
        None, a scalar (any numeric type, applied to every atom along z),
        one collinear (z) moment per atom, or a (natoms, 3) array of
        moment vectors.
    force_spinpol_calculation:
        Keep 2 components even if all collinear moments are zero.

    Returns
    -------
    ``(magmom_av, ncomponents)``.  The number of components is chosen as
    follows:
    - 4 for vector input;
    - 1 when all moments are zero and spin polarization is not forced;
    - 2 otherwise.

    >>> h = Atoms('H', magmoms=[1])
    >>> normalize_initial_magmoms(h)
    (array([[0., 0., 1.]]), 2)
    >>> normalize_initial_magmoms(h, [[1, 0, 0]])
    (array([[1., 0., 0.]]), 4)
    >>> normalize_initial_magmoms(h, 1)
    (array([[0., 0., 1.]]), 2)
    """
    magmom_av = np.zeros((len(atoms), 3))
    ncomponents = 2

    if magmoms is None:
        magmom_av[:, 2] = atoms.get_initial_magnetic_moments()
    else:
        magmoms = np.asarray(magmoms)
        if magmoms.ndim <= 1:
            # Scalar of any numeric type, or one value per atom: collinear
            # moments along z.  (Previously only Python floats were
            # recognized as scalars; an int fell through to the vector
            # branch and was broadcast to all three components.)
            magmom_av[:, 2] = magmoms
        else:
            magmom_av[:] = magmoms
            ncomponents = 4

    if (ncomponents == 2 and
        not force_spinpol_calculation and
        not magmom_av[:, 2].any()):
        ncomponents = 1

    return magmom_av, ncomponents
def ____create_kpts(kpts: dict[str, Any], atoms: Atoms) -> BZPoints:
    """Build Brillouin-zone points from a k-point specification dict.

    The dict is interpreted as follows:
    - ``'kpts'`` key: explicit k-points, passed to :class:`BZPoints`;
    - ``'path'`` key: k-points of ``atoms.cell.bandpath(pbc=atoms.pbc,
      **kpts)``;
    - otherwise: a Monkhorst-Pack grid from ``kpts2sizeandoffsets``.

    Raises ValueError if a k-point has a nonzero component along a
    non-periodic direction.

    NOTE(review): not referenced in this file; the builder uses
    ``params.kpts.build(atoms)`` instead.
    """
    if 'kpts' in kpts:
        bz = BZPoints(kpts['kpts'])
    elif 'path' in kpts:
        bandpath = atoms.cell.bandpath(pbc=atoms.pbc, **kpts)
        bz = BZPoints(bandpath.kpts)
    else:
        mp_size, mp_offset = kpts2sizeandoffsets(**kpts, atoms=atoms)
        bz = MonkhorstPackKPoints(mp_size, mp_offset)

    for axis, periodic in enumerate(atoms.pbc):
        if periodic:
            continue
        if not np.allclose(bz.kpt_Kc[:, axis], 0.0):
            raise ValueError('K-points can only be used with PBCs!')
    return bz
def calculate_number_of_bands(nbands: int | str | None,
                              setups: Setups,
                              charge: float,
                              initial_magmom_av: Array2D,
                              is_lcao: bool) -> int:
    """Determine the number of bands.

    With ``nvalence = setups.nvalence - charge`` and ``M`` the norm of the
    summed initial magnetic moments, *nbands* is interpreted as:

    * None: ``ceil(1.2 * (nvalence + M) / 2) + 4``, capped at the sum of
      ``setup.get_default_nbands()`` and, for LCAO, at ``setups.nao``.
    * ``'nao'``: the number of atomic orbitals.
    * ``'<x>%'``: x percent of ``(nvalence + M) / 2``, rounded up.
    * integer <= 0: ``int(nvalence + M + 0.5) // 2`` plus ``-nbands``
      extra bands, at least 1.
    * positive integer: used as-is.

    Returns 1 if any setup is orbital-free.

    Raises
    ------
    ValueError
        For an unknown string, more bands than atomic orbitals in LCAO
        mode, a negative number of valence electrons, or too few bands to
        hold the valence electrons.
    """
    nao = setups.nao
    nvalence = setups.nvalence - charge
    M = np.linalg.norm(initial_magmom_av.sum(0))

    orbital_free = any(setup.orbital_free for setup in setups)
    if orbital_free:
        return 1

    if nbands is None:
        # Number of bound partial waves:
        nbandsmax = sum(setup.get_default_nbands()
                        for setup in setups)
        N = int(np.ceil(1.2 * (nvalence + M) / 2)) + 4
        N = min(N, nbandsmax)
        if is_lcao and N > nao:
            N = nao
    elif isinstance(nbands, str):
        if nbands == 'nao':
            N = nao
        elif nbands[-1] == '%':
            cfgbands = (nvalence + M) / 2
            N = int(np.ceil(float(nbands[:-1]) / 100 * cfgbands))
        else:
            url = 'https://gpaw.readthedocs.io/documentation/basic.html'
            raise ValueError(
                f'Bad value for nbands: {nbands!r}. '
                f'See {url}#manual-nbands for help')
    elif nbands <= 0:
        N = max(1, int(nvalence + M + 0.5) // 2 + (-nbands))
    else:
        N = nbands

    # Report the *resolved* band count N: nbands may be None, 'nao' or a
    # percentage string.  (The old message also had a stray '%d'.)
    if N > nao and is_lcao:
        raise ValueError('Too many bands for LCAO calculation: '
                         f'{N} bands and only {nao} atomic orbitals!')

    if nvalence < 0:
        raise ValueError(
            f'Charge {charge} is not possible - not enough valence electrons')

    if nvalence > 2 * N:
        raise ValueError(
            f'Too few bands! Electrons: {nvalence}, bands: {N}')

    return N
def create_uniform_grid(mode: str,
                        gpts,
                        cell,
                        pbc,
                        symmetries,
                        h: float | None = None,
                        interpolation: int | str | None = None,
                        ecut: float = None,
                        comm: MPIComm = serial_comm) -> UGDesc:
    """Create grid in a backwards compatible way.

    *cell* (and *h*, if given) are divided by ``Bohr``.  Zero boundary
    conditions are used along non-periodic axes only for real-space
    grids, i.e. when *mode* is not ``'pw'`` and *interpolation* is not
    ``'fft'``.  The grid size is *gpts* if given.  Otherwise it is
    obtained from ``get_number_of_grid_points`` using *h*, *mode*,
    *ecut* and *symmetries*.
    """
    cell = cell / Bohr
    if h is not None:
        h /= Bohr

    realspace = mode != 'pw' and interpolation != 'fft'
    zerobc = ([not periodic for periodic in pbc] if realspace
              else [False] * 3)

    if gpts is None:
        mode_info = SimpleNamespace(name=mode, ecut=ecut)
        size = get_number_of_grid_points(cell, h, mode_info, realspace,
                                         symmetries)
    else:
        size = gpts

    return UGDesc(cell=cell, pbc=pbc, zerobc=zerobc, size=size, comm=comm)