Coverage for gpaw/new/ibzwfs.py: 84%

352 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1from __future__ import annotations 

2 

3from functools import cached_property 

4from typing import TYPE_CHECKING, Callable, Generator, Generic, TypeVar 

5 

6import numpy as np 

7from ase.io.ulm import Writer 

8from ase.units import Bohr, Ha 

9from gpaw.gpu import as_np, synchronize 

10from gpaw.gpu.mpi import CuPyMPI 

11from gpaw.mpi import MPIComm, serial_comm 

12from gpaw.new import zips 

13from gpaw.new.timer import trace 

14from gpaw.new.brillouin import IBZ 

15from gpaw.new.c import GPU_AWARE_MPI 

16from gpaw.new.potential import Potential 

17from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions 

18from gpaw.new.wave_functions import WaveFunctions 

19from gpaw.typing import Array1D, Array2D, Self 

20from gpaw.utilities import pack_density 

21 

22if TYPE_CHECKING: 

23 from gpaw.new.density import Density 

24 

25WFT = TypeVar('WFT', bound=WaveFunctions) 

26 

27 

28class IBZWaveFunctions(Generic[WFT]): 

def __init__(self,
             ibz: IBZ,
             *,
             ncomponents: int,
             wfs_qs: list[list[WFT]],
             kpt_comm: MPIComm = serial_comm,
             kpt_band_comm: MPIComm = serial_comm,
             comm: MPIComm = serial_comm):
    """Collection of wave function objects for k-points in the IBZ.

    Parameters
    ----------
    ibz:
        Brillouin-zone object; ``ibz.ranks(kpt_comm)`` is stored as
        ``rank_k``.
    ncomponents:
        Number of spin components.  The value 4 is treated as
        non-collinear.
    wfs_qs:
        Wave function objects, indexed as ``wfs_qs[q][s]``.  Must hold at
        least one object, since several attributes are copied from it.
    kpt_comm, kpt_band_comm, comm:
        Communicators kept for later reductions and transfers.

    Raises
    ------
    ValueError
        If ``wfs_qs`` contains no wave function object.
    """
    self.ibz = ibz
    self.kpt_comm = kpt_comm
    self.kpt_band_comm = kpt_band_comm
    self.comm = comm
    self.ncomponents = ncomponents
    self.collinear = (ncomponents != 4)
    # ncomponents 1 -> 2; ncomponents 2 or 4 -> 1
    self.spin_degeneracy = ncomponents % 2 + 1
    # ncomponents 1 -> 1; 2 -> 2; 4 -> 1
    self.nspins = ncomponents % 3

    self.rank_k = ibz.ranks(kpt_comm)

    self.wfs_qs = wfs_qs

    self.q_k = {}  # IBZ-index to local index
    wfs = None
    for wfs in self:
        self.q_k[wfs.k] = wfs.q
    if wfs is None:
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError(
            'wfs_qs must contain at least one wave function object')

    # These attributes are copied from the last object that was iterated.
    self.band_comm = wfs.band_comm
    self.domain_comm = wfs.domain_comm
    self.dtype = wfs.dtype
    self.nbands = wfs.nbands

    self.fermi_levels: Array1D | None = None  # hartree

    self.xp = self.wfs_qs[0][0].xp
    if self.xp is not np and not GPU_AWARE_MPI:
        # Arrays are not NumPy arrays and MPI is not GPU aware:
        # wrap the k-point communicator.
        self.kpt_comm = CuPyMPI(self.kpt_comm)  # type: ignore

    # Default: do nothing when wave functions are "moved".
    self.move_wave_functions: Callable[..., None] = lambda *args: None

    self.read_from_file_init_wfs_dm = False

70 

