Coverage for gpaw/core/uniform_grid.py: 81%

464 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1from __future__ import annotations 

2 

3from functools import cached_property 

4from math import pi 

5from typing import Sequence, Literal, TYPE_CHECKING 

6import numpy as np 

7 

8import gpaw.fftw as fftw 

9from gpaw.core.arrays import DistributedArrays 

10from gpaw.core.atom_centered_functions import UGAtomCenteredFunctions 

11from gpaw.core.domain import Domain 

12from gpaw.gpu import as_np, cupy_is_fake 

13from gpaw.grid_descriptor import GridDescriptor 

14from gpaw.mpi import MPIComm, serial_comm 

15from gpaw.new import zips 

16from gpaw.typing import (Array1D, Array2D, Array3D, Array4D, ArrayLike1D, 

17 ArrayLike2D, Vector) 

18from gpaw.new.c import add_to_density, add_to_density_gpu, symmetrize_ft 

19from gpaw.fd_operators import Gradient 

20 

21if TYPE_CHECKING: 

22 import plotly.graph_objects as go 

23 

24 

class UGDesc(Domain['UGArray']):
    def __init__(self,
                 *,
                 cell: ArrayLike1D | ArrayLike2D,  # bohr
                 size: ArrayLike1D,
                 pbc=(True, True, True),
                 zerobc=(False, False, False),
                 kpt: Vector | None = None,  # in units of reciprocal cell
                 comm: MPIComm = serial_comm,
                 decomp: Sequence[Sequence[int]] | None = None,
                 dtype=None):
        """Description of 3D uniform grid.

        Parameters
        ----------
        cell:
            Unit cell given as three floats (orthorhombic grid), six floats
            (three lengths and the angles in degrees) or a 3x3 matrix
            (units: bohr).
        size:
            Number of grid points along axes.
        pbc:
            Periodic boundary conditions flag(s).
        zerobc:
            Zero-boundary conditions flag(s).  Skip first grid-point
            (assumed to be zero).  A single scalar flag (bool, int or
            numpy bool) is applied to all three axes.
        comm:
            Communicator for domain-decomposition.
        kpt:
            K-point for Bloch-boundary conditions specified in units of the
            reciprocal cell.
        decomp:
            Decomposition of the domain: for each axis a sequence of
            grid-point indices such that domain piece ``p`` covers
            ``decomp[c][p]:decomp[c][p + 1]``.  Computed from a
            ``GridDescriptor`` when not given.
        dtype:
            Data-type (float or complex).

        Raises
        ------
        ValueError
            If an axis is flagged both periodic and zero-boundary.
        """
        self.size_c = np.array(size, int)
        # Broadcast a single scalar flag to all three axes:
        if np.ndim(zerobc) == 0:
            zerobc = (zerobc,) * 3
        self.zerobc_c = np.array(zerobc, bool)

        if decomp is None:
            gd = GridDescriptor(size, pbc_c=~self.zerobc_c, comm=comm)
            decomp = gd.n_cp
        self.decomp_cp = [np.asarray(d) for d in decomp]

        # Number of domain pieces per axis and this rank's position in the
        # 3D process grid (row-major order, see np.unravel_index):
        self.parsize_c = np.array([len(d_p) - 1 for d_p in self.decomp_cp])
        self.mypos_c = np.unravel_index(comm.rank, self.parsize_c)

        # Global index range [start_c, end_c) owned by this rank:
        self.start_c = np.array([d_p[p]
                                 for d_p, p
                                 in zips(self.decomp_cp, self.mypos_c)])
        self.end_c = np.array([d_p[p + 1]
                               for d_p, p
                               in zips(self.decomp_cp, self.mypos_c)])
        self.mysize_c = self.end_c - self.start_c

        Domain.__init__(self, cell, pbc, kpt, comm, dtype)
        self.myshape = tuple(self.mysize_c)

        # Volume element: cell volume per grid point.
        self.dv = self.volume / self.size_c.prod()

        # Bytes per element (double precision real or complex):
        self.itemsize = 8 if self.dtype == float else 16

        # pbc_c is set by Domain.__init__ above.
        if (self.zerobc_c & self.pbc_c).any():
            raise ValueError('Bad boundary conditions')

    @property
    def size(self):
        """Size of uniform grid (a copy of ``size_c``)."""
        return self.size_c.copy()

    def global_shape(self) -> tuple[int, ...]:
        """Actual size of uniform grid.

        Zero-boundary axes have one point fewer, because their first
        (zero) grid-point is not stored.
        """
        return tuple(self.size_c - self.zerobc_c)

    def __repr__(self):
        return Domain.__repr__(self).replace(
            'Domain(',
            f'UGDesc(size={self.size_c.tolist()}, ')

    def _short_string(self, global_shape):
        return f'uniform wave function grid shape: {global_shape}'

    @cached_property
    def phase_factor_cd(self):
        """Phase factor for Bloch-boundary conditions.

        Returns an array of shape (3, 2).  Entry ``[c, d]`` is
        ``exp(2 pi i k_c disp)`` for the neighbouring domain in the
        ``-1`` (d=0) or ``+1`` (d=1) direction along axis c.  ``disp`` is
        -1, 0 or +1, and is non-zero only when that neighbour wraps around
        the edge of the process grid into the adjacent periodic image.
        """
        delta_d = np.array([-1, 1])
        disp_cd = np.empty((3, 2))
        for pos, size, disp_d in zips(self.mypos_c, self.parsize_c,
                                      disp_cd):
            disp_d[:] = -((pos + delta_d) // size)
        return np.exp(2j * np.pi *
                      disp_cd *
                      self.kpt_c[:, np.newaxis])

    def new(self,
            *,
            kpt=None,
            dtype=None,
            comm: MPIComm | Literal['inherit'] | None = 'inherit',
            size=None,
            pbc=None,
            zerobc=None,
            decomp=None) -> UGDesc:
        """Create new uniform grid description.

        Arguments left at their default are copied from ``self``.  The
        current domain decomposition is reused only if none of *decomp*,
        *comm*, *size*, *pbc* or *zerobc* is given; otherwise a new one is
        computed.  ``comm=None`` gives a serial grid.
        """
        reuse_decomp = (decomp is None and comm == 'inherit' and
                        size is None and pbc is None and zerobc is None)
        if reuse_decomp:
            decomp = self.decomp_cp
        comm = self.comm if comm == 'inherit' else comm
        return UGDesc(cell=self.cell_cv,
                      size=self.size_c if size is None else size,
                      pbc=self.pbc_c if pbc is None else pbc,
                      zerobc=self.zerobc_c if zerobc is None else zerobc,
                      kpt=(self.kpt_c if self.kpt_c.any() else None)
                      if kpt is None else kpt,
                      comm=comm or serial_comm,
                      decomp=decomp,
                      dtype=self.dtype if dtype is None else dtype)

    def empty(self,
              dims: int | tuple[int, ...] = (),
              comm: MPIComm = serial_comm,
              xp=np) -> UGArray:
        """Create new UGArray object.

        Parameters
        ----------
        dims:
            Extra dimensions.
        comm:
            Distribute dimensions along this communicator.
        xp:
            Array module used for the data.
        """
        return UGArray(self, dims, comm, xp=xp)

    def from_data(self, data: np.ndarray) -> UGArray:
        """Wrap *data*; its last three axes are the grid axes."""
        return UGArray(self, data.shape[:-3], data=data)

    def blocks(self, data: np.ndarray):
        """Yield views of blocks of data.

        *data* holds the whole grid in its last three axes.  One view per
        domain piece is yielded, in the same (row-major) order as the
        ranks of the domain communicator.  Indices are taken relative to
        ``decomp_cp[c][0]``, so a skipped zero-boundary point is handled.
        """
        s0, s1, s2 = self.parsize_c
        d0_p, d1_p, d2_p = (d_p - d_p[0] for d_p in self.decomp_cp)
        for p0 in range(s0):
            b0, e0 = d0_p[p0:p0 + 2]
            for p1 in range(s1):
                b1, e1 = d1_p[p1:p1 + 2]
                for p2 in range(s2):
                    b2, e2 = d2_p[p2:p2 + 2]
                    yield data[..., b0:e0, b1:e1, b2:e2]

    def xyz(self) -> Array4D:
        """Create array of (x, y, z) coordinates.

        Returns the Cartesian coordinates (bohr) of the grid points owned
        by this rank, with shape ``(*mysize_c, 3)``.
        """
        indices_Rc = np.indices(self.mysize_c).transpose((1, 2, 3, 0))
        indices_Rc += self.start_c
        # Row c of this matrix is the grid step along cell vector c:
        return indices_Rc @ (self.cell_cv.T / self.size_c).T

    def atom_centered_functions(self,
                                functions,
                                positions,
                                *,
                                qspiral_v=None,
                                atomdist=None,
                                integrals=None,
                                cut=False,
                                xp=None):
        """Create UGAtomCenteredFunctions object.

        Spin spirals (*qspiral_v*) are not supported on uniform grids.
        """
        assert qspiral_v is None
        return UGAtomCenteredFunctions(functions,
                                       positions,
                                       self,
                                       atomdist=atomdist,
                                       integrals=integrals,
                                       cut=cut,
                                       xp=xp)

    def transformer(self, other: UGDesc, stencil_range=3, xp=np):
        """Create transformer from one grid to another.

        (for interpolation and restriction).

        Returns a function ``transform(functions, out=None)`` that applies
        the transformation to every function in *functions* and returns
        *out* (allocated on *other* when not given).
        """
        from gpaw.transformers import Transformer

        apply = Transformer(self._gd, other._gd, nn=stencil_range, xp=xp).apply

        def transform(functions, out=None):
            if out is None:
                out = other.empty(functions.dims, functions.comm, xp=xp)
            for input, output in zips(functions._arrays(), out._arrays()):
                apply(input, output)
            return out

        return transform

    def eikr(self, kpt_c: Vector | None = None) -> Array3D:
        """Plane wave.

        :::
               _ _
               ik.r
              e

        evaluated on the part of the grid owned by this rank (a numpy
        array of shape ``mysize_c``).

        Parameters
        ----------
        kpt_c:
            k-point in units of the reciprocal cell. Defaults to the
            UGDesc objects own k-point.
        """
        if kpt_c is None:
            kpt_c = self.kpt_c
        # Reversed-axis index array, transposed back at the end:
        index_Rc = np.indices(self.mysize_c).T + self.start_c
        return np.exp(2j * pi * (index_Rc @ (kpt_c / self.size_c))).T

    @property
    def _gd(self):
        """Equivalent old-style GridDescriptor for this grid.

        NOTE(review): its ``pbc_c`` is ``~zerobc_c`` rather than
        ``pbc_c``, presumably because the old descriptor's flag controls
        whether the first grid-point is stored -- confirm.
        """
        # Make sure gd can be pickled (in serial):
        comm = self.comm if self.comm.size > 1 else serial_comm

        return GridDescriptor(self.size_c,
                              cell_cv=self.cell_cv,
                              pbc_c=~self.zerobc_c,
                              comm=comm,
                              parsize_c=[len(d_p) - 1
                                         for d_p in self.decomp_cp])

    @classmethod
    def from_cell_and_grid_spacing(cls,
                                   cell: ArrayLike1D | ArrayLike2D,
                                   grid_spacing: float,
                                   pbc=(True, True, True),
                                   kpt: Vector | None = None,
                                   comm: MPIComm = serial_comm,
                                   dtype=None) -> UGDesc:
        """Create UGDesc from grid-spacing."""
        domain: Domain = Domain(cell, pbc, kpt, comm, dtype)
        return domain.uniform_grid_with_grid_spacing(grid_spacing)

    def fft_plans(self,
                  flags: int = fftw.MEASURE,
                  xp=np,
                  dtype=None) -> fftw.FFTPlans:
        """Create FFTW-plans.

        Only rank 0 of the domain communicator gets plans for the full
        grid; the other ranks get plans for an empty ``[0, 0, 0]`` grid.
        """
        if dtype is None:
            dtype = self.dtype
        if self.comm.rank == 0:
            return fftw.create_plans(self.size_c, dtype, flags, xp)
        else:
            return fftw.create_plans([0, 0, 0], dtype)

    def ranks_from_fractional_positions(self,
                                        relpos_ac: Array2D) -> Array1D:
        """Return the domain rank for each fractional position.

        NOTE(review): ranks are found with ``floor(relpos * parsize_c)``,
        i.e. assuming equally sized domain pieces; ``decomp_cp`` itself is
        not consulted.

        Raises
        ------
        ValueError
            If a position lies outside [0, 1) along some axis.
        """
        rank_ac = np.floor(relpos_ac * self.parsize_c).astype(int)
        if (rank_ac < 0).any() or (rank_ac >= self.parsize_c).any():
            raise ValueError('Positions outside cell!')
        return np.ravel_multi_index(rank_ac.T, self.parsize_c)  # type: ignore

    def ekin_max(self) -> float:
        """Maximum value of ekin so that all 0.5 * G^2 < ekin.

        In 1D, this will be 0.5*(pi/h)^2 where h is the grid-spacing.
        """
        # (pi / |a_c|)^2 for each cell vector a_c.  pi * N_c / |a_c| is the
        # distance from G=0 to the boundary plane of the N_c
        # reciprocal-lattice points kept along axis c:
        b2_c = np.pi**2 / (self.cell_cv**2).sum(1)
        return 0.5 * (self.size_c**2 * b2_c).min()

289 

290 

class UGArray(DistributedArrays[UGDesc]):
    def __init__(self,
                 grid: UGDesc,
                 dims: int | tuple[int, ...] = (),
                 comm: MPIComm = serial_comm,
                 data: np.ndarray | None = None,
                 xp=None):
        """Object for storing function(s) on a uniform grid.

        Parameters
        ----------
        grid:
            Description of uniform grid.
        dims:
            Extra dimensions.
        comm:
            Distribute dimensions along this communicator.
        data:
            Data array for storage.
        xp:
            Array module (``np`` or a cupy-like module).
        """
        DistributedArrays.__init__(self, dims, grid.myshape,
                                   comm, grid.comm, data, grid.dv,
                                   grid.dtype, xp)
        self.desc = grid

    def __repr__(self):
        txt = f'UGArray(grid={self.desc}, dims={self.dims}'
        if self.comm.size > 1:
            txt += f', comm={self.comm.rank}/{self.comm.size}'
        if self.xp is not np:
            txt += ', xp=cp'
        return txt + ')'

    def new(self, data=None, zeroed=False, dims=None):
        """Create new UGArray object of same kind.

        Parameters
        ----------
        data:
            Array to use for storage.
        zeroed:
            If True, set data to zero.
        dims:
            Extra dimensions (bands, spin, etc.), required if
            data does not fit the full array.
        """
        if dims:
            assert data is not None
        else:
            dims = self.dims
        if data is None:
            data = self.xp.empty_like(self.data)

        f_xR = UGArray(self.desc, dims, self.comm, data)
        if zeroed:
            f_xR.data[:] = 0.0
        return f_xR

    def __getitem__(self, index):
        """Return a UGArray wrapping ``self.data[index]``.

        The index must leave the last three (grid) axes intact.  The
        result is created with the default (serial) communicator for its
        extra dimensions.
        """
        data = self.data[index]
        return UGArray(data=data,
                       dims=data.shape[:-3],
                       grid=self.desc)

    def __imul__(self,
                 other: float | complex | np.ndarray | UGArray
                 ) -> UGArray:
        """In-place multiplication by a scalar, array or UGArray.

        Any scalar (int, float, complex or numpy scalar) is accepted.
        Arrays must match the grid shape in their last three axes.
        """
        if np.isscalar(other):
            self.data *= other
            return self
        if isinstance(other, UGArray):
            other = other.data
        assert other.shape[-3:] == self.data.shape[-3:]
        self.data *= other
        return self

    def __mul__(self,
                other: float | complex | np.ndarray | UGArray
                ) -> UGArray:
        """Return a new UGArray equal to ``self * other``."""
        result = self.new(data=self.data.copy())
        result *= other
        return result

    def _arrays(self):
        """Data reshaped to (number of functions, *grid shape)."""
        return self.data.reshape((-1,) + self.data.shape[-3:])

    def xy(self, *axes: int | None) -> tuple[Array1D, Array1D]:
        """Extract x, y values along line.

        One entry per axis of ``data``: an integer index, ``None`` for a
        full slice, or ``...`` marking the grid axis to extract along.

        Useful for plotting::

            x, y = grid.xy(0, ..., 0)
            plt.plot(x, y)
        """
        assert len(axes) == 3 + len(self.dims)
        index = tuple([slice(0, None) if axis is None else axis
                       for axis in axes])
        y = self.data[index]  # type: ignore
        c = axes[-3:].index(...)
        grid = self.desc
        # Grid spacing along cell vector c:
        dx = (grid.cell_cv[c]**2).sum()**0.5 / grid.size_c[c]
        x = np.arange(grid.start_c[c], grid.end_c[c]) * dx
        return x, as_np(y)

    def to_complex(self) -> UGArray:
        """Return a copy with dtype=complex.

        The copy keeps the extra dimensions, their distribution and the
        array module of ``self``.
        """
        c = self.desc.new(dtype=complex).empty(self.dims, self.comm,
                                               xp=self.xp)
        c.data[:] = self.data
        return c

    def scatter_from(self, data: np.ndarray | UGArray | None = None) -> None:
        """Scatter data from rank-0 to all ranks.

        Rank 0 of the domain communicator passes the full-grid *data*;
        the other ranks pass nothing and receive their block.
        """
        if isinstance(data, UGArray):
            data = data.data

        comm = self.desc.comm
        if comm.size == 1:
            self.data[:] = data
            return

        if comm.rank != 0:
            comm.receive(self.data, 0, 42)
            return

        requests = []
        assert isinstance(data, self.xp.ndarray)
        for rank, block in enumerate(self.desc.blocks(data)):
            if rank != 0:
                block = block.copy()
                request = comm.send(block, rank, 42, False)
                # Remember to store a reference to the
                # send buffer (block) so that it isn't
                # deallocated:
                requests.append((request, block))
            else:
                self.data[:] = block

        for request, _ in requests:
            comm.wait(request)

    def gather(self, out=None, broadcast=False):
        """Gather data from all ranks to rank-0.

        Returns a serial-grid UGArray on rank 0 (and on all ranks if
        *broadcast* is true); other ranks get None.  In serial, ``self``
        is returned unchanged.
        """
        assert out is None
        comm = self.desc.comm
        if comm.size == 1:
            return self

        if broadcast or comm.rank == 0:
            grid = self.desc.new(comm=serial_comm)
            out = grid.empty(self.dims, comm=self.comm, xp=self.xp)

        if comm.rank != 0:
            # There can be several sends before the corresponding receives
            # are posted, so use synchronous send here
            comm.ssend(self.data, 0, 301)
            if broadcast:
                comm.broadcast(out.data, 0)
                return out
            return

        # Put the subdomains from the slaves into the big array
        # for the whole domain:
        for rank, block in enumerate(self.desc.blocks(out.data)):
            if rank != 0:
                buf = self.xp.empty_like(block)
                comm.receive(buf, rank, 301)
                block[:] = buf
            else:
                block[:] = self.data

        if broadcast:
            comm.broadcast(out.data, 0)

        return out

    def fft(self, plan=None, pw=None, out=None):
        r"""Do FFT.

        Returns:
            PWArray with values
        :::
                    _ _
              _   1 /  _ -iG.r  _
            C(G) = -- |dr e     f(r),
                   V  /

        where `C(\bG)` are the plane wave coefficients and V is the cell
        volume.

        Parameters
        ----------
        plan:
            Plan for FFT.
        pw:
            Target PW description.
        out:
            Target PWArray object.

        Raises
        ------
        TypeError
            If the dtypes of the grid and the PW description differ.
        """
        assert self.dims == ()
        if out is None:
            assert pw is not None
            out = pw.empty(xp=self.xp)
        if pw is None:
            pw = out.desc
        if pw.dtype != self.desc.dtype:
            raise TypeError(
                f'Type mismatch: {self.desc.dtype} -> {pw.dtype}')
        input = self
        if self.desc.comm.size > 1:
            input = input.gather()
        # The FFT is done on the gathered array on rank 0 only:
        if self.desc.comm.rank == 0:
            plan = plan or self.desc.fft_plans(xp=self.xp)
            coefs = plan.fft_sphere(input.data, pw)
        else:
            coefs = None

        out.scatter_from(coefs)

        return out

    def norm2(self):
        """Calculate integral over cell of absolute value squared.

        :::

           / _  2 _
           ||a(r)| dr
           /

        Returns an array of shape ``mydims``, summed over the domain
        communicator.
        """
        norm_x = []
        arrays_xR = self._arrays()
        for a_R in arrays_xR:
            norm_x.append(self.xp.vdot(a_R, a_R).real * self.desc.dv)
        result = self.xp.array(norm_x).reshape(self.mydims)
        self.desc.comm.sum(result)
        return result

    def integrate(self, other=None, skip_sum=False):
        """Integral of self or self times cc(other).

        With *other*, returns the matrix of integrals
        ``a_x(r) * conj(b_y(r))`` of shape ``self.dims + other.dims``.
        A 0-d result is returned as a Python scalar.  With *skip_sum*,
        the sum over the domain communicator is skipped.
        """
        if other is not None:
            assert self.desc.dtype == other.desc.dtype
            a_xR = self._arrays()
            b_yR = other._arrays()
            a_xR = a_xR.reshape((len(a_xR), -1))
            b_yR = b_yR.reshape((len(b_yR), -1))
            result = (a_xR @ b_yR.T.conj()).reshape(self.dims + other.dims)
        else:
            # Make sure we have an array and not a scalar!
            result = self.xp.asarray(self.data.sum(axis=(-3, -2, -1)))

        if not skip_sum:
            self.desc.comm.sum(result)
        if result.ndim == 0:
            result = result.item()  # convert to scalar
        return result * self.desc.dv

    def to_pbc_grid(self):
        """Convert to UniformGrid with ``pbc=(True, True, True)``.

        The skipped zero-boundary points are stored explicitly as zeros.
        The result keeps the extra dimensions, their distribution and the
        array module of ``self``.
        """
        if not self.desc.zerobc_c.any():
            return self
        grid = self.desc.new(zerobc=False)
        result = grid.empty(self.dims, self.comm, xp=self.xp)
        result.data[:] = 0.0
        *_, i, j, k = self.data.shape
        result.data[..., -i:, -j:, -k:] = self.data
        return result

    def multiply_by_eikr(self, kpt_c: Vector | None = None) -> None:
        """Multiply by `exp(ik.r)`.

        *kpt_c* defaults to the grid's own k-point; nothing is done for
        the Gamma point.
        """
        if kpt_c is None:
            kpt_c = self.desc.kpt_c
        else:
            kpt_c = np.asarray(kpt_c)
        if kpt_c.any():
            # eikr() returns a numpy array; convert it for GPU data:
            self.data *= self.xp.asarray(self.desc.eikr(kpt_c))

    def interpolate(self,
                    plan1: fftw.FFTPlans | None = None,
                    plan2: fftw.FFTPlans | None = None,
                    grid: UGDesc | None = None,
                    out: UGArray | None = None) -> UGArray:
        """Interpolate to finer grid.

        Fourier interpolation: the coefficients of the coarse grid are
        zero-padded onto the fine grid.  The Nyquist component of an even
        axis is split equally between the +N/2 and -N/2 frequencies.

        Parameters
        ----------
        plan1:
            Plan for FFT (coarse grid).
        plan2:
            Plan for inverse FFT (fine grid).
        grid:
            Target grid.
        out:
            Target UGArray object.

        Raises
        ------
        ValueError
            If neither *grid* nor *out* is given, if a grid has
            zero-boundary conditions, or if the target grid is not finer
            along every axis.
        """
        if out is None:
            if grid is None:
                raise ValueError('Please specify "grid" or "out".')
            out = grid.empty(self.dims, xp=self.xp)

        if out.desc.zerobc_c.any() or self.desc.zerobc_c.any():
            raise ValueError('Grids must have zerobc=False!')

        # Distributed: interpolate the gathered array on rank 0 and
        # scatter the result.
        if self.desc.comm.size > 1:
            input = self.gather()
            if input is not None:
                output = input.interpolate(plan1, plan2,
                                           out.desc.new(comm=None))
                out.scatter_from(output.data)
            else:
                out.scatter_from()
            return out

        size1_c = self.desc.size_c
        size2_c = out.desc.size_c
        if (size2_c <= size1_c).any():
            raise ValueError('Too few points in target grid!')

        plan1 = plan1 or self.desc.fft_plans(xp=self.xp)
        plan2 = plan2 or out.desc.fft_plans(xp=self.xp)

        if self.dims:
            for input, output in zips(self.flat(), out.flat()):
                input.interpolate(plan1, plan2, grid, output)
            return out

        # Remove the Bloch phase so that the FFT sees a periodic function:
        plan1.tmp_R[:] = self.data
        kpt_c = self.desc.kpt_c
        if kpt_c.any():
            plan1.tmp_R *= self.xp.asarray(self.desc.eikr(-kpt_c))
        plan1.fft()

        a_Q = plan1.tmp_Q
        b_Q = plan2.tmp_Q

        e0, e1, e2 = 1 - size1_c % 2  # even or odd size
        # Coarse spectrum lands at [a:b] of the (fftshifted) fine one:
        a0, a1, a2 = size2_c // 2 - size1_c // 2
        b0, b1, b2 = size1_c + (a0, a1, a2)

        if self.desc.dtype == float:
            # Real FFT: only non-negative frequencies along the last axis.
            b2 = (b2 - a2) // 2 + 1
            a2 = 0
            axes = [0, 1]
        else:
            axes = [0, 1, 2]

        b_Q[:] = 0.0
        b_Q[a0:b0, a1:b1, a2:b2] = self.xp.fft.fftshift(a_Q, axes=axes)

        if e0:
            b_Q[a0, a1:b1, a2:b2] *= 0.5
            b_Q[b0, a1:b1, a2:b2] = b_Q[a0, a1:b1, a2:b2]
            b0 += 1
        if e1:
            b_Q[a0:b0, a1, a2:b2] *= 0.5
            b_Q[a0:b0, b1, a2:b2] = b_Q[a0:b0, a1, a2:b2]
            b1 += 1
        if self.desc.dtype == complex:
            if e2:
                b_Q[a0:b0, a1:b1, a2] *= 0.5
                b_Q[a0:b0, a1:b1, b2] = b_Q[a0:b0, a1:b1, a2]
        else:
            if e2:
                b_Q[a0:b0, a1:b1, b2 - 1] *= 0.5

        b_Q[:] = self.xp.fft.ifftshift(b_Q, axes=axes)
        plan2.ifft()
        out.data[:] = plan2.tmp_R
        out.data *= (1.0 / self.data.size)
        # Restore the Bloch phase on the fine grid:
        out.multiply_by_eikr()
        return out

    def fft_restrict(self,
                     plan1: fftw.FFTPlans | None = None,
                     plan2: fftw.FFTPlans | None = None,
                     grid: UGDesc | None = None,
                     out: UGArray | None = None) -> UGArray:
        """Restrict to coarser grid.

        The Fourier coefficients of the fine grid are truncated to those
        representable on the coarse grid.  The +N/2 and -N/2 components of
        an even coarse axis are averaged into one.

        NOTE(review): unlike ``interpolate()``, no ``exp(ik.r)`` phase is
        removed before the FFT, so this is presumably only meant for
        k=0 data -- confirm.

        Parameters
        ----------
        plan1:
            Plan for FFT.
        plan2:
            Plan for inverse FFT.
        grid:
            Target grid.
        out:
            Target UGArray object.

        Raises
        ------
        ValueError
            If neither *grid* nor *out* is given, or if a grid has
            zero-boundary conditions.
        """
        if out is None:
            if grid is None:
                raise ValueError('Please specify "grid" or "out".')
            out = grid.empty(self.dims, xp=self.xp)

        if out.desc.zerobc_c.any() or self.desc.zerobc_c.any():
            raise ValueError('Grids must have zerobc=False!')

        # Distributed: restrict the gathered array on rank 0 and scatter.
        if self.desc.comm.size > 1:
            input = self.gather()
            if input is not None:
                output = input.fft_restrict(plan1, plan2,
                                            out.desc.new(comm=None))
                out.scatter_from(output.data)
            else:
                out.scatter_from()
            return out

        size1_c = self.desc.size_c
        size2_c = out.desc.size_c

        # Plans must live on the same device as the data (as in
        # interpolate()):
        plan1 = plan1 or self.desc.fft_plans(xp=self.xp)
        plan2 = plan2 or out.desc.fft_plans(xp=self.xp)

        if self.dims:
            for input, output in zips(self.flat(), out.flat()):
                input.fft_restrict(plan1, plan2, grid, output)
            return out

        plan1.tmp_R[:] = self.data
        a_Q = plan2.tmp_Q
        b_Q = plan1.tmp_Q

        e0, e1, e2 = 1 - size2_c % 2  # even or odd size
        # Frequencies -N2//2 ... N2//2 (inclusive) of the shifted fine
        # spectrum:
        a0, a1, a2 = size1_c // 2 - size2_c // 2
        b0, b1, b2 = size2_c // 2 + size1_c // 2 + 1

        if self.desc.dtype == float:
            # Real FFT: only non-negative frequencies along the last axis.
            b2 = size2_c[2] // 2 + 1
            a2 = 0
            axes = [0, 1]
        else:
            axes = [0, 1, 2]

        plan1.fft()
        b_Q[:] = self.xp.fft.fftshift(b_Q, axes=axes)

        # Fold the +N/2 plane into the -N/2 plane for even coarse axes:
        if e0:
            b_Q[a0, a1:b1, a2:b2] += b_Q[b0 - 1, a1:b1, a2:b2]
            b_Q[a0, a1:b1, a2:b2] *= 0.5
            b0 -= 1
        if e1:
            b_Q[a0:b0, a1, a2:b2] += b_Q[a0:b0, b1 - 1, a2:b2]
            b_Q[a0:b0, a1, a2:b2] *= 0.5
            b1 -= 1
        if self.desc.dtype == complex and e2:
            b_Q[a0:b0, a1:b1, a2] += b_Q[a0:b0, a1:b1, b2 - 1]
            b_Q[a0:b0, a1:b1, a2] *= 0.5
            b2 -= 1

        a_Q[:] = b_Q[a0:b0, a1:b1, a2:b2]
        a_Q[:] = self.xp.fft.ifftshift(a_Q, axes=axes)
        plan2.ifft()
        out.data[:] = plan2.tmp_R
        out.data *= (1.0 / self.data.size)
        return out

    def abs_square(self,
                   weights: Array1D,
                   out: UGArray | None = None) -> None:
        """Add weighted absolute square of data to output array.

        *out* is required and is accumulated into, not overwritten.
        """
        assert out is not None

        if self.xp is np:
            for f, psit_R in zips(weights, self.data):
                add_to_density(f, psit_R, out.data)
        elif cupy_is_fake:
            # Fake cupy: run the CPU kernel on the wrapped numpy arrays.
            for f, psit_R in zips(weights, self.data):
                add_to_density(f, psit_R._data, out.data._data)  # type: ignore
        else:
            add_to_density_gpu(self.xp.asarray(weights), self.data, out.data)

    def symmetrize(self, rotation_scc, translation_sc):
        """Make data symmetric.

        The data is replaced by the average over the given symmetry
        operations (rotations plus fractional translations).  The work is
        done with numpy on rank 0 of the domain communicator and scattered
        back.
        """
        if len(rotation_scc) == 1:
            return

        a_xR = self.gather()

        if a_xR is None:
            b_xR = None
        else:
            if self.xp is not np:
                a_xR = a_xR.to_xp(np)
            b_xR = a_xR.new()
            # Fractional translations in units of grid points:
            t_sc = (translation_sc * self.desc.size_c).round().astype(int)
            offset_c = np.array(self.desc.zerobc_c, dtype=int)
            for a_R, b_R in zips(a_xR._arrays(), b_xR._arrays()):
                b_R[:] = 0.0
                for r_cc, t_c in zips(rotation_scc, t_sc):
                    symmetrize_ft(a_R, b_R, r_cc, t_c, offset_c)
            if self.xp is not np:
                b_xR = b_xR.to_xp(self.xp)
        self.scatter_from(b_xR)

        self.data *= 1.0 / len(rotation_scc)

    def randomize(self, seed: int | None = None) -> None:
        """Insert random numbers between -0.5 and 0.5 into data.

        For complex data, real and imaginary parts are both randomized.
        The default seed is unique for each rank.
        """
        if seed is None:
            seed = self.comm.rank + self.desc.comm.rank * self.comm.size
        rng = self.xp.random.default_rng(seed)
        a = self.data.view(float)
        rng.random(a.shape, out=a)
        a -= 0.5

    def moment(self):
        """Calculate moment of data.

        Returns the first moment (integral of r times the data) as a
        Cartesian vector, summed over the domain communicator.
        """
        assert self.dims == ()
        ug = self.desc

        index_cr = [np.arange(ug.start_c[c], ug.end_c[c], dtype=float)
                    for c in range(3)]
        for index_r, size in zip(index_cr, ug.size_c):
            if index_r[0] == 0:
                # We have periodic bc's, so index 0 is the same as index
                # size (= last + 1).  Include both points with 0.5 weight:
                index_r[0] = 0.5 * size

        # Project the data onto each axis:
        rho_ijk = self.data
        rho_ij = rho_ijk.sum(axis=2)
        rho_ik = rho_ijk.sum(axis=1)
        rho_cr = [rho_ij.sum(axis=1), rho_ij.sum(axis=0), rho_ik.sum(axis=0)]
        if self.xp is not np:
            rho_cr = [rho_r.get() for rho_r in rho_cr]

        d_c = [index_r @ rho_r for index_r, rho_r in zips(index_cr, rho_cr)]
        d_v = (d_c / ug.size_c) @ ug.cell_cv * self.dv
        self.desc.comm.sum(d_v)
        return d_v

    def scaled(self, cell: float, values: float = 1.0) -> UGArray:
        """Create new scaled UGArray object.

        Unit cell axes are multiplied by `cell` and data by `values`.
        The grid size, boundary conditions, k-point, dtype and
        communicator are kept.
        """
        grid = self.desc
        grid = UGDesc(cell=grid.cell_cv * cell,
                      size=grid.size_c,
                      pbc=grid.pbc_c,
                      zerobc=grid.zerobc_c,
                      kpt=(grid.kpt_c if grid.kpt_c.any() else None),
                      dtype=grid.dtype,
                      comm=grid.comm)
        return UGArray(grid, self.dims, self.comm, self.data * values)

    def add_ked(self,
                occ_n: Array1D,
                taut_R: UGArray) -> None:
        """Add kinetic-energy density contributions to *taut_R*.

        Applies a finite-difference gradient (``n=3``) along each
        Cartesian direction and passes ``0.5 * f`` and each gradient to
        ``add_to_density``.  Presumably this accumulates
        ``0.5 * f * |d psi / dx_v|**2`` -- confirm against
        ``add_to_density``.

        NOTE(review): the CPU kernel ``add_to_density`` is used even when
        ``xp`` is not numpy.
        """
        grad_v = [
            Gradient(self.desc._gd, v, n=3, dtype=self.desc.dtype)
            for v in range(3)]
        tmp_R = self.desc.empty()
        for f, psit_R in zips(occ_n, self):
            for grad in grad_v:
                grad(psit_R, tmp_R)
                add_to_density(0.5 * f, tmp_R.data, taut_R.data)

    def redist(self,
               domain: UGDesc,
               comm1: MPIComm, comm2: MPIComm) -> UGArray:
        """Redistribute (see DistributedArrays.redist); returns a UGArray."""
        a = super().redist(domain, comm1, comm2)
        assert isinstance(a, UGArray)
        return a

    def isosurface(self, show=True, **kwargs) -> go.Isosurface:
        """Create (and optionally show) a plotly isosurface of the data.

        Complex data is plotted as its absolute value.  By default the
        iso-range spans the central 80% of the data range.  Extra
        *kwargs* are passed to ``go.Isosurface`` and override those
        defaults.
        """
        import plotly.graph_objects as go
        values = self.data
        assert values.ndim == 3
        if values.dtype == complex:
            values = abs(values)
        x, y, z = (c.T.flatten() for c in self.desc.xyz().T)
        vmin = values.min()
        vmax = values.max()
        kwargs = {
            'isomin': vmin + (vmax - vmin) * 0.1,
            'isomax': vmax - (vmax - vmin) * 0.1,
            'caps': dict(x_show=False,
                         y_show=False,
                         z_show=False),
            **kwargs}
        surf = go.Isosurface(x=x, y=y, z=z, value=values.flatten(),
                             **kwargs)
        if show:
            go.Figure(data=[surf]).show()
        return surf