Coverage for gpaw/poisson.py: 80%
632 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1# Copyright (C) 2003 CAMP
2# Please see the accompanying LICENSE file for further information.
4import warnings
5from math import pi
7import numpy as np
8from numpy.fft import fftn, ifftn, fft2, ifft2, rfft2, irfft2, fft, ifft
9from scipy.fftpack import dst as scipydst
11from gpaw import PoissonConvergenceError
12from gpaw.dipole_correction import DipoleCorrection, dipole_correction
13from gpaw.domain import decompose_domain
14from gpaw.fd_operators import Laplace, LaplaceA, LaplaceB
15from gpaw.transformers import Transformer
16from gpaw.utilities.gauss import Gaussian
17from gpaw.utilities.grid import grid2grid
18from gpaw.utilities.ewald import madelung
19from gpaw.utilities.tools import construct_reciprocal
20from gpaw.utilities.timing import NullTimer
# Warning text emitted by FDPoissonSolver.set_grid_descriptor when the
# coarsest multigrid level still has more than 36 points along some axis.
POISSON_GRID_WARNING = """Grid unsuitable for FDPoissonSolver!

Consider using FastPoissonSolver instead.

The FDPoissonSolver does not have sufficient multigrid levels for good
performance and will converge inefficiently if at all, or yield wrong
results.

You may need to manually specify a grid such that the number of points
along each direction is divisible by a high power of 2, such as 8, 16,
or 32 depending on system size; examples:

 GPAW(gpts=(32, 32, 288))

or

 from gpaw.utilities import h2gpts
 GPAW(gpts=h2gpts(0.2, atoms.get_cell(), idiv=16))

Parallelizing over very small domains can also undesirably limit the
number of multigrid levels even if the total number of grid points
is divisible by a high power of 2."""
def create_poisson_solver(name='fast', **kwargs):
    """Build a Poisson solver from a flexible description.

    *name* may be

    * an existing solver instance, which is returned unchanged;
    * a dict of keyword arguments (optionally containing ``'name'``),
      which is merged into *kwargs* before recursing;
    * one of the identifiers ``'fft'``, ``'fdtd'``, ``'fd'``, ``'fast'``,
      ``'ExtraVacuumPoissonSolver'``, ``'MomentCorrectionPoissonSolver'``
      or ``'nointeraction'``.

    Remaining keyword arguments go to the selected constructor (except
    for ``'nointeraction'``, which takes none).  Any other *name* raises
    ValueError.
    """
    if isinstance(name, _PoissonSolver):
        return name

    if isinstance(name, dict):
        kwargs.update(name)
        return create_poisson_solver(**kwargs)

    if name == 'fft':
        return FFTPoissonSolver(**kwargs)
    if name == 'fdtd':
        # Imported lazily: optional module, only needed for this solver.
        from gpaw.fdtd.poisson_fdtd import FDTDPoissonSolver
        return FDTDPoissonSolver(**kwargs)
    if name == 'fd':
        return FDPoissonSolverWrapper(**kwargs)
    if name == 'fast':
        return FastPoissonSolver(**kwargs)
    if name == 'ExtraVacuumPoissonSolver':
        from gpaw.poisson_extravacuum import ExtraVacuumPoissonSolver
        return ExtraVacuumPoissonSolver(**kwargs)
    if name == 'MomentCorrectionPoissonSolver':
        from gpaw.poisson_moment import MomentCorrectionPoissonSolver
        return MomentCorrectionPoissonSolver(**kwargs)
    if name == 'nointeraction':
        return NoInteractionPoissonSolver()

    raise ValueError('Unknown poisson solver: %s' % name)
def PoissonSolver(name='fast', dipolelayer=None, zero_vacuum=False, **kwargs):
    """Create a Poisson solver, optionally wrapped in a dipole correction.

    All arguments except *dipolelayer* and *zero_vacuum* are passed on to
    :func:`create_poisson_solver`.  When *dipolelayer* is given, the
    resulting solver is wrapped in ``DipoleCorrection``.
    """
    solver = create_poisson_solver(name=name, **kwargs)
    if dipolelayer is None:
        return solver
    return DipoleCorrection(solver, dipolelayer, zero_vacuum=zero_vacuum)
def FDPoissonSolverWrapper(dipolelayer=None, zero_vacuum=False, **kwargs):
    """Create an FDPoissonSolver, optionally wrapped in a dipole correction.

    Keyword arguments other than *dipolelayer* and *zero_vacuum* go to
    :class:`FDPoissonSolver`.
    """
    solver = FDPoissonSolver(**kwargs)
    if dipolelayer is None:
        return solver
    return DipoleCorrection(solver, dipolelayer, zero_vacuum=zero_vacuum)
