Coverage for gpaw/new/builder.py: 82%

288 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1from __future__ import annotations 

2from functools import cached_property 

3from types import ModuleType, SimpleNamespace 

4from typing import Any, TYPE_CHECKING 

5import warnings 

6import numpy as np 

7from gpaw import GPAW_USE_GPUS, GPAW_CPUPY 

8from ase import Atoms 

9from ase.calculators.calculator import kpts2sizeandoffsets 

10from ase.geometry.cell import cell_to_cellpar 

11from ase.units import Bohr 

12from gpaw.core import UGDesc 

13from gpaw.core.atom_arrays import (AtomArrays, AtomArraysLayout, 

14 AtomDistribution) 

15from gpaw.core.domain import Domain 

16from gpaw.gpu import cpupy as fake_cupy 

17from gpaw.gpu.mpi import CuPyMPI 

18from gpaw.lfc import BasisFunctions 

19from gpaw.mixer import MixerWrapper, get_mixer_from_keywords 

20from gpaw.mpi import (MPIComm, Parallelization, broadcast, serial_comm, 

21 synchronize_atoms, world) 

22from gpaw.new import prod 

23from gpaw.new.basis import create_basis 

24from gpaw.new.brillouin import BZPoints, MonkhorstPackKPoints 

25from gpaw.new.c import GPU_AWARE_MPI 

26from gpaw.new.density import Density 

27from gpaw.new.ibzwfs import IBZWaveFunctions 

28from gpaw.new.logger import Logger 

29from gpaw.new.potential import Potential 

30from gpaw.new.scf import SCFLoop 

31from gpaw.new.smearing import OccupationNumberCalculator 

32from gpaw.new.xc import create_functional 

33from gpaw.setup import Setups 

34from gpaw.typing import Array2D, ArrayLike1D, ArrayLike2D, DTypeLike 

35from gpaw.utilities.gpts import get_number_of_grid_points 

36if TYPE_CHECKING: 

37 from gpaw.dft import Parameters 

38 

39 

