Coverage for gpaw/dipole_correction.py: 85%

109 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import numpy as np 

2from scipy.special import erf 

3from ase.units import Bohr 

4 

5 

class DipoleCorrection:
    """Dipole-correcting wrapper around another PoissonSolver."""

    def __init__(self, poissonsolver, direction, width=1.0,
                 zero_vacuum=False):
        """Construct dipole correction object.

        poissonsolver:
            Poisson solver.
        direction: int or str
            Specification of layer: 0, 1, 2, 'xy', 'xz' or 'yz'.
        width: float
            Width in Angstrom of dipole layer used for the plane-wave
            implementation.
        zero_vacuum: bool
            If true, the finite-difference path also adds the potential
            shift (``self.correction``) to the dipole-correction potential.
        """
        self.c = direction
        self.poissonsolver = poissonsolver
        self.width = width / Bohr
        self.zero_vacuum = zero_vacuum
        self.correction = None  # shift in potential
        self.sawtooth_q = None  # Fourier transformed sawtooth

    def todict(self):
        """Return the wrapped solver's dict plus the dipole-layer settings.

        The width (in Angstrom) is only stored when it differs from the
        default of 1 Angstrom.
        """
        # NOTE(review): zero_vacuum is not serialized here; confirm whether
        # the solver factory accepts it before adding it to the dict.
        dct = self.poissonsolver.todict()
        dct['dipolelayer'] = self.c
        if self.width != 1.0 / Bohr:
            dct['width'] = self.width * Bohr
        return dct

    def get_stencil(self):
        """Return the stencil of the wrapped solver."""
        return self.poissonsolver.get_stencil()

    def set_grid_descriptor(self, gd):
        """Pass the grid to the wrapped solver and validate the direction."""
        self.poissonsolver.set_grid_descriptor(gd)
        self.check_direction(gd, gd.pbc_c)

    def check_direction(self, gd, pbc_c):
        """Resolve ``self.c`` to an axis index and validate the cell.

        A plane specification such as 'xy' is replaced by the index of the
        first cell vector with no components along the plane's axes.

        Raises ValueError if no such cell vector exists, if the system is
        periodic along the dipole axis, or if the dipole axis is not
        orthogonal to the other two cell vectors.
        """
        if isinstance(self.c, str):
            axes = ['xyz'.index(d) for d in self.c]
            for c in range(3):
                if abs(gd.cell_cv[c, axes]).max() < 1e-12:
                    break
            else:
                raise ValueError('No axis perpendicular to {}-plane!'
                                 .format(self.c))
            self.c = c

        if pbc_c[self.c]:
            raise ValueError('System must be non-periodic perpendicular '
                             'to dipole-layer.')

        # Right now the dipole correction must be along one coordinate
        # axis and orthogonal to the two others.  The two others need not
        # be orthogonal to each other.
        for c1 in range(3):
            if c1 != self.c:
                if abs(np.dot(gd.cell_cv[self.c], gd.cell_cv[c1])) > 1e-12:
                    raise ValueError('Dipole correction axis must be '
                                     'orthogonal to the two other axes.')

    def get_description(self):
        """Return the wrapped solver's description plus the dipole axis."""
        poissondesc = self.poissonsolver.get_description()
        desc = 'Dipole correction along %s-axis' % 'xyz'[self.c]
        return '\n'.join([poissondesc, desc])

    def initialize(self):
        """Initialize the wrapped solver."""
        self.poissonsolver.initialize()

    def solve(self, pot, dens, **kwargs):
        """Solve the Poisson equation with a dipole correction.

        A numpy-array density selects the finite-difference path, and
        anything else selects the plane-wave path.
        """
        # Note that fdsolve() returns number of iterations and pwsolve()
        # returns the energy!!  This is because the
        # ChargedReciprocalSpacePoissonSolver has corrections to
        # the energy ...
        if isinstance(dens, np.ndarray):
            # finite-difference Poisson solver:
            return self.fdsolve(pot, dens, **kwargs)
        # Plane-wave solver:
        return self.pwsolve(pot, dens)

    def fdsolve(self, vHt_g, rhot_g, **kwargs):
        """Finite-difference solve; ``vHt_g`` is updated in place.

        A compensating charge cancels the dipole along ``self.c``.  The
        wrapped solver is run on the compensated density, and the dipole
        potential is added back afterwards.  Returns the wrapped solver's
        return value (the iteration count).
        """
        gd = self.poissonsolver.gd
        drhot_g, dvHt_g, self.correction = dipole_correction(
            self.c, gd, rhot_g)
        if self.zero_vacuum:
            dvHt_g += self.correction
        # Remove the correction from the initial guess and restore it
        # after solving for the dipole-free density.
        vHt_g -= dvHt_g
        iters = self.poissonsolver.solve(vHt_g, rhot_g + drhot_g, **kwargs)
        vHt_g += dvHt_g
        return iters

    def pwsolve(self, vHt_q, dens):
        """Plane-wave solve; ``vHt_q`` is updated in place.

        After the wrapped solve, a sawtooth potential proportional to the
        dipole moment along ``self.c`` is subtracted.  Returns the wrapped
        solver's energy plus 2 pi d_c**2 / V.
        """
        gd = self.poissonsolver.pd.gd

        if self.sawtooth_q is None:
            self.initialize_sawtooth()

        epot = self.poissonsolver.solve(vHt_q, dens)

        dip_v = dens.calculate_dipole_moment()
        c = self.c
        L = gd.cell_cv[c, c]
        self.correction = 2 * np.pi * dip_v[c] * L / gd.volume
        vHt_q -= 2 * self.correction * self.sawtooth_q

        return epot + 2 * np.pi * dip_v[c]**2 / gd.volume

    def initialize_sawtooth(self):
        """Build, Fourier transform and scatter the sawtooth potential.

        Along axis ``c`` the sawtooth is x / L - 0.5.  Within w = width / 2
        of x = 0 it is replaced by a cubic a * x + b * x**3.  The cubic is 0
        at x = 0 and matches the value and slope of the linear part at
        x = w.  That region is then mirrored with opposite sign at the
        other end of the cell.  If w is smaller than one grid spacing, no
        grid points fall inside the region and the plain sawtooth is used.
        Only rank 0 builds the real-space array, and the result is
        distributed with ``pd.scatter``.

        Raises ValueError if the dipole layer is not narrower than the
        cell along the dipole axis.
        """
        pd = self.poissonsolver.pd
        gd = pd.gd
        c = self.c
        L = gd.cell_cv[c, c]
        w = self.width / 2
        # Check before the rank split so that every rank raises.  Otherwise
        # the other ranks would be left in scatter() (presumably a
        # collective call) while rank 0 fails.
        if not w < L / 2:
            raise ValueError(
                'Dipole layer too wide: width/2 = {} Bohr must be smaller '
                'than L/2 = {} Bohr.'.format(w, L / 2))
        if gd.comm.rank == 0:
            gc = int(w / gd.h_cv[c, c])
            x = gd.coords(c)
            sawtooth = x / L - 0.5
            if gc > 0:
                # For gc == 0 the slice [-0:] would select the whole
                # array while [0:0:-1] is empty, hence the guard.
                a = 1 / L - 0.75 / w
                b = 0.25 / w**3
                sawtooth[:gc] = x[:gc] * (a + b * x[:gc]**2)
                sawtooth[-gc:] = -sawtooth[gc:0:-1]
            sawtooth_g = gd.empty(global_array=True)
            shape = [1, 1, 1]
            shape[c] = -1
            sawtooth_g[:] = sawtooth.reshape(shape)
            sawtooth_q = pd.fft(sawtooth_g, local=True)
        else:
            sawtooth_q = None
        self.sawtooth_q = pd.scatter(sawtooth_q)

    def estimate_memory(self, mem):
        """Delegate memory estimation to the wrapped solver."""
        self.poissonsolver.estimate_memory(mem)

    def build(self, grid, xp):
        """Set up for the new GPAW code path and return a wrapper."""
        from gpaw.new.poisson import PoissonSolverWrapper
        self.xp = xp
        self.set_grid_descriptor(grid._gd)
        return PoissonSolverWrapper(self)

