Coverage for gpaw/response/pair.py: 97%

269 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import numpy as np 

2 

3from gpaw.response import ResponseContext, ResponseGroundStateAdapter, timer 

4from gpaw.response.pw_parallelization import block_partition 

5from gpaw.utilities.blas import mmm 

6 

7 

class KPoint:
    """Plain container for the data belonging to one k-point and spin.

    Attributes:

    s:
        Spin index.
    K:
        BZ k-point index.
    n1, n2:
        First band and first band not included.
    blocksize:
        Number of bands per block.
    na, nb:
        First band of the block and first band of the block not included.
    ut_nR:
        Periodic part of the wave functions in real-space.
    eps_n:
        Eigenvalues.
    f_n:
        Occupation numbers.
    P_ani:
        PAW projections.
    k_c:
        k-point coordinates.
    """

    def __init__(self, s, K, n1, n2, blocksize, na, nb,
                 ut_nR, eps_n, f_n, P_ani, k_c):
        # Which k-point and spin channel this is
        self.s, self.K = s, K
        # Full band range [n1, n2) and the block [na, nb) held locally
        self.n1, self.n2 = n1, n2
        self.blocksize = blocksize
        self.na, self.nb = na, nb
        # Wave functions, energies, occupations and projections
        self.ut_nR, self.eps_n, self.f_n = ut_nR, eps_n, f_n
        self.P_ani, self.k_c = P_ani, k_c

23 

24 

class KPointPair:
    """Container object for a pair of k-points.

    Holds the two k-points together with the associated set of Fourier
    components, for use when calculating pair quantities."""

    def __init__(self, kpt1, kpt2, Q_G):
        self.kpt1, self.kpt2, self.Q_G = kpt1, kpt2, Q_G

    def get_transition_energies(self):
        """Return the eigenvalue differences eps(kpt1) - eps(kpt2).

        The eigenvalues of kpt1 run along the first axis of the result
        and those of kpt2 are broadcast along the last axis."""
        eps1_n = self.kpt1.eps_n
        eps2_m = self.kpt2.eps_n
        return eps1_n[:, np.newaxis] - eps2_m

    def get_occupation_differences(self):
        """Return the occupation differences f(kpt1) - f(kpt2).

        The occupations of kpt1 run along the first axis of the result
        and those of kpt2 are broadcast along the last axis."""
        f1_n = self.kpt1.f_n
        f2_m = self.kpt2.f_n
        return f1_n[:, np.newaxis] - f2_m

48 

49 

