Coverage for gpaw/new/sjm.py: 51%

118 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from ase.units import Bohr 

5from gpaw.core import UGArray, PWDesc, PWArray 

6from gpaw.jellium import create_background_charge 

7from gpaw.new.environment import Environment, FixedPotentialJellium, Jellium 

8from gpaw.new.poisson import PoissonSolverWrapper 

9from gpaw.new.pw.poisson import PWPoissonSolver 

10from gpaw.new.solvation import SolvationEnvironment, Solvation 

11 

12 

class SJM(Solvation):
    """Solvated jellium method (SJM) input parameters.

    Extends the :class:`Solvation` parameters with a slab of jellium
    background charge carrying ``excess_electrons``.  When
    ``target_potential`` (eV) is given, :meth:`build` creates a
    :class:`FixedPotentialJellium` with that value as ``workfunction`` and
    ``tol`` (eV) as ``tolerance``; otherwise a plain :class:`Jellium` is
    created.
    """

    name = 'sjm'

    def __init__(self,
                 *,
                 cavity,
                 dielectric,
                 interactions,
                 jelliumregion: dict | None = None,
                 target_potential: float | None,  # eV
                 excess_electrons: float = 0.0,
                 tol: float = 0.001):  # eV
        super().__init__(cavity, dielectric, interactions)
        # Any falsy value (None, empty dict) is replaced by a fresh dict.
        self.jelliumregion = jelliumregion if jelliumregion else {}
        self.target_potential = target_potential
        self.excess_electrons = excess_electrons
        self.tol = tol

    def build(self,
              setups,
              grid,
              relpos_ac,
              log,
              comm) -> SJMEnvironment:
        """Build the solvation environment plus a jellium slab.

        The slab spans from 3.0 (Å) above the highest atom up to
        ``jelliumregion['top']`` if given, else 1.0 (Å) below the top of
        the cell.  NOTE(review): the z-coordinate is computed as
        ``relpos_ac[:, 2] * cell_cv[2, 2]``, which assumes the third cell
        vector points along z — confirm for non-orthogonal cells.
        """
        solvation = super().build(setups=setups,
                                  grid=grid,
                                  relpos_ac=relpos_ac,
                                  log=log,
                                  comm=comm)
        natoms = len(relpos_ac)
        # Cell height along z, converted from Bohr to Å:
        cell_height = grid.cell_cv[2, 2] * Bohr
        z_bottom = relpos_ac[:, 2].max() * cell_height + 3.0
        z_top = self.jelliumregion.get('top', cell_height - 1.0)
        background = create_background_charge(charge=self.excess_electrons,
                                              z1=z_bottom,
                                              z2=z_top)
        background.set_grid_descriptor(grid._gd)
        if self.target_potential is not None:
            jellium = FixedPotentialJellium(
                background,
                natoms=natoms,
                grid=grid,
                workfunction=self.target_potential,
                tolerance=self.tol)
        else:
            jellium = Jellium(background, natoms=natoms, grid=grid)
        return SJMEnvironment(solvation, jellium)

    def todict(self):
        """Return the parent's parameter dict extended with SJM parameters."""
        dct = super().todict()
        dct['jelliumregion'] = self.jelliumregion
        dct['target_potential'] = self.target_potential
        dct['excess_electrons'] = self.excess_electrons
        dct['tol'] = self.tol
        return dct

68 

69 

class SJMEnvironment(Environment):
    """Environment combining a solvation environment with a jellium.

    Density updates are forwarded to both parts.  The value returned by
    :meth:`update2` comes from the solvation part only.
    """

    def __init__(self,
                 solvation: SolvationEnvironment,
                 jellium: Jellium):
        super().__init__(solvation.natoms)
        self.solvation = solvation
        self.jellium = jellium
        # Jellium charge; re-read after each post-SCF convergence check.
        self.charge = jellium.charge
        self.dielectric = solvation.dielectric

    def create_poisson_solver(self, **kwargs):
        """Wrap the solvation Poisson solver in an SJMPoissonSolver."""
        wrapped = self.solvation.create_poisson_solver(**kwargs)
        return SJMPoissonSolver(wrapped.solver, self.solvation.dielectric)

    def post_scf_convergence(self,
                             ibzwfs,
                             nelectrons,
                             occ_calc,
                             mixer,
                             log) -> bool:
        """Delegate the convergence check to the jellium.

        The jellium charge is re-read afterwards; presumably the jellium
        may adjust it (e.g. in fixed-potential mode) — TODO confirm.
        """
        done = self.jellium.post_scf_convergence(
            ibzwfs, nelectrons, occ_calc, mixer, log)
        self.charge = self.jellium.charge
        return done

    def update1(self, nt_r):
        """Pass the real-space density to solvation, then to jellium."""
        for part in (self.solvation, self.jellium):
            part.update1(nt_r)

    def update1pw(self, nt_g):
        """Plane-wave variant of :meth:`update1`.

        The solvation part gets a real-space density obtained by inverse
        FFT onto a non-distributed grid, then scattered.  The jellium gets
        ``nt_g`` unchanged.  NOTE(review): ``nt_g`` is presumably None on
        ranks that do not hold the gathered density — TODO confirm.
        """
        grid = self.jellium.grid
        nt_r = grid.empty()
        if nt_g is None:
            source = None
        else:
            source = nt_g.ifft(grid=grid.new(comm=None))
        nt_r.scatter_from(source)
        self.solvation.update1(nt_r)
        self.jellium.update1pw(nt_g)

    def update2(self, nt_r, vHt_r, vt_sr) -> float:
        """Return the solvation part's ``update2`` result."""
        return self.solvation.update2(nt_r, vHt_r, vt_sr)

