Coverage for gpaw/core/plane_waves.py: 82%

545 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from __future__ import annotations 

2 

3from math import pi 

4from typing import TYPE_CHECKING, Literal, Sequence 

5 

6import numpy as np 

7from ase.units import Ha 

8 

9import gpaw.fftw as fftw 

10from gpaw import debug 

11from gpaw.core.arrays import DistributedArrays 

12from gpaw.core.domain import Domain 

13from gpaw.core.matrix import Matrix 

14from gpaw.core.pwacf import PWAtomCenteredFunctions 

15from gpaw.gpu import cupy as cp 

16from gpaw.new.c import pw_norm_kinetic_gpu, pw_norm_gpu 

17from gpaw.mpi import MPIComm, serial_comm 

18from gpaw.new import prod, zips 

19from gpaw.new.c import (add_to_density, add_to_density_gpu, pw_insert, 

20 pw_insert_gpu) 

21from gpaw.pw.descriptor import pad 

22from gpaw.typing import (Array1D, Array2D, Array3D, ArrayLike1D, ArrayLike2D, 

23 Vector) 

24from gpaw.fftw import get_efficient_fft_size 

25from gpaw.utilities import as_real_dtype, as_complex_dtype 

26 

27if TYPE_CHECKING: 

28 from gpaw.core import UGArray, UGDesc 

29 

30 

