Coverage for gpaw/poisson.py: 80%

632 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1# Copyright (C) 2003 CAMP 

2# Please see the accompanying LICENSE file for further information. 

3 

4import warnings 

5from math import pi 

6 

7import numpy as np 

8from numpy.fft import fftn, ifftn, fft2, ifft2, rfft2, irfft2, fft, ifft 

9from scipy.fftpack import dst as scipydst 

10 

11from gpaw import PoissonConvergenceError 

12from gpaw.dipole_correction import DipoleCorrection, dipole_correction 

13from gpaw.domain import decompose_domain 

14from gpaw.fd_operators import Laplace, LaplaceA, LaplaceB 

15from gpaw.transformers import Transformer 

16from gpaw.utilities.gauss import Gaussian 

17from gpaw.utilities.grid import grid2grid 

18from gpaw.utilities.ewald import madelung 

19from gpaw.utilities.tools import construct_reciprocal 

20from gpaw.utilities.timing import NullTimer 

21 

# Long-form warning text.  FDPoissonSolver.set_grid_descriptor joins it with
# the solver description and passes it to warnings.warn when the coarsest
# multigrid level still has more than 36 points along some axis.
POISSON_GRID_WARNING = """Grid unsuitable for FDPoissonSolver!

Consider using FastPoissonSolver instead.

The FDPoissonSolver does not have sufficient multigrid levels for good
performance and will converge inefficiently if at all, or yield wrong
results.

You may need to manually specify a grid such that the number of points
along each direction is divisible by a high power of 2, such as 8, 16,
or 32 depending on system size; examples:

  GPAW(gpts=(32, 32, 288))

or

  from gpaw.utilities import h2gpts
  GPAW(gpts=h2gpts(0.2, atoms.get_cell(), idiv=16))

Parallelizing over very small domains can also undesirably limit the
number of multigrid levels even if the total number of grid points
is divisible by a high power of 2."""

44 

45 

def create_poisson_solver(name='fast', **kwargs):
    """Return a Poisson solver described by ``name``.

    Parameters
    ----------
    name : str, dict or _PoissonSolver
        * an existing solver instance, which is returned unchanged;
        * a dict of constructor keywords (optionally holding ``'name'``),
          merged into ``kwargs`` before dispatching again;
        * one of ``'fast'``, ``'fd'``, ``'fft'``, ``'fdtd'``,
          ``'nointeraction'``, ``'ExtraVacuumPoissonSolver'`` or
          ``'MomentCorrectionPoissonSolver'``.
    **kwargs
        Forwarded to the selected constructor (not used for
        ``'nointeraction'``).

    Raises
    ------
    ValueError
        If ``name`` does not identify a known solver.
    """
    if isinstance(name, _PoissonSolver):
        return name

    if isinstance(name, dict):
        kwargs.update(name)
        return create_poisson_solver(**kwargs)

    if name == 'nointeraction':
        return NoInteractionPoissonSolver()

    # Optional solvers live in separate modules and are imported lazily.
    if name == 'fft':
        factory = FFTPoissonSolver
    elif name == 'fdtd':
        from gpaw.fdtd.poisson_fdtd import FDTDPoissonSolver as factory
    elif name == 'fd':
        factory = FDPoissonSolverWrapper
    elif name == 'fast':
        factory = FastPoissonSolver
    elif name == 'ExtraVacuumPoissonSolver':
        from gpaw.poisson_extravacuum import (
            ExtraVacuumPoissonSolver as factory)
    elif name == 'MomentCorrectionPoissonSolver':
        from gpaw.poisson_moment import (
            MomentCorrectionPoissonSolver as factory)
    else:
        raise ValueError('Unknown poisson solver: %s' % name)

    return factory(**kwargs)

71 

72 

def PoissonSolver(name='fast', dipolelayer=None, zero_vacuum=False, **kwargs):
    """Create a Poisson solver, optionally wrapped in a dipole correction.

    ``name`` and ``kwargs`` are handed to ``create_poisson_solver``.  When
    ``dipolelayer`` is given, the solver is wrapped in ``DipoleCorrection``
    together with ``zero_vacuum``; otherwise it is returned as is.
    """
    solver = create_poisson_solver(name=name, **kwargs)
    if dipolelayer is None:
        return solver
    return DipoleCorrection(solver, dipolelayer, zero_vacuum=zero_vacuum)

78 

79 

def FDPoissonSolverWrapper(dipolelayer=None, zero_vacuum=False, **kwargs):
    """Build an ``FDPoissonSolver`` from ``kwargs``.

    If ``dipolelayer`` is not None, the solver is wrapped in
    ``DipoleCorrection(solver, dipolelayer, zero_vacuum=zero_vacuum)``.
    """
    solver = FDPoissonSolver(**kwargs)
    if dipolelayer is None:
        return solver
    return DipoleCorrection(solver, dipolelayer, zero_vacuum=zero_vacuum)

85 

86 

87class _PoissonSolver: 

88 """Abstract PoissonSolver class 

89 

90 This class defines an interface and a common ancestor 

91 for various PoissonSolver implementations (including wrappers).""" 

92 def __init__(self): 

93 object.__init__(self) 

94 

95 def set_grid_descriptor(self, gd): 

96 raise NotImplementedError() 

97 

98 def solve(self): 

99 raise NotImplementedError() 

100 

101 def todict(self): 

102 raise NotImplementedError(self.__class__.__name__) 

103 

104 def get_description(self): 

105 return self.__class__.__name__ 

106 

107 def estimate_memory(self, mem): 

108 raise NotImplementedError() 

109 

110 def build(self, grid, xp): 

111 from gpaw.new.poisson import PoissonSolverWrapper 

112 self.xp = xp 

113 self.set_grid_descriptor(grid._gd) 

114 return PoissonSolverWrapper(self) 

115 

116 

