Coverage for gpaw/directmin/fdpw/etdm_inner_loop.py: 91%
267 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1"""
Optimization of orbitals
among occupied and a few virtual states
represented on a grid or with plane waves
in order to calculate an excited state

arXiv:2102.06542 [physics.comp-ph]
8"""
10from gpaw.directmin.tools import get_n_occ, get_indices, expm_ed, \
11 sort_orbitals_according_to_occ
12from gpaw.directmin.sd_etdm import LSR1P
13from gpaw.directmin.ls_etdm import MaxStep
14from gpaw.directmin.derivatives import get_approx_analytical_hessian
15from ase.units import Hartree
16import numpy as np
17import time
class ETDMInnerLoop:
    """Inner-loop optimisation of orbitals through a unitary rotation.

    For every k-point a square matrix ``A`` is optimised and the reference
    orbitals are rotated with ``exp(A)`` (see ``rotate_wavefunctions``).
    All per-k-point dictionaries held by this object (``n_occ``, ``U_k``,
    ``Unew_k``, ``psit_knG``, ``precond``) and the ``a_k``/``g_k``
    dictionaries passed between methods are keyed by ``kpointval(kpt)``.
    """

    def __init__(self, odd_pot, wfs, nstates='all', maxiter=100,
                 maxstepxst=0.2, g_tol=5.0e-4, useprec=False, momevery=10):
        """Store the settings and set up identity rotations per k-point.

        :param odd_pot: object providing
            ``get_energy_and_gradients_inner_loop``
        :param wfs: wave-function object; ``kd``, ``dtype``, ``bd`` and
            ``kpt_u`` are read here
        :param nstates: ``'all'`` (rotate ``wfs.bd.nbands`` states) or
            ``'occupied'`` (rotate ``get_n_occ(kpt)[0]`` states)
        :param maxiter: upper bound on the number of iterations in ``run``
        :param maxstepxst: ``max_step`` handed to the line search
        :param g_tol: gradient tolerance; divided by ``Hartree`` before it
            is stored (presumably supplied in eV -- confirm with callers)
        :param useprec: whether ``run`` requests a preconditioner
        :param momevery: ``check_mom`` only acts on every ``momevery``-th
            call
        :raises NotImplementedError: for any other value of ``nstates``
        """
        self.odd_pot = odd_pot
        self.n_kps = wfs.kd.nibzkpts
        self.g_tol = g_tol / Hartree
        self.dtype = wfs.dtype
        self.get_en_and_grad_iters = 0
        self.precond = {}
        self.max_iter_line_search = 6
        self.n_counter = maxiter
        self.maxstep = maxstepxst
        self.eg_count = 0
        self.total_eg_count = 0
        self.run_count = 0
        self.U_k = {}
        self.Unew_k = {}
        self.e_total = 0.0
        self.n_occ = {}
        self.useprec = useprec
        for kpt in wfs.kpt_u:
            k = self.kpointval(kpt)
            if nstates == 'all':
                self.n_occ[k] = wfs.bd.nbands
            elif nstates == 'occupied':
                self.n_occ[k] = get_n_occ(kpt)[0]
            else:
                raise NotImplementedError(
                    f"nstates={nstates!r} is not supported; "
                    "use 'all' or 'occupied'")
            # Start from the identity rotation.
            self.U_k[k] = np.eye(self.n_occ[k], dtype=self.dtype)
            self.Unew_k[k] = np.eye(self.n_occ[k], dtype=self.dtype)
        self.momcounter = 0
        self.momevery = momevery
        self.restart = False
        self.eks = 0.0
        self.esic = 0.0
        self.kappa = 0.0

    def update_ks_energy(self, wfs, dens, ham):
        """Recompute projections, density and Hamiltonian; return the energy.

        :return: the value of ``ham.get_energy(0.0, wfs, False)``
        """
        wfs.timer.start('Update Kohn-Sham energy')
        # calc projectors
        for kpt in wfs.kpt_u:
            wfs.pt.integrate(kpt.psit_nG, kpt.P_ani, kpt.q)

        dens.update(wfs)
        ham.update(dens, wfs, False)
        wfs.timer.stop('Update Kohn-Sham energy')
        return ham.get_energy(0.0, wfs, False)

    def get_energy_and_gradients(self, a_k, wfs, dens, ham):
        """
        Energy E = E[A]. Gradients G_ij[A] = dE/dA_ij
        Returns E[A] and G[A] at psi = exp(A).T kpt.psi

        Side effects: the wave functions are rotated in place, ``eks``,
        ``esic``, ``kappa`` and ``e_total`` are refreshed, the MOM check
        is run, and both evaluation counters are incremented.

        :param a_k: A, one matrix per k-point, keyed by ``kpointval(kpt)``
        :return: ``(e_total, g_k)`` with ``g_k`` keyed like ``a_k``
        """
        g_k = {}
        self.e_total = 0.0
        self.esic = 0.0
        self.kappa = 0.0
        for kpt in wfs.kpt_u:
            k = self.kpointval(kpt)
            if self.n_occ[k] == 0:
                g_k[k] = np.zeros_like(a_k[k])

        evecs, evals = self.rotate_wavefunctions(wfs, a_k)

        self.eks = self.update_ks_energy(wfs, dens, ham)

        for kpt in wfs.kpt_u:
            k = self.kpointval(kpt)
            if self.n_occ[k] == 0:
                # rotate_wavefunctions produced no evals/evecs for this
                # k-point; its gradient stays the zero matrix set above.
                continue
            wfs.timer.start('Energy and gradients')
            g_k[k], esic, kappa1 = \
                self.odd_pot.get_energy_and_gradients_inner_loop(
                    wfs, kpt, a_k[k], evals[k], evecs[k], ham=ham,
                    exstate=True)
            wfs.timer.stop('Energy and gradients')
            if kappa1 > self.kappa:
                self.kappa = kappa1
            self.esic += esic

        self.check_mom(wfs, dens)
        self.e_total = self.eks + self.esic

        self.kappa = wfs.kd.comm.max_scalar(self.kappa)
        self.eg_count += 1
        self.total_eg_count += 1

        return self.e_total, g_k

    def rotate_wavefunctions(self, wfs, a_k):
        """Rotate the reference orbitals of every k-point by ``exp(A)``.

        The unitary matrix returned by ``expm_ed`` is stored in ``Unew_k``
        and ``kpt.psit_nG[:n_occ]`` is overwritten with
        ``u_mat.T . psit_knG[k][:n_occ]``. K-points with ``n_occ == 0`` are
        skipped and get no entry in the returned dictionaries.

        :return: ``(evecs, evals)``, the second and third outputs of
            ``expm_ed``, keyed by ``kpointval(kpt)``
        """
        evecs = {}
        evals = {}
        for kpt in wfs.kpt_u:
            k = self.kpointval(kpt)
            n_occ = self.n_occ[k]
            if n_occ == 0:
                continue
            wfs.timer.start('Unitary matrix')
            u_mat, evecs[k], evals[k] = expm_ed(a_k[k], evalevec=True)
            wfs.timer.stop('Unitary matrix')
            self.Unew_k[k] = u_mat.copy()
            kpt.psit_nG[:n_occ] = \
                np.tensordot(u_mat.T, self.psit_knG[k][:n_occ], axes=1)
            # calc projectors
            wfs.pt.integrate(kpt.psit_nG, kpt.P_ani, kpt.q)

        return evecs, evals

    def evaluate_phi_and_der_phi(self, a_k, p_k, alpha,
                                 wfs, dens, ham,
                                 phi=None, g_k=None):
        """
        phi = f(x_k + alpha_k*p_k)
        der_phi = grad f(x_k + alpha_k*p_k) cdot p_k

        When both ``phi`` and ``g_k`` are supplied they are reused and no
        new energy/gradient evaluation is made.

        :return: phi, der_phi, grad f(x_k + alpha_k*p_k)
        """
        if phi is None or g_k is None:
            x_k = {k: a_k[k] + alpha * p_k[k] for k in a_k}
            phi, g_k = \
                self.get_energy_and_gradients(x_k, wfs, dens, ham)
            del x_k

        der_phi = 0.0
        for k, p in p_k.items():
            il1 = get_indices(p.shape[0])
            der_phi += np.dot(g_k[k][il1].conj(), p[il1]).real

        der_phi = wfs.kd.comm.sum_scalar(der_phi)

        return phi, der_phi, g_k

    def get_search_direction(self, a_k, g_k, wfs):
        """Return a skew-hermitian search direction for every k-point.

        Only the matrix elements selected by ``get_indices`` are handed to
        the search-direction algorithm ``self.sd``; the result is expanded
        back to full matrices.
        """
        # structure of vector is
        # (x_1_up, x_2_up,..,y_1_up, y_2_up,..,
        #  x_1_down, x_2_down,..,y_1_down, y_2_down,.. )

        indices = {k: get_indices(a_k[k].shape[0]) for k in a_k}
        a = {k: a_k[k][il1] for k, il1 in indices.items()}
        g = {k: g_k[k][il1] for k, il1 in indices.items()}

        p = self.sd.update_data(wfs, a, g, precond=self.precond, mode='lcao')
        del a, g

        p_k = {}
        for k in p:
            p_k[k] = np.zeros_like(a_k[k])
            p_k[k][indices[k]] = p[k]
            # make it skew-hermitian
            ind_l = np.tril_indices(p_k[k].shape[0], -1)
            p_k[k][(ind_l[1], ind_l[0])] = -p_k[k][ind_l].conj()
        del p

        return p_k

    def run(self, wfs, dens, log, outer_counter=0, ham=None):
        """Run the inner-loop minimisation.

        :param log: callable used for output, or None to stay silent
        :param outer_counter: iteration number continued from the caller;
            with None the internal iteration count is reported instead
        :return: ``(e_total, iteration_count)``; ``(0.0, 0)`` when a
            restart was flagged during the initial energy evaluation
        """
        self.run_count += 1
        self.counter = 0
        self.eg_count = 0
        self.momcounter = 1
        self.converged = False

        # Reference orbitals and zero rotation matrices.
        self.psit_knG = {}
        a_k = {}
        for kpt in wfs.kpt_u:
            k = self.kpointval(kpt)
            n_occ = self.n_occ[k]
            self.psit_knG[k] = np.tensordot(
                self.U_k[k].T, kpt.psit_nG[:n_occ], axes=1)
            a_k[k] = np.zeros(shape=(n_occ, n_occ), dtype=self.dtype)

        self.sd = LSR1P(memory=50)
        self.ls = MaxStep(self.evaluate_phi_and_der_phi, max_step=self.maxstep)

        # get initial energy and gradients
        self.e_total, g_k = self.get_energy_and_gradients(a_k, wfs, dens, ham)
        g_max = g_max_norm(g_k, wfs)
        if g_max < self.g_tol:
            self.converged = True
            self._restore_wavefunctions(wfs)
            return self._run_result(outer_counter)

        if self.restart:
            del self.psit_knG
            return 0.0, 0

        # stuff which are needed for minim.
        phi_0 = self.e_total
        phi_old = None
        der_phi_old = None
        phi_old_2 = None
        der_phi_old_2 = None

        # outer_counter=None is accepted everywhere else in this method,
        # so it must not be incremented blindly here.
        if outer_counter is not None:
            outer_counter += 1
        if log is not None:
            log_f(log, self.counter, self.kappa, self.eks, self.esic,
                  outer_counter, g_max)

        alpha = 1.0
        not_converged = True
        while not_converged:
            self.precond = self.update_preconditioning(wfs, self.useprec)

            # calculate search direction for current As and Gs
            p_k = self.get_search_direction(a_k, g_k, wfs)

            # calculate derivative along the search direction
            phi_0, der_phi_0, g_k = \
                self.evaluate_phi_and_der_phi(
                    a_k, p_k, 0.0, wfs, dens, ham=ham, phi=phi_0, g_k=g_k)
            if self.counter > 1:
                phi_old = phi_0
                der_phi_old = der_phi_0

            # choose optimal step length along the search direction
            # also get energy and gradients for optimal step
            alpha, phi_0, der_phi_0, g_k = \
                self.ls.step_length_update(
                    a_k, p_k, wfs, dens, ham, mode='lcao',
                    phi_0=phi_0, der_phi_0=der_phi_0,
                    phi_old=phi_old_2, der_phi_old=der_phi_old_2,
                    alpha_max=3.0, alpha_old=alpha, kpdescr=wfs.kd)

            # broadcast data if gd.comm > 1
            if wfs.gd.comm.size > 1:
                alpha_phi_der_phi = np.array([alpha, phi_0, der_phi_0])
                wfs.gd.comm.broadcast(alpha_phi_der_phi, 0)
                alpha = alpha_phi_der_phi[0]
                phi_0 = alpha_phi_der_phi[1]
                der_phi_0 = alpha_phi_der_phi[2]
                for kpt in wfs.kpt_u:
                    k = self.kpointval(kpt)
                    if self.n_occ[k] == 0:
                        continue
                    wfs.gd.comm.broadcast(g_k[k], 0)

            phi_old_2 = phi_old
            der_phi_old_2 = der_phi_old

            if self.restart:
                if log is not None:
                    log('MOM has detected variational collapse, '
                        'occupied orbitals have changed')
                break

            if alpha <= 1.0e-10:
                # No usable step was found: stop iterating.
                break

            # calculate new matrices at optimal step length
            a_k = {k: a_k[k] + alpha * p_k[k] for k in a_k}
            g_max = g_max_norm(g_k, wfs)

            # output
            self.counter += 1
            if outer_counter is not None:
                outer_counter += 1
            if log is not None:
                log_f(
                    log, self.counter, self.kappa, self.eks, self.esic,
                    outer_counter, g_max)

            not_converged = \
                g_max > self.g_tol and \
                self.counter < self.n_counter
            if g_max <= self.g_tol:
                self.converged = True

        if log is not None:
            log('INNER LOOP FINISHED.\n')
            log('Total number of e/g calls:' + str(self.eg_count))

        if not self.restart:
            self._restore_wavefunctions(wfs)

        return self._run_result(outer_counter)

    def _restore_wavefunctions(self, wfs):
        """Write ``U_k.conj() . psit_knG`` back and accumulate ``Unew_k``."""
        for kpt in wfs.kpt_u:
            k = self.kpointval(kpt)
            n_occ = self.n_occ[k]
            kpt.psit_nG[:n_occ] = np.tensordot(
                self.U_k[k].conj(), self.psit_knG[k], axes=1)
            # calc projectors
            wfs.pt.integrate(kpt.psit_nG, kpt.P_ani, kpt.q)
            self.U_k[k] = self.U_k[k] @ self.Unew_k[k]

    def _run_result(self, outer_counter):
        """Return ``(e_total, count)`` using the counter the caller tracks."""
        if outer_counter is None:
            return self.e_total, self.counter
        return self.e_total, outer_counter

    def update_preconditioning(self, wfs, use_prec):
        """Return the preconditioner dictionary, or None when disabled.

        The preconditioner is rebuilt from the approximate analytical
        Hessian whenever ``self.counter`` is a multiple of 30; otherwise
        the stored one is returned unchanged. Hessian elements whose
        magnitude is below 1.0e-4 get a unit entry instead of an inverse.
        """
        if not use_prec:
            return None

        refresh_every = 30
        if self.counter % refresh_every != 0:
            return self.precond

        for kpt in wfs.kpt_u:
            k = self.kpointval(kpt)
            hess = get_approx_analytical_hessian(kpt, self.dtype)
            self.precond[k] = np.zeros_like(hess)
            if self.dtype == float:
                for i in range(hess.shape[0]):
                    if abs(hess[i]) < 1.0e-4:
                        self.precond[k][i] = 1.0
                    else:
                        self.precond[k][i] = 1.0 / hess[i].real
            else:
                for i in range(hess.shape[0]):
                    if abs(hess[i]) < 1.0e-4:
                        self.precond[k][i] = 1.0 + 1.0j
                    else:
                        self.precond[k][i] = \
                            1.0 / hess[i].real + 1.0j / hess[i].imag
        return self.precond

    def check_mom(self, wfs, dens):
        """Every ``momevery``-th call, refresh MOM occupations if in use.

        When ``wfs.occupations.name`` is ``'mom'`` the occupation numbers
        are recalculated and ``self.restart`` is set to the value returned
        by ``sort_orbitals_according_to_occ``. The call counter is always
        incremented.
        """
        if self.momcounter % self.momevery == 0:
            occ_name = getattr(wfs.occupations, "name", None)
            if occ_name == 'mom':
                wfs.calculate_occupation_numbers(dens.fixed)
                self.restart = sort_orbitals_according_to_occ(
                    wfs, update_mom=True, update_eps=False)
        self.momcounter += 1

    def kpointval(self, kpt):
        """Return the dictionary key of ``kpt``: ``n_kps * kpt.s + kpt.q``."""
        return self.n_kps * kpt.s + kpt.q
