Coverage for gpaw/convergence_criteria.py: 91%
220 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from collections import deque
2from inspect import signature
4import numpy as np
5from ase.calculators.calculator import InputError
6from ase.units import Bohr, Ha
8from gpaw.mpi import broadcast_float
def get_criterion(name):
    """Return the built-in criterion class whose ``name`` attribute matches.

    Parameters:

    name: str
        The ``name`` class attribute of a built-in criterion, e.g.
        ``'energy'`` or ``'work function'``.

    Returns the criterion class itself (not an instance).

    Raises InputError with a list of the known names if ``name`` does not
    belong to a built-in criterion. User-written criteria cannot be looked
    up through this function.
    """
    # All built-in criteria should be in this list.
    builtin = [Energy, Density, Eigenstates, Eigenvalues, Forces,
               WorkFunction, MinIter, MaxIter]
    by_name = {criterion.name: criterion for criterion in builtin}
    try:
        return by_name[name]
    except KeyError:
        known = ', '.join(f'{key!r}' for key in by_name)
        msg = (
            f'The convergence keyword "{name}" was supplied, which we do not '
            'know how to handle. If this is a typo, please correct '
            f'(known keywords are {known}). If this'
            ' is a user-written convergence criterion, it cannot be '
            'imported with this function; please see the GPAW manual for '
            'details.')
        # The message is self-contained; hide the internal KeyError so the
        # user sees a single, clear InputError.
        raise InputError(msg) from None
def dict2criterion(dictionary):
    """Convert a dictionary to a convergence criterion instance.

    The dictionary can either be that generated from 'todict'; that is like
    {'name': 'energy', 'tol': 0.005, 'n_old': 3}. Or from a user-specified
    shortcut like {'energy': 0.005} or {'energy': (0.005, 3)}, or a
    combination like {'energy': {'name': 'energy', 'tol': 0.005, 'n_old': 3}}.

    In the tuple/list shortcut the items are matched, in order, to the
    parameters of the criterion's constructor and passed by keyword, so
    {'energy': (0.005, 3)} becomes Energy(tol=0.005, n_old=3) even though
    n_old is keyword-only. A scalar shortcut value is passed as the single
    positional argument.

    The input dictionary is not modified.

    Raises InputError if a shortcut dictionary does not have exactly one
    entry, if a tuple/list shortcut has more items than the constructor has
    parameters, or (via get_criterion) if the criterion name is unknown.
    """
    d = dictionary.copy()
    if 'name' in d:  # from 'todict'
        name = d.pop('name')
        ThisCriterion = get_criterion(name)
        return ThisCriterion(**d)
    # Explicit check rather than assert, so it is not stripped by python -O.
    if len(d) != 1:
        raise InputError(
            'A convergence shortcut dictionary must have exactly one entry, '
            f'got {dictionary!r}.')
    name, value = next(iter(d.items()))
    if isinstance(value, dict) and 'name' in value:
        return dict2criterion(value)
    ThisCriterion = get_criterion(name)
    if isinstance(value, (tuple, list)):
        # Bind by parameter name so keyword-only parameters can be filled.
        parameters = list(signature(ThisCriterion).parameters)
        if len(value) > len(parameters):
            raise InputError(
                f'Too many values {value!r} for convergence criterion '
                f'{name!r}; it accepts at most {len(parameters)}.')
        return ThisCriterion(**dict(zip(parameters, value)))
    return ThisCriterion(value)
def check_convergence(criteria, ctx):
    """Evaluate all convergence criteria for one SCF step.

    Parameters:

    criteria: dict
        Maps a name to a criterion object that is callable as
        ``criterion(ctx)`` (returning ``(ok, entry)``) and exposes the
        ``calc_last`` and ``override_others`` attributes.
    ctx:
        Context object handed unchanged to every criterion.

    Criteria with ``calc_last`` False are evaluated first. Among those, a
    criterion with ``override_others`` set does not affect the combined
    result, but if it passes the step is reported as converged anyway.
    Criteria with ``calc_last`` True are evaluated only if every regular
    criterion of the first group passed; otherwise they are recorded as
    not converged with an empty log entry.

    Returns ``(converged, converged_items, entries)``: the overall flag and
    two dicts mapping each criterion name to its ok-flag and its log-file
    string (first-group criteria first).
    """
    log_entries = {}   # name -> string for the log file
    status = {}        # name -> ok flag
    all_met = True
    forced_stop = False

    # First pass: the cheap criteria that are always evaluated.
    for key, crit in criteria.items():
        if crit.calc_last:
            continue
        ok, text = crit(ctx)
        if not crit.override_others:
            all_met = all_met and ok
        elif ok:
            forced_stop = True
        status[key] = ok
        log_entries[key] = text

    # Second pass: deferred criteria, only worth computing if the rest passed.
    for key, crit in criteria.items():
        if not crit.calc_last:
            continue
        if not all_met:
            status[key] = False
            log_entries[key] = ''
            continue
        ok, text = crit(ctx)
        all_met &= ok
        status[key] = ok
        log_entries[key] = text

    return all_met or forced_stop, status, log_entries