class BasePoissonSolver(_PoissonSolver):
    """Shared charge handling for the grid-based Poisson solvers.

    ``solve`` decides how to treat the net charge of ``rho`` (neutral,
    charged and fully periodic, charged and fully non-periodic, or charged
    with mixed boundary conditions).  It then hands a neutral problem to
    ``solve_neutral``.  This class defines neither ``_init`` nor
    ``solve_neutral``; subclasses must provide both.  ``self.gd`` starts
    as None.  It must be set before ``solve`` is called; the subclasses in
    this module set it in ``set_grid_descriptor``.
    """
    def __init__(self, *, remove_moment=None,
                 use_charge_center=False,
                 metallic_electrodes=False,
                 eps=None,
                 use_charged_periodic_corrections=False,
                 xp=np):
        """Store the options shared by the derived solvers.

        Parameters
        ----------
        remove_moment : int or None
            Deprecated (a FutureWarning is emitted when not None).  If
            truthy, ``solve`` removes the multipole moments
            L = 0 .. remove_moment - 1 with Gaussians before calling
            ``solve_neutral``.  Only allowed for fully non-periodic cells.
        use_charge_center : bool
            For charged non-periodic systems, centre the compensating
            Gaussian on the centre of the majority charge.
        metallic_electrodes : False, None, 'single' or 'both'
            Treatment of charged systems with mixed boundary conditions.
        eps : None
            Not used by this class; anything other than None emits a
            FutureWarning.
        use_charged_periodic_corrections : bool
            For charged, fully periodic systems, add
            ``actual_charge * madelung(cell_cv)`` to the potential.
        xp : module
            Array module stored as ``self.xp`` (``numpy`` by default).
        """

        self.xp = xp

        if eps is not None:
            warnings.warn(
                "Please do not specify the eps parameter "
                f"for {self.__class__.__name__}. "
                "The parameter doesn't do anything for this solver "
                "and defining it will throw an error in the future.",
                FutureWarning)

        if remove_moment is not None:
            warnings.warn(
                "Please do not specify the remove_moment parameter "
                f"for {self.__class__.__name__}. "
                "The remove moment functionality is deprecated in this solver "
                "and will throw an error in the future. Instead "
                "use the MomentCorrectionPoissonSolver as a wrapper to "
                f"{self.__class__.__name__}.",
                FutureWarning)

        # metallic electrodes: mirror image method to allow calculation of
        # charged, partly periodic systems
        self.gd = None
        self.remove_moment = remove_moment
        self.use_charge_center = use_charge_center
        self.use_charged_periodic_corrections = \
            use_charged_periodic_corrections
        # Lazily computed Madelung term; see the periodic branch of solve().
        self.charged_periodic_correction = None
        self.eps = eps
        self.metallic_electrodes = metallic_electrodes
        assert self.metallic_electrodes in [False, None, 'single', 'both']

    def todict(self):
        """Return the enabled (truthy) options; subclasses replace 'name'."""
        d = {'name': 'basepoisson'}
        if self.remove_moment:
            d['remove_moment'] = self.remove_moment
        if self.use_charge_center:
            d['use_charge_center'] = self.use_charge_center
        if self.use_charged_periodic_corrections:
            d['use_charged_periodic_corrections'] = \
                self.use_charged_periodic_corrections
        if self.metallic_electrodes:
            d['metallic_electrodes'] = self.metallic_electrodes

        return d

    def get_description(self):
        # The idea is that the subclass writes a header and main parameters,
        # then adds the below string.
        # NOTE(review): the text says "up to L=remove_moment", but solve()
        # loops over range(remove_moment), i.e. stops at remove_moment - 1.
        # Confirm which one is intended.
        lines = []
        if self.remove_moment is not None:
            lines.append('    Remove moments up to L=%d' % self.remove_moment)
        if self.use_charge_center:
            lines.append('    Compensate for charged system using center of '
                         'majority charge')
        if self.use_charged_periodic_corrections:
            lines.append('    Subtract potential of homogeneous background')

        return '\n'.join(lines)

    def solve(self, phi, rho, charge=None, maxcharge=1e-6,
              zero_initial_phi=False, timer=NullTimer()):
        """Compute the potential ``phi`` (updated in place) for ``rho``.

        Parameters
        ----------
        phi : array
            Potential on ``self.gd``.  It is handed to ``solve_neutral``
            and modified in place.
        rho : array
            Charge density on ``self.gd``.
        charge : float or None
            Charge used only to choose a branch below.  Defaults to the
            integral of ``rho``.  The corrections always use the integrated
            charge.
        maxcharge : float
            A charge whose absolute value is at most this counts as neutral.
        zero_initial_phi : bool
            Reset ``phi`` to zero first.  Only honoured in the charged
            periodic and charged non-periodic branches.
        timer : callable
            Forwarded to ``solve_neutral``.

        Returns
        -------
        Whatever ``solve_neutral`` returns.  For the solvers in this module
        that is an iteration count.

        Raises
        ------
        RuntimeError
            If ``use_charge_center`` puts the compensating charge closer
            than 7 grid spacings (measured via ``h_cv``) to the cell border.
        NotImplementedError
            For charged systems with mixed boundary conditions, unless
            ``metallic_electrodes`` is 'single' or 'both'.
        """
        self._init()
        assert np.all(phi.shape == self.gd.n_c)
        assert np.all(rho.shape == self.gd.n_c)

        actual_charge = self.gd.integrate(rho)
        # Uniform density carrying the total charge (charge / cell volume).
        background = (actual_charge / self.gd.dv /
                      self.gd.get_size_of_global_array().prod())

        if self.remove_moment:
            assert not self.gd.pbc_c.any()
            if not hasattr(self, 'gauss'):
                self.gauss = Gaussian(self.gd)
            rho_neutral = rho.copy()
            phi_cor_L = []
            for L in range(self.remove_moment):
                phi_cor_L.append(self.gauss.remove_moment(rho_neutral, L))
            # Remove multipoles for better initial guess
            for phi_cor in phi_cor_L:
                phi -= phi_cor

            niter = self.solve_neutral(phi, rho_neutral, timer=timer)
            # correct error introduced by removing multipoles
            for phi_cor in phi_cor_L:
                phi += phi_cor

            return niter
        if charge is None:
            charge = actual_charge
        if abs(charge) <= maxcharge:
            # (Numerically) neutral: remove the tiny residual charge as a
            # uniform background.
            return self.solve_neutral(phi, rho - background, timer=timer)

        elif abs(charge) > maxcharge and self.gd.pbc_c.all():
            # System is charged and periodic. Subtract a homogeneous
            # background charge

            # Set initial guess for potential
            if zero_initial_phi:
                phi[:] = 0.0

            iters = self.solve_neutral(phi, rho - background, timer=timer)

            if self.use_charged_periodic_corrections:
                if self.charged_periodic_correction is None:
                    self.charged_periodic_correction = madelung(
                        self.gd.cell_cv)
                phi += actual_charge * self.charged_periodic_correction

            return iters

        elif abs(charge) > maxcharge and not self.gd.pbc_c.any():
            # The system is charged and in a non-periodic unit cell.
            # Determine the potential by 1) subtract a gaussian from the
            # density, 2) determine potential from the neutralized density
            # and 3) add the potential from the gaussian density.

            # Load necessary attributes

            # use_charge_center: The monopole will be removed at the
            # center of the majority charge, which prevents artificial
            # dipoles.
            # Due to the shape of the Gaussian and it's Fourier-Transform,
            # the Gaussian representing the charge should stay at least
            # 7 gpts from the borders - see:
            # listserv.fysik.dtu.dk/pipermail/gpaw-developers/2015-July/005806.html
            if self.use_charge_center:
                charge_sign = actual_charge / abs(actual_charge)
                # Keep only density with the same sign as the net charge.
                rho_sign = rho * charge_sign
                rho_sign[np.where(rho_sign < 0)] = 0
                absolute_charge = self.gd.integrate(rho_sign)
                center = - (self.gd.calculate_dipole_moment(rho_sign) /
                            absolute_charge)
                border_offset = np.inner(self.gd.h_cv, np.array((7, 7, 7)))
                borders = np.inner(self.gd.h_cv, self.gd.N_c)
                borders -= border_offset
                if np.any(center > borders) or np.any(center < border_offset):
                    raise RuntimeError('Poisson solver: '
                                       'center of charge outside borders '
                                       '- please increase box')
                    # NOTE(review): the clamp below can never change
                    # anything, because the check above already raised
                    # whenever a component of center exceeds borders.
                    center[np.where(center > borders)] = borders
                self.load_gauss(center=center)
            else:
                self.load_gauss()

            # Remove monopole moment
            q = actual_charge / np.sqrt(4 * pi)  # Monopole moment
            rho_neutral = rho - q * self.rho_gauss  # neutralized density

            # Set initial guess for potential
            if zero_initial_phi:
                phi[:] = 0.0
            else:
                phi -= q * self.phi_gauss

            # Determine potential from neutral density using standard solver
            niter = self.solve_neutral(phi, rho_neutral, timer=timer)

            # correct error introduced by removing monopole
            phi += q * self.phi_gauss

            return niter
        else:
            # System is charged with mixed boundaryconditions
            if self.metallic_electrodes == 'single':
                # Dipole correction along axis 2, with its origin at the
                # upper end of that axis.
                self.c = 2
                origin_c = [0, 0, 0]
                origin_c[self.c] = self.gd.N_c[self.c]
                drhot_g, dvHt_g, self.correction = dipole_correction(
                    self.c,
                    self.gd,
                    rho,
                    origin_c=origin_c)
                # self.correction *=-1.
                phi -= dvHt_g
                iters = self.solve_neutral(phi, rho + drhot_g, timer=timer)
                phi += dvHt_g
                phi -= self.correction
                self.correction = 0.0

                return iters

            elif self.metallic_electrodes == 'both':
                # rho is passed on unmodified (no neutralization).
                iters = self.solve_neutral(phi, rho, timer=timer)
                return iters

            else:
                # System is charged with mixed boundaryconditions
                msg = ('Charged systems with mixed periodic/zero'
                       ' boundary conditions')
                raise NotImplementedError(msg)

    def load_gauss(self, center=None):
        """Cache ``self.rho_gauss`` / ``self.phi_gauss``.

        They come from ``Gaussian(self.gd, center=center)``:
        ``get_gauss(0)`` and ``get_gauss_pot(0)``, converted with
        ``self.xp.asarray``.  They are rebuilt whenever ``center`` is given
        and reused otherwise.  ``solve`` scales them by
        ``charge / sqrt(4 * pi)``.  Presumably ``rho_gauss`` is the L=0
        Gaussian density and ``phi_gauss`` its potential; confirm in
        gpaw.utilities.gauss.
        """
        if not hasattr(self, 'rho_gauss') or center is not None:
            gauss = Gaussian(self.gd, center=center)
            self.rho_gauss = self.xp.asarray(gauss.get_gauss(0))
            self.phi_gauss = self.xp.asarray(gauss.get_gauss_pot(0))