87class _PoissonSolver:
88 """Abstract PoissonSolver class
90 This class defines an interface and a common ancestor
91 for various PoissonSolver implementations (including wrappers)."""
92 def __init__(self):
93 object.__init__(self)
95 def set_grid_descriptor(self, gd):
96 raise NotImplementedError()
98 def solve(self):
99 raise NotImplementedError()
101 def todict(self):
102 raise NotImplementedError(self.__class__.__name__)
104 def get_description(self):
105 return self.__class__.__name__
107 def estimate_memory(self, mem):
108 raise NotImplementedError()
110 def build(self, grid, xp):
111 from gpaw.new.poisson import PoissonSolverWrapper
112 self.xp = xp
113 self.set_grid_descriptor(grid._gd)
114 return PoissonSolverWrapper(self)
class BasePoissonSolver(_PoissonSolver):
    """Shared charge handling for grid-based Poisson solvers.

    Subclasses are expected to provide ``_init()`` and
    ``solve_neutral(phi, rho, timer=...)``.  :meth:`solve` calls them and
    deals with a net charge: a homogeneous background is subtracted for
    fully periodic cells, a Gaussian monopole is removed for fully
    non-periodic cells, and mixed boundary conditions are only supported
    through the ``metallic_electrodes`` option.
    """
    def __init__(self, *, remove_moment=None,
                 use_charge_center=False,
                 metallic_electrodes=False,
                 eps=None,
                 use_charged_periodic_corrections=False,
                 xp=np):
        """Store the solver options.

        remove_moment: deprecated (FutureWarning); number of multipole
            moments removed from the density before solving.  Only
            allowed for non-periodic cells (asserted in solve).
        use_charge_center: place the compensating Gaussian at the centre
            of the majority-sign charge instead of Gaussian's default
            centre.
        metallic_electrodes: one of False, None, 'single' or 'both'.
        eps: accepted only for backwards compatibility; giving it issues
            a FutureWarning and has no effect in this class.
        use_charged_periodic_corrections: for charged periodic systems,
            add ``charge * madelung(cell)`` to the potential.
        xp: array module used for the Gaussian arrays (numpy default).
        """

        self.xp = xp

        if eps is not None:
            warnings.warn(
                "Please do not specify the eps parameter "
                f"for {self.__class__.__name__}. "
                "The parameter doesn't do anything for this solver "
                "and defining it will throw an error in the future.",
                FutureWarning)

        if remove_moment is not None:
            warnings.warn(
                "Please do not specify the remove_moment parameter "
                f"for {self.__class__.__name__}. "
                "The remove moment functionality is deprecated in this solver "
                "and will throw an error in the future. Instead "
                "use the MomentCorrectionPoissonSolver as a wrapper to "
                f"{self.__class__.__name__}.",
                FutureWarning)

        # metallic electrodes: mirror image method to allow calculation of
        # charged, partly periodic systems
        self.gd = None
        self.remove_moment = remove_moment
        self.use_charge_center = use_charge_center
        self.use_charged_periodic_corrections = \
            use_charged_periodic_corrections
        # Lazily computed Madelung correction (see solve()).
        self.charged_periodic_correction = None
        self.eps = eps
        self.metallic_electrodes = metallic_electrodes
        assert self.metallic_electrodes in [False, None, 'single', 'both']

    def todict(self):
        """Return the options as a dict; falsy options are omitted."""
        d = {'name': 'basepoisson'}
        if self.remove_moment:
            d['remove_moment'] = self.remove_moment
        if self.use_charge_center:
            d['use_charge_center'] = self.use_charge_center
        if self.use_charged_periodic_corrections:
            d['use_charged_periodic_corrections'] = \
                self.use_charged_periodic_corrections
        if self.metallic_electrodes:
            d['metallic_electrodes'] = self.metallic_electrodes

        return d

    def get_description(self):
        # The idea is that the subclass writes a header and main parameters,
        # then adds the below string.
        lines = []
        if self.remove_moment is not None:
            lines.append(' Remove moments up to L=%d' % self.remove_moment)
        if self.use_charge_center:
            lines.append(' Compensate for charged system using center of '
                         'majority charge')
        if self.use_charged_periodic_corrections:
            lines.append(' Subtract potential of homogeneous background')

        return '\n'.join(lines)

    def solve(self, phi, rho, charge=None, maxcharge=1e-6,
              zero_initial_phi=False, timer=NullTimer()):
        """Solve for the potential *phi* (modified in place) from *rho*.

        charge: net charge used to choose the treatment; defaults to the
            integral of *rho* over the grid.
        maxcharge: ``abs(charge) <= maxcharge`` is treated as neutral.
        zero_initial_phi: in the charged periodic and charged
            non-periodic branches, reset *phi* to zero instead of using
            it as the initial guess.
        timer: forwarded to ``solve_neutral``.

        Returns the value returned by ``solve_neutral`` (an iteration
        count for the solvers defined in this module).  Raises
        NotImplementedError for charged systems with mixed boundary
        conditions unless ``metallic_electrodes`` is set.
        """
        self._init()
        assert np.all(phi.shape == self.gd.n_c)
        assert np.all(rho.shape == self.gd.n_c)

        actual_charge = self.gd.integrate(rho)
        # Uniform density carrying the full charge, i.e. the neutralizing
        # background is -background.
        background = (actual_charge / self.gd.dv /
                      self.gd.get_size_of_global_array().prod())

        if self.remove_moment:
            # Deprecated path: strip multipoles L < remove_moment from a
            # copy of rho, solve, then add their potentials back.
            assert not self.gd.pbc_c.any()
            if not hasattr(self, 'gauss'):
                self.gauss = Gaussian(self.gd)
            rho_neutral = rho.copy()
            phi_cor_L = []
            for L in range(self.remove_moment):
                phi_cor_L.append(self.gauss.remove_moment(rho_neutral, L))
            # Remove multipoles for better initial guess
            for phi_cor in phi_cor_L:
                phi -= phi_cor

            niter = self.solve_neutral(phi, rho_neutral, timer=timer)
            # correct error introduced by removing multipoles
            for phi_cor in phi_cor_L:
                phi += phi_cor

            return niter
        if charge is None:
            charge = actual_charge
        if abs(charge) <= maxcharge:
            return self.solve_neutral(phi, rho - background, timer=timer)

        elif abs(charge) > maxcharge and self.gd.pbc_c.all():
            # System is charged and periodic. Subtract a homogeneous
            # background charge

            # Set initial guess for potential
            if zero_initial_phi:
                phi[:] = 0.0

            iters = self.solve_neutral(phi, rho - background, timer=timer)

            if self.use_charged_periodic_corrections:
                # Computed once per solver and cached.
                if self.charged_periodic_correction is None:
                    self.charged_periodic_correction = madelung(
                        self.gd.cell_cv)
                phi += actual_charge * self.charged_periodic_correction

            return iters

        elif abs(charge) > maxcharge and not self.gd.pbc_c.any():
            # The system is charged and in a non-periodic unit cell.
            # Determine the potential by 1) subtract a gaussian from the
            # density, 2) determine potential from the neutralized density
            # and 3) add the potential from the gaussian density.

            # Load necessary attributes

            # use_charge_center: The monopole will be removed at the
            # center of the majority charge, which prevents artificial
            # dipoles.
            # Due to the shape of the Gaussian and it's Fourier-Transform,
            # the Gaussian representing the charge should stay at least
            # 7 gpts from the borders - see:
            # listserv.fysik.dtu.dk/pipermail/gpaw-developers/2015-July/005806.html
            if self.use_charge_center:
                charge_sign = actual_charge / abs(actual_charge)
                # Keep only the part of rho with the sign of the net charge.
                rho_sign = rho * charge_sign
                rho_sign[np.where(rho_sign < 0)] = 0
                absolute_charge = self.gd.integrate(rho_sign)
                center = - (self.gd.calculate_dipole_moment(rho_sign) /
                            absolute_charge)
                border_offset = np.inner(self.gd.h_cv, np.array((7, 7, 7)))
                borders = np.inner(self.gd.h_cv, self.gd.N_c)
                borders -= border_offset
                if np.any(center > borders) or np.any(center < border_offset):
                    raise RuntimeError('Poisson solver: '
                                       'center of charge outside borders '
                                       '- please increase box')
                # NOTE(review): unreachable clamp - the check above already
                # raises whenever any component of center exceeds borders.
                center[np.where(center > borders)] = borders
                self.load_gauss(center=center)
            else:
                self.load_gauss()

            # Remove monopole moment
            q = actual_charge / np.sqrt(4 * pi)  # Monopole moment
            rho_neutral = rho - q * self.rho_gauss  # neutralized density

            # Set initial guess for potential
            if zero_initial_phi:
                phi[:] = 0.0
            else:
                phi -= q * self.phi_gauss

            # Determine potential from neutral density using standard solver
            niter = self.solve_neutral(phi, rho_neutral, timer=timer)

            # correct error introduced by removing monopole
            phi += q * self.phi_gauss

            return niter
        else:
            # System is charged with mixed boundaryconditions
            if self.metallic_electrodes == 'single':
                # Dipole correction along the z axis (c = 2) with the
                # origin at the far end of the cell.
                self.c = 2
                origin_c = [0, 0, 0]
                origin_c[self.c] = self.gd.N_c[self.c]
                drhot_g, dvHt_g, self.correction = dipole_correction(
                    self.c,
                    self.gd,
                    rho,
                    origin_c=origin_c)
                # self.correction *=-1.
                phi -= dvHt_g
                iters = self.solve_neutral(phi, rho + drhot_g, timer=timer)
                phi += dvHt_g
                phi -= self.correction
                self.correction = 0.0

                return iters

            elif self.metallic_electrodes == 'both':
                iters = self.solve_neutral(phi, rho, timer=timer)
                return iters

            else:
                # System is charged with mixed boundaryconditions
                msg = ('Charged systems with mixed periodic/zero'
                       ' boundary conditions')
                raise NotImplementedError(msg)

    def load_gauss(self, center=None):
        """Create (or reuse) the compensating Gaussian and its potential.

        The arrays are cached on the instance; they are rebuilt when none
        are cached yet or when an explicit *center* is given.
        """
        if not hasattr(self, 'rho_gauss') or center is not None:
            gauss = Gaussian(self.gd, center=center)
            self.rho_gauss = self.xp.asarray(gauss.get_gauss(0))
            self.phi_gauss = self.xp.asarray(gauss.get_gauss_pot(0))