class Criterion:
    """Common base class for SCF convergence criteria.

    Provides generic __repr__ and todict implementations. They assume that
    a subclass stores each __init__ argument as an attribute of the same
    name (``__init__(self, a, b)`` sets ``self.a`` and ``self.b``), because
    both rebuild the argument list from the constructor signature. todict
    additionally needs a ``name`` attribute. Every criterion (subclass of
    Criterion) must define ``name``, ``tablename``, ``description``,
    ``__init__`` and ``__call__``; see the online documentation for details.
    """
    # When True, the criterion is only evaluated after every criterion with
    # calc_last == False has been satisfied.
    calc_last = False
    # When True, a satisfied criterion marks the SCF step as converged even
    # if other criteria are not met.
    override_others = False
    description: str

    def __repr__(self):
        cls = type(self)
        args = (str(getattr(self, arg)) for arg in signature(cls).parameters)
        return f'{cls.__name__}({", ".join(args)})'

    def todict(self):
        """Return the constructor arguments plus ``name`` as a dict."""
        arg_names = signature(type(self)).parameters
        return {'name': self.name,
                **{arg: getattr(self, arg) for arg in arg_names}}

    def reset(self):
        """Clear any state kept between SCF steps (nothing to clear here)."""
# Built-in criteria follow. Make sure that any new criteria added below
# are also added to the list in get_criterion() so that it can import
# them correctly by name.
class Energy(Criterion):
    """A convergence criterion for the total energy.

    Parameters:

    tol: float
        Tolerance for convergence; that is the maximum variation among the
        last n_old values of the (extrapolated) total energy.
    n_old: int
        Number of energy values to compare. I.e., if n_old is 3, then this
        compares the peak-to-peak difference among the current total energy
        and the two previous.
    relative: bool
        Use total energy [eV] or total energy relative to number of
        valence electrons [eV/(valence electron)].
    """
    name = 'energy'
    tablename = 'energy'

    def __init__(self, tol: float, *, n_old: int = 3, relative: bool = True):
        self.tol = tol
        self.n_old = n_old
        self.relative = relative
        self.description = (
            f'Maximum [total energy] change in last {self.n_old} cycles: '
            f'{self.tol:g} eV')
        if relative:
            self.description += ' / valence electron'
        # Create the (empty) energy history here too, so a freshly built
        # criterion can be called without an explicit reset() first.
        self.reset()

    def reset(self):
        """Forget all previously recorded energies."""
        self._old = deque(maxlen=self.n_old)

    def __call__(self, context):
        """Return ``(converged, entry)`` for the current SCF step.

        converged is True once the last n_old recorded energies span less
        than ``tol`` (energies are divided by the number of valence
        electrons when ``relative`` is set and that number is non-zero).
        entry is the extrapolated total energy in eV formatted for the log
        file, or '' if the energy is not finite.
        """
        # Note the previous code was calculating the peak-to-
        # peak energy difference on e_total_free, while reporting
        # e_total_extrapolated in the SCF table (logfile). I changed it to
        # use e_total_extrapolated for both. (Should be a miniscule
        # difference, but more consistent.)
        total_energy = context.ham.e_total_extrapolated * Ha
        if context.wfs.nvalence == 0 or not self.relative:
            energy = total_energy
        else:
            energy = total_energy / context.wfs.nvalence
        # The deque silently drops the oldest value beyond n_old entries.
        self._old.append(energy)
        error = np.inf
        # Only judge convergence once a full window of energies exists.
        if len(self._old) == self._old.maxlen:
            error = np.ptp(self._old)
        converged = error < self.tol
        entry = ''
        if np.isfinite(energy):
            entry = f'{total_energy:11.6f}'
        return converged, entry
