Coverage for gpaw/new/sjm.py: 51%
118 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1from __future__ import annotations
3import numpy as np
4from ase.units import Bohr
5from gpaw.core import UGArray, PWDesc, PWArray
6from gpaw.jellium import create_background_charge
7from gpaw.new.environment import Environment, FixedPotentialJellium, Jellium
8from gpaw.new.poisson import PoissonSolverWrapper
9from gpaw.new.pw.poisson import PWPoissonSolver
10from gpaw.new.solvation import SolvationEnvironment, Solvation
class SJM(Solvation):
    """Settings for the solvated jellium method (SJM).

    Extends the implicit-solvation settings of :class:`Solvation` with a
    background ("jellium") charge placed in a slab above the atoms.  If
    ``target_potential`` (eV) is given, a :class:`FixedPotentialJellium` is
    built with it as the work function and ``tol`` (eV) as tolerance;
    otherwise a plain :class:`Jellium` is used.  In both cases the
    background charge starts out as ``excess_electrons``.
    """

    name = 'sjm'

    def __init__(self,
                 *,
                 cavity,
                 dielectric,
                 interactions,
                 jelliumregion: dict | None = None,
                 target_potential: float | None,  # eV
                 excess_electrons: float = 0.0,
                 tol: float = 0.001):  # eV
        super().__init__(cavity, dielectric, interactions)
        # Only the optional 'top' entry is read (in build()).
        self.jelliumregion = jelliumregion or {}
        self.target_potential = target_potential
        self.excess_electrons = excess_electrons
        self.tol = tol

    def build(self, setups, grid, relpos_ac, log, comm) -> SJMEnvironment:
        """Create the solvation environment plus its jellium slab.

        The background charge spans from 3 Å above the highest scaled
        z-position times the cell height, up to ``jelliumregion['top']``
        (default: 1 Å below the cell height).  Heights are in Å.
        """
        solvation_env = super().build(
            setups=setups, grid=grid, relpos_ac=relpos_ac,
            log=log, comm=comm)

        # Cell height along z, converted from Bohr to Å.
        cell_height = grid.cell_cv[2, 2] * Bohr
        natoms = len(relpos_ac)
        slab_bottom = relpos_ac[:, 2].max() * cell_height + 3.0
        slab_top = self.jelliumregion.get('top', cell_height - 1.0)

        charge = create_background_charge(charge=self.excess_electrons,
                                          z1=slab_bottom,
                                          z2=slab_top)
        charge.set_grid_descriptor(grid._gd)

        if self.target_potential is not None:
            jellium = FixedPotentialJellium(
                charge,
                natoms=natoms,
                grid=grid,
                workfunction=self.target_potential,
                tolerance=self.tol)
        else:
            jellium = Jellium(charge, natoms=natoms, grid=grid)
        return SJMEnvironment(solvation_env, jellium)

    def todict(self):
        """Return the parent's settings dict extended with the SJM ones."""
        settings = super().todict()
        settings.update({'jelliumregion': self.jelliumregion,
                         'target_potential': self.target_potential,
                         'excess_electrons': self.excess_electrons,
                         'tol': self.tol})
        return settings
class SJMEnvironment(Environment):
    """Environment made of an implicit-solvation part and a jellium part.

    Density updates (``update1``/``update1pw``) are forwarded to both parts.
    ``update2`` and the Poisson solver come from the solvation part.  The
    SCF post-convergence step is delegated to the jellium.
    """

    def __init__(self,
                 solvation: SolvationEnvironment,
                 jellium: Jellium):
        super().__init__(solvation.natoms)
        self.solvation = solvation
        self.jellium = jellium
        # Copy of the jellium charge; refreshed in post_scf_convergence().
        self.charge = jellium.charge
        self.dielectric = solvation.dielectric

    def create_poisson_solver(self, **kwargs):
        """Wrap the solvation part's inner solver in an SJMPoissonSolver."""
        wrapper = self.solvation.create_poisson_solver(**kwargs)
        inner_solver = wrapper.solver
        return SJMPoissonSolver(inner_solver, self.solvation.dielectric)

    def post_scf_convergence(self,
                             ibzwfs,
                             nelectrons,
                             occ_calc,
                             mixer,
                             log) -> bool:
        """Let the jellium react to SCF convergence; sync ``self.charge``.

        Returns the jellium's converged flag unchanged.
        """
        is_converged = self.jellium.post_scf_convergence(
            ibzwfs, nelectrons, occ_calc, mixer, log)
        self.charge = self.jellium.charge
        return is_converged

    def update1(self, nt_r):
        """Forward the real-space density ``nt_r`` to both parts."""
        self.solvation.update1(nt_r)
        self.jellium.update1(nt_r)

    def update1pw(self, nt_g):
        """Plane-wave variant of :meth:`update1`.

        The solvation part gets a real-space copy: ``nt_g`` is transformed
        on a serial (``comm=None``) grid wherever it is not None, then
        scattered onto the jellium's grid.  The jellium gets ``nt_g``.
        """
        grid = self.jellium.grid
        nt_r = grid.empty()
        if nt_g is None:
            serial_nt_r = None
        else:
            serial_nt_r = nt_g.ifft(grid=grid.new(comm=None))
        nt_r.scatter_from(serial_nt_r)
        self.solvation.update1(nt_r)
        self.jellium.update1pw(nt_g)

    def update2(self, nt_r, vHt_r, vt_sr) -> float:
        """Return the solvation part's ``update2`` result.

        The jellium is not involved here.
        """
        energy = self.solvation.update2(nt_r, vHt_r, vt_sr)
        return energy