@classmethod
def create(cls,
           *,
           ibz: IBZ,
           ncomponents: int,
           create_wfs_func,
           kpt_comm: MPIComm = serial_comm,
           kpt_band_comm: MPIComm = serial_comm,
           comm: MPIComm = serial_comm,
           ) -> Self:
    """Collection of wave function objects for k-points in the IBZ.

    Builds the wave function objects for the k-points that
    ``ibz.ranks(kpt_comm)`` assigns to this rank of ``kpt_comm`` and
    passes them on to the constructor.

    ``create_wfs_func`` is called as
    ``create_wfs_func(spin, q, k, ibz.kpt_kc[k], ibz.weight_k[k])`` where
    ``q`` is the local index and ``k`` the IBZ index; spins form the
    inner loop.
    """
    rank_k = ibz.ranks(kpt_comm)
    mine_k = (rank_k == kpt_comm.rank)
    k_q = np.arange(len(ibz))[mine_k]

    nspins = ncomponents % 3

    wfs_qs: list[list[WFT]] = [
        [create_wfs_func(spin, q, k, ibz.kpt_kc[k], ibz.weight_k[k])
         for spin in range(nspins)]
        for q, k in enumerate(k_q)]

    return cls(ibz,
               ncomponents=ncomponents,
               wfs_qs=wfs_qs,
               kpt_comm=kpt_comm,
               kpt_band_comm=kpt_band_comm,
               comm=comm)

103 

@cached_property
def mode(self) -> str:
    """Return ``'pw'``, ``'fd'`` or ``'lcao'``.

    Decided from the first wave function object only: anything that is
    not a ``PWFDWaveFunctions`` is ``'lcao'``; otherwise ``'pw'`` when the
    descriptor of ``psit_nX`` has an ``ecut`` attribute and ``'fd'`` if
    not.
    """
    wfs = self.wfs_qs[0][0]
    if not isinstance(wfs, PWFDWaveFunctions):
        return 'lcao'
    return 'pw' if hasattr(wfs.psit_nX.desc, 'ecut') else 'fd'

112 

def has_wave_functions(self) -> bool:
    """Not implemented in this class; always raises NotImplementedError."""
    raise NotImplementedError

115 

def get_max_shape(self, global_shape: bool = False) -> tuple[int, ...]:
    """Find the largest wave function array shape.

    For a PW-calculation, this shape could depend on k-point.

    With ``global_shape=False`` only the local wave function objects are
    compared.  With ``global_shape=True`` the local maximum of
    ``array_shape(global_shape=True)`` is additionally passed to
    ``kpt_comm.max`` and returned as a tuple of plain Python ints.
    """
    if not global_shape:
        return max(wfs.array_shape() for wfs in self)

    shape = np.array(max(wfs.array_shape(global_shape=True)
                         for wfs in self))
    # The return value is ignored, so the reduced result is expected to
    # end up in ``shape`` itself.
    self.kpt_comm.max(shape)
    # Plain ints (not NumPy scalars) so the result matches the annotation.
    return tuple(int(n) for n in shape)

127 

@property
def fermi_level(self) -> float:
    """Return the one and only entry of ``self.fermi_levels``.

    Asserts that ``fermi_levels`` has been set and has exactly one entry.
    """
    levels = self.fermi_levels
    assert levels is not None and len(levels) == 1
    return levels[0]

133 

def __str__(self):
    """Return a multi-line summary of the wave function collection.

    The text contains the symmetries, the IBZ, a short description from
    the first wave function object, spin/band/dtype information, a
    storage estimate and the sizes of the k-point, domain and band
    communicators.
    """
    shape = self.get_max_shape(global_shape=True)
    wfs = self.wfs_qs[0][0]
    nbytes = (len(self.ibz) *
              self.nbands *
              len(self.wfs_qs[0]) *
              wfs.bytes_per_band)
    ncores = (self.kpt_comm.size *
              self.domain_comm.size *
              self.band_comm.size)
    prefix = '' if self.collinear else 'non-'
    # NOTE(review): the spacing inside these literals is copied from the
    # visible text, where runs of blanks may have been collapsed --
    # confirm the column alignment against the upstream file.
    return (f'{self.ibz.symmetries}\n'
            f'{self.ibz}\n'
            f'{wfs._short_string(shape)}\n'
            f'spin-components: {self.ncomponents}'
            ' # (' +
            prefix + 'collinear spins)\n'
            f'bands: {self.nbands}\n'
            f'spin-degeneracy: {self.spin_degeneracy}\n'
            f'dtype: {self.dtype}\n\n'
            'memory:\n'
            f' storage: {"CPU" if self.xp is np else "GPU"}\n'
            f' wave functions: {nbytes:_} # bytes '
            f' ({nbytes // ncores:_} per core)\n\n'
            'parallelization:\n'
            f' kpt: {self.kpt_comm.size}\n'
            f' domain: {self.domain_comm.size}\n'
            f' band: {self.band_comm.size}\n')

161 

