Coverage for gpaw/new/pwfd/davidson.py: 99%
153 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from __future__ import annotations
3from pprint import pformat
5import numpy as np
6from gpaw import debug
7from gpaw.core.matrix import Matrix
8from gpaw.gpu import as_np
9from gpaw.mpi import broadcast_exception
10from gpaw.new.pwfd.eigensolver import PWFDEigensolver, calculate_residuals
11from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
12from gpaw.typing import Array2D
13from gpaw.new import trace, tracectx
class Davidson(PWFDEigensolver):
    """Davidson eigensolver built on top of :class:`PWFDEigensolver`.

    Every call to :meth:`iterate1` first does a subspace diagonalization
    and then ``niter`` Davidson steps.  In each step the (preconditioned)
    residuals ``psit2_nX`` are added to the basis, a generalized
    eigenvalue problem of dimension ``2 * nbands`` is solved on a single
    rank, and the wave functions are rotated to the lowest ``nbands``
    solutions.
    """

    def __init__(self,
                 nbands: int,
                 wf_grid,
                 band_comm,
                 hamiltonian,
                 converge_bands='occupied',
                 niter: int = 2,
                 scalapack_parameters=None,
                 max_buffer_mem: int = 200 * 1024 ** 2):
        """Create the eigensolver.

        Parameters
        ----------
        nbands, wf_grid, band_comm, scalapack_parameters:
            Accepted but not used by this class.
        hamiltonian, converge_bands, max_buffer_mem:
            Forwarded unchanged to ``PWFDEigensolver.__init__``.
        niter:
            Number of Davidson steps per call to :meth:`iterate1`.
            Must be at least 1.

        Raises
        ------
        ValueError
            If ``niter`` is smaller than 1.  Without this check
            :meth:`iterate1` would do all of its setup work and then fail
            with ``UnboundLocalError``, because the returned error is only
            assigned inside the iteration loop.
        """
        if niter < 1:
            raise ValueError(f'niter must be at least 1, got {niter!r}')
        super().__init__(
            hamiltonian,
            converge_bands,
            max_buffer_mem=max_buffer_mem)
        self.niter = niter
        # The matrices below are allocated in _initialize():
        # H_NN, S_NN: (2B x 2B) on the rank that is master of both the
        # domain and the band communicator, (0 x 0) on all other ranks.
        self.H_NN: Matrix
        self.S_NN: Matrix
        # M_nn, M2_nn: (B x B), distributed over the band communicator.
        self.M_nn: Matrix
        self.M2_nn: Matrix

    def __str__(self) -> str:
        """Return a pretty-printed dict with name, niter and converge_bands."""
        return pformat(dict(name='Davidson',
                            niter=self.niter,
                            converge_bands=self.converge_bands))

    def _initialize(self, ibzwfs):
        """Allocate work arrays, buffers and the subspace matrices.

        The communicators, dtype and array backend are taken from the
        first wave-function object, ``ibzwfs.wfs_qs[0][0]``.
        """
        super()._initialize(ibzwfs)
        self._allocate_work_arrays(ibzwfs, shape=(1,))
        self._allocate_buffer_arrays(ibzwfs, shape=(1,))

        wfs = ibzwfs.wfs_qs[0][0]
        assert isinstance(wfs, PWFDWaveFunctions)
        domain_comm = wfs.psit_nX.desc.comm
        band_comm = wfs.band_comm

        B = ibzwfs.nbands
        xp = ibzwfs.xp
        dtype = wfs.psit_nX.desc.dtype
        if domain_comm.rank == 0 and band_comm.rank == 0:
            # Only this rank solves the (2B x 2B) eigenvalue problem.
            self.H_NN = Matrix(2 * B, 2 * B, dtype=dtype, xp=xp)
            self.S_NN = Matrix(2 * B, 2 * B, dtype=dtype, xp=xp)
        else:
            # Empty placeholders (one shared object) on all other ranks.
            self.H_NN = self.S_NN = Matrix(0, 0)

        self.M_nn = Matrix(B, B, dtype=dtype,
                           dist=(band_comm, band_comm.size),
                           xp=xp)
        self.M2_nn = self.M_nn.new()

    def iterate1(self,
                 wfs: PWFDWaveFunctions,
                 Ht, dH, dS_aii, weight_n):
        """Do ``self.niter`` Davidson steps on ``wfs``.

        Parameters
        ----------
        wfs:
            Wave functions to update.  ``wfs.psit_nX`` is modified in
            place and ``wfs._eig_n`` / ``wfs._P_ani`` are replaced.
        Ht:
            Callable used as ``Ht(psit_nX, out=...)`` to apply the
            Hamiltonian to wave functions.
        dH:
            Callable used as ``dH(P_ani, out_ani=...)`` acting on
            projections.
        dS_aii:
            Passed to ``calculate_residuals`` and
            ``block_diag_multiply`` for the overlap terms.
        weight_n:
            Weights for the residual norms, or ``None``.

        Returns
        -------
        ``np.inf`` if ``weight_n`` is ``None``.  Otherwise
        ``weight_n @ norm2(residuals)`` summed, evaluated at the start of
        the last step (before the residuals are preconditioned).
        """
        H_NN = self.H_NN
        S_NN = self.S_NN
        M_nn = self.M_nn
        M2_nn = self.M2_nn

        xp = M_nn.xp

        psit_nX = wfs.psit_nX
        B = psit_nX.dims[0]  # number of bands
        eig_N = xp.empty(2 * B)
        b = psit_nX.mydims[0]

        # Residuals live in a view of the first work array.
        psit2_nX = psit_nX.new(data=self.work_arrays[0, :b])
        data_buffer = self.data_buffers[0]

        wfs.subspace_diagonalize(Ht, dH,
                                 psit2_nX=psit2_nX,
                                 data_buffer=data_buffer)

        P_ani = wfs.P_ani
        P2_ani = P_ani.new()
        P3_ani = P_ani.new()

        domain_comm = psit_nX.desc.comm
        band_comm = psit_nX.comm
        is_domain_band_master = domain_comm.rank == 0 and band_comm.rank == 0

        # NOTE(review): dist=(band_comm, 1, 1) presumably places the whole
        # matrix on band rank 0 - confirm against Matrix.
        M0_nn = M_nn.new(dist=(band_comm, 1, 1))

        if domain_comm.rank == 0:
            eig_N[:B] = xp.asarray(wfs.eig_n)

        me_buffer_mX = psit_nX.create_work_buffer(data_buffer)

        @trace
        def me(a, b, function=None):
            """Matrix elements"""
            return a.matrix_elements(b,
                                     domain_sum=False,
                                     out=M_nn,
                                     function=function,
                                     cc=True)

        calculate_residuals(wfs.psit_nX,
                            psit2_nX,
                            wfs.pt_aiX,
                            wfs.P_ani,
                            wfs.myeig_n,
                            dH, dS_aii, P2_ani, P3_ani)

        def copy(C_nn: Array2D, M_nn: Matrix) -> None:
            """Sum M_nn over domains and store it in C_nn on the master."""
            domain_comm.sum(M_nn.data, 0)
            if domain_comm.rank == 0:
                M_nn.redist(M0_nn)
                if band_comm.rank == 0:
                    C_nn[:] = M0_nn.data

        for i in range(self.niter):
            if i == self.niter - 1:  # last iteration
                # Calculate error before we destroy residuals:
                if weight_n is None:
                    error = np.inf
                else:
                    error = (weight_n @ as_np(psit2_nX.norm2())).sum()

            sliced_preconditioner(psit_nX, psit2_nX,
                                  buffer=me_buffer_mX,
                                  precon=self.preconditioner)

            # Calculate projections
            wfs.pt_aiX.integrate(psit2_nX, out=P2_ani)
            with tracectx('Matrix elements'):
                # Sliced matrix elements with hamiltonian. See
                # sliced_matrix_elements docstring.
                sliced_matrix_elements(psit_nX, psit2_nX,
                                       buffer_mX=me_buffer_mX,
                                       Ht=Ht,
                                       M1_nn=M_nn,
                                       M2_nn=M2_nn)

                # <psi2 | H | psi2>
                # (beta=1 adds the projection term to what is in M2_nn)
                dH(P2_ani, out_ani=P3_ani)
                P2_ani.matrix.multiply(P3_ani, opb='C', symmetric=True, beta=1,
                                       out=M2_nn)
                copy(H_NN.data[B:, B:], M2_nn)

                # <psi2 | H | psi>
                P3_ani.matrix.multiply(P_ani, opb='C', beta=1.0, out=M_nn)
                copy(H_NN.data[B:, :B], M_nn)

                # <psi2 | S | psi2>
                me(psit2_nX, psit2_nX)
                P2_ani.block_diag_multiply(dS_aii, out_ani=P3_ani)
                P2_ani.matrix.multiply(P3_ani, opb='C', symmetric=True, beta=1,
                                       out=M_nn)
                copy(S_NN.data[B:, B:], M_nn)

                # <psi2 | S | psi>
                me(psit2_nX, psit_nX)
                P3_ani.matrix.multiply(P_ani, opb='C', beta=1.0, out=M_nn)
                copy(S_NN.data[B:, :B], M_nn)

            with tracectx('Diagonalize'):
                with broadcast_exception(domain_comm):
                    with broadcast_exception(band_comm):
                        if is_domain_band_master:
                            # <psi | H | psi> and <psi | S | psi> blocks:
                            # diagonal eigenvalues and the identity.
                            H_NN.data[:B, :B] = xp.diag(eig_N[:B])
                            S_NN.data[:B, :B] = xp.eye(B)
                            # NOTE(review): H_NN.data is read as the
                            # rotation matrix below, so eigh() presumably
                            # leaves the eigenvectors in H_NN - confirm.
                            eig_N[:] = H_NN.eigh(S_NN)
                            wfs._eig_n = as_np(eig_N[:B])
                if domain_comm.rank == 0:
                    band_comm.broadcast(wfs.eig_n, 0)
                domain_comm.broadcast(wfs.eig_n, 0)

            # Distribute the psi -> psi part of the rotation.
            if domain_comm.rank == 0:
                if band_comm.rank == 0:
                    M0_nn.data[:] = H_NN.data[:B, :B]
                    M0_nn.complex_conjugate()
                M0_nn.redist(M_nn)
            domain_comm.broadcast(M_nn.data, 0)

            with tracectx('Rotate Psi'):
                M_nn.multiply(psit_nX, out=psit_nX,
                              data_buffer=data_buffer)
                M_nn.multiply(P_ani, out=P3_ani)

                # Distribute the psi2 -> psi part of the rotation.
                if domain_comm.rank == 0:
                    if band_comm.rank == 0:
                        M0_nn.data[:] = H_NN.data[:B, B:]
                        M0_nn.complex_conjugate()
                    M0_nn.redist(M_nn)
                domain_comm.broadcast(M_nn.data, 0)

                # beta=1.0: accumulate onto the psi contribution above.
                M_nn.multiply(psit2_nX, beta=1.0, out=psit_nX)
                M_nn.multiply(P2_ani, beta=1.0, out=P3_ani)
                # P3_ani now holds the new projections; the old P_ani
                # array becomes scratch space.
                P_ani, P3_ani = P3_ani, P_ani
                wfs._P_ani = P_ani

            if i < self.niter - 1:
                # New residuals for the next step.
                Ht(psit_nX, out=psit2_nX)
                calculate_residuals(
                    wfs.psit_nX,
                    psit2_nX,
                    wfs.pt_aiX, wfs.P_ani, wfs.myeig_n,
                    dH, dS_aii, P2_ani, P3_ani)

        if debug:
            psit_nX.sanity_check()

        return error
