Coverage for gpaw/response/groundstate.py: 94%

253 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from __future__ import annotations 

2from dataclasses import dataclass 

3from typing import Union 

4from pathlib import Path 

5from functools import cached_property 

6from types import SimpleNamespace 

7from typing import TYPE_CHECKING 

8import numpy as np 

9 

10from ase.units import Ha, Bohr 

11 

12import gpaw.mpi as mpi 

13from gpaw.ibz2bz import IBZ2BZMaps 

14from gpaw.calculator import GPAW as OldGPAW 

15from gpaw.new.ase_interface import ASECalculator as NewGPAW 

16from gpaw.response.paw import LeanPAWDataset 

17 

18if TYPE_CHECKING: 

19 from gpaw.setup import Setups, LeanSetup 

20 

21 

class PAWDatasetCollection:
    """PAW datasets of a calculation, indexed by species and by atom.

    Attributes
    ----------
    by_species : dict
        Maps each species id (entries of ``setups.id_a``) to a single
        ResponsePAWDataset, shared by all atoms of that species.
    by_atom : list
        The dataset of each atom, in the iteration order of ``setups``.
    id_by_atom : list
        The species id of each atom, in the same order as ``by_atom``.
    """

    def __init__(self, setups: Setups):
        datasets = {}
        per_atom = []
        ids = []

        for index, setup in enumerate(setups):
            key = setups.id_a[index]
            try:
                dataset = datasets[key]
            except KeyError:
                # First atom of this species: build its dataset once, so that
                # later atoms of the same species reuse the same object.
                dataset = datasets[key] = ResponsePAWDataset(setup)
            per_atom.append(dataset)
            ids.append(key)

        self.by_species = datasets
        self.by_atom = per_atom
        self.id_by_atom = ids

38 

39 

# A GPAW calculator object, from either the old or the new code path.
GPAWCalculator = Union[OldGPAW, NewGPAW]

# Location of a .gpw file, given as a path object or a plain string.
GPWFilename = Union[Path, str]

# Every kind of input that ResponseGroundStateAdapter.from_input() accepts.
ResponseGroundStateAdaptable = Union[
    'ResponseGroundStateAdapter',
    GPAWCalculator,
    GPWFilename,
]

45 

46 

