Coverage for gpaw/eigensolvers/rmmdiis.py: 94%
156 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1"""Module defining ``Eigensolver`` classes."""
2from functools import partial
4import numpy as np
6from gpaw.utilities.blas import axpy
7from gpaw.eigensolvers.eigensolver import Eigensolver
class RMMDIIS(Eigensolver):
    """RMM-DIIS eigensolver

    It is expected that the trial wave functions are orthonormal
    and the integrals of projector functions and wave functions
    ``nucleus.P_uni`` are already calculated

    Solution steps are:

    * Subspace diagonalization
    * Calculation of residuals
    * Improvement of wave functions:  psi' = psi + lambda PR + lambda PR'
    * Orthonormalization"""

    def __init__(self, keep_htpsit=True, blocksize=None, niter=3, rtol=1e-16,
                 limit_lambda=False, use_rayleigh=False, trial_step=0.1):
        """Initialize RMM-DIIS eigensolver.

        Parameters:

        keep_htpsit: bool
            passed on to ``Eigensolver``; when true, the residuals of all
            bands are computed once at the start of each iteration from the
            ``self.Htpsit_nG`` buffer instead of per band block.
        blocksize: int or None
            number of bands processed together; ``None`` lets
            ``initialize()`` choose a default.
        niter: int
            number of RMM-DIIS iterations per band block; with ``niter=1``
            no DIIS extrapolation is done.
        rtol: float
            stored and reported by ``__repr__``; the early exit on small
            errors that would use it is currently disabled.
        limit_lambda: dictionary or False
            determines if step length should be limited.
            supported keys:
            'absolute': True/False limit the absolute value
            'upper': float upper limit for lambda (required)
            'lower': float lower limit for lambda (required)
        use_rayleigh: bool
            not supported; ``True`` raises ``NotImplementedError``.
        trial_step: float or None
            step length of the final preconditioned trial step; ``None``
            reuses the per-band line-search lambdas.
        """

        Eigensolver.__init__(self, keep_htpsit, blocksize)
        self.niter = niter
        self.rtol = rtol
        self.limit_lambda = limit_lambda
        self.use_rayleigh = use_rayleigh
        if use_rayleigh:
            # This mode was never implemented; fail with a clear error
            # instead of a deliberate ZeroDivisionError.
            raise NotImplementedError(
                'RMMDIIS(use_rayleigh=True) is not supported')
        self.trial_step = trial_step
        self.first = True

    def todict(self):
        """Return the parameters that identify this eigensolver."""
        return {'name': 'rmm-diis', 'niter': self.niter}

    def initialize(self, wfs):
        """Pick a default block size if needed, then initialize the base.

        In 'pw' mode the default is the smallest multiple of the
        ``wfs.pd.comm`` size that is at least 10; otherwise it is 10.
        """
        if self.blocksize is None:
            if wfs.mode == 'pw':
                S = wfs.pd.comm.size
                # Use a multiple of S for maximum efficiency
                self.blocksize = int(np.ceil(10 / S)) * S
            else:
                self.blocksize = 10
        Eigensolver.initialize(self, wfs)

    def _clip_lambda(self, lam_x):
        """Return ``lam_x`` limited according to ``self.limit_lambda``.

        With ``'absolute'`` true, ``|lam|`` is clamped into ``[lower, upper]``
        while keeping its sign (an exact zero stays zero); otherwise
        ``lam`` itself is clamped into ``[lower, upper]``.
        """
        upper = self.limit_lambda['upper']
        lower = self.limit_lambda['lower']
        if self.limit_lambda.get('absolute', False):
            lam_x = np.where(np.abs(lam_x) < lower,
                             lower * np.sign(lam_x), lam_x)
            lam_x = np.where(np.abs(lam_x) > upper,
                             upper * np.sign(lam_x), lam_x)
        else:
            lam_x = np.where(lam_x < lower, lower, lam_x)
            lam_x = np.where(lam_x > upper, upper, lam_x)
        return lam_x

    def _diis_coefficients(self, wfs, R_diis_nxG, istart, nit):
        """Return the DIIS mixing coefficients for one band.

        Uses the ``nit + 1`` stored residuals
        ``R_diis_nxG[istart:istart + nit + 1]`` and solves the bordered
        system ``[[<R_i|R_j>, -1], [-1, 0]] (alpha, mu) = (0, ..., 0, -1)``,
        whose last row enforces ``sum(alpha) == 1``.  Only ``alpha`` is
        returned.  ``np.linalg.solve`` raises ``LinAlgError`` if the matrix
        is singular.
        """
        iend = istart + nit + 1
        with self.timer('Construct matrix'):
            # Residual matrix
            R_nn = wfs.integrate(R_diis_nxG[istart:iend],
                                 R_diis_nxG[istart:iend],
                                 global_integral=True)

            # Full (bordered) matrix and right-hand side
            A_nn = -np.ones((nit + 2, nit + 2), wfs.dtype)
            A_nn[:nit + 1, :nit + 1] = R_nn[:]
            A_nn[-1, -1] = 0.0
            x_n = np.zeros(nit + 2, wfs.dtype)
            x_n[-1] = -1.0
        with self.timer('Linear solve'):
            return np.linalg.solve(A_nn, x_n)[:-1]

    def iterate_one_k_point(self, ham, wfs, kpt, weights):
        """Do a single RMM-DIIS iteration for the kpoint.

        Bands are processed in blocks of ``self.blocksize``.  For each block:

        1. residuals are preconditioned;
        2. a line search along the preconditioned residual gives the step
           lengths ``lam_x`` (optionally limited by ``limit_lambda``);
        3. up to ``niter - 1`` DIIS extrapolations follow;
        4. a final preconditioned trial step is taken.

        The wave functions of ``kpt.psit`` are updated through
        ``psit.view(n1, n2)`` band slices.

        Returns the sum over this rank's bands of
        ``weights[n] * <R_n|R_n>`` for the residuals at the start of the
        iteration, with each block summed over ``wfs.gd.comm``.
        """

        self.subspace_diagonalize(ham, wfs, kpt)

        psit = kpt.psit
        P = kpt.projections
        P2 = P.new()
        # Residuals are written into the ``self.Htpsit_nG`` buffer.
        R = psit.new(buf=self.Htpsit_nG)

        # Timers are used as context managers (not start/stop pairs) so the
        # timer stack stays balanced if anything below raises, e.g. a
        # LinAlgError from a singular DIIS matrix.
        with self.timer('RMM-DIIS'):
            if self.keep_htpsit:
                with self.timer('Calculate residuals'):
                    self.calculate_residuals(kpt, wfs, ham, psit, P,
                                             kpt.eps_n, R, P2)

            def integrate(a_G, b_G):
                # Real part of the local (not globally summed) integral;
                # callers reduce over ``comm`` themselves.
                return np.real(wfs.integrate(a_G, b_G,
                                             global_integral=False))

            comm = wfs.gd.comm

            B = self.blocksize
            dR = R.new(dist=None, nbands=B)
            dpsit = dR.new()
            P = P.new(bcomm=None, nbands=B)
            P2 = P.new()
            errors_x = np.zeros(B)

            # DIIS history: for band ib of a block, entries
            # ib * niter ... ib * niter + niter - 1 hold its iterates.
            if self.niter > 1:
                psit_diis_nxG = wfs.empty(B * self.niter, q=kpt.q)
                R_diis_nxG = wfs.empty(B * self.niter, q=kpt.q)

            Ht = partial(wfs.apply_pseudo_hamiltonian, kpt, ham)

            error = 0.0
            for n1 in range(0, wfs.bd.mynbands, B):
                n2 = n1 + B
                if n2 > wfs.bd.mynbands:
                    # Last, smaller block: shrink the work buffers.
                    n2 = wfs.bd.mynbands
                    B = n2 - n1
                    P = P.new(nbands=B)
                    P2 = P.new()
                    dR = dR.new(nbands=B, dist=None)
                    dpsit = dR.new()

                n_x = np.arange(n1, n2)
                psitb = psit.view(n1, n2)
                nslots = B * self.niter  # DIIS history entries in use

                with self.timer('Calculate residuals'):
                    Rb = R.view(n1, n2)
                    if not self.keep_htpsit:
                        psitb.apply(Ht, out=Rb)
                        psitb.matrix_elements(wfs.pt, out=P)
                        self.calculate_residuals(kpt, wfs, ham, psitb,
                                                 P, kpt.eps_n[n_x], Rb, P2,
                                                 n_x)

                # Weighted squared residual norms; for a short last block
                # the unused tail of errors_x stays zero.
                errors_x[:] = 0.0
                for n in range(n1, n2):
                    weight = weights[n]
                    errors_x[n - n1] = weight * integrate(Rb.array[n - n1],
                                                          Rb.array[n - n1])
                comm.sum(errors_x)
                error += np.sum(errors_x)

                # Insert first vectors and residuals (slot 0) for DIIS step
                if self.niter > 1:
                    psit_diis_nxG[:nslots:self.niter] = psitb.array
                    R_diis_nxG[:nslots:self.niter] = Rb.array

                # Precondition the residual:
                with self.timer('precondition'):
                    ekin_x = self.preconditioner.calculate_kinetic_energy(
                        psitb.array, kpt)
                    self.preconditioner(Rb.array, kpt, ekin_x,
                                        out=dpsit.array)

                # Calculate the residual of dpsit_G, dR_G = (H - e S) dpsit_G:
                dpsit.apply(Ht, out=dR)
                with self.timer('projections'):
                    dpsit.matrix_elements(wfs.pt, out=P)

                with self.timer('Calculate residuals'):
                    self.calculate_residuals(kpt, wfs, ham, dpsit,
                                             P, kpt.eps_n[n_x], dR, P2, n_x,
                                             calculate_change=True)

                # Find lam that minimizes the norm of R'_G = R_G + lam dR_G
                with self.timer('Find lambda'):
                    RdR_x = np.array([integrate(dR_G, R_G)
                                      for R_G, dR_G in zip(Rb.array,
                                                           dR.array)])
                    dRdR_x = np.array([integrate(dR_G, dR_G)
                                       for dR_G in dR.array])
                    comm.sum(RdR_x)
                    comm.sum(dRdR_x)
                    lam_x = -RdR_x / dRdR_x

                # Limit lam to the bounds given in limit_lambda (if any)
                if self.limit_lambda:
                    lam_x = self._clip_lambda(lam_x)

                # New trial wavefunction and residual
                with self.timer('Update psi'):
                    for lam, psit_G, dpsit_G, R_G, dR_G in zip(
                            lam_x, psitb.array, dpsit.array,
                            Rb.array, dR.array):
                        axpy(lam, dpsit_G, psit_G)  # psit_G += lam * dpsit_G
                        axpy(lam, dR_G, R_G)  # R_G += lam * dR_G

                with self.timer('DIIS step'):
                    # NOTE: an early exit when the block error is below
                    # self.rtol is not implemented.
                    for nit in range(1, self.niter):
                        # Update the subspace (slot nit of every band)
                        psit_diis_nxG[nit:nslots:self.niter] = psitb.array
                        R_diis_nxG[nit:nslots:self.niter] = Rb.array

                        # XXX Only integrals of nit old psits would be needed
                        # DIIS extrapolation is skipped for nit == 1 unless
                        # lambda limits are active.
                        if nit > 1 or self.limit_lambda:
                            for ib in range(B):
                                istart = ib * self.niter
                                last = istart + nit
                                alpha_i = self._diis_coefficients(
                                    wfs, R_diis_nxG, istart, nit)

                                with self.timer('Update trial vectors'):
                                    # psi = sum_i alpha_i psi_i, same for R
                                    psitb.array[ib] = (alpha_i[nit] *
                                                       psit_diis_nxG[last])
                                    Rb.array[ib] = (alpha_i[nit] *
                                                    R_diis_nxG[last])
                                    for i in range(nit):
                                        axpy(alpha_i[i],
                                             psit_diis_nxG[istart + i],
                                             psitb.array[ib])
                                        axpy(alpha_i[i],
                                             R_diis_nxG[istart + i],
                                             Rb.array[ib])

                        if nit < self.niter - 1:
                            # Preconditioned step from the current residual
                            # using the line-search lambdas, then recompute
                            # the residuals for the next DIIS iteration.
                            with self.timer('precondition'):
                                self.preconditioner(Rb.array, kpt, ekin_x,
                                                    out=dpsit.array)

                            for psit_G, lam, dpsit_G in zip(
                                    psitb.array, lam_x, dpsit.array):
                                axpy(lam, dpsit_G, psit_G)

                            # Calculate the new residuals
                            with self.timer('Calculate residuals'):
                                psitb.apply(Ht, out=Rb)
                                psitb.matrix_elements(wfs.pt, out=P)
                                self.calculate_residuals(
                                    kpt, wfs, ham, psitb, P,
                                    kpt.eps_n[n_x], Rb, P2, n_x,
                                    calculate_change=True)

                # Final trial step
                with self.timer('precondition'):
                    self.preconditioner(Rb.array, kpt, ekin_x,
                                        out=dpsit.array)

                with self.timer('Update psi'):
                    if self.trial_step is not None:
                        lam_x[:] = self.trial_step
                    for lam, psit_G, dpsit_G in zip(lam_x, psitb.array,
                                                    dpsit.array):
                        axpy(lam, dpsit_G, psit_G)  # psit_G += lam * dpsit_G

        return error

    def __repr__(self):
        repr_string = 'RMM-DIIS eigensolver\n'
        repr_string += '       keep_htpsit: %s\n' % self.keep_htpsit
        repr_string += '       DIIS iterations: %d\n' % self.niter
        repr_string += '       Threshold for DIIS: %5.1e\n' % self.rtol
        repr_string += '       Limit lambda: %s\n' % self.limit_lambda
        repr_string += '       use_rayleigh: %s\n' % self.use_rayleigh
        repr_string += '       trial_step: %s' % self.trial_step
        return repr_string