Coverage for gpaw/new/scf.py: 84%
114 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
3import itertools
4import warnings
5from math import inf
6from types import SimpleNamespace
7from typing import Any, Callable
9import numpy as np
11from gpaw import KohnShamConvergenceError
12from gpaw.convergence_criteria import (Criterion, check_convergence,
13 dict2criterion)
14from gpaw.new.energies import DFTEnergies
15from gpaw.new.logger import indent
16from gpaw.new.ibzwfs import IBZWaveFunctions
17from gpaw.scf import write_iteration
18from gpaw.typing import Array2D
class TooFewBandsError(KohnShamConvergenceError):
    """Signal that there are too few bands for the CBM+x criterion.

    ``SCFLoop.iterate`` raises this (rather than a plain
    ``KohnShamConvergenceError``) when the iteration limit is reached while
    the wave-function error is still infinite.
    """
class SCFLoop:
    """Drive the self-consistent field (SCF) iterations.

    Each iteration runs one eigensolver step and recomputes occupation
    numbers and energies.  Unless ``update_density_and_potential`` is False,
    it then updates and mixes the density and recalculates the potential.
    The loop itself is exposed as a generator, see :meth:`iterate`.
    """

    def __init__(self,
                 hamiltonian,
                 occ_calc,
                 eigensolver,
                 mixer,
                 comm,
                 convergence,
                 maxiter):
        """Store the SCF ingredients.

        Parameters
        ----------
        hamiltonian:
            Passed to the eigensolver's ``iterate`` and ``postprocess``.
        occ_calc:
            Occupation-number calculator handed to ``ibzwfs.calculate_occs``.
        eigensolver:
            Object whose ``iterate`` performs one eigensolver step.
        mixer:
            Density mixer; its ``mix`` returns the density error.
        comm:
            Communicator used to check that all ranks agree on convergence.
        convergence:
            Dict of convergence settings, converted into criterion objects
            by ``create_convergence_criteria``.
        maxiter:
            Default maximum number of SCF iterations.
        """
        self.hamiltonian = hamiltonian
        self.eigensolver = eigensolver
        self.mixer = mixer
        self.occ_calc = occ_calc
        self.comm = comm
        self.convergence = create_convergence_criteria(convergence)
        self.maxiter = maxiter
        self.niter = 0
        # When False, density and potential are kept fixed and only the
        # wave functions are iterated.
        self.update_density_and_potential = True
        # Forwarded to ibzwfs.calculate_occs(fix_fermi_level=...).
        self.fix_fermi_level = False

    def __repr__(self):
        return 'SCFLoop(...)'

    def __str__(self):
        return (f'eigensolver:\n{indent(self.eigensolver)}\n'
                f'{self.mixer}\n'
                f'occupation numbers:\n{indent(self.occ_calc)}\n')

    def iterate(self,
                ibzwfs: IBZWaveFunctions,
                density,
                potential,
                energies: DFTEnergies,
                pot_calc,
                *,
                maxiter=None,
                calculate_forces=None,
                log=None):
        """Run SCF iterations and yield an :class:`SCFContext` for each one.

        The context is yielded after the eigensolver step and the occupation
        update, and before the convergence check.  Once all criteria are met
        (and the environment agrees), the eigensolver's ``postprocess`` is
        called and the generator finishes.

        Parameters
        ----------
        maxiter:
            Iteration limit for this call.  Falsy values (``None``, ``0``)
            fall back to ``self.maxiter``.
        calculate_forces:
            Stored on every yielded context.
        log:
            Optional logging callable.  When given, it receives the
            convergence criteria, the effective iteration limit and a
            per-iteration summary.

        Raises
        ------
        TooFewBandsError
            If ``maxiter`` iterations pass without convergence and the
            wave-function error is still infinite.
        KohnShamConvergenceError
            If ``maxiter`` iterations pass without convergence otherwise.
        """
        cc = self.convergence
        maxiter = maxiter or self.maxiter

        if log:
            log('convergence criteria:')
            for criterion in cc.values():
                if criterion.description is not None:
                    log('- ' + criterion.description)
            # Report the limit actually enforced below: the ``maxiter``
            # argument may override self.maxiter.
            log(f'maximum number of iterations: {maxiter}\n')

        self.mixer.reset()

        self.occ_calc.initialize_reference_orbitals()

        if self.update_density_and_potential:
            dens_error = self.mixer.mix(density)
        else:
            dens_error = 0.0

        for self.niter in itertools.count(start=1):
            eig_error, wfs_error, energies = self.eigensolver.iterate(
                ibzwfs, density, potential,
                self.hamiltonian, pot_calc, energies)
            nelectrons = (density.nvalence - density.charge +
                          pot_calc.environment.charge)
            e_band, e_entropy, e_extrapolation = ibzwfs.calculate_occs(
                self.occ_calc,
                nelectrons,
                fix_fermi_level=self.fix_fermi_level)

            energies.set(**pot_calc.xc.energies,
                         band=e_band,
                         entropy=e_entropy,
                         extrapolation=e_extrapolation)

            ctx = SCFContext(
                log, self.niter, energies,
                ibzwfs, density, potential,
                wfs_error, dens_error, eig_error,
                self.comm, calculate_forces,
                pot_calc, self.update_density_and_potential)

            yield ctx

            converged, converged_items, entries = check_convergence(cc, ctx)
            # All ranks must reach the same verdict.
            nconverged = self.comm.sum_scalar(int(converged))
            assert nconverged in [0, self.comm.size], converged_items

            if log:
                write_iteration(cc, converged_items, entries, ctx, log)

            if converged:
                # The environment may still decline convergence, in which
                # case the loop simply continues.
                converged = pot_calc.environment.post_scf_convergence(
                    ibzwfs, nelectrons, self.occ_calc, self.mixer, log)
                if converged:
                    break

            if self.niter == maxiter:
                # An infinite wave-function error is reported as the
                # "too few bands" condition.
                if wfs_error < inf:
                    raise KohnShamConvergenceError
                raise TooFewBandsError

            if self.update_density_and_potential:
                density.update(ibzwfs, ked=pot_calc.xc.type == 'MGGA')
                dens_error = self.mixer.mix(density)
                potential, energies, _ = pot_calc.calculate(
                    density, ibzwfs, potential.vHt_x)

        self.eigensolver.postprocess(
            ibzwfs, density, potential, self.hamiltonian)