class DFTComponentsBuilder:
    """Build the ingredients of a DFT calculation from atoms + parameters.

    ``__init__`` performs the setup that does not depend on the concrete
    mode: magnetic moments, setups, symmetries, IBZ k-points,
    communicators, number of bands, dtype, uniform grids and the XC
    functional.  Methods that raise ``NotImplementedError`` here
    (``create_uniform_grids``, ``create_wf_description``,
    ``create_hamiltonian_operator``, ...) must be provided by a subclass
    (presumably one per ``params.mode`` -- confirm in the mode modules).
    """

    def __init__(self,
                 atoms: Atoms,
                 params: Parameters,
                 *,
                 log=None,
                 comm=None):
        """Set up all mode-independent components.

        Parameters
        ----------
        atoms:
            The system.  A copy is stored in ``self.atoms``; the passed
            object itself is synchronized across ``comm``.
        params:
            Calculation parameters.
        log:
            A ``Logger`` or something ``Logger(log, comm)`` accepts.
        comm:
            Communicator handed to ``Logger`` when *log* is not already a
            ``Logger``; afterwards ``log.comm`` is used everywhere.

        Raises
        ------
        ValueError
            For a cell with fewer than 3 lattice vectors, or for a
            non-collinear calculation with a non-LDA functional.
        """
        self.atoms = atoms.copy()
        self.mode = params.mode.name
        self.params = params
        if not isinstance(log, Logger):
            log = Logger(log, comm)
        self.log = log
        comm = log.comm

        parallel = params.parallel

        synchronize_atoms(atoms, comm)
        self.check_cell(atoms.cell)

        self.initial_magmom_av, self.ncomponents = normalize_initial_magmoms(
            atoms, params.magmoms, params.spinpol or params.hund)

        self.soc = params.soc
        # ncomponents 1, 2, 4 -> nspins 1, 2, 1
        self.nspins = self.ncomponents % 3
        # ncomponents 1, 2, 4 -> spin_degeneracy 2, 1, 1
        self.spin_degeneracy = self.ncomponents % 2 + 1

        xcfunc = params.xc.functional(collinear=(self.ncomponents < 4))

        if self.ncomponents == 4 and xcfunc.type != 'LDA':
            raise ValueError('Only LDA supported for '
                             'SC Non-collinear calculations')

        # NOTE(review): attribute name is misspelled ("comatible"); kept
        # as-is because other modules may reference it.
        self._backwards_comatible = params.experimental.get(
            'backwards_compatible', True)

        self.setups = Setups(
            atoms.numbers,
            params.setups,
            params.basis,
            xcfunc,
            world=comm,
            backwards_compatible=self._backwards_comatible)
        if params.hund:
            # Spread the total charge evenly over the atoms.
            c = params.charge / len(atoms)
            for a, setup in enumerate(self.setups):
                self.initial_magmom_av[a, 2] = setup.get_hunds_rule_moment(c)

        symmetries = params.symmetry.build(
            atoms,
            setup_ids=self.setups.id_a,
            magmoms=self.initial_magmom_av,
            _backwards_compatible=self._backwards_comatible)

        use_time_reversal = params.symmetry.time_reversal

        symmetries._old_symmetry.time_reversal = use_time_reversal  # legacy
        self.setups.set_symmetry(symmetries._old_symmetry)  # legacy

        if self.ncomponents == 4:
            assert (len(symmetries) == 1 and not use_time_reversal)

        bz = params.kpts.build(atoms)
        self.ibz = bz.reduce(
            symmetries,
            strict=False,
            comm=comm,
            use_time_reversal=use_time_reversal)

        # Hybrid functionals default to no domain decomposition.
        d = parallel.get('domain', 1 if xcfunc.type == 'HYB' else None)
        k = parallel.get('kpt', None)
        b = parallel.get('band', None)
        self.communicators = create_communicators(comm, len(self.ibz),
                                                  d, k, b, self.xp)

        if self.mode == 'fd':
            pass  # filter = create_fourier_filter(grid)
            # setups = setups.filter(filter)

        self.nbands = calculate_number_of_bands(params.nbands,
                                                self.setups,
                                                params.charge,
                                                self.initial_magmom_av,
                                                self.mode == 'lcao')
        if self.ncomponents == 4:
            self.nbands *= 2

        self.dtype: DTypeLike
        if params.mode.dtype is None:
            if self.params.mode.force_complex_dtype:
                self.dtype = complex
            else:
                # Real wave functions only for Gamma-point-only,
                # collinear calculations.
                if self.ibz.bz.gamma_only and self.ncomponents < 4:
                    self.dtype = float
                else:
                    self.dtype = complex
        else:
            self.dtype = params.mode.dtype

        self.grid, self.fine_grid = self.create_uniform_grids()

        self.relpos_ac = self.atoms.get_scaled_positions()
        self.relpos_ac %= 1
        # A tiny negative value x gives x % 1 == 1.0 in floating point;
        # the second pass maps such values back to 0.0.
        self.relpos_ac %= 1  # yes, we need to do this twice!

        self.xc = create_functional(xcfunc, self.fine_grid, self.xp)

        self.interpolation_desc: Domain
        self.electrostatic_potential_desc: Domain

    def __repr__(self):
        return f'{self.__class__.__name__}({self.atoms}, {self.params})'

    def get_extensions(self):
        """Build every extension listed in ``params.extensions``."""
        return [ext.build(self.atoms,
                          self.communicators,
                          self.log) for ext in self.params.extensions]

    @cached_property
    def charge(self) -> float:
        """Setup core charge plus the user-specified ``params.charge``."""
        return self.setups.core_charge + self.params.charge

    @cached_property
    def nelectrons(self) -> float:
        """Number of valence electrons minus ``self.charge``."""
        return self.setups.nvalence - self.charge

    @cached_property
    def atomdist(self) -> AtomDistribution:
        """Distribution of atoms over the ranks of ``self.grid.comm``."""
        return AtomDistribution(
            self.grid.ranks_from_fractional_positions(self.relpos_ac),
            self.grid.comm)

    def create_uniform_grids(self):
        """Return ``(grid, fine_grid)``; must be implemented by a subclass."""
        raise NotImplementedError

    def check_cell(self, cell):
        """Validate the unit cell.

        Raises
        ------
        ValueError
            If the cell has fewer than 3 lattice vectors (``cell.rank``).

        Warns (``warnings.warn``) if any cell angle lies outside the open
        interval (40, 140) degrees.
        """
        number_of_lattice_vectors = cell.rank
        if number_of_lattice_vectors < 3:
            raise ValueError(
                'GPAW requires 3 lattice vectors. '
                f'Your system has {number_of_lattice_vectors}.')
        angles = cell_to_cellpar(cell)[3:]
        if not all(40.0 < a < 140.0 for a in angles):
            a, b, c = angles
            # Fixed-point format: a plain ".1" would give one significant
            # digit (e.g. "3e+01" for 30.5 degrees).
            warnings.warn(
                'The angles between your unit-cell vectors are '
                f'{a:.1f}, {b:.1f} and {c:.1f} degrees. '
                'Results may be wrong! '
                'Please Niggli-reduce your unit-cell so that the angles '
                'are closer to 90 degrees:\n\n'
                '    from ase.build import niggli_reduce\n'
                '    niggli_reduce(atoms)\n')

    @cached_property
    def wf_desc(self) -> Domain:
        """Wave-function description (cached result of
        ``create_wf_description``)."""
        return self.create_wf_description()

    @cached_property
    def gpu(self) -> bool:
        """Are we running on a GPU?

        If parallel dict does not specify 'gpu': True or False,
        GPAW_USE_GPUS environment variable will be used to
        determine whether we use GPUs or not.

        Raises ValueError if GPUs are requested, only the fake CuPy
        library is available and GPAW_CPUPY is not set.
        """
        if self.params.parallel.get('gpu', GPAW_USE_GPUS):
            from gpaw.gpu import cupy_is_fake
            if cupy_is_fake and not GPAW_CPUPY:
                parallel_source = ('the `parallel` parameter'
                                   if self.params.parallel.get('gpu') else
                                   'the environment variable `GPAW_USE_GPUS`')
                raise ValueError(
                    f'GPU calculation is requested via {parallel_source}, '
                    'but the requisite CuPy library is not found; '
                    'please set GPAW_CPUPY=1 if you really want to do "GPU" '
                    'calculations with GPAW\'s fake CuPy library '
                    '(gpaw.gpu.cpupy)')
            return True
        return False

    @cached_property
    def xp(self) -> ModuleType:
        """Array module: Numpy or Cupy."""
        if self.gpu:
            from gpaw.gpu import cupy
            if cupy is fake_cupy:
                self.log(fake_cupy.FAKE_CUPY_WARNING)
            return cupy
        return np

    def create_wf_description(self) -> Domain:
        """Must be implemented by a subclass."""
        raise NotImplementedError

    def get_pseudo_core_densities(self):
        """Must be implemented by a subclass."""
        raise NotImplementedError

    def get_pseudo_core_ked(self):
        """Must be implemented by a subclass."""
        raise NotImplementedError

    def create_basis_set(self):
        """Create the atomic basis set (``ncomponents % 3`` == nspins)."""
        return create_basis(self.ibz,
                            self.ncomponents % 3,
                            self.atoms.pbc,
                            self.grid,
                            self.setups,
                            self.dtype,
                            self.relpos_ac,
                            self.communicators['w'],
                            self.communicators['k'],
                            self.communicators['b'])

    def density_from_superposition(self, basis_set):
        """Create an initial density from a superposition of atomic
        densities, using the initial magnetic moments and charge."""
        return Density.from_superposition(
            grid=self.grid,
            nct_aX=self.get_pseudo_core_densities(),
            tauct_aX=self.get_pseudo_core_ked(),
            atomdist=self.atomdist,
            setups=self.setups,
            basis_set=basis_set,
            magmom_av=self.initial_magmom_av,
            ncomponents=self.ncomponents,
            charge=self.charge,
            hund=self.params.hund,
            mgga=self.xc.type == 'MGGA')

    def create_occupation_number_calculator(self):
        """Create the occupation-number calculator.

        The last argument is the inverse of the completed cell, transposed.
        """
        return OccupationNumberCalculator(
            self.params.occupations.params,
            self.atoms.pbc,
            self.ibz,
            self.nbands,
            self.communicators,
            self.initial_magmom_av.sum(0),
            self.ncomponents,
            self.nelectrons,
            np.linalg.inv(self.atoms.cell.complete()).T)

    def create_ibz_wave_functions(self,
                                  basis: BasisFunctions,
                                  potential: Potential) -> IBZWaveFunctions:
        """Must be implemented by a subclass."""
        raise NotImplementedError

    def create_hamiltonian_operator(self):
        """Must be implemented by a subclass."""
        raise NotImplementedError

    def create_eigensolver(self, hamiltonian):
        """Must be implemented by a subclass."""
        raise NotImplementedError

    def create_scf_loop(self):
        """Assemble an ``SCFLoop`` from the Hamiltonian, occupation
        calculator, eigensolver and density mixer.

        The 'bands' entry of ``params.convergence`` is not passed on.
        """
        hamiltonian = self.create_hamiltonian_operator()
        occ_calc = self.create_occupation_number_calculator()
        eigensolver = self.create_eigensolver(hamiltonian)

        mixer = MixerWrapper(
            get_mixer_from_keywords(self.atoms.pbc.any(),
                                    self.ncomponents,
                                    **self.params.mixer.params),
            self.ncomponents,
            self.grid._gd,
            world=self.communicators['w'])

        return SCFLoop(hamiltonian, occ_calc,
                       eigensolver, mixer, self.communicators['w'],
                       {key: value
                        for key, value in self.params.convergence.items()
                        if key != 'bands'},
                       self.params.maxiter)

    def read_ibz_wave_functions(self, reader):
        """Must be implemented by a subclass."""
        raise NotImplementedError

    def create_potential_calculator(self):
        """Must be implemented by a subclass."""
        raise NotImplementedError

    def read_wavefunction_values(self,
                                 reader,
                                 ibzwfs: IBZWaveFunctions) -> None:
        """Read eigenvalues, occupations and projections and fermi levels.

        The values are read using reader and set as the appropriate properties
        of (the already instantiated) wavefunctions contained in ibzwfs.
        Energies are divided by ``reader.ha``.  Projections are read on
        rank 0 of the domain communicator only and then scattered; if the
        file has none, ``wfs._P_ani`` is set to None.
        """
        ha = reader.ha

        domain_comm = self.communicators['d']
        band_comm = self.communicators['b']

        eig_skn = reader.wave_functions.eigenvalues
        occ_skn = reader.wave_functions.occupations

        for wfs in ibzwfs:
            index: tuple[int, ...]
            if self.ncomponents < 4:
                # Collinear: arrays are indexed by (spin, k-point).
                dims = [self.nbands]
                index = (wfs.spin, wfs.k)
            else:
                # Non-collinear: indexed by k-point only, spinor dim of 2.
                dims = [self.nbands, 2]
                index = (wfs.k,)

            wfs._eig_n = eig_skn[index] / ha
            wfs._occ_n = occ_skn[index]
            layout = AtomArraysLayout([(setup.ni,) for setup in self.setups],
                                      atomdist=self.atomdist,
                                      dtype=self.dtype)
            P_ani = AtomArrays(layout, dims=dims, comm=band_comm)

            if domain_comm.rank == 0:
                try:
                    P_nI = reader.wave_functions.proxy('projections', *index)
                except KeyError:
                    data = None
                else:
                    b1, b2 = P_ani.my_slice()  # my bands
                    data = P_nI[b1:b2].astype(ibzwfs.dtype)  # read from file
            else:
                data = None

            # Only domain rank 0 knows whether projections exist;
            # broadcast so all ranks take the same branch below.
            have_projections = broadcast(
                data is not None if domain_comm.rank == 0 else None,
                comm=domain_comm)

            if have_projections:
                P_ani.scatter_from(data)  # distribute over atoms
                wfs._P_ani = P_ani
            else:
                wfs._P_ani = None

        try:
            ibzwfs.fermi_levels = reader.wave_functions.fermi_levels / ha
        except AttributeError:
            # old gpw-file
            ibzwfs.fermi_levels = np.array(
                [reader.occupations.fermilevel / ha])

    def create_environment(self, grid):
        """Build the environment from ``params.environment``."""
        return self.params.environment.build(
            setups=self.setups,
            grid=grid, relpos_ac=self.relpos_ac, log=self.log,
            comm=self.communicators['w'])

