Coverage for gpaw/core/pwacf.py: 94%

375 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from __future__ import annotations 

2 

3from math import pi 

4from typing import TYPE_CHECKING 

5 

6import numpy as np 

7 

8from gpaw.core.atom_arrays import AtomArraysLayout, AtomDistribution 

9from gpaw.core.atom_centered_functions import AtomCenteredFunctions 

10from gpaw.core.matrix import Matrix 

11from gpaw.core.uniform_grid import UGArray 

12from gpaw.ffbt import rescaled_fourier_bessel_transform 

13from gpaw.gpu import cupy_is_fake, gpu_gemm 

14# from gpaw.lfc import BaseLFC 

15from gpaw.new import prod, trace, tracectx 

16from gpaw.new.c import pwlfc_expand, pwlfc_expand_gpu 

17from gpaw.spherical_harmonics import Y, nablarlYL 

18from gpaw.spline import Spline 

19from gpaw.typing import ArrayLike1D 

20from gpaw.utilities import as_complex_dtype, as_real_dtype 

21from gpaw.utilities.blas import mmm 

22 

23if TYPE_CHECKING: 

24 from gpaw.core.plane_waves import PWDesc, PWArray 

25 

26 

class PWAtomCenteredFunctions(AtomCenteredFunctions):
    """Atom-centered functions represented in a plane-wave basis.

    The numerical work is delegated to a :class:`PWLFC` object which is
    created on demand by :meth:`_lazy_init` and discarded again when the
    plane-wave descriptor changes (see :meth:`change_cell`).
    """

    def __init__(self,
                 functions,
                 relpos,
                 pw,
                 atomdist=None,
                 integrals=None,
                 xp=None):
        """Store the description of the functions; nothing is expanded yet.

        functions:
            One sequence of radial functions per atom.  Each function
            must expose an ``l`` attribute (see :meth:`_lazy_init`).
        relpos:
            Atomic positions, forwarded to the base class.
        pw:
            Plane-wave descriptor.
        atomdist:
            Distribution of atoms over ranks.  When it is missing, a
            default one is built from ``pw.comm`` in :meth:`_lazy_init`.
        integrals:
            Forwarded unchanged to :class:`PWLFC`.
        xp:
            Array module; anything falsy selects ``numpy``.
        """
        super().__init__(functions, relpos, atomdist)
        self.pw = pw
        self.xp = xp or np
        self.integrals = integrals

    def new(self, pw, atomdist):
        """Return a copy of this object for another ``pw`` / ``atomdist``.

        The functions, positions, array module and the requested
        ``integrals`` are carried over, so the copy is normalised in the
        same way as the object it was created from.
        """
        return PWAtomCenteredFunctions(
            self.functions,
            self.relpos_ac,
            pw,
            atomdist=atomdist,
            integrals=self.integrals,
            xp=self.xp)

    def _lazy_init(self):
        """Create the PWLFC object and the atom-array layout if needed."""
        if self._lfc is not None:
            return

        self._lfc = PWLFC(self.functions, self.pw, xp=self.xp,
                          integrals=self.integrals)
        if self._atomdist is None:
            self._atomdist = AtomDistribution.from_number_of_atoms(
                len(self.relpos_ac), self.pw.comm)

        self._lfc.set_positions(self.relpos_ac, self._atomdist)

        # Number of (l, m) channels per atom: 2l + 1 for every function.
        size_a = [sum(2 * f.l + 1 for f in funcs)
                  for funcs in self.functions]
        self._layout = AtomArraysLayout(size_a,
                                        self._atomdist,
                                        self.pw.dtype,
                                        xp=self.xp)

    def __repr__(self):
        """Base-class repr, with ``xp=cp`` appended when not on numpy."""
        s = super().__repr__()
        if self.xp is np:
            return s
        # Replace the last character of the base repr (presumably the
        # closing parenthesis) by the extra keyword.
        return s[:-1] + ', xp=cp)'

    def to_uniform_grid(self,
                        out: UGArray,
                        scale: float = 1.0) -> UGArray:
        """Add the functions in reciprocal space and transform to ``out``.

        A zero-initialised plane-wave array is filled through
        ``self.add_to(..., scale)`` and inverse Fourier transformed into
        ``out``; the result of that transform is returned.
        """
        out_G = self.pw.zeros(xp=out.xp)
        self.add_to(out_G, scale)
        return out_G.ifft(out=out)

    def change_cell(self, new_pw):
        """Switch to a new plane-wave descriptor.

        The cached PWLFC object is dropped so that the next call to
        :meth:`_lazy_init` rebuilds it for ``new_pw``.
        """
        self.pw = new_pw
        self._lfc = None

    def multiply(self,
                 C_nM: Matrix,
                 out_nG: PWArray) -> None:
        """Convert from LCAO expansion to PW expansion.

        ``out_nG.data`` is overwritten block by block with
        ``C_nM.data @ f_GI.T / dv`` where ``f_GI`` is the plane-wave
        expansion of the functions.
        """
        self._lazy_init()
        lfc = self._lfc
        assert lfc is not None
        prefactor = 1.0 / self.pw.dv
        for G1, G2 in lfc.block():
            f_GI = lfc.expand(G1, G2, cc=False)
            a_nG = out_nG.data[:, G1:G2]
            if lfc.real:
                # expand() returns interleaved real/imag rows in the real
                # case, so look at the output through the same dtype.
                a_nG = a_nG.view(f_GI.dtype)
            if self.xp is np:
                mmm(prefactor, C_nM.data, 'N', f_GI, 'T', 0.0, a_nG)
            else:
                gpu_gemm('N', 'T',
                         C_nM.data, f_GI, a_nG, prefactor, 0.0)