322 

323 

class FDPoissonSolver(BasePoissonSolver):
    """Multigrid finite-difference Poisson solver.

    The Laplacian, scaled by -1/(4*pi), is relaxed with weighted Jacobi
    (``relax='J'``) or Gauss-Seidel (``relax='GS'``) sweeps.  The sweeps run
    on a hierarchy of grids obtained by repeated ``gd.coarsen()``, with at
    most 8 coarsenings.  One call to ``iterate2`` performs one V-cycle.
    Cycles are repeated until the integrated squared residual drops below
    ``eps``, or until ``maxiter`` cycles have been used.
    """

    def __init__(self, nn=3, relax='J', eps=2e-10, maxiter=1000,
                 remove_moment=None, use_charge_center=False,
                 metallic_electrodes=False,
                 use_charged_periodic_corrections=False, **kwargs):
        """
        Parameters
        ----------
        nn : int or 'M'
            Stencil range passed to ``Laplace``.  'M' selects the Mehrstellen
            stencil (``LaplaceA``/``LaplaceB``), which needs an orthogonal
            cell.
        relax : 'J' or 'GS'
            Relaxation method: Jacobi or Gauss-Seidel.
        eps : float
            Convergence threshold on the integrated squared residual.
        maxiter : int
            Maximum number of multigrid cycles per ``solve_neutral`` call.
        remove_moment, use_charge_center, metallic_electrodes,
        use_charged_periodic_corrections, **kwargs
            Forwarded to ``BasePoissonSolver``.

        Raises
        ------
        NotImplementedError
            For an unknown ``relax`` value.
        """
        super().__init__(
            remove_moment=remove_moment,
            use_charge_center=use_charge_center,
            metallic_electrodes=metallic_electrodes,
            use_charged_periodic_corrections=use_charged_periodic_corrections,
            **kwargs)
        self.eps = eps
        self.relax = relax
        self.nn = nn
        self.maxiter = maxiter

        # Relaxation method
        if relax == 'GS':
            # Gauss-Seidel
            self.relax_method = 1
        elif relax == 'J':
            # Jacobi
            self.relax_method = 2
        else:
            raise NotImplementedError('Relaxation method %s' % relax)

        self.description = None
        self._initialized = False

    def todict(self):
        d = super().todict()
        d.update({'name': 'fd', 'nn': self.nn, 'relax': self.relax,
                  'eps': self.eps})
        return d

    def get_stencil(self):
        return self.nn

    def create_laplace(self, gd, scale=1.0, n=1, dtype=float):
        """Instantiate and return a Laplace operator

        Allows subclasses to change the Laplace operator
        """
        return Laplace(gd, scale, n, dtype, xp=self.xp)

    def set_grid_descriptor(self, gd):
        """Build the multigrid hierarchy for ``gd``.

        Creates the operators, restrictors, interpolators, smoothing counts
        and Jacobi weights.  Warns when the coarsest grid still has more
        than 36 points along some axis.  Work arrays are allocated later,
        in ``_init``.
        """
        # Should probably be renamed initialize
        self.gd = gd
        scale = -0.25 / pi

        if self.nn == 'M':
            if not gd.orthogonal:
                raise RuntimeError('Cannot use Mehrstellen stencil with '
                                   'non orthogonal cell.')

            self.operators = [LaplaceA(gd, -scale, xp=self.xp)]
            self.B = LaplaceB(gd)
        else:
            self.operators = [self.create_laplace(gd, scale, self.nn)]
            self.B = None

        self.interpolators = []
        self.restrictors = []

        level = 0
        self.presmooths = [2]
        self.postsmooths = [1]

        # Weights for the relaxation,
        # only used if 'J' (Jacobi) is chosen as method
        self.weights = [2.0 / 3.0]

        while level < 8:
            try:
                gd2 = gd.coarsen()
            except ValueError:
                break
            self.operators.append(self.create_laplace(gd2, scale, 1))
            self.interpolators.append(Transformer(gd2, gd, xp=self.xp))
            self.restrictors.append(Transformer(gd, gd2, xp=self.xp))
            self.presmooths.append(4)
            self.postsmooths.append(4)
            self.weights.append(1.0)
            level += 1
            gd = gd2

        self.levels = level

        if self.operators[-1].gd.N_c.max() > 36:
            # Try to warn exactly once no matter how one uses the solver.
            if gd.comm.parent is None:
                warn = (gd.comm.rank == 0)
            else:
                warn = (gd.comm.parent.rank == 0)

            if warn:
                warntxt = '\n'.join([POISSON_GRID_WARNING, '',
                                     self.get_description()])
            else:
                warntxt = ('Poisson warning from domain rank %d'
                           % self.gd.comm.rank)

            # Warn from all ranks to avoid deadlocks.
            warnings.warn(warntxt, stacklevel=2)

        self._initialized = False
        # The Gaussians depend on the grid as well so we have to 'unload' them
        if hasattr(self, 'rho_gauss'):
            del self.rho_gauss
            del self.phi_gauss

    def get_description(self):
        name = {1: 'Gauss-Seidel', 2: 'Jacobi'}[self.relax_method]
        coarsest_grid = self.operators[-1].gd.N_c
        coarsest_grid_string = ' x '.join([str(N) for N in coarsest_grid])
        assert self.levels + 1 == len(self.operators)
        lines = ['%s solver with %d multi-grid levels'
                 % (name, self.levels + 1),
                 '    Coarsest grid: %s points' % coarsest_grid_string]
        if coarsest_grid.max() > 24:
            # This friendly warning has lower threshold than the big long
            # one that we print when things are really bad.
            lines.extend(['    Warning: Coarse grid has more than 24 points.',
                          '             More multi-grid levels recommended.'])
        lines.extend(['    Stencil: %s' % self.operators[0].description,
                      '    Max iterations: %d' % self.maxiter])
        lines.extend(['    Tolerance: %e' % self.eps])
        lines.append(super().get_description())
        return '\n'.join(lines)

    def _init(self):
        """Allocate rho/phi/residual arrays on every level (only once)."""
        if self._initialized:
            return
        # Should probably be renamed allocate
        gd = self.gd
        self.rhos = [gd.empty(xp=self.xp)]
        # phis[0] is supplied by the caller in solve_neutral.
        self.phis = [None]
        self.residuals = [gd.empty(xp=self.xp)]
        for _ in range(self.levels):
            gd2 = gd.coarsen()
            self.phis.append(gd2.empty(xp=self.xp))
            self.rhos.append(gd2.empty(xp=self.xp))
            self.residuals.append(gd2.empty(xp=self.xp))
            gd = gd2
        assert len(self.phis) == len(self.rhos) == self.levels + 1

        # Index of the coarsest level.  Taken from self.levels directly:
        # reading the loop variable instead fails with UnboundLocalError
        # when the grid could not be coarsened at all (levels == 0).
        coarsest = self.levels

        # The step is threaded through iterate2 (x4 per level) but is not
        # otherwise used by the relaxation.
        self.step = 0.66666666 / self.operators[0].get_diagonal_element()
        self.presmooths[coarsest] = 8
        self.postsmooths[coarsest] = 8
        self._initialized = True

    def solve_neutral(self, phi, rho, timer=None):
        """Solve for ``phi`` (in place) given a neutral density ``rho``.

        ``phi`` serves as the initial guess.  ``timer`` is accepted for
        interface compatibility and is not used.

        Returns
        -------
        int
            Number of multigrid cycles performed (at most ``maxiter``).

        Raises
        ------
        PoissonConvergenceError
            If the residual is still above ``eps`` after ``maxiter`` cycles.
        """
        self._init()
        self.phis[0] = phi
        eps = self.eps
        if self.B is None:
            self.rhos[0][:] = rho
        else:
            self.B.apply(rho, self.rhos[0])

        maxiter = self.maxiter
        niter = 1
        error = self.iterate2(self.step)
        # Only fail when the residual is still too large after the last
        # allowed cycle; converging exactly on cycle ``maxiter`` is success.
        while error > eps:
            if niter >= maxiter:
                msg = ('Poisson solver did not converge in %d iterations!'
                       % maxiter)
                raise PoissonConvergenceError(msg)
            error = self.iterate2(self.step)
            niter += 1

        # Set the average potential to zero in periodic systems
        if (self.gd.pbc_c).all():
            phi_ave = self.gd.comm.sum_scalar(float(np.sum(phi.ravel())))
            N_c = self.gd.get_size_of_global_array()
            phi_ave /= np.prod(N_c)
            phi -= phi_ave

        return niter

    def iterate2(self, step, level=0):
        """Smooths the solution in every multigrid level

        Performs pre-smoothing, restriction of the residual, a recursive
        coarse-grid correction and post-smoothing.  At ``level == 0`` it
        returns the integrated squared residual; other levels return None.
        """
        self._init()

        residual = self.residuals[level]

        if level < self.levels:
            self.operators[level].relax(self.relax_method,
                                        self.phis[level],
                                        self.rhos[level],
                                        self.presmooths[level],
                                        self.weights[level])

            self.operators[level].apply(self.phis[level], residual)
            residual -= self.rhos[level]
            self.restrictors[level].apply(residual,
                                          self.rhos[level + 1])
            self.phis[level + 1][:] = 0.0
            self.iterate2(4.0 * step, level + 1)
            self.interpolators[level].apply(self.phis[level + 1], residual)
            self.phis[level] -= residual

        self.operators[level].relax(self.relax_method,
                                    self.phis[level],
                                    self.rhos[level],
                                    self.postsmooths[level],
                                    self.weights[level])
        if level == 0:
            self.operators[level].apply(self.phis[level], residual)
            residual -= self.rhos[level]
            error = self.gd.comm.sum_scalar(
                float(self.xp.dot(residual.ravel(),
                                  residual.ravel()))) * self.gd.dv

            # How about this instead:
            # error = self.gd.comm.max(abs(residual).max())

            return error

    def estimate_memory(self, mem):
        """Estimate the bytes used by the work arrays allocated in _init."""
        # XXX Memory estimate works only for J and GS, not FFT solver
        # Poisson solver appears to use same amount of memory regardless
        # of whether it's J or GS, which is a bit strange

        gdbytes = self.gd.bytecount()
        nbytes = -gdbytes  # No phi on finest grid, compensate ahead
        # There are self.levels + 1 grids (finest plus each coarsening);
        # each coarsening is assumed to shrink the grid by a factor 8.
        for level in range(self.levels + 1):
            nbytes += 3 * gdbytes  # Arrays: rho, phi, residual
            gdbytes //= 8
        mem.subnode('rho, phi, residual [%d levels]' % self.levels, nbytes)

    def __repr__(self):
        template = 'FDPoissonSolver(relax=\'%s\', nn=%s, eps=%e)'
        representation = template % (self.relax, repr(self.nn), self.eps)
        return representation