class SCFContext:
    """Per-iteration snapshot handed to the convergence machinery.

    The new-style objects (``energies``, ``ibzwfs``, ``density``,
    ``potential``) are exposed directly.  In addition, ``ham``, ``wfs`` and
    ``dens`` are ``SimpleNamespace`` stand-ins; presumably they mirror the
    attribute layout that ``check_convergence`` / ``write_iteration``
    expect (confirm against gpaw.convergence_criteria).
    """

    def __init__(self,
                 log,
                 niter: int,
                 energies: DFTEnergies,
                 ibzwfs,
                 density,
                 potential,
                 wfs_error: float,
                 dens_error: float,
                 eig_error: float,
                 comm,
                 calculate_forces: Callable[[], Array2D],
                 pot_calc,
                 update_density_and_potential):
        self.log = log
        self.niter = niter
        self.energies = energies
        self.ibzwfs = ibzwfs
        self.density = density
        self.potential = potential

        # Total (extrapolated) energy plus a work-function callback.
        self.ham = SimpleNamespace(
            e_total_extrapolated=energies.total_extrapolated,
            get_workfunctions=self._get_workfunctions)

        # NOTE(review): unlike the electron count in SCFLoop.iterate, this
        # does not subtract density.charge -- confirm this is intended.
        valence = density.nvalence + pot_calc.environment.charge
        solver_info = SimpleNamespace(error=wfs_error)
        self.wfs = SimpleNamespace(
            nvalence=valence,
            world=comm,
            eigensolver=solver_info,
            nspins=density.ndensities,
            collinear=density.collinear)

        # The density counts as "fixed" when SCF does not update it.
        self.dens = SimpleNamespace(
            calculate_magnetic_moments=density.calculate_magnetic_moments,
            fixed=not update_density_and_potential,
            error=dens_error)

        self.eig_error = eig_error
        self.calculate_forces = calculate_forces
        self.poisson_solver = pot_calc.poisson_solver

    def _get_workfunctions(self, _):
        """Return ``[phi + d, phi - d]``.

        Here ``phi`` is the vacuum level minus the single Fermi level and
        ``d`` is the Poisson solver's dipole-layer correction.  The
        positional argument is ignored.  Unpacking raises ``ValueError``
        unless there is exactly one Fermi level.
        """
        vacuum = self.potential.get_vacuum_level()
        (eps_f,) = self.ibzwfs.fermi_levels
        phi = vacuum - eps_f
        shift = self.poisson_solver.dipole_layer_correction()
        return np.array([phi + shift, phi - shift])
def create_convergence_criteria(criteria: dict[str, Any]
                                ) -> dict[str, Criterion]:
    """Turn user convergence settings into fresh criterion objects.

    Behaviour, step by step:

    * Missing ``energy``, ``density`` and ``eigenstates`` entries are filled
      in with default values.
    * Every regular entry is converted with ``dict2criterion``.  Objects
      that provide ``todict()`` are rebuilt from that dict, so no two
      calculators share an instance.
    * Objects listed under ``custom`` are added as-is, with a warning.
    * Dict entries under ``custom`` (as read back from a .gpw file) are
      skipped, with a warning.
    * Every resulting criterion is ``reset()`` before returning.

    The caller's dict itself is not modified; a shallow copy is used.
    """
    settings = criteria.copy()
    defaults = [('energy', 0.0005),  # eV / electron
                ('density', 1.0e-4),  # electrons / electron
                ('eigenstates', 4.0e-8)]  # eV^2 / electron
    for key, value in defaults:
        settings.setdefault(key, value)

    custom = settings.pop('custom', [])

    # Gather convergence criteria for SCF loop.
    result = {}
    for key, value in settings.items():
        if hasattr(value, 'todict'):
            # Rebuild from the dict form so the instance is not shared.
            spec = value.todict()
        else:
            spec = {key: value}
        result[key] = dict2criterion(spec)

    extras = custom if isinstance(custom, (list, tuple)) else [custom]
    for extra in extras:
        if isinstance(extra, dict):  # from .gpw file
            warnings.warn(
                'Custom convergence criterion "{:s}" encountered, '
                'which GPAW does not know how to load. This '
                'criterion is NOT enabled; you may want to manually'
                ' set it.'.format(extra['name']))
            continue
        result[extra.name] = extra
        warnings.warn('Custom convergence criterion {:s} encountered. '
                      'Please be sure that each calculator is fed a '
                      'unique instance of this criterion. '
                      'Note that if you save the calculator instance to '
                      'a .gpw file you may not be able to re-open it. '
                      .format(extra.name))

    for crit in result.values():
        crit.reset()

    return result