class KPointPairFactory:
    """Build KPoint and KPointPair objects from a ground state adapter.

    On construction it is asserted that the k-point symmetry is
    symmorphic and that the ground state communicator has size one.
    """

    def __init__(self, gs, context):
        self.gs = gs
        self.context = context
        assert self.gs.kd.symmetry.symmorphic
        assert self.gs.world.size == 1

    @timer('Get a k-point')
    def get_k_point(self, s, K, n1, n2, blockcomm=None):
        """Return wave functions for a specific k-point and spin.

        s: int
            Spin index (0 or 1).
        K: int
            BZ k-point index.
        n1, n2: int
            Range of bands to include.
        blockcomm: object with ``size`` and ``rank`` attributes, optional
            When given, the band range is divided into ``blockcomm.size``
            blocks and wave functions and projections are only loaded for
            the block of ``blockcomm.rank``. Eigenvalues and occupation
            numbers always cover the full range n1:n2.
        """
        assert n1 <= n2

        gs = self.gs
        kd = gs.kd

        # Without a block communicator, one block holds all the bands
        nblocks, rank = ((blockcomm.size, blockcomm.rank) if blockcomm
                         else (1, 0))

        # Bands per block (rounded up) and the block [na, nb) of this rank
        blocksize = (n2 - n1 + nblocks - 1) // nblocks
        na = min(n1 + rank * blocksize, n2)
        nb = min(na + blocksize, n2)
        nmybands = nb - na

        ik = kd.bz2ibz_k[K]
        assert kd.comm.size == 1
        kpt = gs.kpt_qs[ik][s]

        assert n2 <= len(kpt.eps_n), \
            'Increase GS-nbands or decrease chi0-nbands!'
        eps_n = kpt.eps_n[n1:n2]
        # Occupations are stored multiplied by the k-point weight
        f_n = kpt.f_n[n1:n2] / kpt.weight

        k_c = gs.ibz2bz[K].map_kpoint()

        with self.context.timer('load wfs'):
            psit_nG = kpt.psit_nG
            ut_nR = gs.gd.empty(nmybands, gs.dtype)
            for myn, n in enumerate(range(na, nb)):
                # Transform to real space, then map onto k-point K
                ut_R = gs.pd.ifft(psit_nG[n], ik)
                ut_nR[myn] = gs.ibz2bz[K].map_pseudo_wave(ut_R)

        with self.context.timer('Load projections'):
            P_ani = []
            if nmybands > 0:
                proj = kpt.projections.new(nbands=nmybands, bcomm=None)
                proj.array[:] = kpt.projections.array[na:nb]
                proj = gs.ibz2bz[K].map_projections(proj)
                P_ani = [P_ni for _, P_ni in proj.items()]

        return KPoint(s, K, n1, n2, blocksize, na, nb,
                      ut_nR, eps_n, f_n, P_ani, k_c)

    @timer('Get kpoint pair')
    def get_kpoint_pair(self, qpd, s, K, n1, n2, m1, m2,
                        blockcomm=None, flipspin=False):
        """Return the KPointPair of k-point K and the k-point at k + q.

        The first k-point gets the bands n1:n2 and spin s. The second
        k-point gets the bands m1:m2, the opposite spin if flipspin is
        true, and is the only one that receives blockcomm.
        """
        assert m1 <= m2
        assert n1 <= n2

        finder = self.gs.kpoints.kptfinder

        # Look up both k-points from their coordinates
        k_c = self.gs.kd.bzk_kc[K]
        K1 = finder.find(k_c)
        K2 = finder.find(k_c + qpd.q_c)
        s2 = (s + flipspin) % 2

        with self.context.timer('get k-points'):
            kpt1 = self.get_k_point(s, K1, n1, n2)
            kpt2 = self.get_k_point(s2, K2, m1, m2, blockcomm=blockcomm)

        with self.context.timer('fft indices'):
            Q_G = phase_shifted_fft_indices(kpt1.k_c, kpt2.k_c, qpd)

        return KPointPair(kpt1, kpt2, Q_G)

    def pair_calculator(self, blockcomm=None):
        """Return an ActualPairDensityCalculator sharing gs and context.

        If no blockcomm is given, one is created by partitioning the
        context communicator into a single block.
        """
        # We have decoupled the actual pair density calculator
        # from the kpoint factory, but it's still handy to
        # keep this shortcut -- for now.
        if blockcomm is None:
            blockcomm, _ = block_partition(self.context.comm, nblocks=1)
        return ActualPairDensityCalculator(self, blockcomm)

145 

146 