class SJMPoissonSolver(PoissonSolverWrapper):
    """Wrapped real-space Poisson solver with a dielectric-aware saw-tooth.

    After the wrapped solver has filled ``vHt_r``, a saw-tooth built from
    the running integral of 1/eps along z (:func:`modified_saw_tooth`) is
    subtracted.  Its amplitude makes the xy-averaged potential equal at the
    two z-planes in ``z_sample_indices``.  The potential is then shifted so
    that its xy-average over the last z-plane is zero.  (Presumably this
    cancels the artificial field of the slab dipole; not verified here.)
    """

    # z-plane indices (on the gathered grid) where the xy-averaged
    # potential and the saw-tooth are sampled.  Stored as a tuple and
    # converted to a list for NumPy fancy indexing in solve().
    z_sample_indices = (2, 10)

    def __init__(self, solver, dielectric):
        super().__init__(solver)
        # Previously discarded.  Kept now, as SJMPWPoissonSolver does.
        # solve() still reads the wrapped solver's own ``dielectric``.
        self.dielectric = dielectric

    def solve(self,
              vHt_r,
              rhot_r) -> float:
        """Solve for ``vHt_r`` in place and apply the saw-tooth correction.

        Returns NaN: no electrostatic energy is computed here.
        """
        self.solver.solve(vHt_r.data, rhot_r.data)
        eps_r = vHt_r.desc.from_data(self.solver.dielectric.eps_gradeps[0])
        # gather() runs on every rank.  The correction is applied only where
        # an array comes back (elsewhere None), then scattered to all ranks.
        eps0_r = eps_r.gather()
        vHt0_r = vHt_r.gather()
        if eps0_r is not None:
            iz = list(self.z_sample_indices)
            saw_tooth_z = modified_saw_tooth(eps0_r)
            s1, s2 = saw_tooth_z[iz]
            v1, v2 = vHt0_r.data[:, :, iz].mean(axis=(0, 1))
            # Remove the saw-tooth component that makes v1 != v2:
            vHt0_r.data -= ((v2 - v1) / (s2 - s1) *
                            saw_tooth_z[np.newaxis, np.newaxis])
            # Reference: zero xy-average on the last z-plane.
            vHt0_r.data -= vHt0_r.data[:, :, -1].mean()
        vHt_r.scatter_from(vHt0_r)
        return np.nan
def modified_saw_tooth(eps_r: UGArray) -> np.ndarray:
    """Return a 1/eps-weighted saw-tooth along the z-axis.

    With ``a[k]`` the xy-average of ``1 / eps`` (strictly, the reciprocal of
    the xy-averaged eps) on plane ``k``, the result is
    ``s[k] = a[0] + ... + a[k-1] + a[k] / 2``.  This is a running sum in
    which each plane's own contribution has half weight.  No grid-spacing
    factor is applied.
    """
    inv_eps_z = 1.0 / eps_r.data.mean(axis=(0, 1))
    # Half weight for the current plane.  The original author marked this
    # +0.5 offset (relative to z = 0) as uncertain.
    return np.cumsum(inv_eps_z) - 0.5 * inv_eps_z