108 

109 

class SJMPoissonSolver(PoissonSolverWrapper):
    """Real-space Poisson solver with a dielectric-weighted field correction.

    After the wrapped solver has run, the rank that receives the gathered
    arrays corrects the potential in two steps:

    1. Subtract ``c * s(z)``.  Here ``s`` is the running integral of
       ``1 / <eps>_xy`` along z (see ``modified_saw_tooth``), and ``c`` is
       chosen so that the xy-averaged potential is equal at the two
       reference z-planes ``zplanes``.
    2. Shift the potential so that its average over the topmost z-plane is
       zero.

    Parameters
    ----------
    solver:
        Wrapped real-space Poisson solver (operates on raw arrays).
    dielectric:
        Dielectric whose ``eps_gradeps[0]`` supplies eps on the grid.
    zplanes:
        Indices of the two reference z-planes (default ``(2, 10)``).
    """

    def __init__(self, solver, dielectric, *, zplanes=(2, 10)):
        super().__init__(solver)
        # Previously this argument was accepted but silently discarded.
        self.dielectric = dielectric
        self.zplanes = tuple(zplanes)

    def solve(self,
              vHt_r,
              rhot_r) -> float:
        """Solve for ``vHt_r`` in place.

        The energy is not computed here, so NaN is returned.
        """
        self.solver.solve(vHt_r.data, rhot_r.data)
        eps_r = vHt_r.desc.from_data(self.dielectric.eps_gradeps[0])
        eps0_r = eps_r.gather()
        vHt0_r = vHt_r.gather()
        if eps0_r is not None:
            saw_tooth_z = modified_saw_tooth(eps0_r)
            iz1, iz2 = self.zplanes
            s1, s2 = saw_tooth_z[[iz1, iz2]]
            v1, v2 = vHt0_r.data[:, :, [iz1, iz2]].mean(axis=(0, 1))
            # Remove the field so the two reference planes become equal:
            vHt0_r.data -= (v2 - v1) / (s2 - s1) * saw_tooth_z[np.newaxis,
                                                                np.newaxis]
            # Zero reference: average over the topmost z-plane.
            vHt0_r.data -= vHt0_r.data[:, :, -1].mean()
        vHt_r.scatter_from(vHt0_r)
        return np.nan

130 

131 

def modified_saw_tooth(eps_r: UGArray) -> np.ndarray:
    """Return the running integral of 1/eps along z in grid-index units.

    eps is first averaged over each xy-plane, giving ``a[i] = 1/<eps>_i``.
    Element ``i`` of the result is ``a[0] + ... + a[i-1] + a[i] / 2``.
    """
    inv_eps_z = 1.0 / eps_r.data.mean(axis=(0, 1))
    # NOTE(review): the half-step term presumably reflects where the first
    # grid point sits relative to z=0 — unconfirmed (original asked "???").
    return np.cumsum(inv_eps_z) - 0.5 * inv_eps_z

137 

138 