class PWDesc(Domain['PWArray']):
    """Description of a plane-wave basis, optionally distributed over comm.

    The G-vectors with ``|G+k|^2 / 2 <= ecut`` are found once by
    ``find_reciprocal_vectors()``; each rank of ``comm`` then owns the
    contiguous slice ``[ng1:ng2]`` (at most ``maxmysize`` plane waves).
    """

    itemsize = 16

    def __init__(self,
                 *,
                 ecut: float | None = None,
                 gcut: float | None = None,
                 cell: ArrayLike1D | ArrayLike2D,  # bohr
                 kpt: Vector | None = None,  # in units of reciprocal cell
                 comm: MPIComm = serial_comm,
                 dtype=None):
        """Description of plane-wave basis.

        parameters
        ----------
        ecut:
            Cutoff energy for kinetic energy of plane waves (units: hartree).
            Exactly one of ``ecut`` and ``gcut`` must be given.
        gcut:
            Cutoff for the length of G+k (units: 1/bohr).  Equivalent to
            giving ``ecut = gcut**2 / 2``.
        cell:
            Unit cell given as three floats (orthorhombic grid), six floats
            (three lengths and the angles in degrees) or a 3x3 matrix
            (units: bohr).
        comm:
            Communicator for distribution of plane-waves.
        kpt:
            K-point for Bloch-boundary conditions specified in units of the
            reciprocal cell.
        dtype:
            Data-type (float or complex).
        """
        if ecut is None:
            assert gcut is not None, 'Either ecut or gcut must be given'
            ecut = 0.5 * gcut**2
        else:
            assert gcut is None, 'Only one of ecut and gcut can be given'
            gcut = (2.0 * ecut)**0.5
        self.gcut = gcut
        self.ecut = ecut
        Domain.__init__(self, cell, (True, True, True), kpt, comm, dtype)

        G_plus_k_Gv, ekin_G, self.indices_cG = find_reciprocal_vectors(
            ecut, self.cell_cv, self.kpt_c, self.dtype)

        # Find distribution: rank r owns plane waves
        # [r * maxmysize, min((r + 1) * maxmysize, ng)); the slice can be
        # empty on the last ranks.
        S = comm.size
        ng = len(ekin_G)
        self.maxmysize = (ng + S - 1) // S
        ng1 = comm.rank * self.maxmysize
        ng2 = min(ng1 + self.maxmysize, ng)
        self.ng1 = ng1
        self.ng2 = ng2

        # Distribute things:
        self.ekin_G = ekin_G[ng1:ng2].copy()
        self.ekin_G.flags.writeable = False
        self.G_plus_k_Gv = G_plus_k_Gv[ng1:ng2].copy()

        self.shape = (ng,)
        self.myshape = (len(self.ekin_G),)

        # Convert from np.float64 to float to avoid fake cupy problem ...
        # XXX Fix cupy!!!
        self.dv = float(abs(np.linalg.det(self.cell_cv)))

        # Cache of FFT-grid indices, keyed by grid shape (see indices()).
        self._indices_cache: dict[tuple[int, ...], Array1D] = {}

    def __repr__(self) -> str:
        m = self.myshape[0]
        n = self.shape[0]
        return super().__repr__().replace(
            'Domain(',
            f'PWDesc(ecut={self.ecut} <coefs={m}/{n}>, ')

    def _short_string(self, global_shape):
        return (f'plane wave coefficients: {global_shape[-1]}\n'
                f'cutoff: {self.ecut * Ha} eV\n')

    def global_shape(self) -> tuple[int, ...]:
        """Tuple with one element: number of plane waves."""
        return self.shape

    def reciprocal_vectors(self) -> Array2D:
        """Returns reciprocal lattice vectors, G + k, in xyz coordinates.

        Only the plane waves owned by this rank are returned.
        """
        return self.G_plus_k_Gv

    def kinetic_energies(self) -> Array1D:
        """Kinetic energy of plane waves (read-only, this rank's slice).

        :::

             _ _ 2
            |G+k| / 2

        """
        return self.ekin_G

    def empty(self,
              dims: int | tuple[int, ...] = (),
              comm: MPIComm = serial_comm,
              xp=None) -> PWArray:
        """Create new plane-wave expansion (PWArray) object.

        parameters
        ----------
        dims:
            Extra dimensions.
        comm:
            Distribute dimensions along this communicator.
        xp:
            Array module for the data (numpy or cupy).
        """
        return PWArray(self, dims, comm, xp=xp)

    def from_data(self, data):
        """Wrap an existing data array (last axis: plane waves)."""
        return PWArray(self, data.shape[:-1], data=data)

    def new(self,
            *,
            ecut: float | None = None,
            gcut: float | None = None,
            kpt=None,
            dtype=None,
            comm: MPIComm | Literal['inherit'] | None = 'inherit'
            ) -> PWDesc:
        """Create new plane-wave expansion description.

        Unspecified parameters are taken from self.  ``comm='inherit'``
        reuses ``self.comm``; ``comm=None`` means no distribution.
        """
        comm = self.comm if comm == 'inherit' else comm or serial_comm
        if ecut is None and gcut is None:
            ecut = self.ecut
        return PWDesc(gcut=gcut,
                      ecut=ecut,
                      cell=self.cell_cv,
                      kpt=self.kpt_c if kpt is None else kpt,
                      dtype=dtype or self.dtype,
                      comm=comm)

    def indices(self, shape: tuple[int, ...]) -> Array1D:
        """Return indices into FFT-grid.

        The flat indices (int32) of all global G-vectors in a C-ordered
        grid of the given shape (negative components wrap around).  The
        result is cached per shape.
        """
        Q_G = self._indices_cache.get(shape)
        if Q_G is None:
            # We should do this here instead of everywhere calling this: !!!!
            # if self.dtype == float:
            #     shape = (shape[0], shape[1], shape[2] // 2 + 1)
            Q_G = np.ravel_multi_index(self.indices_cG, shape,  # type: ignore
                                       mode='wrap').astype(np.int32)
            if debug:
                assert (Q_G[1:] > Q_G[:-1]).all()
            self._indices_cache[shape] = Q_G
        return Q_G

    def minimal_uniform_grid(self,
                             n: int = 1,
                             factors: Sequence[int] = (2, 3, 5, 7)
                             ) -> UGDesc:
        """Smallest uniform grid that can hold all the plane waves.

        Each grid size is rounded up to a multiple of ``n`` and, if
        ``factors`` is non-empty, to an FFT-efficient size.
        """
        from gpaw.core import UGDesc
        size_c = np.ptp(self.indices_cG, axis=1) + 1
        if np.issubdtype(self.dtype, np.floating):
            # Only half of the G-vectors are stored along the last axis.
            size_c[2] = size_c[2] * 2 - 1
        size_c = (size_c + n - 1) // n * n
        if factors:
            size_c = np.array([get_efficient_fft_size(N, n, factors)
                               for N in size_c])
        return UGDesc(size=size_c,
                      cell=self.cell_cv,
                      pbc=self.pbc_c,
                      kpt=self.kpt_c,
                      dtype=self.dtype,
                      comm=self.comm)

    def cut(self, array_Q: Array3D) -> Array1D:
        """Cut out G-vectors with (G+k)^2/2<E_kin."""
        return array_Q.ravel()[self.indices(array_Q.shape)]

    def paste(self, coef_G: Array1D, array_Q: Array3D) -> None:
        """Paste G-vectors with (G+k)^2/2<E_kin into 3-D FFT grid and
        zero-pad."""
        Q_G = self.indices(array_Q.shape)
        if debug:
            assert (Q_G[1:] > Q_G[:-1]).all()
            assert (Q_G >= 0).all()
            assert (Q_G < array_Q.size).all()
            assert coef_G.shape == Q_G.shape
            assert coef_G.flags.c_contiguous
            assert Q_G.flags.c_contiguous
            assert array_Q.flags.c_contiguous

        assert isinstance(coef_G, np.ndarray)
        assert isinstance(array_Q, np.ndarray)
        pw_insert(coef_G, Q_G, 1.0, array_Q)

    def map_indices(self, other: PWDesc) -> tuple[Array1D, list[Array1D]]:
        """Map from one (distributed) set of plane waves to smaller global set.

        Say we have 9 G-vector on two cores::

           5 3 4             . 3 4           0 . .
           2 0 1 -> rank=0:  2 0 1  rank=1:  . . .
           8 6 7             . . .           3 1 2

        and we want a mapping to these 5 G-vectors::

             3
           2 0 1
             4

        On rank=0: the return values are::

           [0, 1, 2, 3], [[0, 1, 2, 3], [4]]

        and for rank=1::

           [1], [[0, 1, 2, 3], [4]]
        """
        size_c = tuple(np.ptp(self.indices_cG, axis=1) + 1)  # type: ignore
        Q_G = self.indices(size_c)
        G_Q = np.empty(prod(size_c), int)
        G_Q[Q_G] = np.arange(len(Q_G))
        G_g = G_Q[other.indices(size_c)]
        ng1 = 0
        g_r = []
        for rank in range(self.comm.size):
            ng2 = min(ng1 + self.maxmysize, self.shape[0])
            myg = (ng1 <= G_g) & (G_g < ng2)
            g_r.append(np.nonzero(myg)[0])
            if rank == self.comm.rank:
                my_G_g = G_g[myg] - ng1
            ng1 = ng2
        return my_G_g, g_r

    def atom_centered_functions(self,
                                functions,
                                positions,
                                *,
                                qspiral_v=None,
                                atomdist=None,
                                integrals=None,
                                cut=False,
                                xp=None):
        """Create PlaneWaveAtomCenteredFunctions object.

        With ``qspiral_v`` given, a spin-spiral variant is returned
        instead; note that ``integrals`` and ``xp`` are not passed on in
        that case, and ``cut`` is currently not used at all.
        """
        if qspiral_v is None:
            return PWAtomCenteredFunctions(functions, positions, self,
                                           atomdist=atomdist,
                                           xp=xp, integrals=integrals)

        from gpaw.new.spinspiral import SpiralPWACF
        return SpiralPWACF(functions, positions, self,
                           atomdist=atomdist,
                           qspiral_v=qspiral_v)