class ResponseGroundStateAdapter:
    """Expose the ground state quantities needed by the response code.

    The adapter wraps a GPAW calculator and republishes its descriptors,
    densities, k-point data and PAW datasets as attributes, properties and
    methods, so that the response modules can read them from one place.
    """

    def __init__(self, calc: GPAWCalculator):
        """Collect the ground state quantities from ``calc``."""
        wfs = calc.wfs  # wavefunction object from gpaw.wavefunctions

        self.atoms = calc.atoms  # ASE Atoms object
        self.kd = wfs.kd  # KPointDescriptor object from gpaw.kpt_descriptor.
        self.world = calc.world  # _Communicator object from gpaw.mpi

        # GridDescriptor from gpaw.grid_descriptor.
        # Describes a grid in real space
        self.gd = wfs.gd

        # Also a GridDescriptor, with a finer grid...
        self.finegd = calc.density.finegd
        self.bd = wfs.bd  # BandDescriptor from gpaw.band_descriptor
        self.nspins = wfs.nspins  # number of spins: int
        self.dtype = wfs.dtype  # data type of wavefunctions, real or complex

        self.spos_ac = calc.spos_ac  # scaled position vector: np.ndarray

        self.kpt_u = wfs.kpt_u  # kpoints: list of Kpoint from gpaw.kpoint
        # NOTE(review): presumably the same k-points arranged by (q, s)
        # index -- confirm against the wfs implementation.
        self.kpt_qs = wfs.kpt_qs

        self.fermi_level = wfs.fermi_level  # float
        self.pawdatasets = PAWDatasetCollection(calc.setups)

        self.pbc = self.atoms.pbc
        self.volume = self.gd.volume

        # The number of valence electrons is required to be integer valued
        self.nvalence = int(round(wfs.nvalence))
        assert self.nvalence == wfs.nvalence

        self.nocc1, self.nocc2 = self.count_occupied_bands()

        self.ibz2bz = IBZ2BZMaps.from_calculator(calc)

        self._wfs = wfs
        self._density = calc.density
        self._hamiltonian = calc.hamiltonian
        self._calc = calc

    @staticmethod
    def from_input(
            gs: ResponseGroundStateAdaptable) -> ResponseGroundStateAdapter:
        """Return an adapter for any supported ground state input.

        An existing adapter is returned unchanged, a calculator is wrapped
        and a path/string is interpreted as the name of a .gpw file.

        Raises
        ------
        ValueError
            If ``gs`` is none of the supported input types.
        """
        if isinstance(gs, ResponseGroundStateAdapter):
            return gs
        elif isinstance(gs, (OldGPAW, NewGPAW)):
            return ResponseGroundStateAdapter(calc=gs)
        elif isinstance(gs, (Path, str)):  # GPWFilename
            return ResponseGroundStateAdapter.from_gpw_file(gpw=gs)
        # Format into a single string; passing gs as a second argument would
        # make the message render as a tuple.
        raise ValueError(
            f'Expected ResponseGroundStateAdaptable, got {gs!r}')

    @classmethod
    def from_gpw_file(cls, gpw: GPWFilename) -> ResponseGroundStateAdapter:
        """Initiate the ground state adapter directly from a .gpw file."""
        from gpaw import GPAW, disable_dry_run
        assert Path(gpw).is_file(), f'No such gpw file: {gpw}'
        with disable_dry_run():
            # The calculator is always read on a serial communicator
            calc = GPAW(gpw, txt=None, communicator=mpi.serial_comm)
        return cls(calc)

    @property
    def pd(self):
        """The ``pd`` attribute of the wavefunction object."""
        # This is an attribute error in FD/LCAO mode.
        # We need to abstract away "calc" in all places used by response
        # code, and that includes places that are also compatible with FD.
        return self._wfs.pd

    def is_parallelized(self):
        """Are we dealing with a parallel calculator?"""
        return self.world.size > 1

    @cached_property
    def global_pd(self):
        """Get a PWDescriptor that includes all k-points.

        In particular, this is necessary to allow all cores to be able to work
        on all k-points in the case where calc is parallelized over k-points,
        see gpaw.response.kspair
        """
        from gpaw.pw.descriptor import PWDescriptor

        assert self.gd.comm.size == 1
        kd = self.kd.copy()  # global KPointDescriptor without a comm
        return PWDescriptor(self.pd.ecut, self.gd,
                            dtype=self.pd.dtype,
                            kd=kd, fftwflags=self.pd.fftwflags,
                            gammacentered=self.pd.gammacentered)

    def get_occupations_width(self):
        """Return the occupation smearing width divided by Ha."""
        # Ugly hack only used by pair.intraband_pair_density I think.
        # Actually: was copy-pasted in chi0 also.
        # More duplication can probably be eliminated around those.

        # Only works with Fermi-Dirac distribution
        occs = self._wfs.occupations
        assert occs.name in {'fermi-dirac', 'zero-width'}

        # No carriers when T=0
        width = getattr(occs, '_width', 0.0) / Ha
        return width

    @cached_property
    def cd(self):
        """CellDescriptor built from the coarse grid cell and the pbc."""
        return CellDescriptor(self.gd.cell_cv, self.pbc)

    @property
    def nt_sR(self):
        """Pseudo density on the coarse grid (``density.nt_sG``)."""
        # Used by localft and fxc_kernels
        return self._density.nt_sG

    @property
    def nt_sr(self):
        """Pseudo density on the fine grid (``density.nt_sg``)."""
        # Used by localft
        if self._density.nt_sg is None:
            # Not available yet: ask the density object to interpolate it
            self._density.interpolate_pseudo_density()
        return self._density.nt_sg

    @cached_property
    def n_sR(self):
        """All-electron density evaluated with gridrefinement=1."""
        return self._density.get_all_electron_density(
            atoms=self.atoms, gridrefinement=1)[0]

    @cached_property
    def n_sr(self):
        """All-electron density evaluated with gridrefinement=2."""
        return self._density.get_all_electron_density(
            atoms=self.atoms, gridrefinement=2)[0]

    @property
    def D_asp(self):
        """The ``D_asp`` attribute of the density object."""
        # Used by fxc_kernels
        return self._density.D_asp

    def get_pseudo_density(self, gridrefinement=2):
        """Return the pseudo density and the grid descriptor it lives on.

        ``gridrefinement`` must be 1 (coarse grid) or 2 (fine grid),
        otherwise a ValueError is raised.
        """
        # Used by localft
        if gridrefinement == 1:
            return self.nt_sR, self.gd
        elif gridrefinement == 2:
            return self.nt_sr, self.finegd
        else:
            raise ValueError(f'Invalid gridrefinement {gridrefinement}')

    def get_all_electron_density(self, gridrefinement=2):
        """Return the all-electron density and its grid descriptor.

        ``gridrefinement`` must be 1 (coarse grid) or 2 (fine grid),
        otherwise a ValueError is raised.
        """
        # Used by fxc, fxc_kernels and localft
        if gridrefinement == 1:
            return self.n_sR, self.gd
        elif gridrefinement == 2:
            return self.n_sr, self.finegd
        else:
            raise ValueError(f'Invalid gridrefinement {gridrefinement}')

    # Things used by EXX.  This is getting pretty involved.
    #
    # EXX naughtily accesses the density object in order to
    # interpolate_pseudo_density() which is in principle mutable.

    def hacky_all_electron_density(self, **kwargs):
        """Return calc.get_all_electron_density(**kwargs) times Bohr**3."""
        # fxc likes to get all electron densities.  It calls
        # calc.get_all_electron_density() and so we wrap that here.
        # But it also collects to serial (bad), and it also zeropads
        # nonperiodic directions (probably WRONG!).
        #
        # Also this one returns in user units, whereas the calling
        # code actually wants internal units.  Very silly then.
        #
        # Also, the calling code often wants the gd, which is not
        # returned, so it is redundantly reconstructed in multiple
        # places by refining the "right" number of times.
        n_g = self._calc.get_all_electron_density(**kwargs)
        n_g *= Bohr**3
        return n_g

    # Used by EXX.
    @property
    def hamiltonian(self):
        """The hamiltonian object of the wrapped calculator."""
        return self._hamiltonian

    # Used by EXX.
    @property
    def density(self):
        """The density object of the wrapped calculator."""
        return self._density

    # Ugh SOC
    def soc_eigenstates(self, **kwargs):
        """Forward to gpaw.spinorbit.soc_eigenstates for the calculator."""
        from gpaw.spinorbit import soc_eigenstates
        return soc_eigenstates(self._calc, **kwargs)

    @property
    def xcname(self):
        """Name of the xc functional attached to the hamiltonian."""
        return self.hamiltonian.xc.name

    def get_xc_difference(self, xc):
        """Forward to the calculator's get_xc_difference()."""
        # XXX used by gpaw/xc/tools.py
        return self._calc.get_xc_difference(xc)

    def get_wave_function_array(self, u, n):
        """Return the real-space wave function array for indices (u, n)."""
        # XXX used by gpaw/xc/tools.py in a hacky way
        return self._wfs._get_wave_function_array(
            u, n, realspace=True)

    def pair_density_paw_corrections(self, qpd):
        """Get the pair density PAW corrections for the descriptor qpd."""
        from gpaw.response.paw import get_pair_density_paw_corrections
        return get_pair_density_paw_corrections(
            pawdatasets=self.pawdatasets, qpd=qpd, spos_ac=self.spos_ac,
            atomrotations=self.atomrotations)

    def matrix_element_paw_corrections(self, qpd, rshe_a):
        """Get the matrix element PAW corrections for qpd and rshe_a."""
        from gpaw.response.paw import get_matrix_element_paw_corrections
        return get_matrix_element_paw_corrections(
            qpd, self.pawdatasets, rshe_a, self.spos_ac)

    def get_pos_av(self):
        """Return the atomic positions as spos_ac times the cell matrix."""
        # gd.cell_cv must always be the same as pd.gd.cell_cv, right??
        return np.dot(self.spos_ac, self.gd.cell_cv)

    def count_occupied_bands(self, ftol: float = 1e-6) -> tuple[int, int]:
        """Count the number of filled (nocc1) and nonempty bands (nocc2).

        ftol : float
            Threshold determining whether a band is completely filled
            (f > 1 - ftol) or completely empty (f < ftol).
        """
        # Count the number of occupied bands for this rank
        nocc1, nocc2 = self._count_occupied_bands(ftol=ftol)
        # Minimize/maximize over k-points
        nocc1 = self.kd.comm.min_scalar(nocc1)  # bands filled for all k
        nocc2 = self.kd.comm.max_scalar(nocc2)  # bands filled for any k
        # Sum over band distribution
        nocc1 = self.bd.comm.sum_scalar(nocc1)  # number of filled bands
        nocc2 = self.bd.comm.sum_scalar(nocc2)  # number of nonempty bands
        return int(nocc1), int(nocc2)

    def _count_occupied_bands(self, *, ftol: float) -> tuple[int, int]:
        """Count filled and nonempty bands over the k-points in kpt_u."""
        nocc1 = 9999999  # number of completely filled bands
        nocc2 = 0  # number of nonempty bands
        for kpt in self.kpt_u:
            # Divide out the k-point weight before comparing to the thresholds
            f_n = kpt.f_n / kpt.weight
            nocc1 = min((f_n > 1 - ftol).sum(), nocc1)
            nocc2 = max((f_n > ftol).sum(), nocc2)
        return int(nocc1), int(nocc2)

    def get_band_transitions(self, nbands: int | slice | None = None):
        """Determine the indices that define the range of occupied bands
        n1, n2 and unoccupied bands m1, m2.

        nbands : int, slice or None
            None selects all bands, an int selects the bands [0, nbands)
            and a slice (step None or 1) the bands [start, stop). A missing
            slice start defaults to 0 and a missing stop to the total
            number of bands.

        Returns the tuple (n1, n2, m1, m2), where n2 is the number of
        nonempty bands (nocc2) and m1 the number of filled bands (nocc1).

        Raises ValueError if nbands is of an unsupported type.
        """

        if nbands is None:
            n1 = 0
            m2 = self.nbands
        elif isinstance(nbands, int):
            n1 = 0
            m2 = nbands
            assert 1 <= m2 <= self.nbands
        elif isinstance(nbands, slice):
            # Open-ended slices follow the usual Python slice semantics
            n1 = 0 if nbands.start is None else nbands.start
            m2 = self.nbands if nbands.stop is None else nbands.stop
            assert n1 >= 0 and m2 >= 0
            assert nbands.step in {None, 1}
            assert n1 < m2 <= self.nbands
            assert n1 <= self.nocc1
        else:
            raise ValueError(
                f"Invalid type for nbands: {type(nbands)}. "
                "Expected None, int, or slice.")

        n2 = self.nocc2
        m1 = self.nocc1

        assert n1 < n2

        return n1, n2, m1, m2

    def get_eigenvalue_range(self, nbands: int | slice | None = None):
        """Get smallest and largest Kohn-Sham eigenvalues."""
        n1, n2, m1, m2 = self.get_band_transitions(nbands)
        epsmin = np.inf
        epsmax = -np.inf
        # NOTE(review): only the k-points in self.kpt_u are visited here; no
        # reduction over a communicator is performed -- confirm that this is
        # intended when the calculator is k-point parallelized.
        for kpt in self.kpt_u:
            epsmin = min(epsmin, kpt.eps_n[n1])  # the eigenvalues are ordered
            epsmax = max(epsmax, kpt.eps_n[m2 - 1])
        return epsmin, epsmax

    @property
    def nbands(self):
        """Total number of bands, as given by the band descriptor."""
        return self.bd.nbands

    @property
    def metallic(self):
        """True if there are partially occupied bands."""
        # Metallic if the number of filled bands differs from the number of
        # nonempty bands.
        return self.nocc1 != self.nocc2

    @cached_property
    def ibzq_qc(self):
        """The first array returned by kd.get_ibz_q_points()."""
        # For G0W0Kernel
        kd = self.kd
        bzq_qc = kd.get_bz_q_points(first=True)
        U_scc = kd.symmetry.op_scc
        ibzq_qc = kd.get_ibz_q_points(bzq_qc, U_scc)[0]

        return ibzq_qc

    def get_ibz_vertices(self):
        """Return the IBZ vertices computed by gpaw.bztools.get_bz()."""
        # For the tetrahedron method in Chi0
        from gpaw.bztools import get_bz
        # NB: We are ignoring the pbc_c keyword to get_bz() in order to mimic
        # find_high_symmetry_monkhorst_pack() in gpaw.bztools. XXX
        _, ibz_vertices_kc, _ = get_bz(self._calc)
        return ibz_vertices_kc

    def get_aug_radii(self):
        """Return, for each atom, the largest rcut_j of its PAW dataset."""
        return np.array([max(pawdata.rcut_j)
                         for pawdata in self.pawdatasets.by_atom])

    @cached_property
    def micro_setups(self):
        """List with one extracted micro setup per atom."""
        from gpaw.response.localft import extract_micro_setup
        return [extract_micro_setup(pawdata, self.D_asp[a])
                for a, pawdata in enumerate(self.pawdatasets.by_atom)]

    @property
    def atomrotations(self):
        """The ``atomrotations`` attribute of the wavefunction setups."""
        return self._wfs.setups.atomrotations

    @cached_property
    def kpoints(self):
        """ResponseKPointGrid built from the k-point descriptor."""
        from gpaw.response.kpoints import ResponseKPointGrid
        return ResponseKPointGrid(self.kd)