def __iter__(self) -> Generator[WFT, None, None]:
    """Yield all local wave function objects.

    The outer loop runs over the entries of ``wfs_qs``; the objects of
    each entry are yielded in order before moving to the next entry.
    """
    for wfs_s in self.wfs_qs:
        yield from wfs_s

165 

def move(self, relpos_ac, atomdist) -> None:
    """Forward new atomic positions to every local wave function object.

    The positions are first checked against the symmetries of the IBZ,
    then ``make_sure_wfs_are_read_from_gpw_file`` is called, and finally
    each object's ``move`` is called with ``self.move_wave_functions`` as
    third argument.
    """
    self.ibz.symmetries.check_positions(relpos_ac)
    self.make_sure_wfs_are_read_from_gpw_file()
    mover = self.move_wave_functions
    for wfs in self:
        wfs.move(relpos_ac, atomdist, mover)

171 

def orthonormalize(self, work_array_nX: np.ndarray | None = None):
    """Call ``orthonormalize(work_array_nX)`` on every local object.

    ``work_array_nX`` is passed through unchanged; ``None`` is the
    default.
    """
    for wfs in self:
        wfs.orthonormalize(work_array_nX)

175 

@trace
def calculate_occs(self,
                   occ_calc,
                   nelectrons: float,
                   fix_fermi_level: bool = False
                   ) -> tuple[float, float, float]:
    """Compute occupation numbers plus band and entropy energies.

    ``occ_calc.calculate`` is given ``nelectrons`` divided by the spin
    degeneracy, and eigenvalues and Fermi-level guess multiplied by
    ``Ha``.  Unless ``fix_fermi_level`` is true, the returned Fermi
    levels are divided by ``Ha`` and stored in ``self.fermi_levels``.
    The occupations are stored on each wave function object as
    ``_occ_n``.

    Returns
    -------
    ``(e_band, e_entropy, e_entropy * occ_calc.extrapolate_factor)``
    where ``e_band`` has been summed over ``kpt_comm``.
    """
    degeneracy = self.spin_degeneracy
    wfs_u = list(self)  # u index is q and s combined

    guess = None if self.fermi_levels is None else self.fermi_levels * Ha
    occ_un, fermi_levels, e_entropy = occ_calc.calculate(
        nelectrons=nelectrons / degeneracy,
        eigenvalues=[wfs.eig_n * Ha for wfs in wfs_u],
        weights=[wfs.weight for wfs in wfs_u],
        fermi_levels_guess=guess,
        fix_fermi_level=fix_fermi_level)

    if fix_fermi_level:
        # Keeping the old levels only makes sense if there are some.
        assert self.fermi_levels is not None
    else:
        self.fermi_levels = np.array(fermi_levels) / Ha

    for occ_n, wfs in zips(occ_un, wfs_u):
        wfs._occ_n = occ_n

    e_entropy *= degeneracy / Ha
    e_band = 0.0
    for wfs in wfs_u:
        e_band += wfs.occ_n @ wfs.eig_n * wfs.weight * degeneracy
    e_band = self.kpt_comm.sum_scalar(float(e_band))  # XXX CPU float?

    return e_band, e_entropy, e_entropy * occ_calc.extrapolate_factor

208 

def add_to_density(self, nt_sR, D_asii) -> None:
    """Compute density and add to ``nt_sR`` and ``D_asii``.

    Every local wave function object adds its contribution; afterwards
    ``nt_sR.data`` and ``D_asii.data`` are summed over ``kpt_comm`` and
    then over ``band_comm``.
    """
    for wfs in self:
        wfs.add_to_density(nt_sR, D_asii)

    if self.xp is not np:
        synchronize()

    # This should be done in a more efficient way!!!
    # Also: where do we want the density?
    for comm in (self.kpt_comm, self.band_comm):
        comm.sum(nt_sR.data)
        comm.sum(D_asii.data)

223 

def normalize_density(self, density: Density) -> None:
    """Do nothing here; overwritten in LCAOIBZWaveFunctions class."""
    return None

226 

def add_to_ked(self, taut_sR) -> None:
    """Let every local object add to ``taut_sR`` and sum the result.

    ``taut_sR.data`` is summed over ``kpt_comm`` and then over
    ``band_comm``.
    """
    for wfs in self:
        wfs.add_to_ked(taut_sR)
    if self.xp is not np:
        synchronize()
    for comm in (self.kpt_comm, self.band_comm):
        comm.sum(taut_sR.data)