class FDPoissonSolver(BasePoissonSolver):
    """Real-space multigrid Poisson solver using finite differences.

    The Laplacian uses an ``nn``-neighbour stencil (or the Mehrstellen
    stencil for ``nn='M'``).  The equation is relaxed with Gauss-Seidel
    (``relax='GS'``) or weighted Jacobi (``relax='J'``) sweeps in V-cycles
    over a hierarchy of successively coarsened grids.  Cycling continues
    until the integrated squared residual drops below ``eps`` or
    ``maxiter`` cycles have been used.
    """
    def __init__(self, nn=3, relax='J', eps=2e-10, maxiter=1000,
                 remove_moment=None, use_charge_center=False,
                 metallic_electrodes=False,
                 use_charged_periodic_corrections=False, **kwargs):
        super(FDPoissonSolver, self).__init__(
            remove_moment=remove_moment,
            use_charge_center=use_charge_center,
            metallic_electrodes=metallic_electrodes,
            use_charged_periodic_corrections=use_charged_periodic_corrections,
            **kwargs)
        self.eps = eps
        self.relax = relax
        self.nn = nn
        self.maxiter = maxiter

        # Relaxation method
        if relax == 'GS':
            # Gauss-Seidel
            self.relax_method = 1
        elif relax == 'J':
            # Jacobi
            self.relax_method = 2
        else:
            raise NotImplementedError('Relaxation method %s' % relax)

        self.description = None
        self._initialized = False

    def todict(self):
        d = super().todict()
        d.update({'name': 'fd', 'nn': self.nn, 'relax': self.relax,
                  'eps': self.eps})
        return d

    def get_stencil(self):
        return self.nn

    def create_laplace(self, gd, scale=1.0, n=1, dtype=float):
        """Instantiate and return a Laplace operator

        Allows subclasses to change the Laplace operator
        """
        return Laplace(gd, scale, n, dtype, xp=self.xp)

    def set_grid_descriptor(self, gd):
        """Build operators and transformers for every multigrid level.

        Coarsens *gd* up to 8 times, stopping early when ``gd.coarsen()``
        raises ValueError, and warns when the coarsest grid is still
        large.  Any cached Gaussians are discarded, since they belong to
        the previous grid.
        """
        # Should probably be renamed initialize
        self.gd = gd
        scale = -0.25 / pi

        if self.nn == 'M':
            if not gd.orthogonal:
                raise RuntimeError('Cannot use Mehrstellen stencil with '
                                   'non orthogonal cell.')

            self.operators = [LaplaceA(gd, -scale, xp=self.xp)]
            self.B = LaplaceB(gd)
        else:
            self.operators = [self.create_laplace(gd, scale, self.nn)]
            self.B = None

        self.interpolators = []
        self.restrictors = []

        level = 0
        self.presmooths = [2]
        self.postsmooths = [1]

        # Weights for the relaxation,
        # only used if 'J' (Jacobi) is chosen as method
        self.weights = [2.0 / 3.0]

        while level < 8:
            try:
                gd2 = gd.coarsen()
            except ValueError:
                break
            self.operators.append(self.create_laplace(gd2, scale, 1))
            self.interpolators.append(Transformer(gd2, gd, xp=self.xp))
            self.restrictors.append(Transformer(gd, gd2, xp=self.xp))
            self.presmooths.append(4)
            self.postsmooths.append(4)
            self.weights.append(1.0)
            level += 1
            gd = gd2

        self.levels = level

        if self.operators[-1].gd.N_c.max() > 36:
            # Try to warn exactly once no matter how one uses the solver.
            if gd.comm.parent is None:
                warn = (gd.comm.rank == 0)
            else:
                warn = (gd.comm.parent.rank == 0)

            if warn:
                warntxt = '\n'.join([POISSON_GRID_WARNING, '',
                                     self.get_description()])
            else:
                warntxt = ('Poisson warning from domain rank %d'
                           % self.gd.comm.rank)

            # Warn from all ranks to avoid deadlocks.
            warnings.warn(warntxt, stacklevel=2)

        self._initialized = False
        # The Gaussians depend on the grid as well so we have to 'unload' them
        if hasattr(self, 'rho_gauss'):
            del self.rho_gauss
            del self.phi_gauss

    def get_description(self):
        name = {1: 'Gauss-Seidel', 2: 'Jacobi'}[self.relax_method]
        coarsest_grid = self.operators[-1].gd.N_c
        coarsest_grid_string = ' x '.join([str(N) for N in coarsest_grid])
        assert self.levels + 1 == len(self.operators)
        lines = ['%s solver with %d multi-grid levels'
                 % (name, self.levels + 1),
                 ' Coarsest grid: %s points' % coarsest_grid_string]
        if coarsest_grid.max() > 24:
            # This friendly warning has lower threshold than the big long
            # one that we print when things are really bad.
            lines.extend([' Warning: Coarse grid has more than 24 points.',
                          ' More multi-grid levels recommended.'])
        lines.extend([' Stencil: %s' % self.operators[0].description,
                      ' Max iterations: %d' % self.maxiter])
        lines.extend([' Tolerance: %e' % self.eps])
        lines.append(super().get_description())
        return '\n'.join(lines)

    def _init(self):
        """Allocate the per-level work arrays (only once per grid)."""
        if self._initialized:
            return
        # Should probably be renamed allocate
        gd = self.gd
        self.rhos = [gd.empty(xp=self.xp)]
        self.phis = [None]  # finest level uses the caller's phi
        self.residuals = [gd.empty(xp=self.xp)]
        for _ in range(self.levels):
            gd2 = gd.coarsen()
            self.phis.append(gd2.empty(xp=self.xp))
            self.rhos.append(gd2.empty(xp=self.xp))
            self.residuals.append(gd2.empty(xp=self.xp))
            gd = gd2
        assert len(self.phis) == len(self.rhos)
        # Index of the coarsest level.  Taken from self.levels rather than
        # from the leaked loop variable, which is unbound when the grid
        # could not be coarsened at all (levels == 0).
        level = self.levels
        assert len(self.rhos) == level + 1

        self.step = 0.66666666 / self.operators[0].get_diagonal_element()
        # Smooth harder on the coarsest level, which is never restricted.
        self.presmooths[level] = 8
        self.postsmooths[level] = 8
        self._initialized = True

    def solve_neutral(self, phi, rho, timer=None):
        """Solve for *phi* (in place) given a neutral density *rho*.

        Returns the number of multigrid cycles used.  Raises
        PoissonConvergenceError when the residual is still above
        ``eps`` after ``maxiter`` cycles.
        """
        self._init()
        self.phis[0] = phi
        eps = self.eps
        if self.B is None:
            self.rhos[0][:] = rho
        else:
            # Mehrstellen: the right-hand side is B applied to rho.
            self.B.apply(rho, self.rhos[0])

        niter = 1
        maxiter = self.maxiter
        error = self.iterate2(self.step)
        while error > eps and niter < maxiter:
            niter += 1
            error = self.iterate2(self.step)
        # Judge convergence by the last residual rather than by the cycle
        # count, so that a solve which converges in exactly ``maxiter``
        # cycles is not reported as a failure.
        if error > eps:
            msg = 'Poisson solver did not converge in %d iterations!' % maxiter
            raise PoissonConvergenceError(msg)

        # Set the average potential to zero in periodic systems
        if (self.gd.pbc_c).all():
            phi_ave = self.gd.comm.sum_scalar(float(np.sum(phi.ravel())))
            N_c = self.gd.get_size_of_global_array()
            phi_ave /= np.prod(N_c)
            phi -= phi_ave

        return niter

    def iterate2(self, step, level=0):
        """Smooths the solution in every multigrid level

        Performs one V-cycle starting at *level*.  At level 0 the
        integrated squared residual is returned; coarser levels return
        None.
        """
        self._init()

        residual = self.residuals[level]

        if level < self.levels:
            self.operators[level].relax(self.relax_method,
                                        self.phis[level],
                                        self.rhos[level],
                                        self.presmooths[level],
                                        self.weights[level])

            # Restrict the residual, solve the coarse-grid correction
            # recursively and subtract its interpolation.
            self.operators[level].apply(self.phis[level], residual)
            residual -= self.rhos[level]
            self.restrictors[level].apply(residual,
                                          self.rhos[level + 1])
            self.phis[level + 1][:] = 0.0
            self.iterate2(4.0 * step, level + 1)
            self.interpolators[level].apply(self.phis[level + 1], residual)
            self.phis[level] -= residual

        self.operators[level].relax(self.relax_method,
                                    self.phis[level],
                                    self.rhos[level],
                                    self.postsmooths[level],
                                    self.weights[level])
        if level == 0:
            self.operators[level].apply(self.phis[level], residual)
            residual -= self.rhos[level]
            error = self.gd.comm.sum_scalar(
                float(self.xp.dot(residual.ravel(),
                                  residual.ravel()))) * self.gd.dv

            # How about this instead:
            # error = self.gd.comm.max(abs(residual).max())

            return error

    def estimate_memory(self, mem):
        # XXX Memory estimate works only for J and GS, not FFT solver
        # Poisson solver appears to use same amount of memory regardless
        # of whether it's J or GS, which is a bit strange

        gdbytes = self.gd.bytecount()
        nbytes = -gdbytes  # No phi on finest grid, compensate ahead
        for level in range(self.levels):
            nbytes += 3 * gdbytes  # Arrays: rho, phi, residual
            gdbytes //= 8
        mem.subnode('rho, phi, residual [%d levels]' % self.levels, nbytes)

    def __repr__(self):
        template = 'FDPoissonSolver(relax=\'%s\', nn=%s, eps=%e)'
        representation = template % (self.relax, repr(self.nn), self.eps)
        return representation
