Coverage for gpaw/new/scf.py: 84%

114 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from __future__ import annotations 

2 

3import itertools 

4import warnings 

5from math import inf 

6from types import SimpleNamespace 

7from typing import Any, Callable 

8 

9import numpy as np 

10 

11from gpaw import KohnShamConvergenceError 

12from gpaw.convergence_criteria import (Criterion, check_convergence, 

13 dict2criterion) 

14from gpaw.new.energies import DFTEnergies 

15from gpaw.new.logger import indent 

16from gpaw.new.ibzwfs import IBZWaveFunctions 

17from gpaw.scf import write_iteration 

18from gpaw.typing import Array2D 

19 

20 

class TooFewBandsError(KohnShamConvergenceError):
    """Not enough bands for CBM+x convergence criterion.

    Raised by ``SCFLoop.iterate`` instead of a plain
    ``KohnShamConvergenceError`` when the iteration limit is reached while
    the wave-function error is not finite.
    """

23 

24 

class SCFLoop:
    """Self-consistent-field driver.

    Holds the Hamiltonian, occupation-number calculator, eigensolver,
    density mixer, communicator and convergence criteria.  The actual
    iterations are performed by the ``iterate`` generator.
    """

    def __init__(self,
                 hamiltonian,
                 occ_calc,
                 eigensolver,
                 mixer,
                 comm,
                 convergence,
                 maxiter):
        """Store the SCF ingredients.

        ``convergence`` is the user-level settings dict; it is turned into
        fresh, reset criterion objects by ``create_convergence_criteria``.
        ``maxiter`` is the default iteration limit used by ``iterate``.
        """
        self.hamiltonian = hamiltonian
        self.eigensolver = eigensolver
        self.mixer = mixer
        self.occ_calc = occ_calc
        self.comm = comm
        self.convergence = create_convergence_criteria(convergence)
        self.maxiter = maxiter
        self.niter = 0
        # When False, density and potential are never updated/mixed:
        # only wave functions and occupations are iterated.
        self.update_density_and_potential = True
        # Passed on to ``ibzwfs.calculate_occs``.
        self.fix_fermi_level = False

    def __repr__(self):
        return 'SCFLoop(...)'

    def __str__(self):
        return (f'eigensolver:\n{indent(self.eigensolver)}\n'
                f'{self.mixer}\n'
                f'occupation numbers:\n{indent(self.occ_calc)}\n')

    def iterate(self,
                ibzwfs: IBZWaveFunctions,
                density,
                potential,
                energies: DFTEnergies,
                pot_calc,
                *,
                maxiter=None,
                calculate_forces=None,
                log=None):
        """Run SCF iterations, yielding an ``SCFContext`` after each step.

        Parameters
        ----------
        ibzwfs, density, potential, energies, pot_calc:
            Current wave functions, density, potential, energies and the
            potential calculator.  ``potential`` and ``energies`` are
            replaced locally each time the density is updated.
        maxiter:
            Iteration limit for this call.  A falsy value (``None`` or
            ``0``) falls back to ``self.maxiter``.
        calculate_forces:
            Passed through unchanged to each ``SCFContext``.
        log:
            Optional callable taking a string.  When given, the criteria,
            the effective iteration limit and every iteration are logged.

        Yields
        ------
        SCFContext
            Snapshot of the current iteration.  Convergence is checked
            only after the consumer resumes the generator.

        Raises
        ------
        KohnShamConvergenceError
            If ``maxiter`` iterations pass without convergence.
        TooFewBandsError
            Raised instead when, at that point, the wave-function error is
            not finite (``wfs_error < inf`` fails).
        """
        cc = self.convergence
        # ``or`` rather than ``is None``: maxiter=0 must also mean "use the
        # default", otherwise the ``self.niter == maxiter`` test below
        # (niter starts at 1) could never trigger.
        maxiter = maxiter or self.maxiter

        if log:
            log('convergence criteria:')
            for criterion in cc.values():
                if criterion.description is not None:
                    log('- ' + criterion.description)
            # Report the limit actually in effect for this call (which may
            # have been overridden by the ``maxiter`` argument).
            log(f'maximum number of iterations: {maxiter}\n')

        self.mixer.reset()

        self.occ_calc.initialize_reference_orbitals()

        if self.update_density_and_potential:
            dens_error = self.mixer.mix(density)
        else:
            dens_error = 0.0

        for self.niter in itertools.count(start=1):
            eig_error, wfs_error, energies = self.eigensolver.iterate(
                ibzwfs, density, potential,
                self.hamiltonian, pot_calc, energies)
            nelectrons = (density.nvalence - density.charge +
                          pot_calc.environment.charge)
            e_band, e_entropy, e_extrapolation = ibzwfs.calculate_occs(
                self.occ_calc,
                nelectrons,
                fix_fermi_level=self.fix_fermi_level)

            energies.set(**pot_calc.xc.energies,
                         band=e_band,
                         entropy=e_entropy,
                         extrapolation=e_extrapolation)

            ctx = SCFContext(
                log, self.niter, energies,
                ibzwfs, density, potential,
                wfs_error, dens_error, eig_error,
                self.comm, calculate_forces,
                pot_calc, self.update_density_and_potential)

            yield ctx

            converged, converged_items, entries = check_convergence(cc, ctx)
            # All ranks must agree on convergence; a split decision means
            # the criteria were evaluated inconsistently across ranks.
            nconverged = self.comm.sum_scalar(int(converged))
            assert nconverged in [0, self.comm.size], converged_items

            if log:
                write_iteration(cc, converged_items, entries, ctx, log)

            if converged:
                # The environment gets a final say and may request more
                # iterations by returning a falsy value.
                converged = pot_calc.environment.post_scf_convergence(
                    ibzwfs, nelectrons, self.occ_calc, self.mixer, log)
                if converged:
                    break

            if self.niter == maxiter:
                if wfs_error < inf:
                    raise KohnShamConvergenceError
                raise TooFewBandsError

            if self.update_density_and_potential:
                density.update(ibzwfs, ked=pot_calc.xc.type == 'MGGA')
                dens_error = self.mixer.mix(density)
                potential, energies, _ = pot_calc.calculate(
                    density, ibzwfs, potential.vHt_x)

        self.eigensolver.postprocess(
            ibzwfs, density, potential, self.hamiltonian)

