Coverage for gpaw/dipole_correction.py: 85%
109 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import numpy as np
2from scipy.special import erf
3from ase.units import Bohr
class DipoleCorrection:
    """Dipole-correcting wrapper around another PoissonSolver.

    Most methods forward to the wrapped solver.  solve() dispatches to
    fdsolve() for real-space (ndarray) densities and to pwsolve()
    otherwise; both add a correction for the dipole moment of the
    density along one axis (self.c).
    """

    def __init__(self, poissonsolver, direction, width=1.0,
                 zero_vacuum=False):
        """Construct dipole correction object.

        poissonsolver:
            Poisson solver.
        direction: int or str
            Specification of layer: 0, 1, 2, 'xy', 'xz' or 'yz'.  A plane
            is replaced by the index of the cell axis perpendicular to it
            when check_direction() runs.
        width: float
            Width in Angstrom of dipole layer used for the plane-wave
            implementation.
        zero_vacuum: bool
            Only read by fdsolve(): if true, the potential shift
            (self.correction) is added to the correction potential.
        """
        self.c = direction
        self.poissonsolver = poissonsolver
        self.width = width / Bohr  # stored in Bohr
        self.zero_vacuum = zero_vacuum
        self.correction = None  # shift in potential
        self.sawtooth_q = None  # Fourier transformed sawtooth (built lazily)

    def todict(self):
        """Return the wrapped solver's dict plus the dipole-layer settings.

        'width' (converted back to Angstrom) is only included when it
        differs from the default of 1 Angstrom.
        """
        # NOTE(review): zero_vacuum is not stored, so it is lost if the
        # solver is rebuilt from this dict -- confirm whether intended.
        dct = self.poissonsolver.todict()
        dct['dipolelayer'] = self.c
        if self.width != 1.0 / Bohr:
            dct['width'] = self.width * Bohr
        return dct

    def get_stencil(self):
        """Return the stencil of the wrapped solver."""
        return self.poissonsolver.get_stencil()

    def set_grid_descriptor(self, gd):
        """Pass gd to the wrapped solver, then validate the direction."""
        self.poissonsolver.set_grid_descriptor(gd)
        self.check_direction(gd, gd.pbc_c)

    def check_direction(self, gd, pbc_c):
        """Resolve self.c to an axis index and check the cell geometry.

        If self.c is a plane such as 'xy', it is replaced by the index of
        the first cell vector with no component along that plane's
        Cartesian axes.

        Raises ValueError if no such cell vector exists, if pbc_c is
        periodic along the correction axis, or if the correction axis is
        not orthogonal to the two other cell vectors.
        """
        if isinstance(self.c, str):
            axes = ['xyz'.index(d) for d in self.c]
            for c in range(3):
                if abs(gd.cell_cv[c, axes]).max() < 1e-12:
                    break
            else:
                raise ValueError('No axis perpendicular to {}-plane!'
                                 .format(self.c))
            self.c = c

        if pbc_c[self.c]:
            raise ValueError('System must be non-periodic perpendicular '
                             'to dipole-layer.')

        # Right now the dipole correction must be along one coordinate
        # axis and orthogonal to the two others.  The two others need not
        # be orthogonal to each other.
        for c1 in range(3):
            if c1 != self.c:
                if abs(np.dot(gd.cell_cv[self.c], gd.cell_cv[c1])) > 1e-12:
                    raise ValueError('Dipole correction axis must be '
                                     'orthogonal to the two other axes.')

    def get_description(self):
        """Return the wrapped solver's description plus a dipole line."""
        poissondesc = self.poissonsolver.get_description()
        if isinstance(self.c, str):
            # Plane not yet resolved to an axis by check_direction();
            # indexing 'xyz' with a str would raise TypeError.
            desc = 'Dipole correction perpendicular to %s-plane' % self.c
        else:
            desc = 'Dipole correction along %s-axis' % 'xyz'[self.c]
        return '\n'.join([poissondesc, desc])

    def initialize(self):
        """Forward initialization to the wrapped solver."""
        self.poissonsolver.initialize()

    def solve(self, pot, dens, **kwargs):
        """Solve with dipole correction; dispatch on the density type."""
        # Note that fdsolve() returns number of iterations and pwsolve()
        # returns the energy!! This is because the
        # ChargedReciprocalSpacePoissonSolver has corrections to
        # the energy ...
        if isinstance(dens, np.ndarray):
            # finite-difference Poisson solver:
            return self.fdsolve(pot, dens, **kwargs)
        # Plane-wave solver:
        return self.pwsolve(pot, dens)

    def fdsolve(self, vHt_g, rhot_g, **kwargs):
        """Real-space solve with dipole correction; updates vHt_g in place.

        The wrapped solver is run on rhot_g plus the compensating charge
        from dipole_correction(), and the matching potential correction
        is then added to the result.  Returns the wrapped solver's return
        value (the number of iterations).
        """
        gd = self.poissonsolver.gd
        drhot_g, dvHt_g, self.correction = dipole_correction(
            self.c, gd, rhot_g)
        if self.zero_vacuum:
            dvHt_g += self.correction
        # Presumably turns a previously corrected vHt_g into an initial
        # guess for the (rhot_g + drhot_g) problem.
        vHt_g -= dvHt_g
        iters = self.poissonsolver.solve(vHt_g, rhot_g + drhot_g, **kwargs)
        vHt_g += dvHt_g
        return iters

    def pwsolve(self, vHt_q, dens):
        """Plane-wave solve with dipole correction; updates vHt_q in place.

        After the wrapped solver fills vHt_q, a sawtooth scaled by the
        dipole moment of dens along self.c is subtracted.  Returns the
        wrapped solver's energy plus 2 pi dip_c**2 / volume.
        """
        gd = self.poissonsolver.pd.gd

        if self.sawtooth_q is None:
            self.initialize_sawtooth()

        epot = self.poissonsolver.solve(vHt_q, dens)

        dip_v = dens.calculate_dipole_moment()
        c = self.c
        L = gd.cell_cv[c, c]
        self.correction = 2 * np.pi * dip_v[c] * L / gd.volume
        vHt_q -= 2 * self.correction * self.sawtooth_q

        return epot + 2 * np.pi * dip_v[c]**2 / gd.volume

    def initialize_sawtooth(self):
        """Build self.sawtooth_q, the Fourier-transformed sawtooth profile.

        The profile x / L - 0.5 along self.c has its jump at x = 0 (= L)
        smoothed over a layer of total width self.width.  It is built on
        rank 0 and distributed with pd.scatter().

        Raises ValueError if the layer is not narrower than the cell.
        """
        pd = self.poissonsolver.pd
        gd = pd.gd
        c = self.c
        L = gd.cell_cv[c, c]
        w = self.width / 2
        # Checked on every rank (not only rank 0) so all ranks fail
        # together; pd.scatter() is presumably collective.
        if not w < L / 2:
            raise ValueError(
                'Dipole-layer width ({:.4f} Bohr) must be smaller than the '
                'cell length along the correction axis ({:.4f} Bohr).'
                .format(2 * w, L))
        if gd.comm.rank == 0:
            x = gd.coords(c)
            sawtooth = x / L - 0.5
            # Number of grid points inside each half of the layer.
            gc = int(w / gd.h_cv[c, c])
            # With gc == 0, sawtooth[-gc:] would be the whole array, so
            # the layer is left unsmoothed when it is thinner than the
            # grid spacing.
            if gc > 0:
                # Odd cubic x * (a + b x**2): 0 at x = 0 and matching the
                # value and slope of the linear sawtooth at x = w.
                a = 1 / L - 0.75 / w
                b = 0.25 / w**3
                sawtooth[:gc] = x[:gc] * (a + b * x[:gc]**2)
                # Mirror points 1..gc with opposite sign onto the
                # last gc points.
                sawtooth[-gc:] = -sawtooth[gc:0:-1]
            sawtooth_g = gd.empty(global_array=True)
            shape = [1, 1, 1]
            shape[c] = -1
            sawtooth_g[:] = sawtooth.reshape(shape)
            sawtooth_q = pd.fft(sawtooth_g, local=True)
        else:
            sawtooth_q = None
        self.sawtooth_q = pd.scatter(sawtooth_q)

    def estimate_memory(self, mem):
        """Forward memory estimation to the wrapped solver."""
        self.poissonsolver.estimate_memory(mem)

    def build(self, grid, xp):
        """Set up for grid and return a PoissonSolverWrapper around self."""
        from gpaw.new.poisson import PoissonSolverWrapper
        self.xp = xp
        self.set_grid_descriptor(grid._gd)
        return PoissonSolverWrapper(self)