558 

559 

class NoInteractionPoissonSolver(_PoissonSolver):
    """Dummy solver that switches the electrostatic interaction off.

    ``solve`` leaves ``phi`` untouched and reports 0 iterations; setting
    the grid and estimating memory are no-ops.
    """

    relax_method = 0
    nn = 1

    def set_grid_descriptor(self, gd):
        pass

    def solve(self, phi, rho, charge, timer=None):
        return 0

    def estimate_memory(self, mem):
        pass

    def get_stencil(self):
        return 1

    def get_description(self):
        return 'No interaction'

    def todict(self):
        return {'name': 'nointeraction'}

581 

582 

class FFTPoissonSolver(BasePoissonSolver):
    """FFT Poisson solver for general unit cells.

    Requires fully periodic boundary conditions.  The density is moved
    through three auxiliary decompositions.  In each one, a single axis is
    held entirely on every rank so it can be Fourier transformed locally.
    The result is multiplied by ``4 pi / k^2`` and transformed back the
    same way.
    """

    relax_method = 0
    nn = 999

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._initialized = False

    def get_description(self):
        return 'Parallel FFT'

    def todict(self):
        return {'name': 'fft'}

    def set_grid_descriptor(self, gd):
        """Store ``gd`` and create one decomposition per serial axis."""
        # We will probably want to use this on non-periodic grids too...
        assert gd.pbc_c.all()
        self.gd = gd

        serial_grids = []
        for axis in range(3):
            # Domain decomposition that never splits ``axis``.
            shape_c = gd.N_c.copy()
            shape_c[axis] = 1
            parsize_c = decompose_domain(shape_c, gd.comm.size)
            serial_grids.append(gd.new_descriptor(parsize_c=parsize_c))
        self.grids = [gd] + serial_grids
        self._initialized = False

    def _init(self):
        """Precompute ``4 pi / k^2`` on the last decomposition (once)."""
        if not self._initialized:
            k2_Q, _ = construct_reciprocal(self.grids[-1])
            self.poisson_factor_Q = 4.0 * np.pi / k2_Q
            self._initialized = True

    def solve_neutral(self, phi_g, rho_g, timer=None):
        """Overwrite ``phi_g`` with the potential of ``rho_g``; return 1."""
        self._init()
        # Will be a bit more efficient if reduced dimension is always
        # contiguous. Probably more things can be improved...

        source_gd = self.gd
        data_g = rho_g

        # Forward: redistribute so that axis c is local, then FFT along it.
        for c in range(3):
            target_gd = self.grids[c + 1]
            buffer_g = target_gd.empty(dtype=data_g.dtype)
            grid2grid(source_gd.comm, source_gd, target_gd, data_g, buffer_g)
            data_g = fftn(buffer_g, axes=[c])
            source_gd = target_gd

        data_g *= self.poisson_factor_Q

        # Backward: inverse FFT along the local axis, then redistribute.
        for c in (2, 1, 0):
            target_gd = self.grids[c]
            buffer_g = ifftn(data_g, axes=[c])
            data_g = target_gd.empty(dtype=buffer_g.dtype)
            grid2grid(source_gd.comm, source_gd, target_gd, buffer_g, data_g)
            source_gd = target_gd

        phi_g[:] = data_g.real
        return 1

    def estimate_memory(self, mem):
        mem.subnode('k squared', self.grids[-1].bytecount())