class Density(Criterion):
    """Convergence criterion on the change of the electron density.

    Parameters:

    tol : float
        Largest accepted density change, measured as the integral of the
        absolute density change divided by the number of valence
        electrons. [electrons/(valence electron)]
    """
    name = 'density'
    tablename = 'dens'

    def __init__(self, tol):
        self.tol = tol
        self.description = ('Maximum integral of absolute [dens]ity change: '
                            f'{tol:g} electrons / valence electron')

    def __call__(self, context):
        """Return ``(converged, entry)`` for the current SCF step.

        entry is log10 of the per-electron density error formatted for the
        log file, or '' when the error is zero or infinite.
        """
        # A fixed density counts as converged (needed by old GPAW).
        if context.dens.fixed:
            return True, ''
        n_valence = context.wfs.nvalence
        if n_valence == 0:
            return True, ''
        # Broadcast so that all ranks agree on the density error.
        world = context.wfs.world
        error = broadcast_float(context.dens.error, world) / n_valence
        converged = error < self.tol
        if error == 0 or np.isinf(error):
            return converged, ''
        return converged, f'{np.log10(error):+5.2f}'
class Eigenstates(Criterion):
    """Convergence criterion on the Kohn--Sham eigenstates.

    Parameters:

    tol : float
        Largest accepted eigenstate error: the integral of the squared
        residuals of the Kohn--Sham equations divided by the number of
        valence electrons. [eV^2/(valence electron)]
    """
    name = 'eigenstates'
    tablename = 'eigst'

    def __init__(self, tol):
        self.tol = tol
        self.description = ('Maximum integral of absolute [eigenst]ate '
                            f'change: {tol:g} eV^2 / valence electron')

    def __call__(self, context):
        """Return ``(converged, entry)`` for the current SCF step.

        entry is log10 of the error formatted for the log file, or '' when
        the error is zero or infinite. Without valence electrons the
        criterion is trivially satisfied.
        """
        if context.wfs.nvalence == 0:
            return True, ''
        error = self.get_error(context)
        entry = ('' if error == 0 or np.isinf(error)
                 else f'{np.log10(error):+6.2f}')
        return error < self.tol, entry

    def get_error(self, context):
        """Return the eigensolver error in eV^2 per valence electron."""
        wfs = context.wfs
        return wfs.eigensolver.error * Ha**2 / wfs.nvalence
class Eigenvalues(Criterion):
    """A convergence criterion for the eigenvalues.

    Parameters:

    tol : float
        Tolerance for convergence; the maximum absolute change in the
        eigenvalues, taken from ``context.eig_error`` and multiplied by Ha
        (i.e. converted from atomic units to eV). [eV]
    """
    name = 'eigenvalues'
    tablename = 'eigs'
    calc_last = False

    def __init__(self, tol=1e-3):
        self.tol = tol
        # Report the threshold in the log, like the other criteria do.
        self.description = ('Maximum absolute change in eigenvalues [eigs]: '
                            f'{self.tol:g} eV')

    def __call__(self, context):
        """Return ``(converged, entry)`` for the current SCF step.

        entry is log10 of the eigenvalue error formatted for the log file,
        or '' when the error is zero or infinite. Without valence electrons
        the criterion is trivially satisfied.
        """
        if context.wfs.nvalence == 0:
            return True, ''
        error = self.get_error(context)
        converged = (error < self.tol)
        if error == 0 or np.isinf(error):
            entry = ''
        else:
            entry = f'{np.log10(error):+6.2f}'
        return converged, entry

    def get_error(self, context):
        """Return the maximum absolute eigenvalue change in eV."""
        return context.eig_error * Ha
class Forces(Criterion):
    """A convergence criterion for the forces.

    Parameters:

    atol : float
        Absolute tolerance for convergence; that is, the force on each atom
        is compared with its force from the previous iteration, and the change
        in each atom's force is calculated as an l2-norm
        (Euclidean distance). The atom with the largest norm must be less
        than atol. [eV/Angstrom]
    rtol : float
        Relative tolerance for convergence. The largest per-atom change in
        force (as above) must be below rtol times the largest per-atom
        force norm. When rtol is finite, the smaller (stricter) of
        atol and rtol * max force is used as the threshold. The default,
        inf, disables the relative check.
    calc_last : bool
        If True, calculates forces last; that is, it waits until all other
        convergence criteria are satisfied before checking to see if the
        forces have converged. (This is more computationally efficient.)
        If False, checks forces at each SCF step.

    Setting both atol and rtol to inf switches the criterion off.
    """
    name = 'forces'
    tablename = 'force'

    def __init__(self, atol, rtol=np.inf, calc_last=True):
        self.atol = atol
        self.rtol = rtol
        self.description = ('Maximum change in the atomic [forces] across '
                            f'last 2 cycles: {self.atol} eV/Ang OR\n'
                            'Maximum error relative to the maximum '
                            f'force is below {self.rtol}')
        self.calc_last = calc_last
        self.reset()

    def __call__(self, context):
        """Return ``(converged, entry)`` for the current SCF step.

        entry is log10 of the force change formatted for the log file,
        '-inf' if the change is exactly zero, or '' before a previous set
        of forces is available.
        """
        # criterion is off; backwards compatibility
        if np.isinf(self.atol) and np.isinf(self.rtol):
            return True, ''
        # Forces in eV/Angstrom.
        F_av = context.calculate_forces() * (Ha / Bohr)
        # On the first call there is nothing to compare with, so the error
        # stays inf and the criterion cannot be met yet.
        error = np.inf
        max_force = np.max(np.linalg.norm(F_av, axis=1))
        if self.old_F_av is not None:
            # Largest per-atom Euclidean norm of the force change.
            error = ((F_av - self.old_F_av)**2).sum(1).max()**0.5
        self.old_F_av = F_av

        if np.isfinite(self.rtol):
            # NOTE(review): the description says "OR", but taking the
            # minimum means both bounds must hold. If all forces are zero,
            # the threshold is 0 and `error < 0` can never be met. Confirm
            # which behavior is intended.
            error_threshold = min(self.atol, self.rtol * max_force)
        else:
            # Avoid possible inf * 0.0:
            error_threshold = self.atol
        converged = error < error_threshold

        entry = ''
        if np.isfinite(error):
            if error:
                entry = f'{np.log10(error):+5.2f}'
            else:
                entry = '-inf'
        return converged, entry

    def reset(self):
        """Discard the forces stored from the previous call."""
        self.old_F_av = None