380 

381 

def create_communicators(comm: MPIComm = None,
                         nibzkpts: int = 1,
                         domain: int | tuple[int, int, int] | None = None,
                         kpt: int = None,
                         band: int = None,
                         xp: ModuleType = np) -> dict[str, MPIComm]:
    """Split a communicator into the communicators used by a calculation.

    Parameters
    ----------
    comm:
        Communicator to distribute over; ``world`` is used when not given.
    nibzkpts:
        Number of irreducible k-points (passed to ``Parallelization``).
    domain, kpt, band:
        Requested parallelization sizes for ``Parallelization.set``.
        A tuple for *domain* is replaced by the product of its entries.
    xp:
        Array module.  If it is not numpy and MPI is not GPU-aware, every
        communicator is wrapped in ``CuPyMPI``.

    Returns
    -------
    dict
        The communicators from ``build_communicators()`` plus ``'w'``
        (the full communicator).  Any communicator of size 1 is replaced
        by ``serial_comm``.
    """
    # Resolve the default once so that Parallelization and comms['w']
    # agree (storing None under 'w' would fail on ``.size`` below).
    comm = comm or world
    parallelization = Parallelization(comm, nibzkpts)
    if domain is not None and not isinstance(domain, int):
        domain = prod(domain)
    parallelization.set(kpt=kpt,
                        domain=domain,
                        band=band)
    comms = parallelization.build_communicators()
    comms['w'] = comm

    # We replace size=1 MPI communications with serial_comm so that
    # serial_comm.sum(<cupy-array>) works: XXX
    comms = {key: c if c.size > 1 else serial_comm
             for key, c in comms.items()}

    if xp is not np and not GPU_AWARE_MPI:
        comms = {key: CuPyMPI(c) for key, c in comms.items()}

    return comms