234 

def get_all_electron_wave_function(self,
                                   band,
                                   kpt=0,
                                   spin=0,
                                   grid_spacing=0.05,
                                   skip_paw_correction=False):
    """Return one band interpolated to a grid with the given spacing.

    The band is fetched with ``get_wfs``; when that returns ``None`` so
    does this method.  Otherwise the band is converted with
    ``to_pbc_grid()`` and interpolated to a uniform grid built from
    ``grid_spacing``.  Unless ``skip_paw_correction`` is true, the
    partial-wave corrections of the setups are added using the band's
    projections.
    """
    wfs = self.get_wfs(kpt=kpt, spin=spin, n1=band, n2=band + 1)
    if wfs is None:
        return None
    assert isinstance(wfs, PWFDWaveFunctions)

    # ``wfs`` holds only the requested band, hence index 0.
    psit_X = wfs.psit_nX[0].to_pbc_grid()
    grid = psit_X.desc.uniform_grid_with_grid_spacing(grid_spacing)
    psi_r = psit_X.interpolate(grid=grid)

    if skip_paw_correction:
        return psi_r

    dphi_aj = wfs.setups.partial_wave_corrections()
    dphi_air = grid.atom_centered_functions(dphi_aj, wfs.relpos_ac)
    dphi_air.add_to(psi_r, wfs.P_ani[:, 0])
    return psi_r

255 

def get_wfs(self,
            *,
            kpt: int = 0,
            spin: int = 0,
            n1=0,
            n2=0):
    """Collect bands ``n1`` to ``n2`` of one k-point and spin.

    * On the ``kpt_comm`` rank that owns ``kpt``: ``collect(n1, n2)`` is
      called.  If the owner is rank 0 the result is returned directly;
      otherwise a non-``None`` result is sent to rank 0 of ``kpt_comm``
      and ``None`` is returned.
    * On other ranks: rank 0 of ``comm`` receives the object from the
      owner and returns it; every other rank returns ``None``.
    """
    owner = self.rank_k[kpt]
    if owner == self.kpt_comm.rank:
        wfs = self.wfs_qs[self.q_k[kpt]][spin]
        collected = wfs.collect(n1, n2)
        if owner == 0:
            return collected
        if collected is not None:
            collected.send(0, self.kpt_comm)
        return None
    if self.comm.rank == 0:
        return self.wfs_qs[0][0].receive(owner, self.kpt_comm)
    return None

274 

def get_eigs_and_occs(self, k=0, s=0):
    """Return eigenvalues and occupations for IBZ index ``k``, spin ``s``.

    Only ranks that are rank 0 in both ``domain_comm`` and ``band_comm``
    take part.  If rank 0 of ``kpt_comm`` owns the k-point, its stored
    ``_eig_n`` and ``_occ_n`` arrays are returned directly.  Otherwise
    the owner sends both arrays to rank 0 of ``kpt_comm``, which
    receives them into new arrays of length ``nbands`` and returns
    those.  All remaining ranks, including a sending owner, return two
    empty arrays.
    """
    if self.domain_comm.rank != 0 or self.band_comm.rank != 0:
        return np.zeros(0), np.zeros(0)

    owner = self.rank_k[k]
    if owner == self.kpt_comm.rank:
        wfs = self.wfs_qs[self.q_k[k]][s]
        if owner == 0:
            return wfs._eig_n, wfs._occ_n
        self.kpt_comm.send(wfs._eig_n, 0)
        self.kpt_comm.send(wfs._occ_n, 0)
    elif self.kpt_comm.rank == 0:
        eig_n = np.empty(self.nbands)
        occ_n = np.empty(self.nbands)
        self.kpt_comm.receive(eig_n, owner)
        self.kpt_comm.receive(occ_n, owner)
        return eig_n, occ_n
    return np.zeros(0), np.zeros(0)

291 

