Coverage for gpaw/convergence_criteria.py: 91%

220 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from collections import deque 

2from inspect import signature 

3 

4import numpy as np 

5from ase.calculators.calculator import InputError 

6from ase.units import Bohr, Ha 

7 

8from gpaw.mpi import broadcast_float 

9 

10 

def get_criterion(name):
    """Look up a built-in convergence criterion class by its name.

    Parameters:

    name: str
        Value of the ``name`` class attribute of the wanted criterion.

    Returns the criterion class itself (not an instance).  Raises
    InputError, listing all known names, if no built-in criterion has
    that name.
    """
    # Every built-in criterion must be registered in this tuple.
    builtin = (Energy, Density, Eigenstates, Eigenvalues, Forces,
               WorkFunction, MinIter, MaxIter)
    registry = {cls.name: cls for cls in builtin}
    try:
        return registry[name]
    except KeyError:
        known = ', '.join(repr(key) for key in registry)
        msg = (
            f'The convergence keyword "{name}" was supplied, which we do '
            'not know how to handle. If this is a typo, please correct '
            f'(known keywords are {known}). '
            'If this is a user-written convergence criterion, it cannot '
            'be imported with this function; please see the GPAW manual '
            'for details.')
        raise InputError(msg)

30 

31 

def dict2criterion(dictionary):
    """Build a convergence criterion instance from a dictionary.

    Three layouts are understood:

    * the output of 'todict', e.g.
      {'name': 'energy', 'tol': 0.005, 'n_old': 3}; every item except
      'name' is passed to the criterion as a keyword argument,
    * a one-item shortcut {name: value}, e.g. {'energy': 0.005}; the
      value is handed to the criterion as its single positional argument
      (it is passed as-is, a tuple is not unpacked),
    * a one-item wrapper around a 'todict' dictionary, e.g.
      {'energy': {'name': 'energy', 'tol': 0.005, 'n_old': 3}}.

    The input dictionary is not modified.
    """
    spec = dictionary.copy()

    if 'name' in spec:
        # Generated by 'todict'.
        criterion_name = spec.pop('name')
        criterion_class = get_criterion(criterion_name)
        return criterion_class(**spec)

    assert len(spec) == 1
    criterion_name = list(spec)[0]
    value = spec[criterion_name]

    if isinstance(value, dict) and 'name' in value:
        # The value is itself a 'todict' dictionary; the outer key is
        # ignored in favour of the inner 'name'.
        return dict2criterion(value)

    criterion_class = get_criterion(criterion_name)
    return criterion_class(value)

51 

52 

def check_convergence(criteria, ctx):
    """Evaluate a set of convergence criteria.

    Parameters:

    criteria: dict
        Maps a name to a criterion.  Each criterion is called as
        ``criterion(ctx)`` and must return ``(ok, entry)``; it must also
        provide the attributes ``calc_last`` and ``override_others``.
    ctx:
        Passed unchanged to every criterion that is evaluated.

    Returns ``(converged, converged_items, entries)`` where
    ``converged_items`` and ``entries`` hold, per name, the ok-flag and
    the log entry.  Criteria with ``calc_last`` set are evaluated only if
    every ordinary criterion evaluated before them is satisfied; otherwise
    they are reported as not converged with an empty entry.  A satisfied
    criterion with ``override_others`` set makes the overall result true
    regardless of the others, and never counts against it when
    unsatisfied.
    """
    log_entries = {}  # text for the log file, one per criterion
    status = {}  # True/False, one per criterion
    forced = False
    all_ok = True

    # First pass: everything that is not postponed.
    for key, crit in criteria.items():
        if crit.calc_last:
            continue
        ok, entry = crit(ctx)
        if not crit.override_others:
            all_ok = all_ok and ok
        elif ok:
            forced = True
        status[key] = ok
        log_entries[key] = entry

    # Second pass: postponed criteria, evaluated only while all is well.
    for key, crit in criteria.items():
        if not crit.calc_last:
            continue
        if not all_ok:
            status[key] = False
            log_entries[key] = ''
            continue
        ok, entry = crit(ctx)
        all_ok &= ok
        status[key] = ok
        log_entries[key] = entry

    return all_ok or forced, status, log_entries

82 

83 

