Coverage for gpaw/new/pwfd/rmmdiis.py: 92%
73 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1from __future__ import annotations
3import warnings
4from pprint import pformat
6import numpy as np
7from gpaw.gpu import as_np
8from gpaw.new import zips as zip
9from gpaw.new.pwfd.eigensolver import PWFDEigensolver, calculate_residuals
10from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
class RMMDIIS(PWFDEigensolver):
    def __init__(self,
                 nbands: int,
                 wf_grid,
                 band_comm,
                 hamiltonian,
                 converge_bands='occupied',
                 niter: int = 1,
                 trial_step: float | None = None,
                 scalapack_parameters=None,
                 max_buffer_mem: int = 200 * 1024 ** 2):
        """RMM-DIIS eigensolver.

        Solution steps are:

        * Subspace diagonalization
        * Calculation of residuals
        * Improvement of wave functions: psi' = psi + lambda PR + lambda PR'
        * Orthonormalization

        Parameters
        ==========
        nbands, wf_grid, band_comm, scalapack_parameters:
            Accepted for interface compatibility; not used by this class.
        hamiltonian:
            Passed on to the ``PWFDEigensolver`` base class.
        converge_bands:
            Passed on to the base class.
        niter:
            Only one iteration per call is done; any value other than 1
            triggers a warning and is otherwise ignored.
        trial_step:
            Step length for the final step.  Use None to reuse the step
            lengths (lambda) optimized in the first step.
        max_buffer_mem:
            Passed on to the base class; presumably an upper limit (in
            bytes) for the work buffers -- TODO confirm.
        """
        if niter != 1:
            warnings.warn(f'Ignoring niter={niter} in RMMDIIS')
        super().__init__(hamiltonian, converge_bands,
                         max_buffer_mem=max_buffer_mem)
        self.trial_step = trial_step

    def __str__(self):
        return pformat(dict(name='RMMDIIS',
                            converge_bands=self.converge_bands))

    def _initialize(self, ibzwfs):
        """Allocate the arrays used by iterate1().

        One work array (used for the residuals) and two data buffers
        (used as scratch space for the block updates) are allocated.
        """
        super()._initialize(ibzwfs)
        self._allocate_work_arrays(ibzwfs, shape=(1,))
        self._allocate_buffer_arrays(ibzwfs, shape=(2,))

    def iterate1(self,
                 wfs: PWFDWaveFunctions,
                 Ht, dH, dS_aii, weight_n):
        """Do one RMM-DIIS step on the wave functions of *wfs*.

        See here:

        https://gpaw.readthedocs.io/documentation/rmm-diis.html

        The wave functions are subspace-diagonalized, their residuals are
        calculated, all local bands are updated with ``block_step()`` in
        blocks that fit into the data buffers, and finally the wave
        functions are orthonormalized again.

        Parameters
        ==========
        wfs:
            Wave functions; updated in place.
        Ht:
            Hamiltonian operator, called as ``Ht(psit_nX, out=...)``.
        dH, dS_aii:
            PAW corrections passed on to ``calculate_residuals()``.
        weight_n:
            Weights used for the returned error, or None.

        Returns
        =======
        ``weight_n @ |R_n|^2`` for the residuals *before* the update, or
        ``np.inf`` when *weight_n* is None.
        """
        psit_nX = wfs.psit_nX
        mynbands = psit_nX.mydims[0]

        # The first work array holds the residuals of the local bands:
        residual_nX = psit_nX.new(data=self.work_arrays[0, :mynbands])

        P_ani = wfs.P_ani
        work1_ani = P_ani.new()
        work2_ani = P_ani.new()

        wfs.subspace_diagonalize(Ht, dH,
                                 psit2_nX=residual_nX,
                                 data_buffer=self.data_buffers[0])
        calculate_residuals(wfs.psit_nX, residual_nX, wfs.pt_aiX,
                            wfs.P_ani, wfs.myeig_n,
                            dH, dS_aii, work1_ani, work2_ani)

        work1_nX = psit_nX.create_work_buffer(self.data_buffers[0])
        work2_nX = psit_nX.create_work_buffer(self.data_buffers[1])
        blocksize = work1_nX.data.shape[0]
        # Full-size projection buffers; short blocks use slices of them.
        P1_ani = P_ani.layout.empty(blocksize)
        P2_ani = P_ani.layout.empty(blocksize)

        # The error is measured before block_step() changes the residuals:
        if weight_n is None:
            error = np.inf
        else:
            error = weight_n @ as_np(residual_nX.norm2())

        # The number of iterations is derived from sums over comm, so it is
        # the same on all ranks.  Ranks that run out of local bands call
        # block_step() with empty blocks (presumably to keep collective
        # operations matched -- TODO confirm).
        comm = psit_nX.comm
        blocksize_world = comm.sum_scalar(blocksize)
        totalbands = comm.sum_scalar(mynbands)
        for i1, _ in enumerate(range(0, totalbands, blocksize_world)):
            # Clamp both ends to the local band count so that the block
            # size n2 - n1 is never negative:
            n1 = min(i1 * blocksize, mynbands)
            n2 = min(n1 + blocksize, mynbands)
            nb = n2 - n1
            if nb < blocksize:
                # Slice the full-size buffers afresh; re-slicing an already
                # shortened view with a negative stop would give a block
                # whose size differs from psit_nX[n1:n2].
                P1_block_ani = P1_ani[:, :nb]
                P2_block_ani = P2_ani[:, :nb]
            else:
                P1_block_ani = P1_ani
                P2_block_ani = P2_ani
            block_step(
                psit_nX[n1:n2],
                residual_nX[n1:n2],
                wfs.pt_aiX, wfs.myeig_n[n1:n2], Ht, dH, dS_aii,
                self.trial_step,
                work1_nX[:nb],
                work2_nX[:nb],
                P1_block_ani, P2_block_ani,
                self.preconditioner)

        # The wave functions changed above: drop the cached projections and
        # orthonormalize (residual_nX is reused as scratch space).
        wfs._P_ani = None
        wfs.orthonormalized = False
        wfs.orthonormalize(residual_nX)
        return error