def get_all_eigs_and_occs(self, broadcast=False):
    """Gather eigenvalues and occupations for all spins and k-points.

    Returns two arrays of shape ``(nspins, len(ibz), nbands)`` on rank 0
    of ``comm``.  On other ranks the last axis has length 0 unless
    ``broadcast`` is true, in which case both arrays are broadcast from
    rank 0 over ``comm``.

    ``get_eigs_and_occs`` is called on every rank for every
    ``(k, s)`` pair.
    """
    nkpts = len(self.ibz)
    on_master = self.comm.rank == 0
    mynbands = self.nbands if on_master or broadcast else 0
    eig_skn = np.empty((self.nspins, nkpts, mynbands))
    occ_skn = np.empty((self.nspins, nkpts, mynbands))
    for k in range(nkpts):
        for s in range(self.nspins):
            eig_n, occ_n = self.get_eigs_and_occs(k, s)
            if on_master:
                eig_skn[s, k, :] = eig_n
                occ_skn[s, k, :] = occ_n
    if broadcast:
        self.comm.broadcast(eig_skn, 0)
        self.comm.broadcast(occ_skn, 0)
    return eig_skn, occ_skn

307 

def forces(self, potential: Potential) -> Array2D:
    """Accumulate force contributions from the local wave functions.

    An array of zeros with shape ``(len(potential.dH_asii), 3)`` is
    handed to ``force_contribution`` of every local object and then
    summed over ``kpt_band_comm``.
    """
    self.make_sure_wfs_are_read_from_gpw_file()
    natoms = len(potential.dH_asii)
    F_av = self.xp.zeros((natoms, 3))
    for wfs in self:
        wfs.force_contribution(potential, F_av)
    if self.xp is not np:
        synchronize()
    self.kpt_band_comm.sum(F_av)
    return F_av

317 

def write(self, writer: Writer, flags) -> None:
    """Write fermi-level(s), eigenvalues, occupation numbers, ...

    ... k-points, symmetry information, projections and possibly
    also the wave functions.

    Fermi levels and eigenvalues are multiplied by ``Ha`` before being
    written.  Projections are written when ``flags.include_projections``
    is true and wave functions when ``flags.include_wfs`` is true.
    """
    eig_skn, occ_skn = self.get_all_eigs_and_occs()
    if not self.collinear:
        # Drop the spin axis.
        eig_skn = eig_skn[0]
        occ_skn = occ_skn[0]
    assert self.fermi_levels is not None
    writer.write(fermi_levels=self.fermi_levels * Ha,
                 eigenvalues=eig_skn * Ha,
                 occupations=occ_skn)
    ibz = self.ibz
    writer.child('kpts').write(
        atommap=ibz.symmetries.atommap_sa,
        bz2ibz=ibz.bz2ibz_K,
        bzkpts=ibz.bz.kpt_Kc,
        ibzkpts=ibz.kpt_kc,
        rotations=ibz.symmetries.rotation_scc,
        translations=ibz.symmetries.translation_sc,
        weights=ibz.weight_k)

    nproj = self.wfs_qs[0][0].P_ani.layout.size

    spin_k_shape: tuple[int, ...]
    proj_shape: tuple[int, ...]

    if self.collinear:
        spin_k_shape = (self.ncomponents, len(ibz))
        proj_shape = (self.nbands, nproj)
    else:
        spin_k_shape = (len(ibz),)
        proj_shape = (self.nbands, 2, nproj)

    if flags.include_projections:
        proj_dtype = flags.storage_dtype(self.dtype)
        writer.add_array('projections', spin_k_shape + proj_shape,
                         proj_dtype)
        for spin in range(self.nspins):
            for k, rank in enumerate(self.rank_k):
                if rank == self.kpt_comm.rank:
                    # This rank owns the k-point.
                    wfs = self.wfs_qs[self.q_k[k]][spin]
                    P_ani = wfs.P_ani.to_cpu().gather()  # gather atoms
                    if P_ani is not None:
                        P_nI = P_ani.matrix.gather()  # gather bands
                        if P_nI.dist.comm.rank == 0:
                            if rank == 0:
                                writer.fill(P_nI.data.reshape(
                                    proj_shape).astype(proj_dtype))
                            else:
                                self.kpt_comm.send(P_nI.data, 0)
                elif self.comm.rank == 0:
                    # Receive from the owner and write.
                    data = np.empty(proj_shape, self.dtype)
                    self.kpt_comm.receive(data, rank)
                    writer.fill(data.astype(proj_dtype))

    if flags.include_wfs:
        self._write_wave_functions(writer, spin_k_shape, flags)

378 

