Coverage for gpaw/response/chi0_base.py: 97%

210 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

1from __future__ import annotations 

2from abc import ABC, abstractmethod 

3 

4import numpy as np 

5 

6from typing import TYPE_CHECKING 

7 

8from ase.units import Ha 

9from gpaw.bztools import convex_hull_volume 

10from gpaw.response import timer 

11from gpaw.response.pair import KPointPairFactory 

12from gpaw.response.frequencies import NonLinearFrequencyDescriptor 

13from gpaw.response.qpd import SingleQPWDescriptor 

14from gpaw.response.pw_parallelization import block_partition 

15from gpaw.response.integrators import ( 

16 Integrand, PointIntegrator, TetrahedronIntegrator, Domain) 

17from gpaw.response.symmetry import QSymmetryInput, QSymmetryAnalyzer 

18from gpaw.response.kpoints import KPointDomain, KPointDomainGenerator 

19 

20if TYPE_CHECKING: 

21 from gpaw.response.pair import ActualPairDensityCalculator 

22 from gpaw.response.context import ResponseContext 

23 from gpaw.response.groundstate import ResponseGroundStateAdapter 

24 

25 

class Chi0Integrand(Integrand):
    """Integrand of the chi0 k-point integral for a fixed band range.

    Supplies pair-density matrix elements and eigenvalue differences for
    transitions from bands n1:n2 at k to bands m1:m2 at k + q, in the
    form consumed by the k-point integrators.
    """

    def __init__(self, chi0calc: Chi0ComponentPWCalculator,
                 optical: bool,
                 qpd: SingleQPWDescriptor,
                 generator: KPointDomainGenerator,
                 n1: int,
                 n2: int,
                 m1: int,
                 m2: int):
        """
        Parameters
        ----------
        chi0calc : Chi0ComponentPWCalculator
            Calculator supplying the ground state adapter, context,
            k-point pair factory, pair density calculator, integration
            mode and block communicator.
        optical : bool
            If True, matrix_element() returns optical (head and wings)
            pair densities instead of ordinary pair densities.
        qpd : SingleQPWDescriptor
            Plane-wave descriptor at the momentum transfer q.
        generator : KPointDomainGenerator
            Supplies k-point weights and the number of symmetries used to
            weight matrix elements in point integration.
        n1 : int
            Lower occupied band index.
        n2 : int
            Upper occupied band index (exclusive).
        m1 : int
            Lower unoccupied band index.
        m2 : int
            Upper unoccupied band index (exclusive).
        """

        assert m1 <= m2
        assert n1 < n2 <= chi0calc.gs.nocc2
        assert n1 <= chi0calc.gs.nocc1
        assert chi0calc.gs.nocc1 <= m1
        self.m1 = m1
        self.m2 = m2
        self.n1 = n1
        self.n2 = n2

        self._chi0calc = chi0calc

        self.gs: ResponseGroundStateAdapter = chi0calc.gs

        self.context: ResponseContext = chi0calc.context
        self.kptpair_factory: KPointPairFactory = chi0calc.kptpair_factory

        self.qpd = qpd
        self.generator = generator
        self.integrationmode = chi0calc.integrationmode
        self.optical = optical
        self.blockcomm = chi0calc.blockcomm

    @timer('Get matrix element')
    def matrix_element(self, point):
        """Return pair density matrix elements for integration.

        A pair density is defined as::

            <snk| e^(-i (q + G) r) |s'mk+q>,

        where s and s' are spins, n and m are band indices, k is
        the kpoint and q is the momentum transfer. For dielectric
        response s'=s, for the transverse magnetic response
        s' is flipped with respect to s.

        Parameters
        ----------
        point : integration point
            Provides ``kpt_c`` (used as the k-point in Cartesian
            coordinates, despite its name) and ``spin`` (spin index).

        If self.optical, the optical pair densities are returned instead,
        i.e. the head and wings matrix elements indexed by
        P = (x, y, v, G1, G2, ...), which gives qpd.ngmax + 2 columns
        instead of qpd.ngmax.

        Returns
        -------
        ndarray
            Pair densities as a 2D array: the (n, m) band pairs are
            flattened into the first axis, the plane-wave (or P) index
            runs along the second axis.
        """

        if self.optical:
            # pair_calc: ActualPairDensityCalculator from gpaw.response.pair
            target_method = self._chi0calc.pair_calc.get_optical_pair_density
            out_ngmax = self.qpd.ngmax + 2
        else:
            target_method = self._chi0calc.pair_calc.get_pair_density
            out_ngmax = self.qpd.ngmax

        return self._get_any_matrix_element(
            point, target_method=target_method,
        ).reshape(-1, out_ngmax)

    def _get_any_matrix_element(self, point, target_method):
        """Evaluate weighted matrix elements with a given pair density method.

        target_method is called as
        ``target_method(qpd, kptpair, n_n, m_m, pawcorr=..., block=True)``.
        The result is scaled by sqrt(df_nm), where df_nm are the occupation
        differences with values <= 1e-20 set to zero. In point integration
        it is additionally scaled by sqrt(k-point weight / #symmetries).
        Returns the resulting n_nmG array.
        """
        qpd = self.qpd

        k_v = point.kpt_c  # XXX c/v discrepancy

        k_c = np.dot(qpd.gd.cell_cv, k_v) / (2 * np.pi)
        K = self.gs.kpoints.kptfinder.find(k_c)
        # assert point.K == K, (point.K, K)

        weight = np.sqrt(self.generator.get_kpoint_weight(k_c)
                         / self.generator.how_many_symmetries())

        # Here we're again setting pawcorr willy-nilly
        # (lazily created on first use and cached on the calculator)
        if self._chi0calc.pawcorr is None:
            pairden_paw_corr = self.gs.pair_density_paw_corrections
            self._chi0calc.pawcorr = pairden_paw_corr(qpd)

        kptpair = self.kptpair_factory.get_kpoint_pair(
            qpd, point.spin, K, self.n1, self.n2,
            self.m1, self.m2, blockcomm=self.blockcomm)

        m_m = np.arange(self.m1, self.m2)
        n_n = np.arange(self.n1, self.n2)
        n_nmG = target_method(qpd, kptpair, n_n, m_m,
                              pawcorr=self._chi0calc.pawcorr,
                              block=True)

        if self.integrationmode == 'point integration':
            n_nmG *= weight

        # Clip tiny/negative occupation differences so the square root
        # below is well defined.
        df_nm = kptpair.get_occupation_differences()
        df_nm[df_nm <= 1e-20] = 0.0
        n_nmG *= df_nm[..., np.newaxis]**0.5

        return n_nmG

    @timer('Get eigenvalues')
    def eigenvalues(self, point):
        """Return eigenvalue differences for the integration point.

        Computes eps_n(k) - eps_m(k + q) for n in n1:n2 and m in m1:m2,
        using the spin of the point, and returns it flattened to 1D in
        the same (n, m) order as the rows of matrix_element().
        """

        qpd = self.qpd
        gs = self.gs
        kd = gs.kd

        k_v = point.kpt_c  # XXX c/v discrepancy

        k_c = np.dot(qpd.gd.cell_cv, k_v) / (2 * np.pi)
        kptfinder = gs.kpoints.kptfinder
        K1 = kptfinder.find(k_c)
        K2 = kptfinder.find(k_c + qpd.q_c)

        # Map the full-BZ indices to irreducible k-points.
        ik1 = kd.bz2ibz_k[K1]
        ik2 = kd.bz2ibz_k[K2]
        kpt1 = gs.kpt_qs[ik1][point.spin]
        assert kd.comm.size == 1
        kpt2 = gs.kpt_qs[ik2][point.spin]
        deps_nm = np.subtract(kpt1.eps_n[self.n1:self.n2][:, np.newaxis],
                              kpt2.eps_n[self.m1:self.m2])
        return deps_nm.reshape(-1)

