Coverage for gpaw/poisson_moment.py: 99%
141 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1from typing import Any, Dict, Optional, List, Sequence, Union
3import numpy as np
4from ase.units import Bohr
5from ase.utils.timing import Timer
6from gpaw.poisson import _PoissonSolver, create_poisson_solver
7from gpaw.utilities.gauss import Gaussian
8from gpaw.typing import Array1D
9from gpaw.utilities.timing import nulltimer, NullTimer
11from ase.utils.timing import timer
# Accepted forms of the ``moment_corrections`` argument of
# MomentCorrectionPoissonSolver: either a single int, or a list of
# dictionaries (each expected to carry the keys "moms" and "center").
14MomentCorrectionsType = Union[int, List[Dict[str, Any]]]
class MomentCorrection:
    """Set of multipole moments to be corrected around one center.

    Parameters
    ----------
    center
        Position of the correction center in Ångström, or ``None``.
        A given position is divided by ``Bohr`` and stored in units of
        Bohr; ``None`` is stored unchanged.
    moms
        Multipole moment indices; stored as a NumPy array.
    """

    def __init__(self,
                 center: Optional[Union[Sequence, Array1D]],
                 moms: Union[int, Sequence[int]]):
        if center is not None:
            # Internal representation is in Bohr
            center = np.asarray(center) / Bohr
        self.center = center
        self.moms = np.asarray(moms)

    def todict(self) -> Dict[str, Any]:
        """Return a dictionary description of the moment correction.

        The center is converted from the internal units of Bohr back to
        Ångström; the stored center is left untouched.
        """
        center = self.center
        if center is not None:
            center = center * Bohr

        return dict(moms=self.moms, center=center)

    def __str__(self) -> str:
        """Return a short description of the form ``[center] moments``."""
        if self.center is None:
            center = 'center'
        else:
            center = ', '.join([f'{x:.2f}' for x in self.center * Bohr])

        # An empty moment list must not be treated as an increasing
        # sequence: np.diff() of it is empty, so allclose() would be
        # vacuously true and indexing moms[0] would raise IndexError.
        if self.moms.size > 0 and np.allclose(np.diff(self.moms), 1):
            # Increasing sequence
            moms = f'range({self.moms[0]}, {self.moms[-1] + 1})'
        else:
            # List of integers
            _moms = ', '.join([f'{m:.0f}' for m in self.moms])
            moms = f'({_moms})'
        return f'[{center}] {moms}'

    def __repr__(self) -> str:
        """Return ``<moms> @ <center>`` with the center in Ångström."""
        center = self.center
        if center is not None:
            # Build a new array: an in-place ``*=`` would rescale the
            # array stored in self.center on every repr() call.
            center = center * Bohr
        return f'{repr(self.moms)} @ {repr(center)}'