133 

134 

class SCFContext:
    """State of one SCF iteration, handed to the convergence machinery.

    ``SCFLoop.iterate`` yields one of these per iteration and passes it to
    ``check_convergence`` and ``write_iteration``.  Besides plain
    attributes it exposes three ``SimpleNamespace`` views (``ham``, ``wfs``
    and ``dens``), presumably mirroring the attribute names those helpers
    expect from old-style calculator objects -- verify against
    ``gpaw.convergence_criteria`` before relying on this.
    """

    def __init__(self,
                 log,
                 niter: int,
                 energies: DFTEnergies,
                 ibzwfs,
                 density,
                 potential,
                 wfs_error: float,
                 dens_error: float,
                 eig_error: float,
                 comm,
                 calculate_forces: Callable[[], Array2D],
                 pot_calc,
                 update_density_and_potential):
        self.log = log
        self.niter = niter
        self.energies = energies
        self.ibzwfs = ibzwfs
        self.density = density
        self.potential = potential

        total_energy = energies.total_extrapolated
        self.ham = SimpleNamespace(
            e_total_extrapolated=total_energy,
            get_workfunctions=self._get_workfunctions)

        # NOTE(review): unlike ``nelectrons`` in SCFLoop.iterate, this does
        # not subtract ``density.charge`` -- confirm that is intended.
        valence = density.nvalence + pot_calc.environment.charge
        eigensolver_view = SimpleNamespace(error=wfs_error)
        self.wfs = SimpleNamespace(
            nvalence=valence,
            world=comm,
            eigensolver=eigensolver_view,
            nspins=density.ndensities,
            collinear=density.collinear)

        density_is_fixed = not update_density_and_potential
        self.dens = SimpleNamespace(
            calculate_magnetic_moments=density.calculate_magnetic_moments,
            fixed=density_is_fixed,
            error=dens_error)

        self.eig_error = eig_error
        self.calculate_forces = calculate_forces
        self.poisson_solver = pot_calc.poisson_solver

    def _get_workfunctions(self, _):
        """Return ``[wf + delta, wf - delta]`` as a NumPy array.

        ``wf`` is the vacuum level minus the single Fermi level and
        ``delta`` is the Poisson solver's dipole-layer correction.  The
        ignored argument keeps the call signature of ``ham``'s
        ``get_workfunctions``.  Unpacking raises ``ValueError`` unless
        there is exactly one Fermi level.
        """
        vacuum = self.potential.get_vacuum_level()
        [fermi] = self.ibzwfs.fermi_levels
        work_function = vacuum - fermi
        correction = self.poisson_solver.dipole_layer_correction()
        return np.array([work_function + correction,
                         work_function - correction])

180 

181 

def create_convergence_criteria(criteria: dict[str, Any]
                                ) -> dict[str, Criterion]:
    """Build fresh, reset convergence-criterion objects from user settings.

    Parameters
    ----------
    criteria:
        Mapping from criterion name to either a tolerance-like value or an
        object with a ``todict`` method.  The optional ``'custom'`` key
        holds one criterion object or a list/tuple of them.  Missing
        ``'energy'``, ``'density'`` and ``'eigenstates'`` entries get
        default tolerances.  The input mapping is not modified.

    Returns
    -------
    dict
        Name -> criterion.  User-supplied keys come first, in their
        original order, followed by any defaults that were filled in.
        Custom criteria are stored under their ``name`` attribute.  Every
        criterion has had ``reset()`` called.

    Warnings are issued for every custom criterion.  A custom entry given
    as a plain dict (the serialized form read from a .gpw file) is skipped
    rather than enabled.
    """
    default_tolerances = [('energy', 0.0005),  # eV / electron
                          ('density', 1.0e-4),  # electrons / electron
                          ('eigenstates', 4.0e-8)]  # eV^2 / electron
    settings = criteria.copy()
    for key, tolerance in default_tolerances:
        settings.setdefault(key, tolerance)

    extra = settings.pop('custom', [])

    # Rebuild every built-in criterion, even when the user already passed
    # an object: the todict() round-trip makes a copy, so no two
    # calculators share an instance.
    result = {
        name: dict2criterion(value.todict() if hasattr(value, 'todict')
                             else {name: value})
        for name, value in settings.items()}

    if not isinstance(extra, (list, tuple)):
        extra = [extra]
    for item in extra:
        if isinstance(item, dict):  # from .gpw file
            warnings.warn('Custom convergence criterion "{:s}" encountered, '
                          'which GPAW does not know how to load. This '
                          'criterion is NOT enabled; you may want to manually'
                          ' set it.'.format(item['name']))
        else:
            result[item.name] = item
            warnings.warn('Custom convergence criterion {:s} encountered. '
                          'Please be sure that each calculator is fed a '
                          'unique instance of this criterion. '
                          'Note that if you save the calculator instance to '
                          'a .gpw file you may not be able to re-open it. '
                          .format(item.name))

    for criterion in result.values():
        criterion.reset()

    return result