406 

407 

def create_fourier_filter(grid):
    """Return a callable that Fourier-filters a radial function in place.

    ``h`` is the largest value of ``|grid.icell[c]|**-1 / grid.size[c]``
    over the three directions.  The returned callable has the signature
    ``(rgd, rcut, f_r, l=0)``: it calls ``rgd.filter(f_r, rcut * gamma,
    gcut, l)`` with ``gcut = pi / h - 2 / rcut / gamma`` and overwrites
    ``f_r`` with the first ``len(f_r)`` values of the result.
    """
    gamma = 1.6

    row_norms = (grid.icell**2).sum(1)**-0.5
    h = (row_norms / grid.size).max()

    def fourier_filter(rgd, rcut, f_r, l=0):
        gcut = np.pi / h - 2 / rcut / gamma
        filtered_r = rgd.filter(f_r, rcut * gamma, gcut, l)
        f_r[:] = filtered_r[:len(f_r)]

    return fourier_filter

419 

420 

def normalize_initial_magmoms(
        atoms: Atoms,
        magmoms: ArrayLike2D | ArrayLike1D | float | None = None,
        force_spinpol_calculation: bool = False) -> tuple[Array2D, int]:
    """Convert magnetic moments to (natoms, 3)-shaped array.

    Also return number of wave function components (1, 2 or 4).

    Parameters
    ----------
    atoms:
        Only ``len(atoms)`` and, when *magmoms* is None,
        ``atoms.get_initial_magnetic_moments()`` are used.
    magmoms:
        None: take the initial moments of *atoms* as z-components.
        Scalar (any int/float/numpy scalar): same z-moment for every atom.
        1-d: one z-moment per atom.
        2-d: a full moment vector per atom (non-collinear, 4 components).
    force_spinpol_calculation:
        Keep 2 components even if all collinear moments are zero.

    >>> h = Atoms('H', magmoms=[1])
    >>> normalize_initial_magmoms(h)
    (array([[0., 0., 1.]]), 2)
    >>> normalize_initial_magmoms(h, [[1, 0, 0]])
    (array([[1., 0., 0.]]), 4)
    """
    magmom_av = np.zeros((len(atoms), 3))
    ncomponents = 2

    if magmoms is None:
        magmom_av[:, 2] = atoms.get_initial_magnetic_moments()
    elif np.ndim(magmoms) == 0:
        # Any scalar is a collinear moment along z.  (Checking only for
        # ``float`` sent ints into the 2-d branch below, giving a
        # non-collinear (m, m, m) moment.)
        magmom_av[:, 2] = magmoms
    else:
        magmoms = np.asarray(magmoms)
        if magmoms.ndim == 1:
            magmom_av[:, 2] = magmoms
        else:
            magmom_av[:] = magmoms
            ncomponents = 4

    if (ncomponents == 2 and
        not force_spinpol_calculation and
        not magmom_av[:, 2].any()):
        ncomponents = 1

    return magmom_av, ncomponents