def block_step(psit_nX,
               R_nX,
               pt_aiX,
               eig_n,
               Ht,
               dH,
               dS_aii,
               trial_step,
               work1_nX,
               work2_nX,
               P1_ani,
               P2_ani,
               preconditioner) -> None:
    """Update one block of bands in place with an RMM-DIIS step.

    See here:

    https://gpaw.readthedocs.io/documentation/rmm-diis.html

    With PR the preconditioned residual and dR the result of applying
    ``calculate_residuals()`` to PR (presumably (H - eps S) PR), the step
    length for each band is lambda = -<dR|R> / <dR|dR>.  Then::

        psi <- psi + lambda PR
        R   <- R + lambda dR
        psi <- psi + step * P R    (P R of the *updated* R)

    where step is lambda, or *trial_step* when that is not None.

    *psit_nX* and *R_nX* are modified in place; *work1_nX*, *work2_nX*,
    *P1_ani* and *P2_ani* are scratch space.
    """
    xp = psit_nX.xp
    pr_nX = work1_nX  # preconditioned residuals
    dr_nX = work2_nX  # residuals of the preconditioned residuals

    # First preconditioning; the returned ekin_n is reused further down.
    ekin_n = preconditioner(psit_nX, R_nX, out=pr_nX)
    Ht(pr_nX, out=dr_nX)
    proj_ani = pt_aiX.integrate(pr_nX)
    calculate_residuals(pr_nX, dr_nX, pt_aiX, proj_ani, eig_n,
                        dH, dS_aii, P1_ani, P2_ani)

    # Optimal step per band: -<dR|R> / <dR|dR>, shaped for broadcasting
    # over the remaining axes of the data arrays.
    numer_n = xp.asarray([-dr_X.integrate(r_X)
                          for dr_X, r_X in zip(dr_nX, R_nX)])
    denom_n = dr_nX.norm2()
    bcast_shape = (len(numer_n),) + (1,) * (psit_nX.data.ndim - 1)
    lambda_n = (numer_n / denom_n).reshape(bcast_shape)

    # psi <- psi + lambda PR  and  R <- R + lambda dR:
    pr_nX.data *= lambda_n
    psit_nX.data += pr_nX.data
    dr_nX.data *= lambda_n
    R_nX.data += dr_nX.data

    # Final step along the preconditioned updated residual:
    preconditioner(psit_nX, R_nX, out=pr_nX, ekin_n=ekin_n)
    step = lambda_n if trial_step is None else trial_step
    pr_nX.data *= step
    psit_nX.data += pr_nX.data