class NoInteractionPoissonSolver(_PoissonSolver):
    """Dummy solver: :meth:`solve` leaves the potential untouched.

    It reports zero iterations and needs no grid setup.
    """
    relax_method = 0
    nn = 1

    def get_description(self):
        return 'No interaction'

    def get_stencil(self):
        return 1

    def solve(self, phi, rho, charge=None, timer=None, **kwargs):
        # ``charge`` is optional and extra keywords (maxcharge,
        # zero_initial_phi, ...) are accepted so that this solver can be
        # called exactly like BasePoissonSolver.solve.  Everything is
        # ignored.
        return 0

    def set_grid_descriptor(self, gd):
        pass

    def todict(self):
        return {'name': 'nointeraction'}

    def estimate_memory(self, mem):
        pass
class FFTPoissonSolver(BasePoissonSolver):
    """FFT Poisson solver for general unit cells.

    Requires a fully periodic cell.  The density is redistributed so that
    one axis at a time is undistributed and Fourier transformed along
    that axis.  The result is multiplied by ``4 pi / k^2``, and the steps
    are undone in reverse order.
    """

    relax_method = 0
    nn = 999

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._initialized = False

    def get_description(self):
        return 'Parallel FFT'

    def todict(self):
        return {'name': 'fft'}

    def set_grid_descriptor(self, gd):
        # We will probably want to use this on non-periodic grids too...
        assert gd.pbc_c.all()
        self.gd = gd

        # grids[0] is the original layout; grids[1 + axis] keeps ``axis``
        # serial so that it can be transformed locally.
        self.grids = [gd]
        for axis in range(3):
            serial_N_c = gd.N_c.copy()
            serial_N_c[axis] = 1
            layout_c = decompose_domain(serial_N_c, gd.comm.size)
            self.grids.append(gd.new_descriptor(parsize_c=layout_c))
        self._initialized = False

    def _init(self):
        if not self._initialized:
            k2_Q, _ = construct_reciprocal(self.grids[-1])
            self.poisson_factor_Q = 4.0 * np.pi / k2_Q
            self._initialized = True

    def solve_neutral(self, phi_g, rho_g, timer=None):
        self._init()
        # Will be a bit more efficient if reduced dimension is always
        # contiguous. Probably more things can be improved...

        src_gd = self.gd
        data_g = rho_g

        # Forward: make axis c serial, then FFT along it.
        for c in range(3):
            dst_gd = self.grids[c + 1]
            buf_g = dst_gd.empty(dtype=data_g.dtype)
            grid2grid(src_gd.comm, src_gd, dst_gd, data_g, buf_g)
            data_g = fftn(buf_g, axes=[c])
            src_gd = dst_gd

        data_g *= self.poisson_factor_Q

        # Backward: inverse FFT along c, then move to the previous layout.
        for c in (2, 1, 0):
            dst_gd = self.grids[c]
            buf_g = ifftn(data_g, axes=[c])
            data_g = dst_gd.empty(dtype=buf_g.dtype)
            grid2grid(src_gd.comm, src_gd, dst_gd, buf_g, data_g)
            src_gd = dst_gd

        phi_g[:] = data_g.real
        return 1

    def estimate_memory(self, mem):
        mem.subnode('k squared', self.grids[-1].bytecount())