class SJMPWPoissonSolver(PWPoissonSolver):
    """Plane-wave Poisson solver with a saw-tooth correction along z.

    After the ordinary plane-wave solve, a smoothed saw-tooth (slope 1, see
    ``saw_tooth``) is added to the potential.  Its amplitude is
    ``4 * pi * p_z / volume``, where ``p_z`` is ``rhot_g.moment()[2]``.
    The G=0 coefficient is then shifted so that the potential is zero above
    the slab.

    Parameters
    ----------
    pw:
        Plane-wave descriptor.
    dielectric:
        Stored on the instance; not used by :meth:`solve`.
    width:
        Width of the Gaussian used to smooth the saw-tooth (default 0.25,
        the value previously hard-coded).
    """

    def __init__(self, pw, dielectric, *, width: float = 0.25):
        super().__init__(pw)
        self.dielectric = dielectric
        self.saw_tooth_g = saw_tooth(pw, width)

    def solve(self, vHt_g, rhot_g):
        """Solve for ``vHt_g`` in place and return the base-class energy."""
        energy = super().solve(vHt_g, rhot_g)
        dipole = rhot_g.moment()[2]
        slope = 4 * np.pi * dipole / rhot_g.desc.volume
        vHt_g.data += slope * self.saw_tooth_g.data
        # Shift potential so that it is zero above the slab:
        shift = 0.5 * slope * rhot_g.desc.cell_cv[2, 2]
        # Evaluated on every rank, as before (it may be collective).
        v0 = vHt_g.boundary_value(2)
        # Only the root rank owns the G=0 coefficient of a distributed
        # PWDesc (cf. the rank guards in saw_tooth()).  On other ranks
        # data[0] is some other G-vector and must not be touched.
        if vHt_g.desc.comm.rank == 0:
            vHt_g.data[0] -= shift + v0
        return energy

155 

156 

def saw_tooth_sympy():
    """Fourier-transform helper: derive the integral used by ``saw_tooth``.

    Symbolically evaluates the integral of ``z * sin(G z)`` for z from 0 to
    b (G, b > 0), prints the result and returns it.
    """
    from sympy import Symbol, integrate, sin
    # Symbol() rather than sympy.var(): var() also injects ``z`` into the
    # calling module's global namespace as a side effect.
    z = Symbol('z')
    G = Symbol('G', positive=True)
    b = Symbol('b', positive=True)
    m = integrate(sin(G * z) * z, (z, 0, b))
    print(m)  # -b*cos(G*b)/G + sin(G*b)/G**2
    return m

165 

166 

def saw_tooth(pw: PWDesc, width: float = 0.5) -> PWArray:
    """Saw-tooth in reciprocal space with a slope of 1.

    Only G-vectors with ``G_x = G_y = 0`` get non-zero coefficients.
    Centred on ``z = 0`` with period ``L``, the saw-tooth ``f(z) = z`` on
    ``(-b, b)`` with ``b = L / 2`` has the Fourier coefficients::

        c(G) = -(2i / L) * (sin(bG) / G**2 - b cos(bG) / G)

    (see ``saw_tooth_sympy``).  These are then:

    * multiplied by ``exp(-G**2 width**2 / 4)``, i.e. smoothed with a
      Gaussian ``exp(-z**2 / width**2)``;
    * multiplied by ``exp(i G b)``, i.e. shifted by half the cell height.

    The G=0 coefficient is zero.

    Parameters
    ----------
    pw:
        Plane-wave descriptor whose cell has its third vector along z,
        orthogonal to the first two (asserted).
    width:
        Width of the smoothing Gaussian.
    """
    assert np.allclose(pw.cell_cv[:2, 2], 0.0)
    assert np.allclose(pw.cell_cv[2, :2], 0.0)

    # G-vectors on this rank (range ng1:ng2) with G_x = G_y = 0:
    m0_g, m1_g = pw.indices_cG[:2, pw.ng1:pw.ng2] == 0
    mask_g = m0_g & m1_g
    Gz_i = pw.G_plus_k_Gv[mask_g, 2]  # boolean indexing returns a copy

    # Only the root rank holds the G=0 vector (as the first element):
    owns_G0 = pw.comm.rank == 0
    if owns_G0:
        assert Gz_i[0] == 0.0
        Gz_i[0] = 1.0  # dummy value to avoid 0/0; coefficient zeroed below

    L = pw.cell_cv[2, 2]
    b = L / 2
    st_i = -(np.sin(b * Gz_i) / Gz_i -
             b * np.cos(b * Gz_i)) / Gz_i * (2j / L)
    if owns_G0:
        st_i[0] = 0.0

    # Make the saw-tooth more smooth (fold with Gaussian):
    alpha = width**-2
    st_i *= np.exp(-Gz_i**2 / (4 * alpha))

    # Shift by half the cell height:
    st_i *= np.exp(1j * Gz_i * b)

    st_g = pw.zeros()
    st_g.data[mask_g] = st_i
    return st_g