376 

377 

@dataclass
class CellDescriptor:
    """Unit cell vectors together with their periodic boundary conditions."""
    # Cell vectors, indexed as [cell vector c, cartesian component v]. The
    # "Bohr -> Å" conversion below implies that they are given in Bohr.
    cell_cv: np.ndarray
    # One flag per cell vector, truthy for the periodic directions
    pbc_c: np.ndarray

    @property
    def nonperiodic_hypervolume(self):
        """Get the hypervolume of the cell along nonperiodic directions.

        Returns the hypervolume Λ in units of Å^d, where d is the number of
        nonperiodic directions, that is

        Λ = 1 in 3D
        Λ = L in 2D, where L is the out-of-plane cell vector length
        Λ = A in 1D, where A is the transverse cell area
        Λ = V in 0D, where V is the cell volume
        """
        cell_cv = np.asarray(self.cell_cv)
        # Work on a genuine boolean mask: for integer flags "~" is a bitwise
        # negation, which would silently pick the wrong rows and columns.
        pbc_c = np.asarray(self.pbc_c, dtype=bool)
        nonpbc_c = ~pbc_c
        if pbc_c.any():
            # In 1D and 2D, we assume the cartesian representation of the unit
            # cell to be block diagonal, separating the periodic and
            # nonperiodic cell vectors in different blocks.
            assert np.allclose(cell_cv[nonpbc_c][:, pbc_c], 0.) and \
                np.allclose(cell_cv[pbc_c][:, nonpbc_c], 0.), \
                "In 1D and 2D, please put the periodic/nonperiodic axis " \
                "along a cartesian component"
        # Determinant of the nonperiodic block, i.e. its length/area/volume
        L = np.abs(np.linalg.det(cell_cv[nonpbc_c][:, nonpbc_c]))
        return L * Bohr**nonpbc_c.sum()  # Bohr -> Å