# Backend switch for the discrete sine transforms below.  When True,
# scipy's type-1 DST is used.  When False, the pure-numpy fallbacks in
# rfst2/irfst2/fst/ifst are used; they odd-extend the data and apply an
# FFT.
use_scipy_transforms = True
def rfst2(A_g, axes=(0, 1)):
    """Scaled type-1 discrete sine transform over *axes*.

    Parameters
    ----------
    A_g : ndarray
        Array to transform (3D for the numpy fallback).
    axes : sequence of int
        Axes to transform.  Lists and tuples are both accepted.

    Returns
    -------
    ndarray
        ``2 ** len(axes)`` times the successive scipy DST-I transforms
        along *axes*.  :func:`irfst2` is the matching inverse.

    The module flag ``use_scipy_transforms`` chooses between scipy's DST
    and a numpy fallback based on an odd extension plus ``rfft2``.  The
    fallback supports exactly two axes of a 3D array.
    """
    axes = list(axes)  # accept tuples too; also avoids aliasing the default

    if use_scipy_transforms:
        Y = A_g
        for axis in axes:
            Y = scipydst(Y, axis=axis, type=1)
        Y *= 2**len(axes)
        return Y

    # Put the transformed axes first and the remaining axis last.
    # (The set is not bound to the name ``all``, to avoid shadowing the
    # builtin.)
    third = [({0, 1, 2} - set(axes)).pop()]
    order = axes + third
    A_g = np.transpose(A_g, order)
    x, y, z = A_g.shape
    # Odd extension in both transformed directions.
    temp_g = np.zeros((x * 2 + 2, y * 2 + 2, z))
    temp_g[1:x + 1, 1:y + 1, :] = A_g
    temp_g[x + 2:, 1:y + 1, :] = -A_g[::-1, :, :]
    temp_g[1:x + 1, y + 2:, :] = -A_g[:, ::-1, :]
    temp_g[x + 2:, y + 2:, :] = A_g[::-1, ::-1, :]
    X = -4 * rfft2(temp_g, axes=[0, 1])[1:x + 1, 1:y + 1, :].real
    return np.transpose(X, np.argsort(order))