class MomentCorrectionPoissonSolver(_PoissonSolver):
    """Wrapper for the poisson solver that includes moment corrections

    Parameters
    ----------
    poissonsolver
        underlying poisson solver (an instance, or a dictionary that is
        passed to ``create_poisson_solver``)
    moment_corrections
        list of moment corrections, expressed as dictionaries
        `{'moms': ..., 'center': ...}` that specify the multipole moments
        and their centres.

        moment_corrections = [{'moms': moms_list1, 'center': center1},
                              {'moms': moms_list2, 'center': center2},
                              ...]

        Here moms_listX is list of integers of multipole moments to be
        corrected at centerX. For example moms_list=range(4) corresponds to
        s, p_x, p_y and p_z type multipoles.

        Optionally setting moment_corrections to an integer m is equivalent
        to including multipole moments corresponding to range(m) at the
        center of the cell

        ``None`` means that no moment corrections are applied.
    timer
        timer used for the timed sections of the solver, unless a
        ``timer`` keyword argument is passed to ``solve``

    """

    def __init__(self,
                 poissonsolver: Union[_PoissonSolver, Dict[str, Any]],
                 moment_corrections: Optional[MomentCorrectionsType],
                 timer: Union[NullTimer, Timer] = nulltimer):

        self._initialized = False
        self.poissonsolver = create_poisson_solver(poissonsolver)
        self.timer = timer

        if moment_corrections is None:
            self.moment_corrections = []
        elif isinstance(moment_corrections, int):
            # Shorthand: moments range(m), with no explicit center
            moms = range(moment_corrections)
            center = None
            self.moment_corrections = [MomentCorrection(moms=moms,
                                                        center=center)]
        elif isinstance(moment_corrections, list):
            assert all(['moms' in mom and 'center' in mom
                        for mom in moment_corrections]), \
                   (f'{self.__class__.__name__}: each element in '
                    'moment_correction must be a dictionary '
                    'with the keys "moms" and "center"')
            self.moment_corrections = [MomentCorrection(**mom)
                                       for mom in moment_corrections]
        else:
            raise ValueError(f'{self.__class__.__name__}: moment_correction '
                             'must be a list of dictionaries')

    def todict(self):
        """Return a dictionary description of the solver, including the
        wrapped solver and every moment correction."""
        mom_corr = [mom.todict() for mom in self.moment_corrections]
        d = {'name': 'MomentCorrectionPoissonSolver',
             'poissonsolver': self.poissonsolver.todict(),
             'moment_corrections': mom_corr}

        return d

    def set_grid_descriptor(self, gd):
        """Forward the grid descriptor to the wrapped solver and keep a
        reference to it as ``self.gd``."""
        self.poissonsolver.set_grid_descriptor(gd)
        self.gd = gd

    def get_description(self) -> str:
        """Return the description of the wrapped solver followed by one
        line per moment correction."""
        description = self.poissonsolver.get_description()
        n = len(self.moment_corrections)

        lines = [description]
        lines.extend([f' {n} moment corrections:'])
        lines.extend([f' {str(mom)}'
                      for mom in self.moment_corrections])

        return '\n'.join(lines)

    @timer('Poisson initialize')
    def _init(self):
        """Initialize the wrapped solver and the correction masks once.

        Raises
        ------
        NotImplementedError
            if the grid is not orthogonal or has a periodic direction.
        """
        if self._initialized:
            return
        self.poissonsolver._init()

        if not self.gd.orthogonal or self.gd.pbc_c.any():
            raise NotImplementedError('Only orthogonal unit cells '
                                      'and non-periodic boundary '
                                      'conditions are tested')
        self.load_moment_corrections_gauss()

        self._initialized = True

    @timer('Load moment corrections')
    def load_moment_corrections_gauss(self):
        """Create one Gaussian and one integer grid mask per correction.

        Fills ``self.gauss_i``, ``self.mom_ij`` and ``self.mask_ig``.
        Each grid point is assigned (mask value 1) to exactly one
        correction: the one whose ``gauss.r`` value is smallest there.
        ``gauss.r`` is presumably the distance to the center of that
        Gaussian -- confirm against ``gpaw.utilities.gauss.Gaussian``.
        """
        self.gauss_i = []
        self.mom_ij = []
        self.mask_ig = []

        if len(self.moment_corrections) == 0:
            return

        r_ir = []
        for rmom in self.moment_corrections:
            gauss = Gaussian(self.gd, center=rmom.center)
            self.gauss_i.append(gauss)
            self.mom_ij.append(rmom.moms)
            r_ir.append(gauss.r.ravel())

        # For every grid point, the index of the correction with the
        # smallest r.  A single vectorized argmin replaces a Python loop
        # over all grid points; ties go to the lowest index, as with a
        # per-point argmin.
        nearest_r = np.argmin(np.array(r_ir), axis=0)

        for i in range(len(r_ir)):
            mask_r = (nearest_r == i).astype(int)
            self.mask_ig.append(mask_r.reshape(self.gd.n_c))

    def solve(self, phi, rho, **kwargs):
        """Initialize if needed, then solve; returns what ``_solve``
        returns (the return value of the wrapped solver's ``solve``)."""
        self._init()
        return self._solve(phi, rho, **kwargs)

    @timer('Solve')
    def _solve(self, phi, rho, **kwargs):
        """Solve with the multipole moments removed from the density.

        ``phi`` is updated in place. Without moment corrections the call
        is forwarded unchanged to the wrapped solver.
        """
        # A timer passed by the caller takes precedence over self.timer
        timer = kwargs.get('timer', self.timer)

        if len(self.moment_corrections) > 0:
            assert not self.gd.pbc_c.any()

            timer.start('Multipole moment corrections')

            rho_neutral = rho * 0.0
            phi_cor_g = self.gd.zeros()
            for gauss, mask_g, mom_j in zip(self.gauss_i, self.mask_ig,
                                            self.mom_ij):
                # Part of the density assigned to this center
                rho_masked = rho * mask_g
                for mom in mom_j:
                    # presumably remove_moment modifies rho_masked in
                    # place and returns the potential of the removed
                    # moment -- confirm against Gaussian.remove_moment
                    phi_cor_g += gauss.remove_moment(rho_masked, mom)
                rho_neutral += rho_masked

            # Remove multipoles for better initial guess
            phi -= phi_cor_g

            timer.stop('Multipole moment corrections')

            timer.start('Solve neutral')
            niter = self.poissonsolver.solve(phi, rho_neutral, **kwargs)
            timer.stop('Solve neutral')

            timer.start('Multipole moment corrections')
            # correct error introduced by removing multipoles
            phi += phi_cor_g
            timer.stop('Multipole moment corrections')

            return niter
        else:
            return self.poissonsolver.solve(phi, rho, **kwargs)

    def estimate_memory(self, mem):
        """Add the wrapped solver's estimate and one grid-sized array per
        moment correction (the masks) to ``mem``."""
        self.poissonsolver.estimate_memory(mem)
        gdbytes = self.gd.bytecount()
        if self.moment_corrections is not None:
            mem.subnode('moment_corrections masks',
                        len(self.moment_corrections) * gdbytes)

    def __repr__(self):
        """Return the class name with a summary of the corrections."""
        if len(self.moment_corrections) == 0:
            corrections_str = 'no corrections'
        elif len(self.moment_corrections) < 2:
            corrections_str = f'{repr(self.moment_corrections[0])}'
        else:
            corrections_str = f'{len(self.moment_corrections)} corrections'

        representation = f'MomentCorrectionPoissonSolver ({corrections_str})'
        return representation