def log_f(log, niter, kappa, eks, esic, outer_counter=None, g_max=np.inf):
    """Write one line of inner-loop progress through ``log``.

    A table header is emitted first when ``niter`` is 0. The iteration
    number printed is ``outer_counter`` whenever that is not None,
    otherwise ``niter``. ``eks``, ``esic``, their sum and ``g_max`` are
    multiplied by ``Hartree`` before printing; ``kappa`` is printed as is.
    """
    now = time.localtime()

    if niter == 0:
        log('\nINNER LOOP:\n'
            ' Kohn-Sham SIC'
            ' Total \n'
            ' time energy: energy:'
            ' energy: Error: G_max:')

    shown_iter = niter if outer_counter is None else outer_counter

    # Iteration number followed by the wall-clock time as HH:MM:SS.
    stamp = (shown_iter, now.tm_hour, now.tm_min, now.tm_sec)
    log('iter: %3d %02d:%02d:%02d ' % stamp, end='')

    values = (Hartree * eks,
              Hartree * esic,
              Hartree * (eks + esic),
              kappa,
              Hartree * g_max)
    log('%11.6f %11.6f %11.6f %11.1e %11.1e' % values, end='')
    log(flush=True)
def g_max_norm(g_k, wfs):
    """Return the largest absolute gradient element over all k-points.

    :param g_k: gradient matrices keyed by ``nibzkpts * kpt.s + kpt.q``
    :param wfs: object providing ``kd.nibzkpts``, ``kpt_u`` and
        ``world.max_scalar``
    :return: result of ``wfs.world.max_scalar`` applied to the local
        maximum; matrices with a zero first dimension count as 0.0
    """
    n_kps = wfs.kd.nibzkpts

    local_maxima = []
    for kpt in wfs.kpt_u:
        grad = g_k[n_kps * kpt.s + kpt.q]
        if grad.shape[0] == 0:
            local_maxima.append(0.0)
        else:
            local_maxima.append(np.max(np.absolute(grad)))

    # With no local k-points np.max would raise on the empty list before
    # the reduction below is reached, so fall back to 0.0 in that case.
    if local_maxima:
        local_max = np.max(np.asarray(local_maxima))
    else:
        local_max = 0.0

    return wfs.world.max_scalar(local_max)