650 

651 

# Selects the implementation of the type-I discrete sine transforms used
# for non-periodic axes by rfst2, irfst2, transform and itransform.  When
# True, scipy's DST-I is used.  When False, the pure NumPy fallbacks in
# those functions (and fst/ifst) are used instead.
use_scipy_transforms = True

676 

677 

def rfst2(A_g, axes=(0, 1)):
    """Forward two-dimensional type-I discrete sine transform.

    Parameters
    ----------
    A_g : ndarray
        Three-dimensional real array.
    axes : sequence of two ints
        Axes to transform; the remaining axis is left untouched.  Lists,
        tuples and integer arrays are all accepted.

    Returns
    -------
    ndarray
        Array with the shape of ``A_g``.  It equals ``2**len(axes)`` times
        ``scipy.fftpack.dst(..., type=1)`` applied along each axis in turn.
        When ``use_scipy_transforms`` is False, the same values are computed
        with NumPy only, via an odd extension and ``rfft2``.
    """
    axes = list(axes)

    if use_scipy_transforms:
        Y = A_g
        for axis in axes:
            Y = scipydst(Y, axis=axis, type=1)
        Y *= 2**len(axes)
        return Y

    # NumPy fallback: move the transform axes to the front and embed the
    # data in the odd extension [0, A, 0, -A[::-1]] (length 2 * n + 2)
    # along both of them.  The real FFT of the extension equals the sine
    # transform up to a factor -4.
    (third,) = {0, 1, 2}.difference(axes)
    order = axes + [third]
    A_g = np.transpose(A_g, order)
    x, y, z = A_g.shape
    temp_g = np.zeros((x * 2 + 2, y * 2 + 2, z))
    temp_g[1:x + 1, 1:y + 1, :] = A_g
    temp_g[x + 2:, 1:y + 1, :] = -A_g[::-1, :, :]
    temp_g[1:x + 1, y + 2:, :] = -A_g[:, ::-1, :]
    temp_g[x + 2:, y + 2:, :] = A_g[::-1, ::-1, :]
    X = -4 * rfft2(temp_g, axes=[0, 1])[1:x + 1, 1:y + 1, :].real
    return np.transpose(X, np.argsort(order))

