Coverage for gpaw/new/pwfd/wave_functions.py: 86%

273 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2 

3from functools import partial 

4from math import pi 

5from typing import Optional, Callable 

6 

7import numpy as np 

8from gpaw.core.arrays import DistributedArrays as XArray 

9from gpaw.core.atom_arrays import AtomArrays, AtomDistribution 

10from gpaw.core.atom_centered_functions import AtomCenteredFunctions 

11from gpaw.core.plane_waves import PWArray 

12from gpaw.core.uniform_grid import UGArray, UGDesc 

13from gpaw.fftw import get_efficient_fft_size 

14from gpaw.gpu import as_np, XP 

15from gpaw.mpi import receive, send 

16from gpaw.new import prod, trace, zips 

17from gpaw.new.potential import Potential 

18from gpaw.new.wave_functions import WaveFunctions 

19from gpaw.setup import Setups 

20from gpaw.typing import Array2D, Array3D, Vector 

21from gpaw.utilities import as_real_dtype 

22 

23 

class PWFDWaveFunctions(WaveFunctions, XP):
    """Wave functions stored as plane-wave or uniform-grid coefficients.

    The coefficients live in ``psit_nX`` (a ``PWArray`` or a ``UGArray``).
    The first (band) dimension is distributed over ``psit_nX.comm`` and the
    plane-wave/grid part over ``psit_nX.desc.comm``.  PAW projector
    functions (:attr:`pt_aiX`) and projections (:attr:`P_ani`) are created
    lazily on first access.
    """

    def __init__(self,
                 psit_nX: XArray,
                 *,
                 spin: int,
                 q: int,
                 k: int,
                 setups: Setups,
                 relpos_ac: Array2D,
                 atomdist: AtomDistribution,
                 weight: float = 1.0,
                 ncomponents: int = 1,
                 qspiral_v: Vector | None = None):
        """Wrap ``psit_nX`` and pass the bookkeeping on to WaveFunctions.

        Parameters
        ----------
        psit_nX:
            Pseudo wave-function coefficients; ``psit_nX.dims[0]`` is the
            number of bands.
        spin, q, k, weight:
            Passed unchanged to ``WaveFunctions.__init__``.
        setups:
            PAW setups (one per atom).
        relpos_ac:
            Relative (scaled) atomic positions.
        atomdist:
            Distribution of atoms over the domain communicator.
        ncomponents:
            Number of density components; 4 means two-component spinors
            (see :meth:`array_shape`).
        qspiral_v:
            Optional spin-spiral vector.
        """
        self.psit_nX = psit_nX
        nbands = psit_nX.dims[0]
        super().__init__(setups=setups,
                         nbands=nbands,
                         spin=spin,
                         q=q,
                         k=k,
                         kpt_c=psit_nX.desc.kpt_c,
                         relpos_ac=relpos_ac,
                         atomdist=atomdist,
                         weight=weight,
                         ncomponents=ncomponents,
                         qspiral_v=qspiral_v,
                         dtype=psit_nX.desc.dtype,
                         domain_comm=psit_nX.desc.comm,
                         band_comm=psit_nX.comm)
        # Projector functions are created on demand (see pt_aiX):
        self._pt_aiX: Optional[AtomCenteredFunctions] = None
        self.orthonormalized = False
        self.bytes_per_band = (prod(self.array_shape(global_shape=True)) *
                               psit_nX.desc.itemsize)
        XP.__init__(self, self.psit_nX.xp)

    @classmethod
    def from_wfs(cls,
                 wfs: PWFDWaveFunctions,
                 psit_nX: XArray,
                 relpos_ac=None,
                 atomdist=None) -> PWFDWaveFunctions:
        """Create new wave functions with ``psit_nX`` and metadata of ``wfs``.

        ``relpos_ac`` and ``atomdist`` default to those of ``wfs``.
        """
        return cls(
            psit_nX,
            spin=wfs.spin,
            q=wfs.q,
            k=wfs.k,
            setups=wfs.setups,
            relpos_ac=wfs.relpos_ac if relpos_ac is None else relpos_ac,
            # Use an explicit None-check (like relpos_ac) rather than
            # truthiness of the distribution object:
            atomdist=wfs.atomdist if atomdist is None else atomdist,
            weight=wfs.weight,
            ncomponents=wfs.ncomponents,
            qspiral_v=wfs.qspiral_v)

    def __del__(self):
        # __init__ may have failed before psit_nX was assigned:
        psit_nX = getattr(self, 'psit_nX', None)
        if psit_nX is None:
            return
        # We could be reading from a gpw-file: in that case the data
        # object carries an open file handle (``fd``) that must be closed.
        data = psit_nX.data
        if hasattr(data, 'fd'):
            data.fd.close()

    def _short_string(self, global_shape: tuple[int]) -> str:
        """Delegate to the descriptor of ``psit_nX``."""
        return self.psit_nX.desc._short_string(global_shape)

    def array_shape(self, global_shape=False):
        """Shape of one band (global or local part).

        A leading spinor axis of length 2 is added when
        ``ncomponents == 4``.
        """
        if global_shape:
            shape = self.psit_nX.desc.global_shape()
        else:
            shape = self.psit_nX.desc.myshape
        if self.ncomponents == 4:
            shape = (2,) + shape
        return shape

    @property
    def pt_aiX(self) -> AtomCenteredFunctions:
        """PAW projector functions.

        :::

             ~a _
             p (r)
              i
        """
        if self._pt_aiX is None:
            self._pt_aiX = self.psit_nX.desc.atom_centered_functions(
                [setup.pt_j for setup in self.setups],
                self.relpos_ac,
                atomdist=self.atomdist,
                qspiral_v=self.qspiral_v,
                xp=self.psit_nX.xp)
        return self._pt_aiX

    @property
    def P_ani(self) -> AtomArrays:
        """PAW projections.

        :::

               ~a  ~
             <p | ψ >
               i   n

        Raises
        ------
        RuntimeError
            If the projections have not been computed yet and there are no
            wave-function data to compute them from.
        """
        if self._P_ani is None:
            self._P_ani = self.pt_aiX.empty(self.psit_nX.dims,
                                            self.psit_nX.comm)
            if self.psit_nX.data is None:
                raise RuntimeError('There are no projections or wavefunctions')
            self.pt_aiX.integrate(self.psit_nX, self._P_ani)
        return self._P_ani

    def move(self,
             relpos_ac: Array2D,
             atomdist: AtomDistribution,
             move_wave_functions: Callable[..., None]) -> None:
        """Move atoms to ``relpos_ac``.

        ``move_wave_functions`` is called (with old and new positions,
        projections, wave functions and setups) only when wave-function data
        exist.  Afterwards the wave functions are marked as not
        orthonormalized and the projector functions are moved.
        """
        if self.psit_nX.data is not None:
            move_wave_functions(
                self.relpos_ac,
                relpos_ac,
                self.P_ani,
                self.psit_nX,
                self.setups)
        super().move(relpos_ac, atomdist, move_wave_functions)
        self.orthonormalized = False
        self.pt_aiX.move(relpos_ac, atomdist)

    def add_to_density(self,
                       nt_sR: UGArray,
                       D_asii: AtomArrays) -> None:
        """Add this k-point's contribution to pseudo density and D_asii.

        Occupations are scaled by ``weight * spin_degeneracy``.  For
        non-collinear spinors (``ncomponents == 4``) the four components
        (total density and three magnetization components) are accumulated
        from the two spinor components.
        """
        occ_n = self.weight * self.spin_degeneracy * self.myocc_n

        self.add_to_atomic_density_matrices(occ_n, D_asii)

        if self.ncomponents < 4:
            self.psit_nX.abs_square(weights=occ_n, out=nt_sR[self.spin])
            return

        psit_nsG = self.psit_nX
        assert isinstance(psit_nsG, PWArray)

        tmp_sR = nt_sR.desc.new(dtype=complex).empty(2)
        p1_R, p2_R = tmp_sR.data
        nt_xR = nt_sR.data

        for f, psit_sG in zips(occ_n, psit_nsG):
            psit_sG.ifft(out=tmp_sR)
            p11_R = p1_R.real**2 + p1_R.imag**2
            p22_R = p2_R.real**2 + p2_R.imag**2
            p12_R = p1_R.conj() * p2_R
            nt_xR[0] += f * (p11_R + p22_R)
            nt_xR[1] += 2 * f * p12_R.real
            nt_xR[2] += 2 * f * p12_R.imag
            nt_xR[3] += f * (p11_R - p22_R)

    def add_to_ked(self, taut_sR) -> None:
        """Add occupation-weighted kinetic-energy density to taut_sR."""
        occ_n = self.weight * self.spin_degeneracy * self.myocc_n
        self.psit_nX.add_ked(occ_n, taut_sR[self.spin])

    @trace
    def orthonormalize(self, psit2_nX=None):
        r"""Orthonormalize wave functions.

        Computes the overlap matrix:::

                   / ~ _ *~ _   _   ---  a *  a   a
            S    = | ψ(r) ψ(r) dr + >  (P  )  P  ΔS
             mn    /  m    n        ---  im    jn  ij
                                    aij

        With `LSL^\dagger=1`, we update the wave functions and projections
        inplace like this:::

                    --  *
            Ψ  <-  >  L   Ψ ,
             m      --  mn  n
                    n

        and:::

              a      --  *   a
            P    <-  >  L   P  .
             mi      --  mn  ni
                     n

        Does nothing if the wave functions are already orthonormalized.

        Parameters
        ----------
        psit2_nX:
            Work array with the same layout as ``psit_nX``.  A new one is
            allocated when None (the default).
        """
        if self.orthonormalized:
            return
        psit_nX = self.psit_nX
        domain_comm = psit_nX.desc.comm
        P_ani = self.P_ani

        P2_ani = P_ani.new()
        if psit2_nX is None:
            psit2_nX = psit_nX.new()
        dS_aii = self.setups.get_overlap_corrections(
            P_ani.layout.atomdist,
            self.xp,
            dtype=as_real_dtype(P_ani.data.dtype))

        # We are actually calculating S^*:
        S = psit_nX.matrix_elements(psit_nX, domain_sum=False, cc=True)
        P_ani.block_diag_multiply(dS_aii, out_ani=P2_ani)
        P_ani.matrix.multiply(P2_ani, opb='C', symmetric=True, out=S, beta=1.0)
        domain_comm.sum(S.data, 0)

        # Factorize on the domain master only, then share the result:
        if domain_comm.rank == 0:
            S.invcholesky()
        domain_comm.broadcast(S.data, 0)
        # S now contains L^*

        S.multiply(psit_nX, out=psit2_nX)
        S.multiply(P_ani, out=P2_ani)
        psit_nX.data[:] = psit2_nX.data
        P_ani.data[:] = P2_ani.data
        self.orthonormalized = True

    @trace
    def subspace_diagonalize(self,
                             Ht,
                             dH,
                             psit2_nX,
                             data_buffer=None,
                             scalapack_parameters=(None, 1, 1, None)):
        """Diagonalize the Hamiltonian in the space spanned by the bands.

        The wave functions are orthonormalized first.  Afterwards
        ``psit_nX`` and the projections are rotated to the eigenvectors and
        ``self._eig_n`` holds the eigenvalues.

        If data_buffer is None, psit2_nX will be used as a buffer
        for the wave functions.  Otherwise ``psit2_nX`` (the output of
        ``Ht``) is rotated as well.

        Ht(in, out):::

            ~   ^   ~
            H = T + v

        dH:::

             ~  ~    a    ~  ~
            <𝜓 |p> ΔH   <p |𝜓>
              m  i   ij   j  n
        """
        self.orthonormalize(psit2_nX)
        psit_nX = self.psit_nX
        P_ani = self.P_ani
        P2_ani = P_ani.new()
        domain_comm = psit_nX.desc.comm

        Ht = partial(Ht, out=psit2_nX, spin=self.spin)
        H = psit_nX.matrix_elements(psit_nX,
                                    function=Ht,
                                    domain_sum=False,
                                    cc=True)
        dH(P_ani, out_ani=P2_ani, spin=self.spin)
        P_ani.matrix.multiply(P2_ani, opb='C', symmetric=True,
                              out=H, beta=1.0)
        domain_comm.sum(H.data, 0)

        if domain_comm.rank == 0:
            slcomm, r, c, b = scalapack_parameters
            if r == c == 1:
                slcomm = None
            self._eig_n = as_np(H.eigh(scalapack=(slcomm, r, c, b)))
            H.complex_conjugate()
            # H.data[n, :] now contains the nth eigenvector and eps_n[n]
            # the nth eigenvalue
        else:
            self._eig_n = np.empty(psit_nX.dims)

        domain_comm.broadcast(H.data, 0)
        domain_comm.broadcast(self._eig_n, 0)
        if data_buffer is None:
            H.multiply(psit_nX, out=psit2_nX)
            psit_nX.data[:] = psit2_nX.data
            H.multiply(P_ani, out=P2_ani)
            P_ani.data[:] = P2_ani.data
        else:
            H.multiply(psit_nX, out=psit_nX, data_buffer=data_buffer)
            H.multiply(psit2_nX, out=psit2_nX, data_buffer=data_buffer)
            H.multiply(P_ani, out=P2_ani)
            P_ani.data[:] = P2_ani.data

    def force_contribution(self,
                           potential: Potential,
                           F_av: Array2D) -> None:
        """Add this k-point's PAW contribution to the forces in ``F_av``.

        Uses the derivatives of the projector functions together with
        ``potential.dH_asii``, the setups' ``dO_ii``, the eigenvalues and the
        (weighted) occupation numbers.
        """
        xp = self.xp
        dH_asii = potential.dH_asii
        myeig_n = xp.asarray(self.myeig_n)
        myocc_n = xp.asarray(
            self.weight * self.spin_degeneracy * self.myocc_n)

        if self.ncomponents == 4:
            self._non_collinear_force_contribution(dH_asii, myocc_n, F_av)
            return

        F_anvi = self.pt_aiX.derivative(self.psit_nX)
        for a, F_nvi in F_anvi.items():
            F_nvi = F_nvi.conj()
            F_nvi *= myocc_n[:, np.newaxis, np.newaxis]
            dH_ii = dH_asii[a][self.spin]
            P_ni = self.P_ani[a]
            F_vii = xp.einsum('nvi, nj, jk -> vik', F_nvi, P_ni, dH_ii)
            F_nvi *= myeig_n[:, np.newaxis, np.newaxis]
            dO_ii = xp.asarray(self.setups[a].dO_ii)
            F_vii -= xp.einsum('nvi, nj, jk -> vik', F_nvi, P_ni, dO_ii)
            F_av[a] += 2 * F_vii.real.trace(0, 1, 2)

    def _non_collinear_force_contribution(self,
                                          dH_asii,
                                          myocc_n,
                                          F_av):
        """Non-collinear (spinor) version of :meth:`force_contribution`.

        ``dH_asii[a][0]`` and ``dH_asii[a][1:4]`` are combined into a 2x2
        spin matrix (Pauli-matrix form).  All array work uses ``self.xp``
        so that this path agrees with the collinear one.
        """
        xp = self.xp
        myeig_n = xp.asarray(self.myeig_n)
        F_ansvi = self.pt_aiX.derivative(self.psit_nX)
        for a, F_nsvi in F_ansvi.items():
            F_nsvi = F_nsvi.conj()
            F_nsvi *= myocc_n[:, np.newaxis, np.newaxis, np.newaxis]
            dH_sii = dH_asii[a]
            dH_ii = dH_sii[0]
            dH_vii = dH_sii[1:]
            # 2x2 spin matrix: dH_ii*1 + dH_x*σx + dH_y*σy + dH_z*σz
            dH_ssii = xp.empty((2, 2) + dH_ii.shape, dtype=complex)
            dH_ssii[0, 0] = dH_ii + dH_vii[2]
            dH_ssii[0, 1] = dH_vii[0] - 1j * dH_vii[1]
            dH_ssii[1, 0] = dH_vii[0] + 1j * dH_vii[1]
            dH_ssii[1, 1] = dH_ii - dH_vii[2]
            P_nsi = self.P_ani[a]
            F_v = xp.einsum('nsvi, stij, ntj -> v', F_nsvi, dH_ssii, P_nsi)
            F_nsvi *= myeig_n[:, np.newaxis, np.newaxis, np.newaxis]
            dO_ii = xp.asarray(self.setups[a].dO_ii)
            F_v -= xp.einsum('nsvi, ij, nsj -> v', F_nsvi, dO_ii, P_nsi)
            F_av[a] += 2 * F_v.real

    def collect(self,
                n1: int = 0,
                n2: int = 0) -> PWFDWaveFunctions | None:
        """Collect range of bands to master of band and domain comms.

        Bands ``n1:n2`` are collected; ``n2 <= 0`` counts from the end
        (``nbands + n2``).  Returns a new, non-distributed
        PWFDWaveFunctions object on the master rank and None on all other
        ranks.  Every rank of both communicators must call this method.
        """
        # Also collect projections instead of recomputing XXX
        n2 = n2 if n2 > 0 else self.nbands + n2
        spinors = (2,) if self.ncomponents == 4 else ()
        band_comm = self.psit_nX.comm
        domain_comm = self.psit_nX.desc.comm
        nbands = self.nbands
        mynbands = (nbands + band_comm.size - 1) // band_comm.size
        rank1, b1 = divmod(n1, mynbands)
        rank2, b2 = divmod(n2, mynbands)
        if band_comm.rank == 0:
            if domain_comm.rank == 0:
                psit_nX = self.psit_nX.desc.new(comm=None).empty(
                    (n2 - n1, *spinors))
            rank = rank1
            ba = b1
            na = n1
            # Loop over the band ranks holding bands in [n1, n2):
            while (rank, ba) < (rank2, b2):
                bb = min((rank + 1) * mynbands, nbands) - rank * mynbands
                if rank == rank2 and bb > b2:
                    bb = b2
                nb = na + bb - ba
                if bb > ba:
                    if rank == 0:
                        psit_bX = self.psit_nX[ba:bb].gather()
                        if domain_comm.rank == 0:
                            psit_nX.data[:bb - ba] = psit_bX.data
                    else:
                        if domain_comm.rank == 0:
                            band_comm.receive(psit_nX.data[na - n1:nb - n1],
                                              rank)
                rank += 1
                ba = 0
                na = nb
            if domain_comm.rank == 0:
                wfs = PWFDWaveFunctions.from_wfs(
                    self,
                    psit_nX,
                    atomdist=self.atomdist.gather())
                wfs._eig_n = self.eig_n[n1:n2]
                return wfs
        else:
            rank = band_comm.rank
            ranka, ba = max((rank1, b1), (rank, 0))
            rankb, bb = min((rank2, b2), (rank, self.psit_nX.mydims[0]))
            if (ranka, ba) < (rankb, bb):
                assert ranka == rankb == rank
                # Gather over the domain first (like band rank 0 does for
                # its own bands).  Only the domain master takes part in the
                # band communication, matching the receive on band rank 0.
                psit_bX = self.psit_nX[ba:bb].gather()
                if domain_comm.rank == 0:
                    band_comm.send(psit_bX.data, dest=0)

        return None

    def send(self, rank, comm):
        """Send k-point, coefficients, spin, q, k and weight to ``rank``."""
        stuff = (self.kpt_c,
                 self.psit_nX.data,
                 self.spin,
                 self.q,
                 self.k,
                 self.weight)
        send(stuff, rank, comm)

    def receive(self, rank, comm):
        """Receive what :meth:`send` sent; return new wave functions.

        The new object is not distributed (``comm=None``) and reuses the
        setups, positions, spin-component count and spiral vector of
        ``self``.
        """
        kpt_c, data, spin, q, k, weight = receive(rank, comm)
        psit_nX = self.psit_nX.desc.new(kpt=kpt_c, comm=None).from_data(data)
        return PWFDWaveFunctions(psit_nX,
                                 spin=spin,
                                 q=q,
                                 k=k,
                                 setups=self.setups,
                                 relpos_ac=self.relpos_ac,
                                 atomdist=self.atomdist.gather(),
                                 weight=weight,
                                 ncomponents=self.ncomponents,
                                 qspiral_v=self.qspiral_v)

    def dipole_matrix_elements(self) -> Array3D:
        """Calculate dipole matrix-elements.

        :::

             _   / ~ ~ _ _     ---  a  a  _a
             μ = | 𝜓 𝜓 rdr +  >   P  P  Δμ
              mn / m n         ---  im jn  ij
                               aij

        Only real wave functions are supported in plane-wave mode (asserted
        below).

        Returns
        -------
        Array3D:
            matrix elements in atomic units.
        """
        cell_cv = self.psit_nX.desc.cell_cv

        dipole_nnv = np.zeros((self.nbands, self.nbands, 3))

        position_av = self.relpos_ac @ cell_cv

        # PAW correction: l=1 moments (x, y, z) plus position times l=0
        # moment.
        R_aiiv = []
        for setup, position_v in zips(self.setups, position_av):
            Delta_iiL = setup.Delta_iiL
            R_iiv = Delta_iiL[:, :, [3, 1, 2]] * (4 * pi / 3)**0.5
            R_iiv += position_v * Delta_iiL[:, :, :1] * (4 * pi)**0.5
            R_aiiv.append(R_iiv)

        for a, P_ni in self.P_ani.items():
            dipole_nnv += np.einsum('mi, ijv, nj -> mnv',
                                    P_ni, R_aiiv[a], P_ni)

        self.psit_nX.desc.comm.sum(dipole_nnv)

        if isinstance(self.psit_nX, UGArray):
            psit_nR = self.psit_nX
        else:
            assert isinstance(self.psit_nX, PWArray)
            # Find size of fft grid large enough to store square of wfs.
            pw = self.psit_nX.desc
            s1, s2, s3 = np.ptp(pw.indices_cG, axis=1)  # type: ignore
            assert pw.dtype == float
            # Last dimension is special because dtype=float:
            size_c = [2 * s1 + 2,
                      2 * s2 + 2,
                      4 * s3 + 2]
            size_c = [get_efficient_fft_size(N, 2) for N in size_c]
            grid = UGDesc(cell=pw.cell_cv, size=size_c)
            psit_nR = self.psit_nX.ifft(grid=grid)

        # Pseudo part; the matrix is symmetric, so fill both triangles:
        for na, psita_R in enumerate(psit_nR):
            for nb, psitb_R in enumerate(psit_nR[:na + 1]):
                d_v = (psita_R * psitb_R).moment()
                dipole_nnv[na, nb] += d_v
                if na != nb:
                    dipole_nnv[nb, na] += d_v

        return dipole_nnv

    def gather_wave_function_coefficients(self) -> np.ndarray | None:
        """Gather all coefficients (domain and bands) on one rank.

        Returns the full array on the gathering rank and None elsewhere.
        """
        psit_nX = self.psit_nX.gather()  # gather X
        if psit_nX is not None:
            data_nX = psit_nX.matrix.gather()  # gather n
            if data_nX.dist.comm.rank == 0:
                # XXX PW-gamma-point mode: float or complex matrix.dtype?
                return data_nX.data.view(
                    psit_nX.data.dtype).reshape((-1,) + psit_nX.data.shape[1:])
        return None

    def to_uniform_grid_wave_functions(self,
                                       grid,
                                       basis):
        """Return wave functions on a uniform grid.

        ``self`` is returned unchanged if it already uses a grid; otherwise
        the plane-wave coefficients are Fourier transformed onto ``grid``.
        ``basis`` is not used here.
        """
        if isinstance(self.psit_nX, UGArray):
            return self

        grid = grid.new(kpt=self.kpt_c, dtype=self.dtype)
        psit_nR = grid.zeros(self.nbands, self.band_comm)
        self.psit_nX.ifft(out=psit_nR)
        return PWFDWaveFunctions.from_wfs(self, psit_nR)

    def morph(self, desc, relpos_ac, atomdist):
        """Transfer the wave functions to a new descriptor ``desc``.

        The data of ``self`` (coefficients, projections and projector
        functions) are released to save memory, so ``self`` should not be
        used afterwards.
        """
        desc = desc.new(kpt=self.psit_nX.desc.kpt_c)
        psit_nX = self.psit_nX.morph(desc)

        # Save memory:
        self.psit_nX.data = None
        self._P_ani = None
        self._pt_aiX = None

        return PWFDWaveFunctions.from_wfs(self, psit_nX,
                                          relpos_ac=relpos_ac)