175 

class Chi0ComponentCalculator:
    """Base class for the Chi0XXXCalculator suite."""

    def __init__(self, gs, context, *, nblocks,
                 qsymmetry: QSymmetryInput = True,
                 integrationmode='point integration'):
        """Set up attributes common to all chi0 related calculators.

        Parameters
        ----------
        gs : ResponseGroundStateAdapter
            Ground state adapter.
        context : ResponseContext
            Context providing the communicator and printing.
        nblocks : int
            Divide response function memory allocation in nblocks.
        qsymmetry: bool, dict, or QSymmetryAnalyzer
            QSymmetryAnalyzer, or bool to enable all/no symmetries,
            or dict with which to create QSymmetryAnalyzer.
            Disabling symmetries may be useful for debugging.
        integrationmode : str
            Integrator for the k-point integration.
            If == 'tetrahedron integration' then the kpoint integral is
            performed using the linear tetrahedron method.
            If == 'point integration', point integration is used.
        """
        self.gs = gs
        self.context = context
        self.kptpair_factory = KPointPairFactory(gs, context)

        self.nblocks = nblocks
        self.blockcomm, self.kncomm = block_partition(
            self.context.comm, self.nblocks)

        self.qsymmetry = QSymmetryAnalyzer.from_input(qsymmetry)

        # Set up integrator
        self.integrationmode = integrationmode
        self.integrator = self.construct_integrator()

    @property
    def pbc(self):
        """Periodic boundary conditions of the ground state."""
        return self.gs.pbc

    def construct_integrator(self):  # -> Integrator or child of Integrator
        """Construct k-point integrator"""
        cls = self.get_integrator_cls()
        return cls(
            cell_cv=self.gs.gd.cell_cv,
            context=self.context,
            blockcomm=self.blockcomm,
            kncomm=self.kncomm)

    def get_integrator_cls(self):  # -> Integrator or child of Integrator
        """Get the appointed k-point integrator class.

        Raises ValueError for an unknown integration mode. For tetrahedron
        integration with symmetries enabled, also checks that the k-point
        grid contains all IBZ vertices.
        """
        if self.integrationmode == 'point integration':
            self.context.print('Using integrator: PointIntegrator')
            cls = PointIntegrator
        elif self.integrationmode == 'tetrahedron integration':
            self.context.print('Using integrator: TetrahedronIntegrator')
            cls = TetrahedronIntegrator
            if not self.qsymmetry.disabled:
                self.check_high_symmetry_ibz_kpts()
        else:
            raise ValueError(f'Integration mode "{self.integrationmode}"'
                             ' not implemented.')
        return cls

    def check_high_symmetry_ibz_kpts(self):
        """Check that the ground state includes all corners of the IBZ.

        Raises ValueError if some IBZ vertex is not found in the BZ
        k-point grid (compared along periodic directions only, modulo
        reciprocal lattice vectors).
        """
        ibz_vertices_kc = self.gs.get_ibz_vertices()
        # Here we mimic the k-point grid compatibility check of
        # gpaw.bztools.find_high_symmetry_monkhorst_pack()
        bzk_kc = self.gs.kd.bzk_kc
        for ibz_vertex_c in ibz_vertices_kc:
            # Relative coordinate difference to the k-point grid
            diff_kc = np.abs(bzk_kc - ibz_vertex_c)[:, self.gs.pbc].round(6)
            # The ibz vertex should exist in the BZ grid up to a reciprocal
            # lattice vector, meaning that the relative coordinate difference
            # is allowed to be an integer. Thus, at least one relative k-point
            # difference should vanish, modulo 1
            mod_diff_kc = np.mod(diff_kc, 1)
            nodiff_k = np.all(mod_diff_kc < 1e-5, axis=1)
            if not np.any(nodiff_k):
                raise ValueError(
                    'The ground state k-point grid does not include all '
                    'vertices of the IBZ. '
                    'Please use find_high_symmetry_monkhorst_pack() from '
                    'gpaw.bztools to generate your k-point grid.')

    def get_integration_domain(self, q_c, spins):
        """Get integrator domain and prefactor for the integral.

        Returns
        -------
        symmetries, generator, domain, prefactor
            The q-symmetries and k-point domain generator, the integration
            Domain over the reduced k-points and the given spins, and the
            scalar prefactor of the k-point integral.
        """
        for spin in spins:
            assert spin in range(self.gs.nspins)
        # The integration domain is determined by the following function
        # that reduces the integration domain to the irreducible zone
        # of the little group of q.
        symmetries, generator, kpoints = self.get_kpoints(
            q_c, integrationmode=self.integrationmode)

        domain = Domain(kpoints.k_kv, spins)

        if self.integrationmode == 'tetrahedron integration':
            # If there are non-periodic directions it is possible that the
            # integration domain is not compatible with the symmetry operations
            # which essentially means that too large domains will be
            # integrated. We normalize by vol(BZ) / vol(domain) to
            # compensate for this.
            domainvol = convex_hull_volume(
                kpoints.k_kv) * generator.how_many_symmetries()
            bzvol = (2 * np.pi)**3 / self.gs.volume
            factor = bzvol / domainvol
        else:
            factor = 1

        prefactor = (2 * factor * generator.how_many_symmetries()
                     / (self.gs.nspins * (2 * np.pi)**3))  # Remember prefactor

        if self.integrationmode == 'point integration':
            nbzkpts = self.gs.kd.nbzkpts
            prefactor *= len(kpoints) / nbzkpts

        return symmetries, generator, domain, prefactor

    @timer('Get kpoints')
    def get_kpoints(self, q_c, integrationmode):
        """Get the integration domain.

        Parameters
        ----------
        q_c : array-like
            Momentum transfer in relative coordinates.
        integrationmode : str
            'point integration' or 'tetrahedron integration'.

        Returns
        -------
        symmetries, generator, kpoints
            Output of the q-symmetry analysis and the KPointDomain.

        Raises
        ------
        ValueError
            If integrationmode is not one of the supported modes
            (previously this surfaced as an UnboundLocalError).
        """
        symmetries, generator = self.qsymmetry.analyze(
            np.asarray(q_c), self.gs.kpoints, self.context)

        if integrationmode == 'point integration':
            k_kc = generator.get_kpt_domain()
        elif integrationmode == 'tetrahedron integration':
            k_kc = generator.get_tetrahedron_kpt_domain(
                pbc_c=self.pbc, cell_cv=self.gs.gd.cell_cv)
        else:
            raise ValueError(f'Integration mode "{integrationmode}"'
                             ' not implemented.')
        kpoints = KPointDomain(k_kc, self.gs.gd.icell_cv)

        # In the future, we probably want to put enough functionality on the
        # KPointDomain such that we don't need to also return the
        # KPointDomainGenerator XXX
        return symmetries, generator, kpoints

    def get_gs_info_string(self, tab=''):
        """Return a multi-line summary of the ground state, indented by tab."""
        gs = self.gs
        gd = gs.gd

        ns = gs.nspins
        nk = gs.kd.nbzkpts
        nik = gs.kd.nibzkpts

        nocc = self.gs.nocc1
        npocc = self.gs.nocc2
        ngridpoints = gd.N_c[0] * gd.N_c[1] * gd.N_c[2]
        nstat = ns * npocc
        # Estimated storage in MiB for complex128 (16 bytes) values of
        # every partially occupied state on the real-space grid.
        occsize = nstat * ngridpoints * 16. / 1024**2

        gs_list = [f'{tab}Ground state adapter containing:',
                   f'Number of spins: {ns}', f'Number of kpoints: {nk}',
                   f'Number of irreducible kpoints: {nik}',
                   f'Number of completely occupied states: {nocc}',
                   f'Number of partially occupied states: {npocc}',
                   f'Occupied states memory: {occsize} M / cpu']

        return f'\n{tab}'.join(gs_list)