class Criterion:
    """Common base class of all convergence criteria.

    Provides generic __repr__ and todict implementations.  Both inspect
    the constructor signature and read an attribute of the same name for
    every constructor parameter, so they work for any class that stores
    its arguments directly, i.e. __init__(self, a, b) doing self.a = a
    and self.b = b.  todict additionally needs a self.name attribute.

    Subclasses must define self.name, self.tablename, self.description,
    self.__init__ and self.__call__.  See the online documentation for
    details.
    """
    # A criterion with calc_last = True is only checked once all other
    # (non-last) criteria have been met.
    calc_last = False
    # See check_convergence() for how this flag is used.
    override_others = False
    description: str

    def __repr__(self):
        """Return 'ClassName(v1, v2, ...)' using str() of each argument."""
        cls = self.__class__
        values = [str(getattr(self, key))
                  for key in signature(cls).parameters]
        joined = ', '.join(values)
        return f'{cls.__name__}({joined})'

    def todict(self):
        """Return {'name': self.name} plus one item per constructor
        parameter, in signature order."""
        result = {'name': self.name}
        for key in signature(self.__class__).parameters:
            result[key] = getattr(self, key)
        return result

    def reset(self):
        """Hook for clearing per-run state; does nothing by default."""
        return None

115 

116 

117# Built-in criteria follow. Make sure that any new criteria added below 

118# are also added to the list in get_criterion() so that it can import 

119# them correctly by name. 

120 

121 

class Energy(Criterion):
    """A convergence criterion for the total energy.

    Parameters:

    tol: float
        Tolerance for convergence; that is the maximum variation among the
        last n_old values of the (extrapolated) total energy.
    n_old: int
        Number of energy values to compare. I.e., if n_old is 3, then this
        compares the peak-to-peak difference among the current total energy
        and the two previous.
    relative: bool
        Use total energy [eV] or total energy relative to number of
        valence electrons [eV/(valence electron)].
    """
    name = 'energy'
    tablename = 'energy'

    def __init__(self, tol: float, *, n_old: int = 3, relative: bool = True):
        self.tol = tol
        self.n_old = n_old
        self.relative = relative
        self.description = (
            f'Maximum [total energy] change in last {self.n_old} cycles: '
            f'{self.tol:g} eV')
        if relative:
            self.description += ' / valence electron'
        # Create the history buffer right away (as Forces does) so that
        # calling the criterion before an explicit reset() does not fail
        # with an AttributeError.
        self.reset()

    def reset(self):
        """Forget all previously recorded energies."""
        self._old = deque(maxlen=self.n_old)

    def __call__(self, context):
        """Record the current energy and test for convergence.

        Returns (bool, entry), where bool is True if converged and False
        if not, and entry is the string to be printed in the user log
        file: the total energy formatted as '11.6f', or an empty string
        if the energy is not finite.
        """
        # The peak-to-peak difference and the value reported in the SCF
        # table (logfile) are both based on e_total_extrapolated.
        total_energy = context.ham.e_total_extrapolated * Ha
        if context.wfs.nvalence == 0 or not self.relative:
            energy = total_energy
        else:
            energy = total_energy / context.wfs.nvalence
        # The deque keeps at most n_old values; the oldest is dropped.
        self._old.append(energy)
        # Until n_old values have been collected the error stays infinite,
        # so the criterion cannot be met.
        error = np.inf
        if len(self._old) == self._old.maxlen:
            error = np.ptp(self._old)
        converged = error < self.tol
        entry = ''
        if np.isfinite(energy):
            # The log always shows the total (not per-electron) energy.
            entry = f'{total_energy:11.6f}'
        return converged, entry

177 

178 

class Density(Criterion):
    """A convergence criterion for the electron density.

    Parameters:

    tol : float
        Tolerance for convergence; that is the maximum change in the
        electron density, calculated as the integrated absolute value of
        the density change, normalized per valence electron.
        [electrons/(valence electron)]
    """
    name = 'density'
    tablename = 'dens'

    def __init__(self, tol):
        self.tol = tol
        self.description = (
            'Maximum integral of absolute [dens]ity change: '
            f'{self.tol:g} electrons / valence electron')

    def __call__(self, context):
        """Test the density change against the tolerance.

        Returns (bool, entry), where bool is True if converged and False
        if not, and entry is a short string for the user log file: the
        base-10 logarithm of the error, or an empty string when there is
        nothing sensible to print.
        """
        # A fixed density counts as converged (old GPAW needs this), and
        # so does a system without valence electrons.
        if context.dens.fixed:
            return True, ''
        n_valence = context.wfs.nvalence
        if n_valence == 0:
            return True, ''

        # Broadcast so that all ranks agree on the density error.
        shared = broadcast_float(context.dens.error, context.wfs.world)
        error = shared / n_valence
        converged = (error < self.tol)

        printable = not (error is None or np.isinf(error) or error == 0)
        entry = f'{np.log10(error):+5.2f}' if printable else ''
        return converged, entry