def dipole_correction(c, gd, rhot_g, center=False, origin_c=None):
    """Return charge and potential corrections that cancel a dipole.

    The correction charge drhot_g is an odd, Gaussian-damped profile
    along axis c.  It is scaled so that its dipole moment along c is
    minus that of rhot_g, assuming gd.calculate_dipole_moment is linear
    in the density.  dphit_g is an error-function profile along c that
    is largely constant at the cell boundaries and beyond.

    The two are combined as in DipoleCorrection.fdsolve: solve for
    rhot_g + drhot_g, then add dphit_g to obtain the corrected potential
    of rhot_g.  Equivalently, if rhot_g has the potential phit_g, then
    rhot_g + drhot_g has the potential phit_g - dphit_g.

    Parameters:

    c: int
        Axis index along which the dipole is removed.
    gd:
        Grid descriptor for rhot_g.
    rhot_g: ndarray
        Charge density on gd.
    center, origin_c:
        Passed through to gd.calculate_dipole_moment().

    Returns (drhot_g, dphit_g, shift), where shift is the amplitude of
    the error-function potential.  If the dipole moment along c is
    below 1e-12 in magnitude, two zero arrays and 0.0 are returned.
    """
    # This implementation is not particularly economical memory-wise.

    def moment_along_c(density_g):
        return gd.calculate_dipole_moment(density_g, center=center,
                                          origin_c=origin_c)[c]

    dipole = moment_along_c(rhot_g)
    if abs(dipole) < 1e-12:
        return gd.zeros(), gd.zeros(), 0.0

    # Coordinate along c rescaled so that the cell spans [-1, 1].
    length = abs(gd.cell_cv[c, c])
    s_g = 2.0 / length * gd.get_grid_point_coordinates()[c] - 1.0

    alpha = 12.0  # Gaussian exponent; should perhaps be variable

    shape_g = s_g * np.exp(-alpha * s_g**2)
    scale = -dipole / moment_along_c(shape_g)
    drhot_g = shape_g * scale

    shift = scale * (np.pi / alpha)**1.5 * length**2 / 4.0
    dphit_g = -shift * erf(s_g * np.sqrt(alpha))
    return drhot_g, dphit_g, shift