406 

407 

# Holds the pieces of a Setups entry that the response calculators rely on.
class ResponsePAWDataset(LeanPAWDataset):
    def __init__(self, setup: LeanSetup, **kwargs):
        """Copy the response-relevant data out of ``setup``.

        Additional keyword arguments are passed on to LeanPAWDataset.
        """
        super().__init__(
            rgd=setup.rgd,
            l_j=setup.l_j,
            rcut_j=setup.rcut_j,
            phit_jg=setup.data.phit_jg,
            phi_jg=setup.data.phi_jg,
            **kwargs)
        # The base class and the setup must agree on ni
        assert setup.ni == self.ni

        self.n_j = setup.n_j
        self.N0_q = setup.N0_q
        self.nabla_iiv = setup.nabla_iiv

        self.xc_correction: SimpleNamespace | None
        xc = setup.xc_correction
        if xc is None:
            # A setup without `xc_correction` is taken to mean that pseudo
            # potentials are in use.
            self.xc_correction = None
            # Emptying l_j bypasses the calculation of PAW corrections to
            # pair densities etc. This is quite an ugly hack...
            # Proper support for pseudo potential calculations should skip
            # the PAW corrections at the matrix element calculator level
            # instead.
            self.l_j = np.array([], dtype=float)
        else:
            # Keep only a lightweight copy of the attributes that are needed
            copied = ('rgd', 'Y_nL', 'n_qg', 'nt_qg', 'nc_g', 'nct_g',
                      'nc_corehole_g', 'B_pqL', 'e_xc0')
            self.xc_correction = SimpleNamespace(
                **{name: getattr(xc, name) for name in copied})
        self.hubbard_u = setup.hubbard_u