698 

699 

def irfst2(A_g, axes=(0, 1)):
    """Inverse of ``rfst2`` along the same ``axes``.

    Parameters
    ----------
    A_g : ndarray
        Three-dimensional array of sine-transform coefficients.  The NumPy
        fallback uses only its real part.
    axes : sequence of two ints
        Axes to transform back.  Lists, tuples and integer arrays are all
        accepted.

    Returns
    -------
    ndarray
        ``B`` such that ``rfst2(B, axes)`` reproduces ``A_g``.
    """
    axes = list(axes)

    if use_scipy_transforms:
        Y = A_g
        for axis in axes:
            Y = scipydst(Y, axis=axis, type=1)
        # Applying DST-I twice multiplies by 2 * (n + 1) per axis.  rfst2
        # adds another factor 2 per axis, hence 4**len(axes) (= 16 for the
        # usual two axes).
        magic = 1.0 / (4**len(axes) *
                       np.prod([A_g.shape[axis] + 1 for axis in axes]))
        Y *= magic
        return Y

    # NumPy fallback: odd extension along the first transform axis and a
    # half spectrum along the second, followed by an inverse real 2D FFT.
    (third,) = {0, 1, 2}.difference(axes)
    order = axes + [third]
    A_g = np.transpose(A_g, order)
    x, y, z = A_g.shape
    temp_g = np.zeros((x * 2 + 2, (y * 2 + 2) // 2 + 1, z))
    temp_g[1:x + 1, 1:y + 1, :] = A_g.real
    temp_g[x + 2:, 1:y + 1, :] = -A_g[::-1, :, :].real
    X = -0.25 * irfft2(temp_g, axes=[0, 1])[1:x + 1, 1:y + 1, :]
    return np.transpose(X, np.argsort(order))

721 

722 

# This method needs to be taken from fftw / scipy to gain speedup of ~4x
def fst(A_g, axis):
    """Type-I discrete sine transform of a 3D array along ``axis``.

    This is the pure NumPy fallback used by ``transform``.  The data are
    placed in the odd extension [0, A, 0, -A[::-1]] of length 2 * n + 2
    along ``axis``, which is then Fourier transformed.  The result equals
    ``0.5 * scipy.fftpack.dst(A_g, type=1, axis=axis)`` but has a complex
    dtype.

    Only ``axis`` in (0, 1, 2) is supported.  Other in-range values
    (e.g. -1) raise ``NotImplementedError``; out-of-range values raise
    ``IndexError``.
    """
    x, y, z = A_g.shape
    ext_shape_c = np.array([x, y, z])
    ext_shape_c[axis] = 2 * ext_shape_c[axis] + 2
    ext_g = np.zeros(ext_shape_c, dtype=A_g.dtype)
    if axis not in (0, 1, 2):
        raise NotImplementedError()
    npts = (x, y, z)[axis]

    def along_axis(part):
        # Full slices everywhere except ``part`` along ``axis``.
        index = [slice(None)] * 3
        index[axis] = part
        return tuple(index)

    ext_g[along_axis(slice(1, npts + 1))] = A_g
    ext_g[along_axis(slice(npts + 2, None))] = -np.flip(A_g, axis)
    spectrum_g = 0.5j * fft(ext_g, axis=axis)
    return spectrum_g[along_axis(slice(1, npts + 1))]

747 

748 

def ifst(A_g, axis):
    """Inverse of ``fst``: type-I inverse sine transform along ``axis``.

    Uses an odd extension of length 2 * n + 2 and an inverse FFT.  The
    result is complex, and ``ifst(fst(A, axis), axis)`` recovers ``A``.
    Only ``axis`` in (0, 1, 2) is supported.  Other in-range values raise
    ``NotImplementedError``; out-of-range values raise ``IndexError``.
    """
    x, y, z = A_g.shape
    ext_shape_c = np.array([x, y, z])
    ext_shape_c[axis] = 2 * ext_shape_c[axis] + 2
    ext_g = np.zeros(ext_shape_c, dtype=A_g.dtype)
    if axis not in (0, 1, 2):
        raise NotImplementedError()
    npts = (x, y, z)[axis]

    def along_axis(part):
        # Full slices everywhere except ``part`` along ``axis``.
        index = [slice(None)] * 3
        index[axis] = part
        return tuple(index)

    ext_g[along_axis(slice(1, npts + 1))] = A_g
    ext_g[along_axis(slice(npts + 2, None))] = -np.flip(A_g, axis)
    signal_g = ifft(ext_g, axis=axis)
    return -2j * signal_g[along_axis(slice(1, npts + 1))]

774 

775 

def transform(A_g, axis=None, pbc=True):
    """Forward 1D transform of ``A_g`` along ``axis``.

    Periodic axes (``pbc`` true) use a complex FFT.  Non-periodic axes use
    a type-I sine transform: scipy's DST-I scaled by 0.5, or ``fst`` when
    ``use_scipy_transforms`` is False.  Empty arrays are returned without
    transforming (converted to complex in the periodic case).
    """
    if not pbc:
        if A_g.size == 0:
            return A_g
        if use_scipy_transforms:
            coeffs = scipydst(A_g, axis=axis, type=1)
            coeffs *= .5
            return coeffs
        return fst(A_g, axis)

    if A_g.size == 0:
        return A_g.astype(complex)
    return fft(A_g, axis=axis)

792 

793 

def transform2(A_g, axes=None, pbc=[True, True]):
    """Forward transform of ``A_g`` along the two axes in ``axes``.

    If both axes are periodic, a 2D FFT is used.  If neither is, ``rfst2``
    is used.  Mixed boundary conditions chain two 1D ``transform`` calls.
    Empty arrays pass through (as complex when both axes are periodic).
    """
    both_periodic = all(pbc)
    if both_periodic:
        return A_g.astype(complex) if A_g.size == 0 else fft2(A_g, axes=axes)

    if not any(pbc):
        return A_g if A_g.size == 0 else rfst2(A_g, axes=axes)

    # Mixed boundary conditions: one 1D transform per axis.
    half_g = transform(A_g, axis=axes[0], pbc=pbc[0])
    return transform(half_g, axis=axes[1], pbc=pbc[1])

808 

809 

def itransform(A_g, axis=None, pbc=True):
    """Inverse of ``transform`` along ``axis``.

    Periodic axes use a complex inverse FFT.  Non-periodic axes use scipy's
    DST-I divided by ``n + 1``, or ``ifst`` when ``use_scipy_transforms``
    is False.  Empty arrays are returned without transforming (converted
    to complex in the periodic case).
    """
    if not pbc:
        if A_g.size == 0:
            return A_g
        if use_scipy_transforms:
            values = scipydst(A_g, axis=axis, type=1)
            values *= 1.0 / (A_g.shape[axis] + 1)
            return values
        return ifst(A_g, axis)

    if A_g.size == 0:
        return A_g.astype(complex)
    return ifft(A_g, axis=axis)

827 

828 

def itransform2(A_g, axes=None, pbc=[True, True]):
    """Inverse of ``transform2`` along the two axes in ``axes``.

    If both axes are periodic, a 2D inverse FFT is used.  If neither is,
    ``irfst2`` is used.  Mixed boundary conditions chain two 1D
    ``itransform`` calls.  Empty arrays pass through (as complex when both
    axes are periodic).
    """
    both_periodic = all(pbc)
    if both_periodic:
        return A_g.astype(complex) if A_g.size == 0 else ifft2(A_g, axes=axes)

    if not any(pbc):
        return A_g if A_g.size == 0 else irfst2(A_g, axes=axes)

    # Mixed boundary conditions: one 1D inverse transform per axis.
    half_g = itransform(A_g, axis=axes[0], pbc=pbc[0])
    return itransform(half_g, axis=axes[1], pbc=pbc[1])

843 

844 

class BadAxesError(ValueError):
    """Cell geometry unsupported: some axis is neither periodic nor
    orthogonal to the other axes (raised by FastPoissonSolver)."""

847 

848 

class FastPoissonSolver(BasePoissonSolver):
    """Direct Poisson solver that diagonalises the finite-difference stencil.

    The Laplacian stencil (scaled by -1/(4*pi)) is diagonalised with FFTs
    along periodic axes and type-I sine transforms along non-periodic
    axes.  The solution is then obtained by dividing by the stencil
    eigenvalues.  Every axis must be periodic or orthogonal to the others.
    """

    def __init__(self, nn=3, **kwargs):
        """
        Parameters
        ----------
        nn : int
            Stencil range of the finite-difference Laplacian.
        **kwargs
            Forwarded to ``BasePoissonSolver``.
        """
        BasePoissonSolver.__init__(self, **kwargs)
        self.nn = nn
        # We may later enable this to work with Cholesky, but not now:
        self.use_cholesky = False

    def _init(self):
        # Everything is prepared in set_grid_descriptor.
        pass

    def set_grid_descriptor(self, gd):
        """Classify axes, build decompositions and precompute eigenvalues.

        Raises
        ------
        BadAxesError
            If some axis is neither periodic nor orthogonal to the others.
        """
        self.gd = gd
        axes = np.arange(3)
        pbc_c = np.array(gd.pbc_c, dtype=bool)
        periodic_axes = axes[pbc_c]
        non_periodic_axes = axes[np.logical_not(pbc_c)]

        # Find out which axes are orthogonal (0, 1 or 3)
        # Note that one expects that the axes are always rotated in
        # conventional form, thus for all axes to be
        # classified as orthogonal, the cell_cv needs to be diagonal.
        # This may always be achieved by rotating
        # the unit-cell along with the atoms. The classification is
        # inherited from grid_descriptor.orthogonal.
        dotprods = np.dot(gd.cell_cv, gd.cell_cv.T)
        # For each direction, check whether there is only one nonzero
        # element in that row (necessarily being the diagonal element,
        # since this is a cell vector length and must be > 0).
        orthogonal_c = (np.abs(dotprods) > 1e-10).sum(axis=0) == 1
        assert sum(orthogonal_c) in [0, 1, 3]

        non_orthogonal_axes = axes[np.logical_not(orthogonal_c)]

        if not all(pbc_c | orthogonal_c):
            raise BadAxesError('Each axis must be periodic or orthogonal '
                               'to other axes. But we have pbc={} '
                               'and orthogonal={}'
                               .format(pbc_c.astype(int),
                                       orthogonal_c.astype(int)))

        # We sort them, and pick the longest non-periodic axes as the
        # cholesky axis.
        sorted_non_periodic_axes = sorted(non_periodic_axes,
                                          key=lambda c: gd.N_c[c])
        if self.use_cholesky:
            if len(sorted_non_periodic_axes) > 0:
                cholesky_axes = [sorted_non_periodic_axes[-1]]
                if cholesky_axes[0] in non_orthogonal_axes:
                    msg = ('Cholesky axis cannot be non-orthogonal. '
                           'Do you really want a non-orthogonal non-periodic '
                           'axis? If so, run with use_cholesky=False.')
                    raise NotImplementedError(msg)
                fst_axes = sorted_non_periodic_axes[0:-1]
            else:
                cholesky_axes = []
                fst_axes = []
        else:
            cholesky_axes = []
            fst_axes = sorted_non_periodic_axes
        fft_axes = list(periodic_axes)

        (self.cholesky_axes, self.fst_axes,
         self.fft_axes) = cholesky_axes, fst_axes, fft_axes

        fftfst_axes = self.fft_axes + self.fst_axes
        axes = self.fft_axes + self.fst_axes + self.cholesky_axes
        self.axes = axes

        # Create xy flat decomposition (where x=axes[0] and y=axes[1])
        parsize_c = [1, 1, 1]
        parsize_c[axes[2]] = gd.comm.size
        gd1d = gd.new_descriptor(parsize_c=parsize_c,
                                 allow_empty_domains=True)
        self.gd1d = gd1d

        # Create z flat decomposition
        domain = gd.N_c.copy()
        domain[axes[2]] = 1
        parsize_c = decompose_domain(domain, gd.comm.size)
        gd2d = gd.new_descriptor(parsize_c=parsize_c)
        self.gd2d = gd2d

        # Calculate eigenvalues in fst/fft decomposition for
        # non-cholesky axes in parallel
        xp = self.xp
        r_cx = xp.indices(gd2d.n_c)
        r_cx += xp.asarray(gd2d.beg_c[:, xp.newaxis, xp.newaxis, xp.newaxis])
        r_cx = r_cx.astype(complex)
        for axis in fftfst_axes:
            r_cx[axis] *= 2j * xp.pi / gd2d.N_c[axis]
            if axis in fst_axes:
                # Sine transforms use half the frequency spacing.
                r_cx[axis] /= 2
        for axis in cholesky_axes:
            r_cx[axis] = 0.0
        xp.exp(r_cx, out=r_cx)
        fft_lambdas = xp.zeros_like(r_cx[0], dtype=complex)
        laplace = Laplace(self.gd, -0.25 / pi, self.nn)
        self.stencil_description = laplace.description

        for coeff, offset_c in zip(laplace.coef_p, laplace.offset_pc):
            offset_c = np.array(offset_c)
            if not any(offset_c):
                # The centerpoint is handled with (temp-1.0)
                continue
            non_zero_axes, = np.where(offset_c)
            if set(non_zero_axes).issubset(fftfst_axes):
                temp = xp.ones_like(fft_lambdas)
                for axis in fftfst_axes:
                    temp *= r_cx[axis] ** offset_c[axis]
                fft_lambdas += coeff * (temp - 1.0)

        assert xp.linalg.norm(fft_lambdas.imag) < 1e-10
        fft_lambdas = fft_lambdas.real.copy()  # arr.real is not contiguous

        # If there is no Cholesky decomposition, the system is already
        # fully diagonal and we can directly invert the linear problem
        # by dividing with the eigenvalues.
        # TODO: Remove cholesky alltogether, since it is not used
        assert len(cholesky_axes) == 0
        # Eigenvalues that are (numerically) zero are mapped to 0 instead
        # of being inverted.
        with np.errstate(divide='ignore'):
            self.inv_fft_lambdas = xp.where(
                xp.abs(fft_lambdas) > 1e-10, 1.0 / fft_lambdas, 0)

    def solve_neutral(self, phi_g, rho_g, timer=None):
        """Overwrite ``phi_g`` with the solution for the neutral ``rho_g``.

        ``rho_g`` is redistributed so that ``axes[:2]`` are local and
        transformed.  It is then redistributed so that ``axes[2]`` is local
        and transformed again, divided by the stencil eigenvalues, and
        transformed back in reverse order.

        Parameters
        ----------
        timer : callable or None
            Called with a label for every phase and used as a context
            manager.  Defaults to a ``NullTimer``, so the method can also be
            called directly without a timer.

        Returns
        -------
        int
            Always 1 (the method is non-iterative).
        """
        if len(self.cholesky_axes) != 0:
            raise NotImplementedError
        if timer is None:
            timer = NullTimer()

        gd = self.gd
        gd1d = self.gd1d
        gd2d = self.gd2d
        comm = self.gd.comm
        axes = self.axes

        with timer('Communicate to 1D'):
            work1d_g = gd1d.empty(dtype=rho_g.dtype, xp=self.xp)
            grid2grid(comm, gd, gd1d, rho_g, work1d_g, xp=self.xp)
        with timer('FFT 2D'):
            work1d_g = transform2(work1d_g, axes=axes[:2],
                                  pbc=gd.pbc_c[axes[:2]])
        with timer('Communicate to 2D'):
            work2d_g = gd2d.empty(dtype=work1d_g.dtype, xp=self.xp)
            grid2grid(comm, gd1d, gd2d, work1d_g, work2d_g, xp=self.xp)
        with timer('FFT 1D'):
            work2d_g = transform(work2d_g, axis=axes[2],
                                 pbc=gd.pbc_c[axes[2]])

        # The remaining problem is 0D dimensional, i.e the problem
        # has been fully diagonalized
        work2d_g *= self.inv_fft_lambdas

        with timer('FFT 1D'):
            work2d_g = itransform(work2d_g, axis=axes[2],
                                  pbc=gd.pbc_c[axes[2]])
        with timer('Communicate from 2D'):
            work1d_g = gd1d.empty(dtype=work2d_g.dtype, xp=self.xp)
            grid2grid(comm, gd2d, gd1d, work2d_g, work1d_g, xp=self.xp)
        with timer('FFT 2D'):
            work1d_g = itransform2(work1d_g, axes=axes[1::-1],
                                   pbc=gd.pbc_c[axes[1::-1]])
        with timer('Communicate from 1D'):
            work_g = gd.empty(dtype=work1d_g.dtype, xp=self.xp)
            grid2grid(comm, gd1d, gd, work1d_g, work_g, xp=self.xp)

        phi_g[:] = work_g.real
        return 1  # Non-iterative method, return 1 iteration

    def todict(self):
        d = super().todict()
        d.update({'name': 'fast', 'nn': self.nn})
        return d

    def estimate_memory(self, mem):
        pass

    def get_description(self):
        lines = [f'{self.__class__.__name__} using',
                 f'    Stencil: {self.stencil_description}',
                 f'    FFT axes: {self.fft_axes}',
                 f'    FST axes: {self.fst_axes}',
                 ]
        lines.append(BasePoissonSolver.get_description(self))
        return '\n'.join(lines)