class ActualPairDensityCalculator:
    """Calculate pair densities and optical-limit matrix elements.

    The ground state adapter and the context are taken from a k-point
    pair factory. The block communicator is used to gather the band
    blocks of the second k-point when block=True is requested.
    """

    def __init__(self, kptpair_factory, blockcomm):
        # The factory only serves as a carrier of gs and context here
        self.gs = kptpair_factory.gs
        self.context = kptpair_factory.context
        self.blockcomm = blockcomm
        # Gradient of wave functions for the optical limit, filled in by
        # calculate_optical_pair_velocity() on first use
        self.ut_sKnvR = None

    def get_optical_pair_density(self, qpd, kptpair, n_n, m_m, *,
                                 pawcorr, block=False):
        """Get the full optical pair density, including the optical limit
        head for q=0.

        The last axis of the result has qpd.ngmax + 2 entries: the three
        head components come first, followed by the plane-wave components
        of the pair density from index 1 and up."""
        n_nmG = self.get_pair_density(qpd, kptpair, n_n, m_m,
                                      pawcorr=pawcorr, block=block)

        # P = (x, y, z, G1, G2, ...)
        shape = (len(n_n), len(m_m), qpd.ngmax + 2)
        n_nmP = np.empty(shape, dtype=n_nmG.dtype)
        n_nmP[:, :, 3:] = n_nmG[:, :, 1:]
        n_nmP[:, :, :3] = self.get_optical_pair_density_head(
            qpd, kptpair, n_n, m_m, block=block)

        return n_nmP

    @timer('get_pair_density')
    def get_pair_density(self, qpd, kptpair, n_n, m_m, *,
                         pawcorr, block=False):
        """Get pair density for a kpoint pair.

        One row is calculated per band n in n_n; m_m is only used for
        the size of the output array."""
        kpt1, kpt2 = kptpair.kpt1, kptpair.kpt2
        Q_G = kptpair.Q_G  # Fourier components of kpoint pair

        n_nmG = np.zeros((len(n_n), len(m_m), len(Q_G)), qpd.dtype)

        for j, n in enumerate(n_n):
            myn = n - kpt1.na  # index into the locally stored bands
            with self.context.timer('conj'):
                ut1cc_R = kpt1.ut_nR[myn].conj()
            with self.context.timer('paw'):
                C1_aGi = pawcorr.multiply(kpt1.P_ani, band=myn)
            n_nmG[j] = self.calculate_pair_density(
                ut1cc_R, C1_aGi, kpt2, qpd, Q_G, block=block)

        return n_nmG

    @timer('get_optical_pair_density_head')
    def get_optical_pair_density_head(self, qpd, kptpair, n_n, m_m,
                                      block=False):
        """Get the optical limit of the pair density head (G=0) for a k-pair.

        Requires qpd.q_c to be zero. The last axis of the result holds
        the three components v = (x, y, z)."""
        assert np.allclose(qpd.q_c, 0.0), f"{qpd.q_c} is not the optical limit"

        kpt1, kpt2 = kptpair.kpt1, kptpair.kpt2

        # v = (x, y, z)
        n_nmv = np.zeros((len(n_n), len(m_m), 3), qpd.dtype)
        for n_mv, n in zip(n_nmv, n_n):
            n_mv[:] = self.calculate_optical_pair_density_head(
                n, m_m, kpt1, kpt2, block=block)

        return n_nmv

    @timer('Calculate pair-densities')
    def calculate_pair_density(self, ut1cc_R, C1_aGi, kpt2, qpd, Q_G,
                               block=True):
        """Calculate FFT of pair-densities and add PAW corrections.

        ut1cc_R: 3-d complex ndarray
            Complex conjugate of the periodic part of the left hand side
            wave function.
        C1_aGi: list of ndarrays
            PAW corrections for all atoms.
        kpt2: KPoint object
            Right hand side k-point object.
        qpd: SingleQPWDescriptor
            Plane-wave descriptor for q=k2-k1.
        Q_G: 1-d int ndarray
            Mapping from flattened 3-d FFT grid to 0.5(G+q)^2<ecut sphere.
        block: bool
            Gather the band blocks of all ranks of the block communicator
            when it has more than one rank; otherwise the local array is
            returned.
        """
        dv = qpd.gd.dv
        n_mG = qpd.empty(kpt2.blocksize)
        # Only the bands actually stored on this rank are filled in
        n_myG = n_mG[:kpt2.nb - kpt2.na]

        for ut_R, n_G in zip(kpt2.ut_nR, n_mG):
            n_R = ut1cc_R * ut_R
            with self.context.timer('fft'):
                n_G[:] = qpd.fft(n_R, 0, Q_G) * dv

        # PAW corrections, accumulated in place into the local bands:
        with self.context.timer('gemm'):
            for C1_Gi, P2_mi in zip(C1_aGi, kpt2.P_ani):
                mmm(1.0, P2_mi, 'N', C1_Gi, 'T', 1.0, n_myG)

        if block and self.blockcomm.size != 1:
            n_MG = qpd.empty(kpt2.blocksize * self.blockcomm.size)
            with self.context.timer('all_gather'):
                self.blockcomm.all_gather(n_mG, n_MG)
            # Cut away the padding beyond the requested band range
            return n_MG[:kpt2.n2 - kpt2.n1]

        return n_mG

    @timer('Optical limit')
    def calculate_optical_pair_velocity(self, n, kpt1, kpt2, block=False):
        """Calculate the pair velocity between band n of kpt1 and the
        bands of kpt2.

        With block=True and more than one rank in the block communicator,
        the blocks of all ranks are gathered and cut to the band range of
        kpt2."""
        # This has the effect of caching at most one kpoint.
        # This caching will be efficient only if we are looping over kpoints
        # in a particular way.
        #
        # It would be better to refactor so this caching is handled explicitly
        # by the caller providing the right thing.
        #
        # See https://gitlab.com/gpaw/gpaw/-/issues/625
        cache = self.ut_sKnvR
        if cache is None or kpt1.K not in cache[kpt1.s]:
            self.ut_sKnvR = self.calculate_derivatives(kpt1)

        gd = self.gs.gd
        k_v = 2 * np.pi * np.dot(kpt1.k_c, np.linalg.inv(gd.cell_cv).T)

        # The cached derivatives are indexed from n1, the locally stored
        # wave functions and projections from na
        myn = n - kpt1.na
        ut_vR = self.ut_sKnvR[kpt1.s][kpt1.K][n - kpt1.n1]
        atomdata_a = self.gs.pawdatasets.by_atom
        C_avi = [np.dot(atomdata.nabla_iiv.T, P_ni[myn])
                 for atomdata, P_ni in zip(atomdata_a, kpt1.P_ani)]

        blockbands = kpt2.nb - kpt2.na
        n0_mv = np.empty((kpt2.blocksize, 3), dtype=complex)
        nt_m = np.empty(kpt2.blocksize, dtype=complex)
        # Views on the part of the buffers holding the local bands
        n0_myv = n0_mv[:blockbands]
        nt_mym = nt_m[:blockbands]

        n0_myv[:] = -gd.integrate(ut_vR, kpt2.ut_nR).T
        nt_mym[:] = gd.integrate(kpt1.ut_nR[myn], kpt2.ut_nR)
        n0_myv += 1j * nt_mym[:, np.newaxis] * k_v[np.newaxis, :]

        # PAW corrections, accumulated in place
        for C_vi, P_mi in zip(C_avi, kpt2.P_ani):
            mmm(1.0, P_mi, 'N', C_vi, 'C', 1.0, n0_myv)

        if block and self.blockcomm.size > 1:
            n0_Mv = np.empty((kpt2.blocksize * self.blockcomm.size, 3),
                             dtype=complex)
            with self.context.timer('all_gather optical'):
                self.blockcomm.all_gather(n0_mv, n0_Mv)
            n0_mv = n0_Mv[:kpt2.n2 - kpt2.n1]

        return -1j * n0_mv

    def calculate_optical_pair_density_head(self, n, m_m, kpt1, kpt2,
                                            block=False):
        """Calculate the optical limit head from the pair velocity.

        The pair velocity is divided by minus the eigenvalue difference.
        Vanishing differences are replaced by infinity, and components
        that exceed the numerical threshold are set to zero."""
        # Numerical threshold for the optical limit k dot p perturbation
        # theory expansion:
        threshold = 1

        eps1 = kpt1.eps_n[n - kpt1.n1]
        deps_m = (eps1 - kpt2.eps_n)[m_m - kpt2.n1]
        n0_mv = self.calculate_optical_pair_velocity(n, kpt1, kpt2,
                                                     block=block)

        # Avoid division by zero for degenerate states
        deps_m = deps_m.copy()
        deps_m[deps_m == 0.0] = np.inf
        deps_mx = deps_m[:, np.newaxis]

        smallness_mv = np.abs(-1e-3 * n0_mv / deps_mx)
        toolarge_mv = (smallness_mv < np.inf) & (smallness_mv > threshold)
        n0_mv *= - 1 / deps_mx
        n0_mv[toolarge_mv] = 0

        return n0_mv

    @timer('Intraband')
    def intraband_pair_density(self, kpt, n_n):
        """Calculate intraband matrix elements of nabla"""
        # Bands and check for block parallelization
        na, nb, n1 = kpt.na, kpt.nb, kpt.n1
        vel_nv = np.zeros((nb - na, 3), dtype=complex)
        assert np.max(n_n) < nb, 'This is too many bands'
        assert np.min(n_n) >= na, 'This is too few bands'

        # Load kpoints
        gd = self.gs.gd
        k_v = 2 * np.pi * np.dot(kpt.k_c, np.linalg.inv(gd.cell_cv).T)
        atomdata_a = self.gs.pawdatasets.by_atom

        # Break bands into degenerate chunks
        degchunks_cn = []  # indexing c as chunk number
        for n in n_n:
            # Has this chunk already been computed?
            if any(n in chunk for chunk in degchunks_cn):
                continue

            inds_n = np.nonzero(np.abs(kpt.eps_n[n - n1] -
                                       kpt.eps_n) < 1e-5)[0] + n1
            if not all(ind in n_n for ind in inds_n):
                raise RuntimeError(
                    'You are cutting over a degenerate band '
                    'using block parallelization.')
            degchunks_cn.append(inds_n)

        # Calculate matrix elements by diagonalizing each block
        for ind_n in degchunks_cn:
            deg = len(ind_n)
            # States are included starting from kpt.na
            myind_n = ind_n - na
            ut_nvR = gd.zeros((deg, 3), complex)
            vel_nnv = np.zeros((deg, deg, 3), dtype=complex)
            ut_nR = kpt.ut_nR[myind_n]

            # Get derivatives
            for ind, ut_vR in zip(ind_n, ut_nvR):
                ut_vR[:] = self.make_derivative(kpt.s, kpt.K,
                                                ind, ind + 1)[0]

            # Treat the whole degenerate chunk
            for n, myind in enumerate(myind_n):
                ut_vR = ut_nvR[n]
                C_avi = [np.dot(atomdata.nabla_iiv.T, P_ni[myind])
                         for atomdata, P_ni in zip(atomdata_a, kpt.P_ani)]

                nabla0_nv = -gd.integrate(ut_vR, ut_nR).T
                nt_n = gd.integrate(ut_nR[n], ut_nR)
                nabla0_nv += 1j * nt_n[:, np.newaxis] * k_v[np.newaxis, :]

                # PAW corrections, accumulated in place
                for C_vi, P_ni in zip(C_avi, kpt.P_ani):
                    mmm(1.0, P_ni[myind_n], 'N', C_vi, 'C', 1.0, nabla0_nv)

                vel_nnv[n] = -1j * nabla0_nv

            for iv in range(3):
                eigenvalues, _ = np.linalg.eig(vel_nnv[..., iv])
                vel_nv[myind_n, iv] = eigenvalues  # Use eigenvalues

        return vel_nv[n_n - na]

    def calculate_derivatives(self, kpt):
        """Return a fresh derivative cache holding only the given k-point.

        The cache is a list with one dictionary per spin, mapping the
        k-point index K to the derivatives of the bands n1:n2."""
        ut_sKnvR = [{}, {}]
        ut_sKnvR[kpt.s][kpt.K] = self.make_derivative(kpt.s, kpt.K,
                                                      kpt.n1, kpt.n2)
        return ut_sKnvR

    @timer('Derivatives')
    def make_derivative(self, s, K, n1, n2):
        """Calculate the derivatives of the wave functions of bands n1:n2.

        Each plane-wave coefficient is multiplied by 1j times the
        reciprocal vector component, transformed to real space, mapped
        onto k-point K and the three components are mixed by M_vv."""
        gs = self.gs
        U_cc = gs.ibz2bz[K].U_cc
        A_cv = gs.gd.cell_cv
        M_vv = np.dot(np.dot(A_cv.T, U_cc.T), np.linalg.inv(A_cv).T)

        ik = gs.kd.bz2ibz_k[K]
        assert gs.kd.comm.size == 1
        psit_nG = gs.kpt_qs[ik][s].psit_nG
        iG_Gv = 1j * gs.pd.get_reciprocal_vectors(q=ik, add_q=False)

        ut_nvR = gs.gd.zeros((n2 - n1, 3), complex)
        for n in range(n1, n2):
            ut_vR = ut_nvR[n - n1]  # view, updated in place below
            for v in range(3):
                ut_R = gs.ibz2bz[K].map_pseudo_wave(
                    gs.pd.ifft(iG_Gv[:, v] * psit_nG[n], ik))
                for v2 in range(3):
                    ut_vR[v2] += ut_R * M_vv[v, v2]

        return ut_nvR