class WorkFunction(Criterion):
    """A convergence criterion for the work function.

    Parameters:

    tol : float
        Tolerance for convergence; that is the maximum variation among the
        last n_old values of either work function. [eV]
    n_old : int
        Number of work functions to compare. I.e., if n_old is 3, then this
        compares the peak-to-peak difference among the current work
        function and the two previous.
    """
    name = 'work function'
    tablename = 'wkfxn'

    def __init__(self, tol=0.005, n_old=3):
        self.tol = tol
        self.n_old = n_old
        self.description = ('Maximum change in the last {:d} '
                            'work functions [wkfxn]: {:g} eV'
                            .format(n_old, tol))
        # Create the (empty) history here too, so a freshly built criterion
        # can be called without an explicit reset() first.
        self.reset()

    def reset(self):
        """Forget all previously recorded work functions."""
        self._old = deque(maxlen=self.n_old)

    def __call__(self, context):
        """Return ``(converged, entry)`` for the current SCF step.

        entry is log10 of the largest peak-to-peak work-function change
        formatted for the log file, or '' until n_old values are recorded.
        """
        workfunctions = context.ham.get_workfunctions(context.wfs)
        workfunctions = Ha * np.array(workfunctions)
        # The deque silently drops the oldest value beyond n_old entries.
        self._old.append(workfunctions)
        if len(self._old) == self._old.maxlen:
            # Largest spread over all work-function components; np.max
            # propagates NaN instead of depending on its position.
            error = np.max(np.ptp(self._old, axis=0))
        else:
            error = np.inf
        converged = (error < self.tol)
        if error < np.inf:
            # error == 0 gives log10 -> -inf; silence numpy's divide warning.
            with np.errstate(divide='ignore'):
                entry = f'{np.log10(error):+5.2f}'
        else:
            entry = ''
        return converged, entry
class MinIter(Criterion):
    """Convergence criterion demanding a minimum number of SCF iterations.

    Parameters:

    n : int
        The criterion is not satisfied until ``context.niter`` reaches n,
        which keeps the SCF cycle from stopping earlier.
    """
    name = 'minimum iterations'
    tablename = 'minit'
    calc_last = False

    def __init__(self, n):
        self.n = n
        self.description = f'Minimum number of iterations [minit]: {n}'

    def __call__(self, context):
        """Return ``(converged, entry)``; entry is the iteration count."""
        niter = context.niter
        return niter >= self.n, f'{niter:d}'
class MaxIter(Criterion):
    """A convergence criterion that enforces a maximum number of iterations.

    Parameters:

    n : int
        Maximum number of iterations. Once ``context.niter`` reaches n this
        criterion reports success, and because ``override_others`` is True
        the SCF step is then reported as converged even if other criteria
        are not met.
    """
    calc_last = False
    name = 'maximum iterations'
    tablename = 'maxit'
    override_others = True

    def __init__(self, n):
        self.n = n
        # The bracketed tag matches this criterion's tablename.
        self.description = f'Maximum number of iterations [maxit]: {n}'

    def __call__(self, context):
        """Return ``(converged, entry)``; entry is the iteration count."""
        converged = context.niter >= self.n
        entry = f'{context.niter:d}'
        return converged, entry