276 

277 

class PWArray(DistributedArrays[PWDesc]):
    """Function(s) stored as plane-wave expansion coefficients."""

    def __init__(self,
                 pw: PWDesc,
                 dims: int | tuple[int, ...] = (),
                 comm: MPIComm = serial_comm,
                 data: np.ndarray | None = None,
                 xp=None):
        """Object for storing function(s) as a plane-wave expansions.

        parameters
        ----------
        pw:
            Description of plane-waves.
        dims:
            Extra dimensions.
        comm:
            Distribute extra dimensions along this communicator.
        data:
            Data array for storage.
        """

        self.real_dtype = as_real_dtype(pw.dtype)
        self.complex_dtype = as_complex_dtype(pw.dtype)

        # Coefficients are always stored as complex numbers, also for
        # real-valued (dtype=float) expansions.
        DistributedArrays.__init__(self, dims, pw.myshape,
                                   comm, pw.comm,
                                   data, pw.dv,
                                   self.complex_dtype, xp)
        self.desc = pw
        # NOTE(review): only annotated here; the attribute itself is
        # presumably initialised by DistributedArrays.__init__ -- confirm.
        self._matrix: Matrix | None

    def __repr__(self):
        txt = f'PWArray(pw={self.desc}, dims={self.dims}'
        if self.comm.size > 1:
            txt += f', comm={self.comm.rank}/{self.comm.size}'
        if self.xp is not np:
            txt += ', xp=cp'
        return txt + ')'

    def __getitem__(self, index: int | slice) -> PWArray:
        data = self.data[index]
        return PWArray(self.desc,
                       data.shape[:-1],
                       data=data)

    def __iter__(self):
        for data in self.data:
            yield PWArray(self.desc,
                          data.shape[:-1],
                          data=data)

    def new(self,
            data=None,
            dims=None) -> PWArray:
        """Create new PWArray object of same kind.

        Parameters
        ----------
        data:
            Array to use for storage.
        dims:
            Extra dimensions (bands, spin, etc.), required if
            data does not fit the full array.
        """
        if data is None:
            assert dims is None
            data = self.xp.empty_like(self.data)
        else:
            if dims is None:
                # Number of plane-waves depends on the k-point. We therefore
                # allow for data to be bigger than needed:
                data = data.ravel()[:self.data.size].reshape(self.data.shape)
            else:
                return PWArray(self.desc, dims, self.comm, data)
        return PWArray(self.desc,
                       self.dims,
                       self.comm,
                       data)

    def copy(self):
        """Create a copy (surprise!)."""
        a = self.new()
        a.data[:] = self.data
        return a

    def sanity_check(self) -> None:
        """Sanity check for PW expansions.

        Raises ValueError if any coefficient is NaN, or (for real-valued
        expansions) if the G=(0,0,0) coefficient has an imaginary part.
        """
        if self.xp.isnan(self.data).any():
            raise ValueError('NaN value')
        if self.desc.dtype == self.real_dtype and self.desc.comm.rank == 0:
            if (self.data[..., 0].imag != 0.0).any():
                val = self.xp.max(self.xp.abs(self.data[..., 0].imag))
                raise ValueError(
                    f'Imag value of {val}')

    def _arrays(self):
        """2-D view of the data: (all extra dimensions, plane waves)."""
        shape = self.data.shape
        return self.data.reshape((prod(shape[:-1]), shape[-1]))

    @property
    def matrix(self) -> Matrix:
        """Matrix view of data.

        For real-valued expansions the complex coefficients are viewed as
        pairs of real numbers (twice as many columns).
        """
        if self._matrix is not None:
            return self._matrix

        shape = (self.dims[0], prod(self.dims[1:]) * self.myshape[0])
        myshape = (self.mydims[0], prod(self.mydims[1:]) * self.myshape[0])
        dist = (self.comm, -1, 1)
        data = self.data.reshape(myshape)

        if self.desc.dtype == self.real_dtype:
            data = data.view(self.real_dtype)
            shape = (shape[0], shape[1] * 2)

        self._matrix = Matrix(*shape, data=data, dist=dist)
        return self._matrix

    def ifft(self,
             *,
             plan=None,
             grid=None,
             grid_spacing=None,
             out=None,
             periodic=False):
        """Do inverse FFT(s) to uniform grid(s).

        Returns:
            UGArray with values
        :::
                   _            _
             _     --          iG.R
            f(r) = >  c(G) e
                   --
                   G

        Parameters
        ----------
        plan:
            Plan for inverse FFT.
        grid:
            Target grid.
        grid_spacing:
            Used to construct the target grid when neither ``out`` nor
            ``grid`` is given.
        out:
            Target UGArray object.
        periodic:
            If False (default), the result is afterwards passed through
            ``out.multiply_by_eikr()``.
        """
        comm = self.desc.comm
        xp = self.xp
        if out is None:
            if grid is None:
                grid = self.desc.uniform_grid_with_grid_spacing(grid_spacing)
            out = grid.empty(self.dims, xp=xp)
        assert self.desc.dtype == out.desc.dtype, \
            (self.desc.dtype, out.desc.dtype)

        assert not out.desc.zerobc_c.any()
        assert comm.size == out.desc.comm.size, (comm, out.desc.comm)

        plan = plan or out.desc.fft_plans(xp=xp)
        # Coefficients are gathered on rank 0; the other ranks still take
        # part in the (distributed) FFT with coef_G=None.
        this = self.gather()
        if this is not None:
            for coef_G, out1 in zips(this._arrays(), out.flat()):
                plan.ifft_sphere(coef_G, self.desc, out1)
        else:
            for out1 in out.flat():
                plan.ifft_sphere(None, self.desc, out1)

        if not periodic:
            out.multiply_by_eikr()

        return out

    def interpolate(self,
                    plan1: fftw.FFTPlans | None = None,
                    plan2: fftw.FFTPlans | None = None,
                    grid: UGDesc | None = None,
                    out: UGArray | None = None) -> UGArray:
        """Interpolate to a uniform grid (an inverse FFT using plan2)."""
        assert plan1 is None
        return self.ifft(plan=plan2, grid=grid, out=out)

    def gather(self, out=None, broadcast=False):
        """Gather coefficients on master.

        Returns the gathered (undistributed) PWArray on rank 0, and also
        on the other ranks if ``broadcast`` is true; otherwise None is
        returned on the other ranks.
        """
        comm = self.desc.comm

        if comm.size == 1:
            if out is None:
                return self
            out.data[:] = self.data
            return out

        if out is None:
            if comm.rank == 0 or broadcast:
                pw = self.desc.new(comm=serial_comm)
                out = pw.empty(self.dims, comm=self.comm, xp=self.xp)
            else:
                out = Empty(self.mydims)

        if comm.rank == 0:
            data = self.xp.empty(self.desc.maxmysize * comm.size,
                                 self.complex_dtype)
        else:
            data = None

        for input, output in zips(self._arrays(), out._arrays()):
            # All ranks send equally sized (padded) chunks:
            mydata = pad(input, self.desc.maxmysize)
            comm.gather(mydata, 0, data)
            if comm.rank == 0:
                output[:] = data[:len(output)]

        if broadcast:
            comm.broadcast(out.data, 0)

        return out if not isinstance(out, Empty) else None

    def gather_all(self, out: PWArray) -> None:
        """Gather coefficients from self[r] on rank r.

        On rank r, an array of all G-vector coefficients will be returned.
        These will be gathered from self[r] on all the cores.
        """
        assert len(self.dims) == 1
        pw = self.desc
        comm = pw.comm
        if comm.size == 1:
            out.data[:] = self.data[0]
            return

        N = self.dims[0]
        assert N <= comm.size

        ng = pw.shape[0]
        myng = pw.myshape[0]
        maxmyng = pw.maxmysize

        ssize_r, soffset_r, rsize_r, roffset_r = a2a_stuff(
            comm, N, ng, myng, maxmyng)

        comm.alltoallv(self.data, ssize_r, soffset_r,
                       out.data, rsize_r, roffset_r)

    def scatter_from(self, data: Array1D | PWArray | None = None) -> None:
        """Scatter data from rank-0 to all ranks."""
        if isinstance(data, PWArray):
            data = data.data
        comm = self.desc.comm
        if comm.size == 1:
            assert data is not None
            self.data[:] = self.xp.asarray(data)
            return

        if comm.rank == 0:
            assert data is not None
            shape = data.shape
            for fro, to in zips(data.reshape((prod(shape[:-1]), shape[-1])),
                                self._arrays()):
                fro = pad(fro, comm.size * self.desc.maxmysize)
                comm.scatter(fro, to, 0)
        else:
            buf = self.xp.empty(self.desc.maxmysize, self.complex_dtype)
            for to in self._arrays():
                comm.scatter(None, buf, 0)
                to[:] = buf[:len(to)]

    def scatter_from_all(self, a_G: PWArray) -> None:
        """Scatter all coefficients from rank r to self on other cores."""
        assert len(self.dims) == 1
        pw = self.desc
        comm = pw.comm
        if comm.size == 1:
            self.data[:] = a_G.data
            return

        N = self.dims[0]
        assert N <= comm.size

        ng = pw.shape[0]
        myng = pw.myshape[0]
        maxmyng = pw.maxmysize

        # Same communication pattern as gather_all(), reversed:
        rsize_r, roffset_r, ssize_r, soffset_r = a2a_stuff(
            comm, N, ng, myng, maxmyng)

        comm.alltoallv(a_G.data, ssize_r, soffset_r,
                       self.data, rsize_r, roffset_r)

    def integrate(self, other: PWArray | None = None) -> np.ndarray:
        """Integral of self, or of self times the complex conjugate of other.

        With ``other`` given, the result has shape
        ``self.dims + other.dims``.
        """
        dv = self.dv
        if other is not None:
            assert self.comm.size == 1
            assert self.desc.dtype == other.desc.dtype
            a = self._arrays()
            b = other._arrays()
            if self.desc.dtype == self.real_dtype:
                # Only half of the G-vectors are stored: count them twice,
                # except for G=0 (corrected below).
                a = a.view(self.real_dtype)
                b = b.view(self.real_dtype)
                dv *= 2
            result = a @ b.T.conj()
            if self.desc.dtype == self.real_dtype and self.desc.comm.rank == 0:
                result -= 0.5 * a[:, :1] @ b[:, :1].T
            self.desc.comm.sum(result)
            result = result.reshape(self.dims + other.dims)
        else:
            # The integral is the G=0 coefficient times the volume.
            if self.desc.comm.rank == 0:
                result = self.data[..., 0]
            else:
                result = self.xp.empty(self.mydims, self.complex_dtype)
            self.desc.comm.broadcast(self.xp.ascontiguousarray(result), 0)
            if self.desc.dtype == self.real_dtype:
                result = result.real
        if result.ndim == 0:
            result = result.item()  # convert to scalar
        return result * dv

    def _matrix_elements_correction(self,
                                    M1: Matrix,
                                    M2: Matrix,
                                    out: Matrix,
                                    symmetric: bool) -> None:
        """Correct matrix elements for real-valued expansions.

        Only half of the G-vectors are stored for real-valued expansions,
        so ``out`` is doubled and the G=0 contribution (which must only
        be counted once) is subtracted on rank 0.  Does nothing for
        complex expansions.
        """
        if self.desc.dtype == self.real_dtype:
            if symmetric:
                # Upper triangle could contain garbage that will overflow
                # when multiplied by 2
                out.data[np.triu_indices(M1.shape[0], 1)] = 42.0
            out.data *= 2.0
            if self.desc.comm.rank == 0:
                correction = M1.data[:, :1] @ M2.data[:, :1].T
                if symmetric:
                    correction *= 0.5 * self.dv
                    out.data -= correction
                    out.data -= correction.T
                else:
                    correction *= self.dv
                    out.data -= correction

    def norm2(self, kind: str = 'normal', skip_sum=False) -> np.ndarray:
        r"""Calculate integral over cell.

        For kind='normal' we calculate:::

          /   _  2 _   --    2
          ||a(r)| dr = > |c | V,
          /            --  G
                       G

        where V is the volume of the unit cell.

        And for kind='kinetic':::

           1  --    2  2
          ---  > |c | G V,
           2  --  G
              G

        Raises ValueError for any other kind.  With ``skip_sum=True`` the
        result is not summed over the plane-wave distribution.
        """
        a_xG = self._arrays().view(self.real_dtype)
        if kind == 'normal':
            if self.xp is not np:
                result_x = self.xp.empty((a_xG.shape[0],),
                                         dtype=self.real_dtype)
                pw_norm_gpu(result_x, self._arrays())
            else:
                result_x = self.xp.einsum('xG, xG -> x', a_xG, a_xG)
        elif kind == 'kinetic':
            x, G2 = a_xG.shape
            if self.xp is not np:
                result_x = self.xp.empty((x,), dtype=self.real_dtype)
                pw_norm_kinetic_gpu(result_x, self._arrays(),
                                    self.xp.asarray(self.desc.ekin_G,
                                                    dtype=self.real_dtype))
            else:
                # z: real and imaginary parts of each coefficient
                a_xGz = a_xG.reshape((x, G2 // 2, 2))
                result_x = self.xp.einsum('xGz, xGz, G -> x',
                                          a_xGz,
                                          a_xGz,
                                          self.xp.asarray(self.desc.ekin_G))
        else:
            raise ValueError(
                f"Unknown kind: {kind!r} (expected 'normal' or 'kinetic')")
        if self.desc.dtype == self.real_dtype:
            # Only half of the G-vectors are stored; G=0 counted once.
            result_x *= 2
            if self.desc.comm.rank == 0 and kind == 'normal':
                result_x -= a_xG[:, 0]**2
        if not skip_sum:
            self.desc.comm.sum(result_x)
        return result_x.reshape(self.mydims) * self.dv

    def abs_square(self,
                   weights: Array1D,
                   out: UGArray,
                   _slow: bool = False) -> None:
        """Add weighted absolute square of self to output array.

        With `a_n(G)` being self and `w_n` the weights:::

              _         _    --    -1    _   2
          out(r) <- out(r) + >  |FFT  [a (G)]| w
                             --         n       n
                             n

        ``_slow=True`` disables the batched GPU code path.
        """
        pw = self.desc
        domain_comm = pw.comm
        xp = self.xp
        a_nG = self

        if domain_comm.size == 1:
            if not _slow and xp is cp and pw.dtype == self.complex_dtype:
                return abs_square_gpu(a_nG, weights, out)

            a_R = out.desc.new(dtype=pw.dtype).empty(xp=xp)
            for weight, a_G in zips(weights, a_nG):
                if weight == 0.0:
                    continue
                a_G.ifft(out=a_R)
                if xp is np:
                    add_to_density(weight, a_R.data, out.data)
                else:
                    out.data += float(weight) * xp.abs(a_R.data)**2
            return

        # Undistributed work arrays:
        a1_R = out.desc.new(comm=None, dtype=pw.dtype).empty(xp=xp)
        a1_G = pw.new(comm=None).empty(xp=xp)
        b1_R = out.desc.new(comm=None).zeros(xp=xp)

        # Each rank transforms one complete function out of every batch
        # of domain_comm.size functions:
        (N,) = self.mydims
        for n1 in range(0, N, domain_comm.size):
            n2 = min(n1 + domain_comm.size, N)
            a_nG[n1:n2].gather_all(a1_G)
            n = n1 + domain_comm.rank
            if n >= N:
                continue
            weight = weights[n]
            if weight == 0.0:
                continue
            a1_G.ifft(out=a1_R)
            if xp is np:
                add_to_density(weight, a1_R.data, b1_R.data)
            else:
                b1_R.data += float(weight) * xp.abs(a1_R.data)**2

        domain_comm.sum(b1_R.data)
        b_R = out.new()
        b_R.scatter_from(b1_R)
        out.data += b_R.data

    def to_pbc_grid(self):
        return self

    def randomize(self, seed: int | None = None) -> None:
        """Insert random complex numbers into data.

        Values are drawn uniformly from the unit disk and damped by
        ``0.5 / (1 + ekin_G)``, so their magnitude is at most 0.5.  For
        real-valued expansions the G=0 coefficient is made real.
        """
        if seed is None:
            seed = self.comm.rank + self.desc.comm.rank * self.comm.size
        rng = self.xp.random.default_rng(seed)

        data = self.data
        if data.ndim == 1:
            # array_split() splits along the first axis.  For a single
            # expansion that is the plane-wave axis itself, which would
            # break the broadcast against ekin_G below (and put the G=0
            # fix-up on the wrong elements), so add a leading axis.
            data = data[None]
        # Work in batches to limit the size of temporary arrays:
        batches = data.size // 5_000_000 + 1
        arrays = self.xp.array_split(data, batches)
        is_real = self.desc.dtype == self.real_dtype
        ekin_G = self.xp.asarray(self.desc.ekin_G)
        for a in arrays:
            # numpy does not require shape, cupy does
            # cupy just makes all elements equal to one random number
            aview = a.view(dtype=self.real_dtype)
            rng.random(aview.shape, out=aview, dtype=self.real_dtype)

            # Uniform distribution inside unit circle
            a[:] = a.real**0.5 * self.xp.exp(2j * self.xp.pi * a.imag)

            # Damp high spatial frequencies
            a[..., :] *= 0.5 / (1.00 + ekin_G[..., :])

            # Make sure gamma point G=0 does not have imaginary part
            if is_real and self.desc.comm.rank == 0:
                a[..., 0].imag = 0.0

    def moment(self):
        """Dipole moment (3 xyz components) computed from the coefficients.

        Only the G-vectors along the three reciprocal axes are used; the
        result is broadcast to all ranks of the domain communicator.
        """
        pw = self.desc
        # Masks:
        m0_G, m1_G, m2_G = (i_G == 0 for i_G in pw.indices_cG)
        a_G = self.gather()
        if a_G is not None:
            b_G = a_G.data.imag
            b_cs = [b_G[m1_G & m2_G],
                    b_G[m0_G & m2_G],
                    b_G[m0_G & m1_G]]
            d_c = [b_s[1:] @ (1.0 / np.arange(1, len(b_s)))
                   for b_s in b_cs]
            m_v = d_c @ pw.cell_cv / pi * pw.dv
        else:
            m_v = np.empty(3)
        pw.comm.broadcast(m_v, 0)
        return m_v

    def boundary_value(self, axis: int) -> float:
        """Calculate average value at boundary of box.

        Only ``axis=2`` and real-valued expansions are supported.
        """
        assert axis == 2
        pw = self.desc
        m0_G, m1_G = pw.indices_cG[:2, pw.ng1:pw.ng2] == 0
        assert self.desc.dtype == self.real_dtype
        # Half of the G-vectors are stored: count twice, G=0 once.
        value = self.data.real[m0_G & m1_G].sum() * 2
        if self.desc.comm.rank == 0:
            value -= self.data[0].real
        return self.desc.comm.sum_scalar(value)

    def morph(self, pw: PWDesc) -> PWArray:
        """Transform expansion to new cell.

        Coefficients are matched by their integer G-vector indices;
        G-vectors missing from either basis are dropped / set to zero.
        """
        in_xG = self.gather()
        if in_xG is not None:
            pwin = in_xG.desc
            pwout = pw.new(comm=None)

            d = {tuple(i_c): G for G, i_c in enumerate(pwout.indices_cG.T)}
            G_G0 = []
            G0_G = []
            for G0, i_c in enumerate(pwin.indices_cG.T):
                G = d.get(tuple(i_c), -1)
                if G != -1:
                    G_G0.append(G)
                    G0_G.append(G0)
            out0_xG = pwout.zeros(self.dims,
                                  comm=self.comm,
                                  xp=self.xp)
            out0_xG.data[..., G_G0] = in_xG.data[..., G0_G]
        else:
            out0_xG = None

        out_xG = pw.zeros(self.dims,
                          comm=self.comm,
                          xp=self.xp)
        out_xG.scatter_from(out0_xG)
        return out_xG

    def add_ked(self,
                occ_n: Array1D,
                taut_R: UGArray) -> None:
        """Add kinetic-energy density to taut_R.

        For each function n with occupation f_n, adds
        ``0.5 * f_n * |ifft(i (G+k)_v psit_n(G))|^2`` summed over the three
        Cartesian directions v.
        """
        psit_nG = self
        pw = psit_nG.desc
        domain_comm = pw.comm

        # Undistributed work arrays:
        dpsit1_R = taut_R.desc.new(comm=None, dtype=pw.dtype).empty()
        pw1 = pw.new(comm=None)
        psit1_G = pw1.empty()
        iGpsit1_G = pw1.empty()
        taut1_R = taut_R.desc.new(comm=None).zeros()
        Gplusk1_Gv = pw1.reciprocal_vectors()

        (N,) = psit_nG.mydims
        for n1 in range(0, N, domain_comm.size):
            n2 = min(n1 + domain_comm.size, N)
            psit_nG[n1:n2].gather_all(psit1_G)
            n = n1 + domain_comm.rank
            if n >= N:
                continue
            f = occ_n[n]
            if f == 0.0:
                continue
            for v in range(3):
                iGpsit1_G.data[:] = psit1_G.data
                iGpsit1_G.data *= 1j * Gplusk1_Gv[:, v]
                iGpsit1_G.ifft(out=dpsit1_R)
                add_to_density(0.5 * f, dpsit1_R.data, taut1_R.data)
        domain_comm.sum(taut1_R.data)
        tmp_R = taut_R.new()
        tmp_R.scatter_from(taut1_R)
        taut_R.data += tmp_R.data

    def transform(self,
                  U_cc: np.ndarray,
                  complex_conjugate: bool = False,
                  pw: PWDesc | None = None) -> PWArray:
        """Symmetry-transform data.

        ``U_cc`` acts on integer G-vector (and k-point) coordinates; with
        ``complex_conjugate=True`` the coefficients are also conjugated.
        If ``pw`` is given, its k-point must equal ``U_cc @ kpt``.
        """
        pw1 = self.desc
        pw2 = pw
        if complex_conjugate:
            U_cc = -U_cc
        kpt2_c = U_cc @ pw1.kpt_c
        if pw2 is None:
            pw2 = pw1.new(kpt=kpt2_c)
        else:
            assert np.allclose(pw2.kpt_c, kpt2_c)

        size_c = np.ptp(pw1.indices_cG, axis=1) + 1
        Q1_G = np.ravel_multi_index(U_cc @ pw1.indices_cG,
                                    size_c,
                                    mode='wrap')
        Q2_G = np.ravel_multi_index(pw2.indices_cG,  # type: ignore
                                    size_c,
                                    mode='wrap')
        G_Q = np.empty(np.prod(size_c), dtype=int)
        G_Q[:] = -1
        G_Q[Q1_G] = np.arange(len(Q1_G), dtype=int)
        G1_G2 = G_Q[Q2_G]
        assert -1 not in G1_G2
        data = np.ascontiguousarray(self.data[..., G1_G2])
        if complex_conjugate:
            np.negative(data.imag, data.imag)
        return PWArray(pw2, self.dims, self.comm, data)

879 

880 

def a2a_stuff(comm, N, ng, myng, maxmyng):
    """Create arrays for MPI alltoallv call.

    Used by ``PWArray.gather_all()``, and with the send/receive roles
    swapped by ``PWArray.scatter_from_all()``.  Rank r ends up with all
    ``ng`` coefficients of the r'th of the ``N`` local expansions; ranks
    r >= N receive nothing.

    Returns
    -------
    ssize_r, soffset_r:
        Number of elements sent to rank r and their offset in the send
        buffer (this rank's ``myng`` coefficients of expansion r).
    rsize_r, roffset_r:
        Number of elements received from rank r and where they go in the
        receive buffer (rank r's block starts at ``r * maxmyng``, clipped
        to ``ng``).
    """
    nranks = comm.size
    rank_r = np.arange(nranks)
    sends_r = rank_r < N

    ssize_r = np.zeros(nranks, int)
    ssize_r[sends_r] = myng
    soffset_r = np.where(sends_r, rank_r * myng, 0)

    roffset_r = np.minimum(rank_r * maxmyng, ng)
    rsize_r = np.zeros(nranks, int)
    if comm.rank < N:
        # Consecutive offsets delimit each block; the last one ends at ng.
        rsize_r[:] = np.diff(roffset_r, append=ng)
    return ssize_r, soffset_r, rsize_r, roffset_r

893 

894 

class Empty:
    """Stand-in for the gathered output on ranks that receive no data.

    ``PWArray.gather()`` iterates over ``out._arrays()`` in lock-step with
    its own arrays.  On ranks that get no result, this object supplies one
    ``None`` per local expansion, so the loop (and its collective gather
    calls) still runs without allocating a buffer.
    """

    def __init__(self, dims):
        self.dims = dims

    def _arrays(self):
        count = prod(self.dims)
        yield from [None] * count

902 

903 

def find_reciprocal_vectors(ecut: float,
                            cell: Array2D,
                            kpt=None,
                            dtype=complex) -> tuple[Array2D,
                                                    Array1D,
                                                    Array2D]:
    """Find reciprocal lattice vectors inside sphere.

    Parameters
    ----------
    ecut:
        Select G-vectors with ``|G+k|^2 / 2 <= ecut``.
    cell:
        3x3 unit cell (rows are cell vectors).
    kpt:
        K-point in units of the reciprocal cell (default: Gamma point).
    dtype:
        For a real (floating) dtype only half of the G-vectors are
        returned: ``i2 > 0``, or ``i2 == 0`` with ``i1 > 0``, or
        ``i2 == i1 == 0`` with ``i0 >= 0``.

    Returns
    -------
    G+k vectors (xyz, shape (nG, 3)), kinetic energies ``|G+k|^2 / 2``
    (shape (nG,)) and integer indices (shape (3, nG)).

    >>> cell = np.eye(3)
    >>> ecut = 0.5 * (2 * pi)**2
    >>> G, e, i = find_reciprocal_vectors(ecut, cell)
    >>> G
    array([[ 0.        ,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  6.28318531],
           [ 0.        ,  0.        , -6.28318531],
           [ 0.        ,  6.28318531,  0.        ],
           [ 0.        , -6.28318531,  0.        ],
           [ 6.28318531,  0.        ,  0.        ],
           [-6.28318531,  0.        ,  0.        ]])
    >>> e
    array([ 0.       , 19.7392088, 19.7392088, 19.7392088, 19.7392088,
           19.7392088, 19.7392088])
    >>> i
    array([[ 0,  0,  0,  0,  0,  1, -1],
           [ 0,  0,  0,  1, -1,  0,  0],
           [ 0,  1, -1,  0,  0,  0,  0]])
    """
    # Avoid a shared mutable default argument:
    if kpt is None:
        kpt = np.zeros(3)

    Gcut = (2 * ecut)**0.5
    n = Gcut * (cell**2).sum(axis=1)**0.5 / (2 * pi) + abs(kpt)
    size = 2 * n.astype(int) + 4

    real = np.issubdtype(dtype, np.floating)
    if real:
        # Last axis only has non-negative indices:
        size[2] = size[2] // 2 + 1
        i_Qc = np.indices(size).transpose((1, 2, 3, 0))
        i_Qc[..., :2] += size[:2] // 2
        i_Qc[..., :2] %= size[:2]
        i_Qc[..., :2] -= size[:2] // 2
    else:
        i_Qc = np.indices(size).transpose((1, 2, 3, 0))  # type: ignore
        half = [s // 2 for s in size]
        i_Qc += half
        i_Qc %= size
        i_Qc -= half

    # Calculate reciprocal lattice vectors:
    B_cv = 2.0 * pi * np.linalg.inv(cell).T
    G_plus_k_Qv = (i_Qc + kpt) @ B_cv

    ekin = 0.5 * (G_plus_k_Qv**2).sum(axis=3)
    mask = ekin <= ecut

    # The sphere must not touch the edges of the index box:
    assert not mask[size[0] // 2].any()
    assert not mask[:, size[1] // 2].any()
    if not real:
        assert not mask[:, :, size[2] // 2].any()
    else:
        assert not mask[:, :, -1].any()

    if real:
        # Keep only one of each (G, -G) pair:
        mask &= ((i_Qc[..., 2] > 0) |
                 (i_Qc[..., 1] > 0) |
                 ((i_Qc[..., 0] >= 0) & (i_Qc[..., 1] == 0)))

    indices = i_Qc[mask]
    ekin = ekin[mask]
    G_plus_k = G_plus_k_Qv[mask]

    return G_plus_k, ekin, indices.T

974 

975 

def abs_square_gpu(psit_nG, weight_n, nt_R):
    """GPU (cupy) version of PWArray.abs_square() for complex expansions.

    Adds the weighted absolute square of the inverse FFT of each function
    in psit_nG to nt_R.data.  Functions are handled in batches of B=32:
    each batch is inserted into a zeroed 3-D grid, inverse-FFT'ed with
    cupyx and accumulated by add_to_density_gpu().
    """
    from gpaw.gpu import cupyx
    pw = psit_nG.desc
    plan = nt_R.desc.fft_plans(xp=cp, dtype=complex)
    # Flat grid indices of the plane waves:
    Q_G = cp.asarray(plan.indices(pw))
    weight_n = cp.asarray(weight_n)
    N = len(weight_n)
    shape = tuple(nt_R.desc.size_c)
    B = 32  # batch size (number of functions per FFT call)
    psit_bR = None
    for b1 in range(0, N, B):
        b2 = min(b1 + B, N)
        nb = b2 - b1
        # Allocate the work buffer once; shrink (view) it for a final,
        # smaller batch.
        if psit_bR is None:
            psit_bR = cp.empty((nb,) + shape, psit_nG.data.dtype)
        elif nb < B:
            psit_bR = psit_bR[:nb]
        psit_bR[:] = 0.0
        # TODO: Remember to give real space size instead of
        # reciprocal space size when doing real wave functions
        # (now psit_bR is shared between real and reciprocal space)
        pw_insert_gpu(psit_nG.data[b1:b2],
                      Q_G,
                      1.0,
                      psit_bR.reshape((nb, -1)), *psit_bR.shape[1:])
        # norm='forward' leaves the inverse transform unscaled (no 1/N).
        psit_bR[:] = cupyx.scipy.fft.ifftn(
            psit_bR,
            shape,
            norm='forward',
            overwrite_x=True)
        add_to_density_gpu(weight_n[b1:b2],
                           psit_bR,
                           nt_R.data)