216 

217 

class Eigenstates(Criterion):
    """A convergence criterion for the eigenstates.

    Parameters:

    tol : float
        Tolerance for convergence; that is the maximum change in the
        eigenstates, calculated as the integration of the square of the
        residuals of the Kohn--Sham equations, normalized per valence
        electron. [eV^2/(valence electron)]
    """
    name = 'eigenstates'
    tablename = 'eigst'

    def __init__(self, tol):
        self.tol = tol
        self.description = ('Maximum integral of absolute [eigenst]ate '
                            f'change: {self.tol:g} eV^2 / valence electron')

    def __call__(self, context):
        """Test the eigenstate error against the tolerance.

        Returns (bool, entry), where bool is True if converged and False
        if not, and entry is a short string for the user log file: the
        base-10 logarithm of the error formatted as '+6.2f', or an empty
        string if the error is zero or infinite.  A system without
        valence electrons is always reported as converged.
        """
        if context.wfs.nvalence == 0:
            return True, ''
        error = self.get_error(context)
        converged = (error < self.tol)
        # nvalence is known to be non-zero here, so only the error itself
        # decides whether its logarithm can be printed.
        if error == 0 or np.isinf(error):
            entry = ''
        else:
            entry = f'{np.log10(error):+6.2f}'
        return converged, entry

    def get_error(self, context):
        """Returns the raw error: the eigensolver error scaled by Ha**2
        and divided by the number of valence electrons."""
        return context.wfs.eigensolver.error * Ha**2 / context.wfs.nvalence

255 

256 

class Eigenvalues(Criterion):
    """A convergence criterion for the eigenvalues.

    Parameters:

    tol : float
        Tolerance for convergence; the criterion is met when the
        eigenvalue error reported by the context (described as the
        maximum absolute change in eigenvalues) is below tol. [eV]
    """
    name = 'eigenvalues'
    tablename = 'eigs'
    calc_last = False

    def __init__(self, tol=1e-3):
        self.tol = tol
        self.description = 'Maximum absolute change in eigenvalues [eV].'

    def __call__(self, context):
        """Test the eigenvalue error against the tolerance.

        Returns (bool, entry), where bool is True if converged and False
        if not, and entry is a short string for the user log file: the
        base-10 logarithm of the error formatted as '+6.2f', or an empty
        string if the error is zero or infinite.  A system without
        valence electrons is always reported as converged.
        """
        if context.wfs.nvalence == 0:
            return True, ''
        error = self.get_error(context)
        converged = (error < self.tol)
        # nvalence is known to be non-zero here, so only the error itself
        # decides whether its logarithm can be printed.
        if error == 0 or np.isinf(error):
            entry = ''
        else:
            entry = f'{np.log10(error):+6.2f}'
        return converged, entry

    def get_error(self, context):
        """Returns the raw error: context.eig_error scaled by Ha."""
        return context.eig_error * Ha

279 

280 

class Forces(Criterion):
    """A convergence criterion for the forces.

    Parameters:

    atol : float
        Absolute tolerance for convergence; that is, the force on each
        atom is compared with its force from the previous iteration, and
        the change in each atom's force is calculated as an l2-norm
        (Euclidean distance). The atom with the largest norm must be less
        than atol. [eV/Angstrom]
    rtol : float
        Relative tolerance for convergence. The difference in the l2-norm
        of force on each atom is calculated, and convergence is achieved
        when the largest difference between two iterations is
        rtol * max force.
    calc_last : bool
        If True, calculates forces last; that is, it waits until all other
        convergence criteria are satisfied before checking to see if the
        forces have converged. (This is more computationally efficient.)
        If False, checks forces at each SCF step.
    """
    name = 'forces'
    tablename = 'force'

    def __init__(self, atol, rtol=np.inf, calc_last=True):
        self.atol = atol
        self.rtol = rtol
        self.calc_last = calc_last
        self.description = (
            'Maximum change in the atomic [forces] across '
            f'last 2 cycles: {self.atol} eV/Ang OR\n'
            'Maximum error relative to the maximum '
            f'force is below {self.rtol}')
        self.reset()

    def __call__(self, context):
        """Compare the current forces with those of the previous call.

        Returns (bool, entry), where bool is True if converged and False
        if not, and entry is a <=5 character string to be printed in
        the user log file.
        """
        # Both tolerances infinite means the criterion is switched off
        # (backwards compatibility); forces are not even calculated.
        if np.isinf(self.atol) and np.isinf(self.rtol):
            return True, ''

        forces = context.calculate_forces() * (Ha / Bohr)
        largest_force = np.max(np.linalg.norm(forces, axis=1))

        # Without a previous force array there is nothing to compare to.
        if self.old_F_av is None:
            error = np.inf
        else:
            delta = forces - self.old_F_av
            error = (delta**2).sum(1).max()**0.5
        self.old_F_av = forces

        # Only form rtol * max force for finite rtol, which avoids a
        # possible inf * 0.0.
        # NOTE(review): min() means that both tolerances have to be met
        # when both are finite, whereas the description text says "OR";
        # confirm which one is intended.
        threshold = self.atol
        if np.isfinite(self.rtol):
            threshold = min(self.atol, self.rtol * largest_force)
        converged = error < threshold

        if not np.isfinite(error):
            entry = ''
        elif error:
            entry = f'{np.log10(error):+5.2f}'
        else:
            # Exactly zero change: the logarithm is written out by hand.
            entry = '-inf'
        return converged, entry

    def reset(self):
        """Forget the forces of the previous call."""
        self.old_F_av = None