def sliced_preconditioner(psit_nX, psit2_nX, buffer, precon):
    """Precondition ``psit2_nX`` in place, one slice of bands at a time.

    The local bands are processed in chunks of ``buffer.data.shape[0]``
    rows.  For each chunk ``precon(psit_slice, psit2_slice, out=view)`` is
    called with a view of ``buffer`` as output, and the result is copied
    back over the corresponding rows of ``psit2_nX.data``.

    Nothing is done when there are no local bands.
    """
    chunk = buffer.data.shape[0]
    nlocal = psit_nX.data.shape[0]
    if nlocal == 0:
        return

    for start in range(0, nlocal, chunk):
        stop = start + chunk
        # Only the final chunk can be shorter than the buffer.
        # NOTE(review): for earlier chunks the stop index exceeds the
        # buffer length, which relies on slicing clamping it - confirm.
        scratch = buffer[:nlocal - start]
        precon(psit_nX[start:stop], psit2_nX[start:stop], out=scratch)
        psit2_nX.data[start:stop] = scratch.data[:]
def sliced_matrix_elements(psit1_nX, psit2_nX, buffer_mX, Ht, M1_nn, M2_nn):
    """Calculate Hamiltonian matrix elements one block of bands at a time.

    Results::

        <psi2 | H | psi2> -> M2_nn
        <psi2 | H | psi1> -> M1_nn

    ``Ht`` is applied to at most ``buffer_mX.data.shape[0]`` local bands
    of ``psit2_nX`` per pass, with ``buffer_mX`` holding the result, and
    the matching rows of ``M1_nn`` and ``M2_nn`` are filled from that
    block.  Apart from using less memory this is meant to be identical
    to::

        psit3_nX = psit2_nX.new()
        psit2_nX.matrix_elements(psit2_nX,
                                 out=M2_nn,
                                 domain_sum=False,
                                 function=partial(Ht, out=psit3_nX),
                                 cc=True)
        psit3_nX.matrix_elements(psit1_nX,
                                 out=M1_nn,
                                 domain_sum=False,
                                 cc=True)
    """
    comm = psit1_nX.comm
    nlocal = psit1_nX.data.shape[0]
    blocksize = buffer_mX.data.shape[0]
    # Block size and band count summed over all ranks of comm.
    blocksize_world = comm.sum_scalar(blocksize)
    totalbands = comm.sum_scalar(nlocal)

    block_starts = range(0, totalbands, blocksize_world)
    for block, first_band in enumerate(block_starts):
        # Local row range of this block, clipped to the local bands.
        lo = block * blocksize
        hi = min(lo + blocksize, nlocal)

        # Number of bands in this block summed over all ranks.
        world_N = min(blocksize_world, totalbands - first_band)

        Hpsit_aX = buffer_mX.new(
            data=buffer_mX.data[:hi - lo],
            dims=(world_N,) + buffer_mX.dims[1:],
        )
        Ht(psit2_nX[lo:hi], out=Hpsit_aX)

        # Row-block views into the two output matrices.  Both are built
        # before either product is evaluated.
        out1, out2 = [
            Matrix(
                M=world_N,
                N=target_nn.shape[1],
                data=target_nn.data[lo:hi, :],
                dist=(comm, -1, 1),
                xp=target_nn.xp)
            for target_nn in (M1_nn, M2_nn)]

        for other_nX, out in ((psit1_nX, out1), (psit2_nX, out2)):
            Hpsit_aX.matrix_elements(other_nX,
                                     out=out,
                                     symmetric=False,
                                     domain_sum=False,
                                     cc=True)