99 

100 

class PWLFC:  # (BaseLFC)
    """Localized functions expanded in plane waves (reciprocal space)."""

    def __init__(self,
                 functions,
                 pw: PWDesc,
                 *,
                 xp,
                 integrals: ArrayLike1D | float | None = None,
                 blocksize: int | None = 5000):
        """Reciprocal-space plane-wave localized function collection.

        functions: list of list of spline objects
            Splines, one list per atom.
        pw: PWDesc
            Plane-wave descriptor object.
        xp:
            Array module used for the large arrays (``numpy`` or a
            GPU array module).
        integrals: scalar, sequence or None
            Requested integrals, one per atom.  A scalar is used for
            every atom, ``None`` means zero (no rescaling).
        blocksize: int
            Block-size to use when looping over G-vectors.  Use None for
            doing all G-vectors in one big block.
        """

        self.xp = xp
        self.pw = pw
        self.spline_aj = functions

        self.dtype = pw.dtype
        self.real = np.issubdtype(pw.dtype, np.floating)

        self.initialized = False

        # These will be filled in later (by initialize()):
        self.Y_GL = np.zeros((0, 0))
        self.f_Gs: np.ndarray = np.zeros((0, 0))
        self.l_s: np.ndarray | None = None
        self.a_J: np.ndarray | None = None
        self.s_J: np.ndarray | None = None
        self.I_J: np.ndarray | None = None
        self.lmax = -1

        if blocksize is not None and pw.maxmysize <= blocksize:
            # No need to block G-vectors
            blocksize = None
        self.blocksize = blocksize

        # These are set later in set_positions():
        self.eikR_a = None
        self.my_atom_indices = None
        self.my_indices = None
        self.pos_av = None
        self.pos_avT = None
        self.G_plus_k_Gv = None
        self.nI = None

        self.comm = pw.comm

        if integrals is None:
            self.integral_a = np.zeros(len(functions))
        elif np.isscalar(integrals):
            # Any scalar (Python int/float or NumPy scalar) is broadcast
            # to all atoms; a 0-d array here would break the per-atom
            # indexing done in initialize().
            self.integral_a = np.zeros(len(functions)) + integrals
        else:
            self.integral_a = np.array(integrals)

    @trace
    def initialize(self) -> None:
        """Initialize position-independent stuff.

        Builds the table of Fourier-transformed radial functions
        (``f_Gs``), the spherical harmonics (``Y_GL``) and the index
        arrays ``l_s``, ``a_J``, ``s_J`` and ``I_J``.  Does nothing when
        called a second time.
        """
        if self.initialized:
            return

        xp = self.xp

        # Give every distinct spline an index s.
        splines: dict[Spline, int] = {}
        for spline_j in self.spline_aj:
            for spline in spline_j:
                if spline not in splines:
                    splines[spline] = len(splines)
        nsplines = len(splines)

        nJ = sum(len(spline_j) for spline_j in self.spline_aj)

        self.f_Gs = xp.empty(self.pw.myshape + (nsplines,),
                             dtype=as_real_dtype(self.dtype))
        self.l_s = np.empty(nsplines, np.int32)
        self.a_J = np.empty(nJ, np.int32)
        self.s_J = np.empty(nJ, np.int32)
        self.I_J = np.empty(nJ, np.int32)

        # |G+k| is the same for all splines, so compute it once.
        G_G = (2 * self.pw.ekin_G)**0.5

        # Fourier transform radial functions:
        J = 0
        done: set[Spline] = set()
        I = 0
        for a, spline_j in enumerate(self.spline_aj):
            for spline in spline_j:
                s = splines[spline]  # get spline index
                if spline not in done:
                    f = rescaled_fourier_bessel_transform(spline)
                    self.f_Gs[:, s] = xp.asarray(f.map(G_G))
                    l = spline.get_angular_momentum_number()
                    self.l_s[s] = l
                    integral = self.integral_a[a]
                    if l == 0 and integral != 0.0:
                        # Rescale so that row 0 matches the requested
                        # integral.
                        # NOTE(review): this uses local row 0 as the
                        # reference - presumably the G=0 component;
                        # confirm for distributed plane waves.
                        x = integral / self.f_Gs[0, s] * (4 * pi)**0.5
                        self.f_Gs[:, s] *= x
                    done.add(spline)
                self.a_J[J] = a
                self.s_J[J] = s
                self.I_J[J] = I
                I += 2 * spline.get_angular_momentum_number() + 1
                J += 1

        self.lmax = max(self.l_s, default=-1)

        # Spherical harmonics:
        G_Gv = self.pw.G_plus_k_Gv
        self.Y_GL = xp.empty((len(G_Gv), (self.lmax + 1)**2),
                             dtype=as_real_dtype(self.dtype))
        for L in range((self.lmax + 1)**2):
            self.Y_GL[:, L] = xp.asarray(Y(L, *G_Gv.T))

        # Move the index arrays to the array module in use.
        self.l_s = xp.asarray(self.l_s)
        self.a_J = xp.asarray(self.a_J)
        self.s_J = xp.asarray(self.s_J)
        self.I_J = xp.asarray(self.I_J)

        self.initialized = True

    def get_function_count(self, a):
        """Return the number of functions (sum of 2l+1) on atom ``a``."""
        return sum(2 * spline.get_angular_momentum_number() + 1
                   for spline in self.spline_aj[a])

    @trace
    def set_positions(self, spos_ac, atomdist):
        """Set the atomic positions and the distribution of atoms.

        spos_ac:
            Scaled positions; they are multiplied by ``pw.cell`` to get
            ``pos_av``.
        atomdist:
            Object whose ``rank_a`` gives the owning rank of each atom.
        """
        self.initialize()

        xp = self.xp
        real_dtype = as_real_dtype(self.dtype)

        if self.real:
            self.eikR_a = xp.ones(len(spos_ac), dtype=real_dtype)
        else:
            self.eikR_a = xp.asarray(
                np.exp(2j * pi * (spos_ac @ self.pw.kpt_c)),
                dtype=as_complex_dtype(self.dtype))
        self.pos_av = xp.asarray(np.dot(spos_ac, self.pw.cell),
                                 dtype=real_dtype)

        self.pos_avT = xp.asarray(self.pos_av.T, real_dtype)
        self.G_plus_k_Gv = self.xp.asarray(self.pw.G_plus_k_Gv, real_dtype)

        rank_a = atomdist.rank_a

        # (a, I1, I2) for the atoms owned by this rank; I1:I2 is the
        # atom's slice of the global function index I.
        self.my_atom_indices = []
        self.my_indices = []
        I1 = 0
        for a, rank in enumerate(rank_a):
            I2 = I1 + self.get_function_count(a)
            if rank == self.comm.rank:
                self.my_atom_indices.append(a)
                self.my_indices.append((a, I1, I2))
            I1 = I2
        self.nI = I1

    @trace
    def expand(self, G1=0, G2=None, cc=False):
        """Expand functions in plane-waves.

        G1: int
            Start G-vector index.
        G2: int
            End G-vector index (defaults to the number of local
            G-vectors).
        cc: bool
            Complex conjugate.

        Returns ``f_GI``: a complex ``(G2 - G1, nI)`` array, or, for a
        real dtype, a real ``(2 * (G2 - G1), nI)`` array with real and
        imaginary parts in alternating rows.
        """
        xp = self.xp

        if G2 is None:
            G2 = self.Y_GL.shape[0]

        Gk_Gv = self.G_plus_k_Gv[G1:G2]
        pos_av = self.pos_av
        eikR_a = xp.asarray(self.eikR_a,
                            dtype=as_complex_dtype(self.dtype))

        f_Gs = self.f_Gs[G1:G2]
        Y_GL = self.Y_GL[G1:G2]

        if not self.real:
            f_GI = xp.empty((G2 - G1, self.nI), as_complex_dtype(self.dtype))
        else:
            # Special layout because BLAS does not have real-complex
            # multiplications.  f_GI(G,I) layout:
            #
            #    real(G1, 0),   real(G1, 1),   ...
            #    imag(G1, 0),   imag(G1, 1),   ...
            #    real(G1+1, 0), real(G1+1, 1), ...
            #    imag(G1+1, 0), imag(G1+1, 1), ...
            #    ...

            f_GI = xp.empty((2 * (G2 - G1), self.nI),
                            as_real_dtype(self.dtype))

        if xp is np:
            # Fast C-code:
            pwlfc_expand(f_Gs, Gk_Gv, pos_av, eikR_a, Y_GL,
                         self.l_s, self.a_J, self.s_J,
                         cc, f_GI)
        elif cupy_is_fake:
            # Fake cupy: hand the wrapped arrays (``_data``) to the
            # CPU kernel.
            pwlfc_expand(f_Gs._data, Gk_Gv._data, pos_av._data,
                         eikR_a._data, Y_GL._data,
                         self.l_s._data, self.a_J._data, self.s_J._data,
                         cc, f_GI._data)
        else:
            pwlfc_expand_gpu(f_Gs, Gk_Gv, pos_av, eikR_a, Y_GL,
                             self.l_s, self.a_J, self.s_J,
                             cc, f_GI, self.I_J)
        return f_GI

    def block(self, ensure_same_number_of_blocks=False):
        """Yield ``(G1, G2)`` index ranges covering the local G-vectors.

        Without a block size a single range ``(0, nG)`` is produced.
        With ``ensure_same_number_of_blocks`` the generator is padded
        with empty ``(nG, nG)`` ranges until it has yielded as many
        times as the rank holding ``pw.maxmysize`` G-vectors, so that
        collective operations inside the loop stay in step.
        """
        nG = self.Y_GL.shape[0]
        B = self.blocksize
        if B:
            G1 = 0
            while G1 < nG:
                G2 = min(G1 + B, nG)
                yield G1, G2
                G1 = G2
            if ensure_same_number_of_blocks:
                # Make sure we yield the same number of times:
                nb = (self.pw.maxmysize + B - 1) // B
                mynb = (nG + B - 1) // B
                # A rank may be more than one block short.
                for _ in range(nb - mynb):
                    yield nG, nG  # empty block
        else:
            yield 0, nG

    @trace
    def get_emiGR_Ga(self, G1, G2):
        """Return ``exp(-i (G+k).R_a) * eikR_a`` for G-vectors G1:G2."""
        Gk_Gv = self.G_plus_k_Gv[G1:G2]
        GkR_Ga = Gk_Gv @ self.pos_avT
        return self.xp.exp(-1j * GkR_Ga) * self.eikR_a

    @trace
    def add(self, a_xG, c_axi=1.0, q=None):
        """Add a linear combination of the functions to ``a_xG``.

        a_xG:
            Plane-wave coefficients; accumulated into through a
            reshaped view (assumes the reshape does not copy).
        c_axi:
            Coefficients per atom, or a scalar that is used for every
            function (only allowed for 1-d ``a_xG``).
        q:
            Not used.
        """
        if self.nI == 0:
            return
        c_xI = self.xp.empty(a_xG.shape[:-1] + (self.nI,), self.dtype)

        if np.isscalar(c_axi):
            assert a_xG.ndim == 1
            c_xI[:] = c_axi
        else:
            # Each rank fills in its own atoms; the sum over ranks then
            # gives everybody the complete coefficient array.
            if self.comm.size != 1:
                c_xI[:] = 0.0
            for a, I1, I2 in self.my_indices:
                c_xI[..., I1:I2] = c_axi[a] * self.eikR_a[a].conj()
            if self.comm.size != 1:
                self.comm.sum(c_xI)

        nx = prod(c_xI.shape[:-1])
        if nx == 0:
            return
        c_xI = c_xI.reshape((nx, self.nI))
        a_xG = a_xG.reshape((nx, a_xG.shape[-1])).view(self.dtype)

        prefactor = 1.0 / self.pw.dv
        for G1, G2 in self.block():
            f_GI = self.expand(G1, G2, cc=False)

            if self.real:
                # Two real numbers per plane wave in the viewed array.
                G1 *= 2
                G2 *= 2

            with tracectx('gemm'):
                if self.xp is np:
                    mmm(prefactor, c_xI, 'N', f_GI, 'T',
                        1.0, a_xG[:, G1:G2])
                else:
                    gpu_gemm('N', 'T',
                             c_xI, f_GI, a_xG[:, G1:G2],
                             prefactor, 1.0)

    @trace
    def integrate(self, a_xG, c_axi=None, q=-1, add_to=False):
        """Integrate ``a_xG`` against the functions.

        a_xG:
            Plane-wave coefficients.
        c_axi:
            Output container indexed by atom.
        q:
            Not used.
        add_to:
            Add to the existing content of ``c_axi`` instead of
            overwriting it.

        Returns ``c_axi`` (also on the early exits for an empty
        problem).
        """
        xp = self.xp
        if self.nI == 0:
            return c_axi
        c_xI = xp.zeros(a_xG.shape[:-1] + (self.nI,), self.dtype)

        nx = prod(c_xI.shape[:-1])
        if nx == 0:
            return c_axi
        b_xI = c_xI.reshape((nx, self.nI))
        a_xG = a_xG.reshape((nx, a_xG.shape[-1]))

        alpha = 1.0
        if self.real:
            alpha *= 2
            a_xG = a_xG.view(self.dtype)

        if c_axi is None:
            # NOTE(review): no ``dict`` method is defined on this class
            # in the code visible here (the BaseLFC base is commented
            # out) - confirm that this path is reachable.
            c_axi = self.dict(a_xG.shape[:-1])

        x = 0.0  # beta: overwrite for the first block, accumulate after
        for G1, G2 in self.block():
            f_GI = self.expand(G1, G2, cc=not self.real)
            if self.real:
                if G1 == 0 and self.comm.rank == 0:
                    # Halve the first row on rank 0, compensating for
                    # alpha = 2 (presumably the G=0 term, which must
                    # not be counted twice).
                    f_GI[0] *= 0.5
                G1 *= 2
                G2 *= 2
            if xp is np:
                mmm(alpha, a_xG[:, G1:G2], 'N', f_GI, 'N', x, b_xI)
            else:
                gpu_gemm('N', 'N',
                         a_xG[:, G1:G2], f_GI, b_xI,
                         alpha, x)
            x = 1.0

        self.comm.sum(b_xI)
        with tracectx('Displace integrals', gpu=True):
            if add_to:
                for a, I1, I2 in self.my_indices:
                    c_axi[a] += self.eikR_a[a] * c_xI[..., I1:I2]
            else:
                for a, I1, I2 in self.my_indices:
                    c_axi[a][:] = self.eikR_a[a] * c_xI[..., I1:I2]

        return c_axi

    @trace
    def derivative(self, a_xG, c_axiv=None, q=-1):
        """Integrate ``a_xG`` against the position derivatives.

        a_xG:
            Plane-wave coefficients.
        c_axiv:
            Output container indexed by atom; the last axis is the
            Cartesian direction.
        q:
            Not used.

        Returns ``c_axiv`` (also on the early exit for an empty
        problem).
        """
        xp = self.xp
        c_vxI = xp.zeros((3,) + a_xG.shape[:-1] + (self.nI,), self.dtype)
        nx = prod(c_vxI.shape[1:-1])
        if nx == 0:
            return c_axiv
        b_vxI = c_vxI.reshape((3, nx, self.nI))
        a_xG = a_xG.reshape((nx, a_xG.shape[-1])).view(self.dtype)

        alpha = 1.0

        if c_axiv is None:
            # NOTE(review): no ``dict`` method is defined on this class
            # in the code visible here - confirm that this path is
            # reachable.
            c_axiv = self.dict(a_xG.shape[:-1], derivative=True)

        x = 0.0  # beta: overwrite for the first block, accumulate after
        for G1, G2 in self.block():
            f_GI = self.expand(G1, G2, cc=True)
            G_Gv = xp.asarray(self.pw.G_plus_k_Gv[G1:G2],
                              dtype=as_real_dtype(self.dtype))
            if self.real:
                d_GI = xp.empty_like(f_GI)
                for v in range(3):
                    # Swap the interleaved real/imag rows while
                    # multiplying by component v of G+k.
                    d_GI[::2] = f_GI[1::2] * G_Gv[:, v, np.newaxis]
                    d_GI[1::2] = f_GI[::2] * G_Gv[:, v, np.newaxis]
                    if xp is np:
                        mmm(2 * alpha,
                            a_xG[:, 2 * G1:2 * G2], 'N',
                            d_GI, 'N',
                            x, b_vxI[v])
                    else:
                        gpu_gemm('N', 'N',
                                 a_xG[:, 2 * G1:2 * G2],
                                 d_GI,
                                 b_vxI[v],
                                 2 * alpha, x)
            else:
                for v in range(3):
                    if xp is np:
                        mmm(-alpha,
                            a_xG[:, G1:G2], 'N',
                            f_GI * G_Gv[:, v, np.newaxis], 'N',
                            x, b_vxI[v])
                    else:
                        gpu_gemm('N', 'N',
                                 a_xG[:, G1:G2],
                                 f_GI * G_Gv[:, v, np.newaxis],
                                 b_vxI[v],
                                 -alpha, x)
            x = 1.0

        self.comm.sum(c_vxI)

        for v in range(3):
            if self.real:
                for a, I1, I2 in self.my_indices:
                    c_axiv[a][..., v] = c_vxI[v, ..., I1:I2]
            else:
                for a, I1, I2 in self.my_indices:
                    c_axiv[a][..., v] = (1.0j * self.eikR_a[a] *
                                         c_vxI[v, ..., I1:I2])

        return c_axiv

    @trace
    def stress_tensor_contribution(self, a_xG, c_axi=1.0):
        """Return the 3x3 stress-tensor contribution of the functions.

        a_xG:
            Plane-wave coefficients.
        c_axi:
            Coefficients per atom, or a scalar used for every atom.
        """
        xp = self.xp
        cache: dict = {}
        things = []
        I1 = 0
        lmax = 0

        # |G+k| is the same for all splines, so compute it once.
        G_G = (2 * self.pw.ekin_G)**0.5

        for a, spline_j in enumerate(self.spline_aj):
            for spline in spline_j:
                if spline not in cache:
                    s = rescaled_fourier_bessel_transform(spline)
                    f_G = []
                    dfdGoG_G = []
                    for G in G_G:
                        f, dfdG = s.get_value_and_derivative(G)
                        if G < 1e-10:
                            # Avoid dividing by (almost) zero.
                            G = 1.0
                        f_G.append(f)
                        dfdGoG_G.append(dfdG / G)
                    f_G = xp.array(f_G)
                    dfdGoG_G = xp.array(dfdGoG_G)
                    cache[spline] = (f_G, dfdGoG_G)
                else:
                    f_G, dfdGoG_G = cache[spline]
                l = spline.l
                lmax = max(l, lmax)
                I2 = I1 + 2 * l + 1
                things.append((a, l, I1, I2, f_G, dfdGoG_G))
                I1 = I2

        if np.isscalar(c_axi):
            c_axi = {a: c_axi for a in range(len(self.pos_av))}

        G0_Gv = self.pw.G_plus_k_Gv

        stress_vv = xp.zeros((3, 3))
        # Every rank must go through the same number of blocks because
        # _stress_tensor_contribution() does a collective sum.
        for G1, G2 in self.block(ensure_same_number_of_blocks=True):
            G_Gv = G0_Gv[G1:G2]
            Z_LvG = xp.array([nablarlYL(L, G_Gv.T)
                              for L in range((lmax + 1)**2)])
            G_Gv = xp.asarray(G_Gv)
            aa_xG = a_xG[..., G1:G2]
            for v1 in range(3):
                for v2 in range(3):
                    stress_vv[v1, v2] += self._stress_tensor_contribution(
                        v1, v2, things, G1, G2, G_Gv, aa_xG, c_axi, Z_LvG)

        return stress_vv

    @trace
    def _stress_tensor_contribution(self, v1, v2, things, G1, G2,
                                    G_Gv, a_xG, c_axi, Z_LvG):
        """Return element ``(v1, v2)`` for the G-vector block G1:G2.

        ``things`` holds one ``(a, l, I1, I2, f_G, dfdGoG_G)`` tuple
        per function, as built by :meth:`stress_tensor_contribution`.
        """
        xp = self.xp
        f_IG = xp.empty((self.nI, G2 - G1), as_complex_dtype(self.dtype))
        emiGR_Ga = self.get_emiGR_Ga(G1, G2)
        Y_LG = self.Y_GL.T
        for a, l, I1, I2, f_G, dfdGoG_G in things:
            L1 = l**2
            L2 = (l + 1)**2
            f_IG[I1:I2] = (emiGR_Ga[:, a] * (-1.0j)**l *
                           (dfdGoG_G[G1:G2] * G_Gv[:, v1] * G_Gv[:, v2] *
                            Y_LG[L1:L2, G1:G2] +
                            f_G[G1:G2] * G_Gv[:, v1] * Z_LvG[L1:L2, v2]))

        c_xI = xp.zeros(a_xG.shape[:-1] + (self.nI,), self.pw.dtype)

        x = prod(c_xI.shape[:-1])
        if x == 0:
            return 0.0
        b_xI = c_xI.reshape((x, self.nI))
        a_xG = a_xG.reshape((x, a_xG.shape[-1]))

        alpha = 1.0
        if self.real:
            alpha = 2.0
            if G1 == 0 and self.pw.comm.rank == 0:
                # Halve the first column on rank 0, compensating for
                # alpha = 2 (presumably the G=0 term).
                f_IG[:, 0] *= 0.5
            f_IG = f_IG.view(as_real_dtype(f_IG.dtype))
            a_xG = a_xG.copy().view(as_real_dtype(f_IG.dtype))

        if xp is np:
            mmm(alpha, a_xG, 'N', f_IG, 'C', 0.0, b_xI)
        else:
            gpu_gemm('N', 'H', a_xG, f_IG, b_xI, alpha, 0.0)
        self.comm.sum(b_xI)

        stress = 0.0
        for a, I1, I2 in self.my_indices:
            stress -= self.eikR_a[a] * (c_axi[a] * c_xI[..., I1:I2]).sum()
        return stress.real
584 return stress.real