336 

337 

class Chi0ComponentPWCalculator(Chi0ComponentCalculator, ABC):
    """Base class for Chi0XXXCalculators, which utilize a plane-wave basis."""

    def __init__(self, gs, context,
                 *,
                 wd,
                 hilbert=True,
                 nbands=None,
                 timeordered=False,
                 ecut=50.0,
                 eta=0.2,
                 **kwargs):
        """Set up attributes to calculate the chi0 body and optical extensions.

        Parameters
        ----------
        gs : ResponseGroundStateAdapter
            Ground state adapter.
        context : ResponseContext
            Context providing the communicator and printing.
        wd : FrequencyDescriptor
            Frequencies for which the chi0 component is evaluated.
        hilbert : bool
            Hilbert transform flag. If True, the dissipative part of the chi0
            component is evaluated, and the reactive part is calculated via a
            hilbert transform. Only works for frequencies on the real axis and
            requires a nonlinear frequency grid.
        nbands : int or slice
            Number of bands to include. If None (or otherwise falsy), all
            ground state bands (gs.nbands) are used.
        timeordered : bool
            Flag for calculating the time ordered chi0 component. Used for
            G0W0, which performs its own hilbert transform.
        ecut : float | dict
            Plane-wave energy cutoff in eV or dictionary for the plane-wave
            descriptor type. See response/qpd.py for details.
        eta : float
            Artificial broadening of the chi0 component in eV.
        **kwargs
            Forwarded to Chi0ComponentCalculator (nblocks, qsymmetry,
            integrationmode).
        """
        super().__init__(gs, context, **kwargs)

        if not isinstance(ecut, dict):
            ecut /= Ha
        self.ecut = ecut
        self.nbands = nbands or self.gs.nbands

        self.wd = wd
        self.context.print(self.wd, flush=False)
        self.eta = eta / Ha
        self.hilbert = hilbert

        # Set self.timeordered before construct_integral_task(): it
        # dispatches to abstract methods implemented by subclasses, which
        # must be able to rely on this attribute already existing.
        self.timeordered = bool(timeordered)
        if self.timeordered:
            assert self.hilbert  # Timeordered is only needed for G0W0

        self.task = self.construct_integral_task()

        # PAW corrections are created lazily on first use.
        self.pawcorr = None

        self.context.print('Nonperiodic BCs: ', (~self.pbc))
        if sum(self.pbc) == 1:
            raise ValueError('1-D not supported atm.')

    @property
    def pair_calc(self) -> ActualPairDensityCalculator:
        """Pair density calculator distributed over the block communicator."""
        return self.kptpair_factory.pair_calculator(self.blockcomm)

    def construct_integral_task(self):
        """Select the integral task from eta and the hilbert flag."""
        if self.eta == 0:
            assert not self.hilbert
            # eta == 0 is used as a synonym for calculating the hermitian part
            # of the response function at a range of imaginary frequencies
            assert not self.wd.omega_w.real.any()
            return self.construct_hermitian_task()

        if self.hilbert:
            # The hilbert flag is used to calculate the response function via
            # a hilbert transform of its dissipative (spectral) part.
            assert isinstance(self.wd, NonLinearFrequencyDescriptor)
            return self.construct_hilbert_task()

        # Otherwise, we perform a literal evaluation of the response function
        # at the given frequencies with broadening eta
        return self.construct_literal_task()

    @abstractmethod
    def construct_hermitian_task(self):
        """Integral task for the hermitian part of chi0."""

    def construct_hilbert_task(self):
        """Pick the spectral-part task matching the integrator type."""
        if isinstance(self.integrator, PointIntegrator):
            return self.construct_point_hilbert_task()
        else:
            assert isinstance(self.integrator, TetrahedronIntegrator)
            return self.construct_tetra_hilbert_task()

    @abstractmethod
    def construct_point_hilbert_task(self):
        """Integral task for point integrating the spectral part of chi0."""

    @abstractmethod
    def construct_tetra_hilbert_task(self):
        """Integral task for tetrahedron integration of the spectral part."""

    @abstractmethod
    def construct_literal_task(self):
        """Integral task for a literal evaluation of chi0."""

    def get_pw_descriptor(self, q_c):
        """Return the plane-wave descriptor at q_c for this cutoff."""
        return SingleQPWDescriptor.from_q(q_c, self.ecut, self.gs.gd)

    def get_response_info_string(self, qpd, tab=''):
        """Return a multi-line summary of the response parameters.

        Cutoff and broadening are reported in eV (converted back from
        Hartree); a dict-valued ecut is reported as-is.
        """
        nw = len(self.wd)
        if not isinstance(self.ecut, dict):
            ecut = self.ecut * Ha
        else:
            ecut = self.ecut
        nbands = self.nbands
        ngmax = qpd.ngmax
        eta = self.eta * Ha

        res_list = [f'{tab}Number of frequency points: {nw}',
                    f'Planewave cutoff: {ecut}',
                    f'Number of bands: {nbands}',
                    f'Number of planewaves: {ngmax}',
                    f'Broadening (eta): {eta}']

        return f'\n{tab}'.join(res_list)