412 

413 

def phase_shifted_fft_indices(k1_c, k2_c, qpd, coordinate_transformation=None):
    """Get phase shifted FFT indices for G-vectors inside the cutoff sphere.

    The output 1D FFT indices Q_G can be used to extract the plane-wave
    components G of the phase shifted Fourier transform

    n_kk'(G+q) = FFT_G[e^(-i[k+q-k']r) n_kk'(r)]

    where n_kk'(r) is some lattice periodic function and the wave vector
    difference k + q - k' is commensurate with the reciprocal lattice.

    If a coordinate_transformation is given, it is applied both to q_c
    and to the 3D FFT grid indices before they are shifted. When there is
    neither a shift nor a transformation, the indices qpd.Q_qG[0] are
    returned as they are.
    """
    transform = coordinate_transformation
    N_c = qpd.gd.N_c
    Q_G = qpd.Q_qG[0]
    q_c = transform(qpd.q_c) if transform else qpd.q_c

    # The shift has to be an integer number of reciprocal lattice vectors
    exact_shift_c = k1_c + q_c - k2_c
    assert np.allclose(exact_shift_c.round(), exact_shift_c)
    shift_c = exact_shift_c.round().astype(int)

    if not (shift_c.any() or transform):
        return Q_G

    # Get the 3D FFT grid indices (relative reciprocal space coordinates)
    # of the G-vectors inside the cutoff sphere
    i_cG = np.unravel_index(Q_G, N_c)
    if transform:
        i_cG = transform(i_cG)
    # Shift the 3D FFT grid indices to account for the Bloch-phase shift
    # e^(-i[k+q-k']r)
    i_cG += shift_c[:, np.newaxis]
    # Transform back the FFT grid to 1D FFT indices, wrapping indices
    # that were shifted outside the grid
    return np.ravel_multi_index(i_cG, N_c, mode='wrap')

448 

449 

def get_gs_and_context(calc, txt, world, timer):
    """Interface to initialize gs and context from old input arguments.
    Should be phased out in the future!

    If calc is a GPAW calculator object (old or new), its communicator is
    asserted to have size one and the adapter is taken from the
    calculator. Anything else is passed on as the gpw argument of
    ResponseGroundStateAdapter.from_gpw_file()."""
    from gpaw.calculator import GPAW as OldGPAW
    from gpaw.new.ase_interface import ASECalculator as NewGPAW

    context = ResponseContext(txt=txt, timer=timer, comm=world)

    if not isinstance(calc, (OldGPAW, NewGPAW)):
        return ResponseGroundStateAdapter.from_gpw_file(gpw=calc), context

    assert calc.wfs.world.size == 1
    return calc.gs_adapter(), context