347 

348 

class WorkFunction(Criterion):
    """A convergence criterion for the work function.

    Parameters:

    tol : float
        Tolerance for convergence; that is the maximum variation among the
        last n_old values of either work function. [eV]
    n_old : int
        Number of work functions to compare. I.e., if n_old is 3, then this
        compares the peak-to-peak difference among the current work
        function and the two previous.
    """
    name = 'work function'
    tablename = 'wkfxn'

    def __init__(self, tol=0.005, n_old=3):
        self.tol = tol
        self.n_old = n_old
        self.description = ('Maximum change in the last {:d} '
                            'work functions [wkfxn]: {:g} eV'
                            .format(n_old, tol))
        # Create the history buffer right away (as Forces does) so that
        # calling the criterion before an explicit reset() does not fail
        # with an AttributeError.
        self.reset()

    def reset(self):
        """Forget all previously recorded work functions."""
        self._old = deque(maxlen=self.n_old)

    def __call__(self, context):
        """Record the current work functions and test for convergence.

        Returns (bool, entry), where bool is True if converged and
        False if not, and entry is a <=5 character string to be printed in
        the user log file.
        """
        workfunctions = context.ham.get_workfunctions(context.wfs)
        workfunctions = Ha * np.array(workfunctions)
        # The deque keeps at most n_old entries; the oldest is dropped.
        self._old.append(workfunctions)
        if len(self._old) == self._old.maxlen:
            # Peak-to-peak variation over the history, taken separately
            # for each work function; the largest one counts.
            error = max(np.ptp(self._old, axis=0))
        else:
            # Not enough history yet, so the criterion cannot be met.
            error = np.inf
        converged = (error < self.tol)
        if error < np.inf:
            entry = f'{np.log10(error):+5.2f}'
        else:
            entry = ''
        return converged, entry

392 

393 

class MinIter(Criterion):
    """A convergence criterion that enforces a minimum number of iterations.

    Parameters:

    n : int
        Minimum number of iterations that must be complete before
        the SCF cycle exits.
    """
    calc_last = False
    name = 'minimum iterations'
    tablename = 'minit'

    def __init__(self, n):
        self.n = n
        self.description = f'Minimum number of iterations [minit]: {n}'

    def __call__(self, context):
        """Return (bool, entry): bool is True once context.niter has
        reached n, and entry is the iteration count as a string."""
        reached = context.niter >= self.n
        return reached, f'{context.niter:d}'

415 

416 

class MaxIter(Criterion):
    """A convergence criterion that enforces a maximum number of iterations.

    The criterion is met once the iteration count reaches n.  Because it
    sets override_others = True, check_convergence() then reports overall
    convergence even if other criteria are not met.

    Parameters:

    n : int
        Number of iterations at which this criterion is met.
    """
    calc_last = False
    name = 'maximum iterations'
    tablename = 'maxit'
    override_others = True

    def __init__(self, n):
        self.n = n
        # The bracketed tag matches this criterion's tablename.
        self.description = f'Maximum number of iterations [maxit]: {n}'

    def __call__(self, context):
        """Return (bool, entry): bool is True once context.niter has
        reached n, and entry is the iteration count as a string."""
        converged = context.niter >= self.n
        entry = f'{context.niter:d}'
        return converged, entry