class SJMPWPoissonSolver(PWPoissonSolver):
    """Plane-wave Poisson solver with a saw-tooth correction along z.

    After the ordinary solve, ``slope * saw_tooth`` is added to the
    potential, where ``slope = 4 pi * (z-moment of rhot) / volume``.  The
    constant (G=0) part is then adjusted.  ``dielectric`` is stored but
    not used by this class.
    """

    def __init__(self, pw, dielectric, width: float = 0.25):
        """Set up the solver.

        ``width`` is the Gaussian smoothing width passed to
        :func:`saw_tooth` (default 0.25, the previously hard-coded value).
        """
        super().__init__(pw)
        self.dielectric = dielectric
        self.saw_tooth_g = saw_tooth(pw, width)

    def solve(self, vHt_g, rhot_g):
        """Solve, add the saw-tooth term, shift; return the parent's energy.

        The energy is the one from the parent solve, computed before the
        saw-tooth is added.
        """
        energy = super().solve(vHt_g, rhot_g)
        dipole = rhot_g.moment()[2]
        slope = 4 * np.pi * dipole / rhot_g.desc.volume
        vHt_g.data += slope * self.saw_tooth_g.data
        # Shift potential so that it is zero above the slab:
        shift = 0.5 * slope * rhot_g.desc.cell_cv[2, 2]
        v0 = vHt_g.boundary_value(2)  # evaluated on every rank
        # As in saw_tooth(), the G=0 coefficient is assumed to be data[0]
        # on rank 0 only.  On other ranks data[0] is some other G-vector,
        # so only rank 0 may apply the constant shift.
        if vHt_g.desc.comm.rank == 0:
            vHt_g.data[0] -= shift + v0
        return energy
def saw_tooth_sympy():
    """Derive the integral behind the saw-tooth Fourier coefficients.

    Prints and returns ``integral_0^b z sin(G z) dz`` for positive ``G``
    and ``b``.  Printed, it reads ``-b*cos(G*b)/G + sin(G*b)/G**2``, the
    expression used in :func:`saw_tooth`.  Requires sympy.
    """
    from sympy import Symbol, integrate, sin
    # Symbol instead of sympy.var: var() also injects ``z`` into the
    # calling module's global namespace as a side effect.
    z = Symbol('z')
    G = Symbol('G', positive=True)
    b = Symbol('b', positive=True)
    m = integrate(sin(G * z) * z, (z, 0, b))
    print(m)  # -b*cos(G*b)/G + sin(G*b)/G**2
    return m
def saw_tooth(pw: PWDesc, width: float = 0.5) -> PWArray:
    """Saw-tooth in reciprocal space with a slope of 1.

    Only plane waves with G_x = G_y = 0 are populated.  For G_z on the
    reciprocal lattice (G_z = 2 pi n / L) the coefficients are, up to
    round-off::

        c_0 = 0,   c_G = 1j / G_z * exp(-(G_z * width)**2 / 4)

    Assuming the expansion f(z) = sum_G c_G exp(i G z), this is z - L/2 on
    0 < z < L, with the jump at z = 0 smoothed by a Gaussian.

    Parameters
    ----------
    pw:
        Plane-wave descriptor.  The third cell vector must be orthogonal to
        the first two (asserted).
    width:
        Gaussian smoothing width, in the length units of ``pw.cell_cv``.
        ``0`` gives the unsmoothed saw-tooth.
    """
    assert np.allclose(pw.cell_cv[:2, 2], 0.0)
    assert np.allclose(pw.cell_cv[2, :2], 0.0)

    # Rank 0 is assumed to hold G = 0 as its first plane wave.
    is_root = pw.comm.rank == 0

    # Local plane waves with G_x = G_y = 0:
    m0_g, m1_g = pw.indices_cG[:2, pw.ng1:pw.ng2] == 0
    mask_g = m0_g & m1_g
    Gz_i = pw.G_plus_k_Gv[mask_g, 2]  # boolean indexing returns a copy
    if is_root:
        assert Gz_i[0] == 0.0
        Gz_i[0] = 1.0  # placeholder to avoid 0/0; c_0 is reset below

    L = pw.cell_cv[2, 2]
    b = L / 2
    # Fourier coefficients of z on (-b, b), using
    # int_0^b z sin(G z) dz = sin(G b) / G**2 - b cos(G b) / G:
    st_i = -(np.sin(b * Gz_i) / Gz_i -
             b * np.cos(b * Gz_i)) / Gz_i * (2j / L)
    if is_root:
        st_i[0] = 0.0

    # Make the saw-tooth more smooth (fold with Gaussian).  Written with
    # width**2 instead of alpha = width**-2 so that width=0 works:
    st_i *= np.exp(-(Gz_i * width)**2 / 4)

    # Shift by half the cell height:
    st_i *= np.exp(1j * Gz_i * b)

    st_g = pw.zeros()
    st_g.data[mask_g] = st_i
    return st_g