def _write_wave_functions(self, writer, spin_k_shape, flags):
    """Write the wave function coefficients as array ``'coefficients'``.

    The array shape is ``spin_k_shape + (nbands,) + xshape`` where
    ``xshape`` is the global maximum shape.  Coefficients are multiplied
    by ``Bohr**-1.5`` before writing, except in ``'lcao'`` mode where the
    factor is 1.
    """
    # We collect all bands to master.  This may have to be changed
    # to only one band at a time XXX
    xshape = self.get_max_shape(global_shape=True)
    shape = spin_k_shape + (self.nbands,) + xshape
    is_pw = self.mode == 'pw'
    dtype = complex if is_pw else self.dtype
    dtype_write = flags.storage_dtype(dtype)
    c = 1.0 if self.mode == 'lcao' else Bohr**-1.5

    writer.add_array('coefficients', shape, dtype=dtype_write)
    # Scratch buffer: receive target and zero-padding target.
    buf_nX = np.empty((self.nbands,) + xshape, dtype=dtype)

    for spin in range(self.nspins):
        for k, rank in enumerate(self.rank_k):
            if rank == self.kpt_comm.rank:
                wfs = self.wfs_qs[self.q_k[k]][spin]
                coef_nX = wfs.gather_wave_function_coefficients()
                if coef_nX is not None:
                    coef_nX = as_np(coef_nX)
                    if is_pw:
                        x = coef_nX.shape[-1]
                        if x < xshape[-1]:
                            # For PW-mode, we may need to zero-pad the
                            # plane-wave coefficient up to the maximum
                            # for all k-points:
                            buf_nX[..., :x] = coef_nX
                            buf_nX[..., x:] = 0.0
                            coef_nX = buf_nX
                    if rank == 0:
                        writer.fill(flags.to_storage_dtype(coef_nX * c))
                    else:
                        self.kpt_comm.send(coef_nX, 0)
            elif self.comm.rank == 0:
                self.kpt_comm.receive(buf_nX, rank)
                writer.fill(flags.to_storage_dtype(buf_nX * c))

414 

def write_summary(self, log):
    """Log Fermi level(s), a band table for up to 3 k-points and gap info.

    ``log`` is called with one string per line (and once without
    arguments).  Eigenvalues and Fermi levels are multiplied by ``Ha``
    before printing.  Only rank 0 of ``comm`` prints the tables; the
    Fermi-level line is logged on every rank.
    """
    # Same precondition as in ``write``: the levels must have been set.
    assert self.fermi_levels is not None
    fl = self.fermi_levels * Ha
    if len(fl) == 1:
        log(f'\nFermi level: {fl[0]:.3f}')
    else:
        log(f'\nFermi levels: {fl[0]:.3f}, {fl[1]:.3f}')

    ibz = self.ibz

    # Called on all ranks before the early return below.
    eig_skn, occ_skn = self.get_all_eigs_and_occs()

    if self.comm.rank != 0:
        return

    eig_skn *= Ha

    D = self.spin_degeneracy
    nbands = eig_skn.shape[2]

    # NOTE(review): the spacing inside the table literals below is copied
    # from the visible text, where runs of blanks may have been
    # collapsed -- confirm the column alignment against the upstream file.
    for k, (x, y, z) in enumerate(ibz.kpt_kc):
        if k == 3:
            log(f'(only showing first 3 out of {len(ibz)} k-points)')
            break

        log(f'\nkpt = [{x:.3f}, {y:.3f}, {z:.3f}], '
            f'weight = {ibz.weight_k[k]:.3f}:')

        if self.nspins == 1:
            skipping = False
            log(f' Band eig [eV] occ [0-{D}]')
            eig_n = eig_skn[0, k]
            # Half-integer position between the bands below and above
            # the first Fermi level.
            n0 = (eig_n < fl[0]).sum() - 0.5
            for n, (e, f) in enumerate(zips(eig_n, occ_skn[0, k])):
                # First, last and +-8 bands window around Fermi level:
                if n == 0 or abs(n - n0) < 8 or n == nbands - 1:
                    log(f' {n:4} {e:13.3f} {D * f:9.3f}')
                    skipping = False
                elif not skipping:
                    # One marker per run of skipped bands.
                    log(' ...')
                    skipping = True
        else:
            log(' Band eig [eV] occ [0-1]'
                ' eig [eV] occ [0-1]')
            for n, (e1, f1, e2, f2) in enumerate(zips(eig_skn[0, k],
                                                      occ_skn[0, k],
                                                      eig_skn[1, k],
                                                      occ_skn[1, k])):
                log(f' {n:4} {e1:13.3f} {f1:9.3f}'
                    f' {e2:10.3f} {f2:9.3f}')

    try:
        from ase.dft.bandgap import GapInfo
    except ImportError:
        log('No gapinfo -- requires new ASE')
        return

    try:
        log()
        fermilevel = fl[0]
        gapinfo = GapInfo(eigenvalues=eig_skn - fermilevel)
        log(gapinfo.description(ibz_kpoints=ibz.kpt_kc))
    except ValueError:
        # Maybe we only have the occupied bands and no empty bands
        log('Could not find a gap')