142 

143 

def dipole_correction(c, gd, rhot_g, center=False, origin_c=None,
                      alpha=12.0):
    """Get dipole corrections to charge and potential.

    Returns ``(drhot_g, dphit_g, phifactor)``.

    drhot_g is a compensating charge, factor * s * exp(-alpha * s**2)
    with s = 2 r / L - 1 the scaled coordinate along axis c.  It is
    scaled so that rhot_g + drhot_g has zero dipole moment along c.

    dphit_g = -phifactor * erf(sqrt(alpha) * s) is the 1D potential of
    -drhot_g, assuming the Poisson equation nabla^2 phi = -4 pi rho.  So
    if rhot_g + drhot_g has the potential phit_g, then rhot_g has the
    potential phit_g + dphit_g.  The error function is chosen so as to be
    largely constant at the cell boundaries and beyond.

    phifactor is the amplitude of dphit_g.

    If the dipole moment along c is below 1e-12 in magnitude, two zero
    arrays and 0.0 are returned.

    alpha: float
        Exponent of the Gaussian envelope in the scaled coordinate.
        Larger values localize the compensating charge nearer the cell
        centre.  The default is 12.0.
    """
    # This implementation is not particularly economical memory-wise

    moment = gd.calculate_dipole_moment(rhot_g, center=center,
                                        origin_c=origin_c)[c]
    if abs(moment) < 1e-12:
        return gd.zeros(), gd.zeros(), 0.0

    r_g = gd.get_grid_point_coordinates()[c]
    cellsize = abs(gd.cell_cv[c, c])
    sr_g = 2.0 / cellsize * r_g - 1.0  # sr ~ 'scaled r'
    drho_g = sr_g * np.exp(-alpha * sr_g**2)
    moment2 = gd.calculate_dipole_moment(drho_g, center=center,
                                         origin_c=origin_c)[c]
    # Scale the compensating charge so that its moment cancels rhot_g's.
    factor = -moment / moment2
    drho_g *= factor
    # Analytic 1D solution: integrating phi'' = -4 pi drho twice in the
    # scaled coordinate gives factor * (pi/alpha)**1.5 * L**2 / 4 * erf.
    phifactor = factor * (np.pi / alpha)**1.5 * cellsize**2 / 4.0
    dphi_g = -phifactor * erf(sr_g * np.sqrt(alpha))
    return drho_g, dphi_g, phifactor