456 

457 

def ____create_kpts(kpts: dict[str, Any], atoms: Atoms) -> BZPoints:
    """Build Brillouin-zone points from a k-point specification dict.

    * ``'kpts'`` key: explicit list of k-points.
    * ``'path'`` key: band path; all entries are forwarded to
      ``atoms.cell.bandpath``.
    * otherwise: Monkhorst-Pack grid from ``kpts2sizeandoffsets``.

    Raises ValueError if a k-point has a non-zero component along a
    non-periodic direction.
    """
    if 'kpts' in kpts:
        bz = BZPoints(kpts['kpts'])
    elif 'path' in kpts:
        bandpath = atoms.cell.bandpath(pbc=atoms.pbc, **kpts)
        bz = BZPoints(bandpath.kpts)
    else:
        size, offset = kpts2sizeandoffsets(**kpts, atoms=atoms)
        bz = MonkhorstPackKPoints(size, offset)

    nonperiodic_axes = [axis for axis, periodic in enumerate(atoms.pbc)
                        if not periodic]
    if any(not np.allclose(bz.kpt_Kc[:, axis], 0.0)
           for axis in nonperiodic_axes):
        raise ValueError('K-points can only be used with PBCs!')
    return bz

471 

472 

def calculate_number_of_bands(nbands: int | str | None,
                              setups: Setups,
                              charge: float,
                              initial_magmom_av: Array2D,
                              is_lcao: bool) -> int:
    """Return the number of bands to use.

    Parameters
    ----------
    nbands:
        None: ``ceil(1.2 * (nvalence + M) / 2) + 4``, capped by the sum of
        ``setup.get_default_nbands()`` and, for LCAO, by ``setups.nao``.
        ``'nao'``: the number of atomic orbitals.
        ``'<x>%'``: x percent of ``(nvalence + M) / 2``, rounded up.
        int <= 0: ``int(nvalence + M + 0.5) // 2 - nbands`` (at least 1).
        int > 0: used as is.
    setups:
        Provides ``nao``, ``nvalence`` and per-atom setups.
    charge:
        Total charge; ``nvalence = setups.nvalence - charge``.
    initial_magmom_av:
        Initial magnetic moments; ``M`` is the norm of their sum.
    is_lcao:
        Whether this is an LCAO calculation.

    Returns 1 if any setup is orbital-free.

    Raises
    ------
    ValueError
        For an unrecognized string, more bands than atomic orbitals in
        LCAO mode, a negative number of valence electrons, or too few
        bands to hold all valence electrons.
    """
    nao = setups.nao
    nvalence = setups.nvalence - charge
    M = np.linalg.norm(initial_magmom_av.sum(0))

    orbital_free = any(setup.orbital_free for setup in setups)
    if orbital_free:
        return 1

    if nbands is None:
        # Number of bound partial waves:
        nbandsmax = sum(setup.get_default_nbands()
                        for setup in setups)
        N = int(np.ceil(1.2 * (nvalence + M) / 2)) + 4
        N = min(N, nbandsmax)
        if is_lcao and N > nao:
            N = nao
    elif isinstance(nbands, str):
        if nbands == 'nao':
            N = nao
        elif nbands.endswith('%'):  # also safe for an empty string
            cfgbands = (nvalence + M) / 2
            N = int(np.ceil(float(nbands[:-1]) / 100 * cfgbands))
        else:
            url = 'https://gpaw.readthedocs.io/documentation/basic.html'
            raise ValueError(
                f'Bad value for nbands: {nbands!r}. '
                f'See {url}#manual-nbands for help')
    elif nbands <= 0:
        N = max(1, int(nvalence + M + 0.5) // 2 + (-nbands))
    else:
        N = nbands

    # Report the resolved number of bands N (nbands may be None or a str).
    if N > nao and is_lcao:
        raise ValueError('Too many bands for LCAO calculation: '
                         f'{N} bands and only {nao} atomic orbitals!')

    if nvalence < 0:
        raise ValueError(
            f'Charge {charge} is not possible - not enough valence electrons')

    if nvalence > 2 * N:
        raise ValueError(
            f'Too few bands! Electrons: {nvalence}, bands: {N}')

    return N

523 

524 

def create_uniform_grid(mode: str,
                        gpts,
                        cell,
                        pbc,
                        symmetries,
                        h: float | None = None,
                        interpolation: int | str | None = None,
                        ecut: float = None,
                        comm: MPIComm = serial_comm) -> UGDesc:
    """Create grid in a backwards compatible way.

    *cell* and *h* are given in Angstrom and converted to Bohr.  Unless
    the mode is ``'pw'`` or *interpolation* is ``'fft'`` (i.e. for
    real-space grids), non-periodic directions get zero boundary
    conditions.  If *gpts* is None, the grid size comes from
    ``get_number_of_grid_points``.
    """
    cell = cell / Bohr
    if h is not None:
        h /= Bohr

    realspace = mode != 'pw' and interpolation != 'fft'
    zerobc = ([not periodic for periodic in pbc] if realspace
              else [False] * 3)

    if gpts is None:
        size = get_number_of_grid_points(
            cell, h, SimpleNamespace(name=mode, ecut=ecut),
            realspace, symmetries)
    else:
        size = gpts

    return UGDesc(cell=cell, pbc=pbc, zerobc=zerobc, size=size, comm=comm)