Coverage for gpaw/eigensolvers/eigensolver.py: 92%
144 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-12 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-12 00:18 +0000
1"""Module defining an eigensolver base-class."""
2from functools import partial
4import numpy as np
5from ase.dft.bandgap import _bandgap
6from ase.units import Ha
7from ase.utils.timing import timer
9from gpaw.matrix import matrix_matrix_multiply as mmm
10from gpaw.utilities.mblas import multi_axpy
11from gpaw.xc.hybrid import HybridXC
12from gpaw.mpi import broadcast_exception
def reshape(a_x, shape):
    """Get an ndarray of size shape from a_x buffer.

    Returns the first ``prod(shape)`` elements of the flattened *a_x*,
    reshaped to *shape*.  When *a_x* is contiguous, ``ravel`` returns a
    view, so the result shares memory with *a_x* and can be used as a
    work buffer; for non-contiguous input ``ravel`` makes a copy.
    """
    # np.prod(()) is the float 1.0, which is not a valid slice index;
    # convert to int so that scalar shapes () work too.
    return a_x.ravel()[:int(np.prod(shape))].reshape(shape)
class Eigensolver:
    """Base class for iterative eigensolvers.

    Subclasses implement :meth:`iterate_one_k_point`; :meth:`iterate`
    loops over the k-points in ``wfs.kpt_u`` and calls it for each one.
    """

    def __init__(self, keep_htpsit=True, blocksize=1):
        """Create the eigensolver.

        keep_htpsit: bool
            If true (and bands are not distributed, see
            :meth:`initialize`), a buffer ``self.Htpsit_nG`` is
            allocated and :meth:`subspace_diagonalize` stores H applied
            to the rotated wave functions in it.
        blocksize: int
            Passed to ``wfs.make_preconditioner`` in :meth:`initialize`.
        """
        self.keep_htpsit = keep_htpsit
        self.initialized = False
        self.Htpsit_nG = None
        self.error = np.inf
        self.blocksize = blocksize
        self.orthonormalization_required = True

    def initialize(self, wfs):
        """Cache communicators/sizes from *wfs* and allocate work arrays."""
        self.timer = wfs.timer
        self.world = wfs.world
        self.kpt_comm = wfs.kd.comm
        self.band_comm = wfs.bd.comm
        self.dtype = wfs.dtype
        self.bd = wfs.bd
        self.nbands = wfs.bd.nbands
        self.mynbands = wfs.bd.mynbands

        # Htpsit is not kept when bands are distributed:
        if wfs.bd.comm.size > 1:
            self.keep_htpsit = False

        if self.keep_htpsit:
            self.Htpsit_nG = np.empty_like(wfs.work_array)

        # Preconditioner for the electronic gradients:
        self.preconditioner = wfs.make_preconditioner(self.blocksize)

        for kpt in wfs.kpt_u:
            if kpt.eps_n is None:
                kpt.eps_n = np.empty(self.mynbands)

        self.initialized = True

    def reset(self):
        """Force re-initialization on the next call to :meth:`iterate`."""
        self.initialized = False

    def weights(self, wfs):
        """Calculate convergence weights for all eigenstates.

        Returns an array of shape ``(len(wfs.kpt_u), self.bd.mynbands)``.
        The mode is selected by ``self.nbands_converge``:

        * an integer: the lowest that many bands (counted globally) get
          the k-point weight, all other bands get zero;
        * ``'occupied'``: the absolute occupation numbers ``|f_n|``;
        * ``'CBM+<delta>'``: states with energy below the conduction-band
          minimum plus *delta* (in eV) get the k-point weight.

        Infinite weights are returned when no eigenvalues are available
        yet, or (CBM mode) when there are not enough bands to reach the
        energy cutoff.
        """
        weight_un = np.zeros((len(wfs.kpt_u), self.bd.mynbands))

        if isinstance(self.nbands_converge, (int, np.integer)):
            # Converge fixed number of bands:
            n = self.nbands_converge - self.bd.beg
            if n > 0:
                for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                    weight_n[:n] = kpt.weight
        elif self.nbands_converge == 'occupied':
            # Converge occupied bands:
            for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                if kpt.f_n is None:  # no eigenvalues yet
                    weight_n[:] = np.inf
                else:
                    # Methfessel-Paxton distribution can give negative
                    # occupation numbers - so we take the absolute value:
                    weight_n[:] = np.abs(kpt.f_n)
        else:
            # Converge state with energy up to CBM + delta:
            assert self.nbands_converge.startswith('CBM+')
            delta = float(self.nbands_converge[4:]) / Ha

            if wfs.kpt_u[0].f_n is None:
                weight_un[:] = np.inf  # no eigenvalues yet
            else:
                # Collect all eigenvalues and calculate band gap:
                efermi = np.mean(wfs.fermi_levels)
                eps_skn = np.array(
                    [[wfs.collect_eigenvalues(k, spin) - efermi
                      for k in range(wfs.kd.nibzkpts)]
                     for spin in range(wfs.nspins)])
                # Only rank 0 has the collected values; share them:
                if wfs.world.rank > 0:
                    eps_skn = np.empty((wfs.nspins,
                                        wfs.kd.nibzkpts,
                                        wfs.bd.nbands))
                wfs.world.broadcast(eps_skn, 0)
                try:
                    # Find bandgap + positions of CBM:
                    try:
                        gap, _, (s, k, n) = _bandgap(eps_skn, direct=False)
                    except TypeError:
                        # Old ASE:
                        gap, _, (s, k, n) = _bandgap(eps_skn, spin=None,
                                                     direct=False)
                except ValueError:
                    gap = 0.0

                if gap == 0.0:
                    cbm = efermi
                else:
                    cbm = efermi + eps_skn[s, k, n]

                ecut = cbm + delta
                for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                    weight_n[kpt.eps_n < ecut] = kpt.weight

                if (eps_skn[:, :, -1] < ecut - efermi).any():
                    # We don't have enough bands!
                    weight_un[:] = np.inf

        return weight_un

    def iterate(self, ham, wfs):
        """Solves eigenvalue problem iteratively

        This method is inherited by the actual eigensolver which should
        implement *iterate_one_k_point* method for a single iteration of
        a single kpoint.  The summed error over all k-points and band
        distributions is stored in ``self.error``.
        """

        if not self.initialized:
            if isinstance(ham.xc, HybridXC):
                self.blocksize = wfs.bd.mynbands
            self.initialize(wfs)

        weight_un = self.weights(wfs)

        error = 0.0
        with broadcast_exception(self.kpt_comm):
            for kpt, weights in zip(wfs.kpt_u, weight_un):
                if not wfs.orthonormalized:
                    wfs.orthonormalize(kpt)
                e = self.iterate_one_k_point(ham, wfs, kpt, weights)
                error += e
                if self.orthonormalization_required:
                    wfs.orthonormalize(kpt)

        wfs.orthonormalized = True
        # Sum the error over the k-point and band communicators:
        self.error = self.band_comm.sum_scalar(
            self.kpt_comm.sum_scalar(error))

    def iterate_one_k_point(self, ham, wfs, kpt, weights):
        """Perform one iteration for a single k-point.

        Implemented in subclasses.  Called from :meth:`iterate` as
        ``self.iterate_one_k_point(ham, wfs, kpt, weights)``, where
        *weights* is the row of :meth:`weights` for this k-point; the
        returned value is accumulated into the total error.
        """
        raise NotImplementedError

    def calculate_residuals(self, kpt, wfs, ham, psit, P, eps_n,
                            R, C, n_x=None, calculate_change=False):
        """Calculate residual.

        From R=Ht*psit calculate R=H*psit-eps*S*psit.

        *R* is modified in place; *C* is used as a work container for
        the projector coefficients (dH*P - eps*dS*P) that are added back
        to *R* through ``wfs.pt.add``.
        """

        # R -= eps * psit (pseudo part of the overlap):
        multi_axpy(-eps_n, psit.array, R.array)

        # Atomic corrections: C = dH P - eps dS P for every atom.
        ham.dH(P, out=C)
        for a, I1, I2 in P.indices:
            dS_ii = ham.setups[a].dO_ii
            C.array[..., I1:I2] -= np.dot((P.array[..., I1:I2].T * eps_n).T,
                                          dS_ii)

        ham.xc.add_correction(kpt, psit.array, R.array,
                              {a: P_ni for a, P_ni in P.items()},
                              {a: C_ni for a, C_ni in C.items()},
                              n_x,
                              calculate_change)
        wfs.pt.add(R.array, {a: C_ni for a, C_ni in C.items()}, kpt.q)

    @timer('Subspace diag')
    def subspace_diagonalize(self, ham, wfs, kpt, rotate_psi=True):
        """Diagonalize the Hamiltonian in the subspace of kpt.psit.

        The pseudo Hamiltonian (``wfs.apply_pseudo_hamiltonian``) is
        applied to the wave functions, the atomic correction
        ``ham.dH`` is added, and the resulting *H_nn* matrix is
        diagonalized.  The eigenvalues are stored in ``kpt.eps_n``.

        If *rotate_psi* is true, the wave functions, the projections
        ``kpt.projections`` and any orbital-dependent XC data are then
        rotated in place.  If ``self.keep_htpsit`` is also true, H
        applied to the rotated wave functions is written into the
        buffer ``self.Htpsit_nG``.

        It is assumed that the wave functions are orthonormal and that
        the projections ``kpt.projections`` are already calculated.

        Returns None; all results are written in place.

        Raises NotImplementedError for strided band distributions over
        more than one band-communicator rank.
        """

        if self.band_comm.size > 1 and wfs.bd.strided:
            raise NotImplementedError

        psit = kpt.psit
        tmp = psit.new(buf=wfs.work_array)
        H = wfs.work_matrix_nn
        P2 = kpt.projections.new()

        Ht = partial(wfs.apply_pseudo_hamiltonian, kpt, ham)

        with self.timer('calc_h_matrix'):
            # We calculate the complex conjugate of H, because
            # that is what is most efficient with BLAS given the layout of
            # our matrices.
            psit.matrix_elements(operator=Ht, result=tmp, out=H,
                                 symmetric=True, cc=True)
            ham.dH(kpt.projections, out=P2)
            mmm(1.0, kpt.projections, 'N', P2, 'C', 1.0, H, symmetric=True)
            ham.xc.correct_hamiltonian_matrix(kpt, H.array)

        with self.timer('diagonalize'):
            slcomm, r, c, b = wfs.scalapack_parameters
            if r == c == 1:
                slcomm = None
            # Complex conjugate before diagonalizing:
            eps_n = H.eigh(cc=True, scalapack=(slcomm, r, c, b))
            # H.array[n, :] now contains the n'th eigenvector and eps_n[n]
            # the n'th eigenvalue
            kpt.eps_n[:] = eps_n[wfs.bd.get_slice()]

        with self.timer('rotate_psi'):
            if not rotate_psi:
                return
            if self.keep_htpsit:
                # tmp still holds H applied to the unrotated psit:
                Htpsit = psit.new(buf=self.Htpsit_nG)
                mmm(1.0, H, 'N', tmp, 'N', 0.0, Htpsit)
            mmm(1.0, H, 'N', psit, 'N', 0.0, tmp)
            psit[:] = tmp
            mmm(1.0, H, 'N', kpt.projections, 'N', 0.0, P2)
            kpt.projections.matrix = P2.matrix
            # Rotate orbital dependent XC stuff:
            ham.xc.rotate(kpt, H.array.T)

    def estimate_memory(self, mem, wfs):
        """Register this eigensolver's memory estimate as subnodes of mem."""
        gridmem = wfs.bytes_per_wave_function()

        keep_htpsit = self.keep_htpsit and (wfs.bd.mynbands == wfs.bd.nbands)

        if keep_htpsit:
            mem.subnode('Htpsit', wfs.bd.nbands * gridmem)
        else:
            mem.subnode('No Htpsit', 0)

        mem.subnode('eps_n', wfs.bd.mynbands * mem.floatsize)
        mem.subnode('eps_N', wfs.bd.nbands * mem.floatsize)
        mem.subnode('Preconditioner', 4 * gridmem)
        mem.subnode('Work', gridmem)