def irfst2(A_g, axes=(0, 1)):
    """Inverse of :func:`rfst2` over *axes*.

    With scipy transforms the result is the successive DST-I along *axes*,
    scaled by ``1 / (4 ** len(axes) * prod(N_axis + 1))``.  This undoes
    both rfst2's ``2 ** len(axes)`` factor and the ``2 * (N + 1)`` factor
    per axis from applying DST-I twice.  For the usual two axes it is the
    historical ``1 / (16 * prod(N + 1))``.  The numpy fallback (selected
    by ``use_scipy_transforms``) supports exactly two axes of a 3D array.
    """
    axes = list(axes)  # accept tuples too; also avoids aliasing the default

    if use_scipy_transforms:
        Y = A_g
        for axis in axes:
            Y = scipydst(Y, axis=axis, type=1)
        magic = 1.0 / (4**len(axes)
                       * np.prod([A_g.shape[axis] + 1 for axis in axes]))
        Y *= magic
        return Y

    # Put the transformed axes first and the remaining axis last
    # (without shadowing the builtin ``all``).
    third = [({0, 1, 2} - set(axes)).pop()]
    order = axes + third
    A_g = np.transpose(A_g, order)
    x, y, z = A_g.shape
    temp_g = np.zeros((x * 2 + 2, (y * 2 + 2) // 2 + 1, z))
    temp_g[1:x + 1, 1:y + 1, :] = A_g.real
    temp_g[x + 2:, 1:y + 1, :] = -A_g[::-1, :, :].real
    X = -0.25 * irfft2(temp_g, axes=[0, 1])[1:x + 1, 1:y + 1, :]

    return np.transpose(X, np.argsort(order))
# This method needs to be taken from fftw / scipy to gain speedup of ~4x
def fst(A_g, axis):
    """Pure-numpy type-1 discrete sine transform of a 3D array.

    The data is odd-extended to length ``2 * N + 2`` along *axis* and
    Fourier transformed; the sine coefficients are then read off.  The
    (complex) result equals ``0.5 * scipy.fftpack.dst(A_g, type=1)``
    along *axis*.  Raises NotImplementedError for axes other than 0, 1
    and 2.
    """
    x, y, z = A_g.shape
    ext_shape_c = np.array([x, y, z])
    n = ext_shape_c[axis]
    ext_shape_c[axis] = 2 * n + 2
    if axis not in (0, 1, 2):
        raise NotImplementedError()

    ext_g = np.zeros(ext_shape_c, dtype=A_g.dtype)
    # Work on views with the transform axis moved to the front.
    ext_front = np.moveaxis(ext_g, axis, 0)
    src_front = np.moveaxis(A_g, axis, 0)
    ext_front[1:n + 1] = src_front
    ext_front[n + 2:] = -src_front[::-1]

    spectrum_g = 0.5j * fft(ext_g, axis=axis)
    return np.moveaxis(np.moveaxis(spectrum_g, axis, 0)[1:n + 1], 0, axis)
def ifst(A_g, axis):
    """Pure-numpy inverse of :func:`fst` along *axis* of a 3D array.

    The data is odd-extended and inverse Fourier transformed.  The result
    equals ``scipy.fftpack.dst(A_g, type=1) / (N + 1)`` along *axis*.
    Raises NotImplementedError for axes other than 0, 1 and 2.
    """
    x, y, z = A_g.shape
    ext_shape_c = np.array([x, y, z])
    n = ext_shape_c[axis]
    ext_shape_c[axis] = 2 * n + 2
    ext_g = np.zeros(ext_shape_c, dtype=A_g.dtype)

    if axis not in (0, 1, 2):
        raise NotImplementedError()

    ext_front = np.moveaxis(ext_g, axis, 0)
    src_front = np.moveaxis(A_g, axis, 0)
    ext_front[1:n + 1] = src_front
    ext_front[n + 2:] = -src_front[::-1]

    back_g = ifft(ext_g, axis=axis)
    kept_g = np.moveaxis(np.moveaxis(back_g, axis, 0)[1:n + 1], 0, axis)
    return -2j * kept_g
def transform(A_g, axis=None, pbc=True):
    """Forward 1D transform of *A_g* along *axis*.

    Periodic axes use a complex FFT.  Non-periodic axes use a type-1
    sine transform scaled by 0.5 relative to scipy's DST-I (scipy or
    :func:`fst`, depending on ``use_scipy_transforms``).  Empty arrays
    are returned as-is; in the periodic case they are cast to complex.
    """
    is_empty = A_g.size == 0
    if pbc:
        return A_g.astype(complex) if is_empty else fft(A_g, axis=axis)
    if is_empty:
        return A_g
    if use_scipy_transforms:
        out = scipydst(A_g, axis=axis, type=1)
        out *= .5
        return out
    return fst(A_g, axis)
def transform2(A_g, axes=None, pbc=[True, True]):
    """Forward 2D transform of *A_g* over the two *axes*.

    Both periodic: 2D FFT.  Neither periodic: :func:`rfst2`.  Mixed: two
    successive 1D :func:`transform` calls, first along ``axes[0]``.
    Empty arrays are returned as-is; in the fully periodic case they are
    cast to complex.
    """
    if all(pbc):
        if A_g.size == 0:
            return A_g.astype(complex)
        return fft2(A_g, axes=axes)
    if not any(pbc):
        return A_g if A_g.size == 0 else rfst2(A_g, axes=axes)
    first_g = transform(A_g, axis=axes[0], pbc=pbc[0])
    return transform(first_g, axis=axes[1], pbc=pbc[1])
def itransform(A_g, axis=None, pbc=True):
    """Inverse of :func:`transform` along *axis*.

    Periodic axes use an inverse FFT.  Non-periodic axes use a DST-I
    divided by ``N + 1`` (scipy, or :func:`ifst` when
    ``use_scipy_transforms`` is False).  Empty arrays are returned as-is;
    in the periodic case they are cast to complex.
    """
    is_empty = A_g.size == 0
    if pbc:
        return A_g.astype(complex) if is_empty else ifft(A_g, axis=axis)
    if is_empty:
        return A_g
    if use_scipy_transforms:
        out = scipydst(A_g, axis=axis, type=1)
        out *= 1.0 / (A_g.shape[axis] + 1)
        return out
    return ifst(A_g, axis)
def itransform2(A_g, axes=None, pbc=[True, True]):
    """Inverse 2D transform of *A_g* over the two *axes*.

    Both periodic: 2D inverse FFT.  Neither periodic: :func:`irfst2`.
    Mixed: two successive 1D :func:`itransform` calls, first along
    ``axes[0]``.  Empty arrays are returned as-is; in the fully periodic
    case they are cast to complex.
    """
    if all(pbc):
        if A_g.size == 0:
            return A_g.astype(complex)
        return ifft2(A_g, axes=axes)
    if not any(pbc):
        return A_g if A_g.size == 0 else irfst2(A_g, axes=axes)
    first_g = itransform(A_g, axis=axes[0], pbc=pbc[0])
    return itransform(first_g, axis=axes[1], pbc=pbc[1])
class BadAxesError(ValueError):
    """Unsupported cell geometry for FastPoissonSolver.

    Raised by FastPoissonSolver.set_grid_descriptor when an axis is
    neither periodic nor orthogonal to the other axes.
    """
class FastPoissonSolver(BasePoissonSolver):
    """Direct Poisson solver based on FFTs and discrete sine transforms.

    Periodic axes are diagonalised with FFTs and non-periodic (orthogonal)
    axes with type-1 sine transforms.  The finite-difference Laplacian
    then becomes diagonal and is inverted by dividing by its eigenvalues.
    Every axis must be periodic or orthogonal to the other axes.
    """
    def __init__(self, nn=3, **kwargs):
        BasePoissonSolver.__init__(self, **kwargs)
        self.nn = nn
        # We may later enable this to work with Cholesky, but not now:
        self.use_cholesky = False

    def _init(self):
        pass

    def set_grid_descriptor(self, gd):
        """Classify the axes, set up decompositions and eigenvalues.

        Raises BadAxesError if an axis is neither periodic nor
        orthogonal to the other axes.
        """
        self.gd = gd
        axes = np.arange(3)
        pbc_c = np.array(gd.pbc_c, dtype=bool)
        periodic_axes = axes[pbc_c]
        non_periodic_axes = axes[np.logical_not(pbc_c)]

        # Find out which axes are orthogonal (0, 1 or 3)
        # Note that one expects that the axes are always rotated in
        # conventional form, thus for all axes to be
        # classified as orthogonal, the cell_cv needs to be diagonal.
        # This may always be achieved by rotating
        # the unit-cell along with the atoms. The classification is
        # inherited from grid_descriptor.orthogonal.
        dotprods = np.dot(gd.cell_cv, gd.cell_cv.T)
        # For each direction, check whether there is only one nonzero
        # element in that row (necessarily being the diagonal element,
        # since this is a cell vector length and must be > 0).
        orthogonal_c = (np.abs(dotprods) > 1e-10).sum(axis=0) == 1
        assert sum(orthogonal_c) in [0, 1, 3]

        non_orthogonal_axes = axes[np.logical_not(orthogonal_c)]

        if not all(pbc_c | orthogonal_c):
            raise BadAxesError('Each axis must be periodic or orthogonal '
                               'to other axes. But we have pbc={} '
                               'and orthogonal={}'
                               .format(pbc_c.astype(int),
                                       orthogonal_c.astype(int)))

        # We sort them, and pick the longest non-periodic axes as the
        # cholesky axis.
        sorted_non_periodic_axes = sorted(non_periodic_axes,
                                          key=lambda c: gd.N_c[c])
        if self.use_cholesky:
            if len(sorted_non_periodic_axes) > 0:
                cholesky_axes = [sorted_non_periodic_axes[-1]]
                if cholesky_axes[0] in non_orthogonal_axes:
                    msg = ('Cholesky axis cannot be non-orthogonal. '
                           'Do you really want a non-orthogonal non-periodic '
                           'axis? If so, run with use_cholesky=False.')
                    raise NotImplementedError(msg)
                fst_axes = sorted_non_periodic_axes[0:-1]
            else:
                cholesky_axes = []
                fst_axes = []
        else:
            cholesky_axes = []
            fst_axes = sorted_non_periodic_axes
        fft_axes = list(periodic_axes)

        (self.cholesky_axes, self.fst_axes,
         self.fft_axes) = cholesky_axes, fst_axes, fft_axes

        fftfst_axes = self.fft_axes + self.fst_axes
        axes = self.fft_axes + self.fst_axes + self.cholesky_axes
        self.axes = axes

        # Create xy flat decomposition (where x=axes[0] and y=axes[1])
        parsize_c = [1, 1, 1]
        parsize_c[axes[2]] = gd.comm.size
        gd1d = gd.new_descriptor(parsize_c=parsize_c,
                                 allow_empty_domains=True)
        self.gd1d = gd1d

        # Create z flat decomposition
        domain = gd.N_c.copy()
        domain[axes[2]] = 1
        parsize_c = decompose_domain(domain, gd.comm.size)
        gd2d = gd.new_descriptor(parsize_c=parsize_c)
        self.gd2d = gd2d

        # Calculate eigenvalues in fst/fft decomposition for
        # non-cholesky axes in parallel
        xp = self.xp
        r_cx = xp.indices(gd2d.n_c)
        r_cx += xp.asarray(gd2d.beg_c[:, xp.newaxis, xp.newaxis, xp.newaxis])
        r_cx = r_cx.astype(complex)
        for axis in fftfst_axes:
            r_cx[axis] *= 2j * xp.pi / gd2d.N_c[axis]
            if axis in fst_axes:
                # Sine transforms use half the FFT wave number.
                r_cx[axis] /= 2
        for axis in cholesky_axes:
            r_cx[axis] = 0.0
        xp.exp(r_cx, out=r_cx)
        fft_lambdas = xp.zeros_like(r_cx[0], dtype=complex)
        laplace = Laplace(self.gd, -0.25 / pi, self.nn)
        self.stencil_description = laplace.description

        # Eigenvalue of the stencil: sum over its points of
        # coeff * (phase factor - 1).
        for coeff, offset_c in zip(laplace.coef_p, laplace.offset_pc):
            offset_c = np.array(offset_c)
            if not any(offset_c):
                # The centerpoint is handled with (temp-1.0)
                continue
            non_zero_axes, = np.where(offset_c)
            if set(non_zero_axes).issubset(fftfst_axes):
                temp = xp.ones_like(fft_lambdas)
                for axis in fftfst_axes:
                    temp *= r_cx[axis] ** offset_c[axis]
                fft_lambdas += coeff * (temp - 1.0)

        assert xp.linalg.norm(fft_lambdas.imag) < 1e-10
        fft_lambdas = fft_lambdas.real.copy()  # arr.real is not contiguous

        # If there is no Cholesky decomposition, the system is already
        # fully diagonal and we can directly invert the linear problem
        # by dividing with the eigenvalues.
        # TODO: Remove cholesky alltogether, since it is not used
        assert len(cholesky_axes) == 0
        # Modes with (numerically) zero eigenvalue are projected out.
        with np.errstate(divide='ignore'):
            self.inv_fft_lambdas = xp.where(
                xp.abs(fft_lambdas) > 1e-10, 1.0 / fft_lambdas, 0)

    def solve_neutral(self, phi_g, rho_g, timer=None):
        """Solve for *phi_g* (in place) given a neutral density *rho_g*.

        *timer* is called as ``timer(name)`` and used as a context
        manager around each stage; when None, a NullTimer is used.
        Always returns 1 (the method is not iterative).
        """
        if len(self.cholesky_axes) != 0:
            raise NotImplementedError
        if timer is None:
            # The default of None used to crash on the first timer(...)
            # call below; fall back to a no-op timer instead.
            timer = NullTimer()

        gd = self.gd
        gd1d = self.gd1d
        gd2d = self.gd2d
        comm = self.gd.comm
        axes = self.axes

        with timer('Communicate to 1D'):
            work1d_g = gd1d.empty(dtype=rho_g.dtype, xp=self.xp)
            grid2grid(comm, gd, gd1d, rho_g, work1d_g, xp=self.xp)
        with timer('FFT 2D'):
            work1d_g = transform2(work1d_g, axes=axes[:2],
                                  pbc=gd.pbc_c[axes[:2]])
        with timer('Communicate to 2D'):
            work2d_g = gd2d.empty(dtype=work1d_g.dtype, xp=self.xp)
            grid2grid(comm, gd1d, gd2d, work1d_g, work2d_g, xp=self.xp)
        with timer('FFT 1D'):
            work2d_g = transform(work2d_g, axis=axes[2],
                                 pbc=gd.pbc_c[axes[2]])

        # The remaining problem is 0D dimensional, i.e the problem
        # has been fully diagonalized
        work2d_g *= self.inv_fft_lambdas

        with timer('FFT 1D'):
            work2d_g = itransform(work2d_g, axis=axes[2],
                                  pbc=gd.pbc_c[axes[2]])
        with timer('Communicate from 2D'):
            work1d_g = gd1d.empty(dtype=work2d_g.dtype, xp=self.xp)
            grid2grid(comm, gd2d, gd1d, work2d_g, work1d_g, xp=self.xp)
        with timer('FFT 2D'):
            work1d_g = itransform2(work1d_g, axes=axes[1::-1],
                                   pbc=gd.pbc_c[axes[1::-1]])
        with timer('Communicate from 1D'):
            work_g = gd.empty(dtype=work1d_g.dtype, xp=self.xp)
            grid2grid(comm, gd1d, gd, work1d_g, work_g, xp=self.xp)

        phi_g[:] = work_g.real
        return 1  # Non-iterative method, return 1 iteration

    def todict(self):
        d = super().todict()
        d.update({'name': 'fast', 'nn': self.nn})
        return d

    def estimate_memory(self, mem):
        pass

    def get_description(self):
        lines = [f'{self.__class__.__name__} using',
                 f' Stencil: {self.stencil_description}',
                 f' FFT axes: {self.fft_axes}',
                 f' FST axes: {self.fst_axes}',
                 ]
        lines.append(BasePoissonSolver.get_description(self))
        return '\n'.join(lines)