480 

def make_sure_wfs_are_read_from_gpw_file(self):
    """Replace file-backed wave function data with in-memory arrays.

    For each local object whose ``psit_nX.data`` has an ``fd`` attribute,
    the data is read via ``data[:]`` and stored as a contiguous NumPy
    array, and ``read_from_file_init_wfs_dm`` is set to true.  The loop
    stops at the first object that has no ``psit_nX`` (or has it set to
    ``None``).
    """
    for wfs in self:
        psit_nX = getattr(wfs, 'psit_nX', None)
        if psit_nX is None:
            return
        if not hasattr(psit_nX.data, 'fd'):  # fd=file-descriptor
            continue
        self.read_from_file_init_wfs_dm = True
        psit_nX.data = np.ascontiguousarray(psit_nX.data[:])  # read

489 

def get_homo_lumo(self, spin: int | None = None) -> Array1D:
    """Return HOMO and LUMO eigenvalues.

    With ``ncomponents == 2`` and ``spin=None`` the result combines both
    spins: the larger HOMO and the smaller LUMO.  For any other
    ``ncomponents``, ``spin`` must not be 1 and spin index 0 is used.

    The number of occupied bands ``n`` is the weighted sum of the
    occupation numbers, summed over ``kpt_comm`` and rounded.  HOMO is
    the largest eigenvalue of band ``n - 1`` (``-inf`` if ``n == 0``) and
    LUMO the smallest eigenvalue of band ``n`` (``inf`` if
    ``n >= nbands``).
    """
    if self.ncomponents == 2:
        if spin is None:
            h0, l0 = self.get_homo_lumo(0)
            h1, l1 = self.get_homo_lumo(1)
            return np.array([max(h0, h1), min(l0, l1)])
    else:
        assert spin != 1
        spin = 0

    nocc = 0.0
    for wfs_s in self.wfs_qs:
        wfs = wfs_s[spin]
        nocc += wfs.occ_n.sum() * wfs.weight
    nocc = self.kpt_comm.sum_scalar(nocc)
    n = int(round(nocc))

    homo = -np.inf
    if n > 0:
        for wfs_s in self.wfs_qs:
            homo = max(homo, wfs_s[spin].eig_n[n - 1])
    homo = self.kpt_comm.max_scalar(homo)

    lumo = np.inf
    if n < self.nbands:
        for wfs_s in self.wfs_qs:
            lumo = min(lumo, wfs_s[spin].eig_n[n])
    lumo = self.kpt_comm.min_scalar(lumo)

    return np.array([homo, lumo])

524 

def calculate_kinetic_energy(self,
                             hamiltonian,
                             density: Density) -> float:
    """Return the kinetic energy including the PAW corrections.

    The first part is the sum of
    ``hamiltonian.calculate_kinetic_energy(wfs, skip_sum=True)`` over the
    local wave functions, summed over ``comm``.  The second part sums
    ``setup.K_p @ D_p + setup.Kc`` over the atoms in ``density.D_asii``
    and is summed over ``density.grid.comm``.
    """
    e_kin = 0.0
    for wfs in self:
        e_kin += hamiltonian.calculate_kinetic_energy(wfs, skip_sum=True)
    e_kin = self.comm.sum_scalar(e_kin)

    # PAW corrections:
    # ``wfs`` is the last object from the loop above; only its ``setups``
    # are needed.
    e_kin_paw = 0.0
    for a, D_sii in density.D_asii.items():
        setup = wfs.setups[a]
        D_p = pack_density(D_sii.real[:density.ndensities].sum(0))
        e_kin_paw += setup.K_p @ D_p + setup.Kc
    e_kin_paw = density.grid.comm.sum_scalar(e_kin_paw)

    return e_kin + e_kin_paw