Coverage for gpaw/directmin/tools.py: 86%
271 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1"""
2Tools for directmin
3"""
5import numpy as np
6import scipy.linalg as lalg
7from copy import deepcopy
8from typing import Callable, cast
9from gpaw.typing import ArrayND, IntVector, RNG
def expm_ed(a_mat, evalevec=False):
    """
    Matrix exponential exp(a_mat) through an eigendecomposition.

    1j * a_mat is diagonalized with numpy.linalg.eigh, i.e. it is treated
    as Hermitian (presumably a_mat is skew-Hermitian -- confirm with the
    callers).

    :param a_mat: matrix to be exponentiated
    :param evalevec: if True, the eigenvectors and eigenvalues of
                     1j * a_mat are returned as well
    :return: exp(a_mat) as a contiguous array, or the tuple
             (exp(a_mat), eigenvectors, eigenvalues) if evalevec is True
    """
    omega, vecs = np.linalg.eigh(1.0j * a_mat)

    # a_mat = -1j * V diag(omega) V^H  =>  exp(a_mat) = V exp(-1j omega) V^H
    phases = np.exp(-1.0j * omega)
    u_mat = (vecs * phases) @ vecs.T.conj()

    # A real input gives a real exponential; drop the imaginary part
    if a_mat.dtype == float:
        u_mat = u_mat.real
    u_mat = np.ascontiguousarray(u_mat)

    return (u_mat, vecs, omega) if evalevec else u_mat
def expm_ed_unit_inv(a_upp_r, oo_vo_blockonly=False):
    """
    calculate matrix exponential using
    Eq. (6) from
    J. Hutter, M. Parrinello, and S. Vogel,
    J. Chem. Phys., 101, 3862 (1994)

    :param a_upp_r: X (see eq in paper), a 2D array of shape
                    (dim_o, dim_v)
    :param oo_vo_blockonly: if True, only the first dim_o columns
                            (the oo block stacked on the vo block)
                            are built and returned
    :return: unitary matrix of shape (dim_o + dim_v, dim_o + dim_v), or
             its (dim_o + dim_v, dim_o) block if oo_vo_blockonly is True
    """
    dim_o = a_upp_r.shape[0]
    dim_v = a_upp_r.shape[1]

    # exp(0) is the identity; skip the eigendecomposition entirely
    if np.allclose(a_upp_r, np.zeros_like(a_upp_r)):
        if oo_vo_blockonly:
            return np.vstack([np.eye(dim_o, dtype=a_upp_r.dtype),
                              np.zeros(shape=(dim_v, dim_o),
                                       dtype=a_upp_r.dtype)])
        return np.eye(dim_o + dim_v, dtype=a_upp_r.dtype)

    p_nn = a_upp_r @ a_upp_r.T.conj()
    eigval, evec = np.linalg.eigh(p_nn)
    # Eigenvalues cannot be negative; the floor also keeps the division
    # by eigval below well defined
    eigval[eigval.real < 1.0e-13] = 1.0e-13
    sqrt_eval = np.sqrt(eigval)

    cos_sqrt_p = matrix_function(sqrt_eval, evec, np.cos)
    # np.sinc(x) = sin(pi x) / (pi x), hence the division by pi
    psin = matrix_function(sqrt_eval / np.pi, evec, np.sinc)
    u_oo = cos_sqrt_p
    u_vo = - a_upp_r.T.conj() @ psin

    if oo_vo_blockonly:
        u = np.vstack([u_oo, u_vo])
    else:
        u_ov = psin @ a_upp_r
        pcos = matrix_function((np.cos(sqrt_eval) - 1) / eigval, evec)
        u_vv = np.eye(dim_v) + a_upp_r.T.conj() @ pcos @ a_upp_r
        u = np.vstack([
            np.hstack([u_oo, u_ov]),
            np.hstack([u_vo, u_vv])])

    return np.ascontiguousarray(u)
def d_matrix(omega):
    """
    Helper function for calculation of gradient
    w.r.t. skew-hermitian matrix
    see eq. 40 from
    A. V. Ivanov, E. Jónsson, T. Vegge, and H. Jónsson
    Comput. Phys. Commun., 267, 108047 (2021).
    arXiv:2101.12597 [physics.comp-ph]

    :param omega: 1D array of length m
    :return: complex (m, m) array with elements
             1j * (exp(-1j * w_ij) - 1) / w_ij, w_ij = omega[i] - omega[j];
             entries that come out as nan or inf (w_ij = 0) are set to 1
    """
    n_eig = omega.shape[0]

    # delta[i, j] = omega[i] - omega[j]
    delta = np.ones(shape=(n_eig, n_eig)) * (omega[:, np.newaxis] - omega)

    # 0 / 0 on the diagonal (and for degenerate pairs) is expected here
    with np.errstate(divide='ignore', invalid='ignore'):
        d_mat = 1.0j * np.divide(np.exp(-1.0j * delta) - 1.0, delta)

    # Replace the undefined entries by the limit value for delta -> 0
    d_mat[~np.isfinite(d_mat)] = 1.0

    return d_mat
def minimum_cubic_interpol(x_0, x_1, f_0, f_1, df_0, df_1):
    """
    given f, f' at boundaries of interval [x0, x1]
    calc. x_min where cubic interpolation is minimal

    The interpolant is written in the shifted variable t = x - x_0 as
    a t^3 + b t^2 + c t + d. The interior stationary point with positive
    curvature is accepted only if it lies strictly inside the interval
    and its interpolated value is below both boundary values; otherwise
    the boundary with the smaller function value is returned (x_1 on a
    tie).

    :param x_0: first boundary of the interval
    :param x_1: second boundary (the boundaries may be given in any order)
    :param f_0: function value at x_0
    :param f_1: function value at x_1
    :param df_0: derivative at x_0
    :param df_1: derivative at x_1
    :return: x_min
    """

    def cubic_function(a, b, c, d, x):
        """
        f(x) = a x^3 + b x^2 + c x + d
        :return: f(x)
        """
        return a * x ** 3 + b * x ** 2 + c * x + d

    if x_0 > x_1:
        x_0, x_1 = x_1, x_0
        f_0, f_1 = f_1, f_0
        df_0, df_1 = df_1, df_0

    r = x_1 - x_0
    if r == 0:
        # Degenerate interval: nothing to interpolate, and the
        # coefficients below would divide by zero
        return x_0

    # Boundary used whenever no acceptable interior minimum exists
    x_boundary = x_0 if f_0 < f_1 else x_1

    a = - 2.0 * (f_1 - f_0) / r ** 3.0 + \
        (df_1 + df_0) / r ** 2.0
    b = 3.0 * (f_1 - f_0) / r ** 2.0 - \
        (df_1 + 2.0 * df_0) / r
    c = df_0
    d = f_0

    if a == 0.0:
        # The interpolant is at most quadratic. The cubic root formula
        # would divide by zero, so use the parabola vertex directly;
        # it is a minimum only for positive curvature.
        if b <= 0.0:
            return x_boundary
        t_min = -c / (2.0 * b)
    else:
        D = b ** 2.0 - 3.0 * a * c
        if D < 0.0:
            # No real stationary point
            return x_boundary
        # Root of the derivative with positive second derivative
        t_min = (-b + np.sqrt(D)) / (3.0 * a)

    r0 = t_min + x_0
    if not x_0 < r0 < x_1:
        return x_boundary

    f_r0 = cubic_function(a, b, c, d, r0 - x_0)
    if f_0 > f_r0 and f_1 > f_r0:
        return r0

    return x_boundary
def matrix_function(evals, evecs, func=lambda x: x):
    """
    Build func(A) from the eigendecomposition of A.

    :param evals: eigenvalues of A
    :param evecs: eigenvectors of A (stored as columns)
    :param func: function applied to the eigenvalues; the identity
                 by default, which reassembles A itself
    :return: func(A)
    """
    mapped_evals = func(evals)
    # Scale column j of evecs by func(evals[j]), then rotate back
    scaled_evecs = evecs * mapped_evals
    return np.matmul(scaled_evecs, evecs.T.conj())
def loewdin_lcao(C_nM, S_MM):
    """
    Loewdin based orthonormalization
    for LCAO mode

    C_nM <- sum_m C_nM[m] [1/sqrt(S)]_mn
    S_mn = (C_nM[m].conj(), S_MM C_nM[n])

    :param C_nM: LCAO coefficients
    :param S_MM: Overlap matrix between basis functions
    :return: Orthonormalized coefficients so that new S_mn = delta_mn
    """
    # Overlap between the orbitals expressed in the LCAO basis
    overlap_nn = C_nM.conj() @ S_MM @ C_nM.T
    ev, vecs_nn = np.linalg.eigh(overlap_nn)

    # Inverse square root of the overlap from its eigendecomposition
    inv_sqrt_nn = (vecs_nn * (1.0 / np.sqrt(ev))) @ vecs_nn.T.conj()

    return inv_sqrt_nn.T @ C_nM
def gramschmidt_lcao(C_nM, S_MM):
    """
    Gram-Schmidt orthonormalization using Cholesky decomposition
    for LCAO mode

    :param C_nM: LCAO coefficients
    :param S_MM: Overlap matrix between basis functions
    :return: Orthonormalized coefficients so that new S_mn = delta_mn
    """
    # Orbital overlap matrix, factorized below as L L^H
    overlap_nn = C_nM @ S_MM.conj() @ C_nM.T.conj()
    chol_nn = lalg.cholesky(overlap_nn,
                            lower=True,
                            overwrite_a=True,
                            check_finite=False)

    # Solving L X = C gives coefficients whose overlap is the identity
    return lalg.solve(chol_nn, C_nM)
def excite(calc, i, a, spin=(0, 0), sort=False):
    """Helper function to initialize a variational excited state calculation.

    Move one electron from orbital homo + i of k-point spin[0] into
    orbital lumo + a of k-point spin[1].

    Parameters
    ----------
    calc: GPAW instance
        GPAW calculator object.
    i: int
        Offset from the homo of k-point spin[0]; the occupation number
        of orbital homo + i is lowered by 1. E.g. i=-1 removes an
        electron from the homo - 1 orbital.
    a: int
        Offset from the lumo of k-point spin[1]; the occupation number
        of orbital lumo + a is raised by 1. E.g. a=1 adds an electron
        to the lumo + 1 orbital.
    spin: tuple of two int
        spin[0] is the k-point losing an electron, spin[1] the k-point
        gaining one.
    sort: bool
        If True, the orbitals in the wfs object are sorted according to
        the new occupation numbers and the f_n attribute of the kpt
        objects is modified. Default is False.

    Returns
    -------
    list of numpy.ndarray
        List of new occupation numbers. Can be supplied to
        mom.prepare_mom_calculation to initialize an excited state
        calculation with MOM.
    """
    f_sn = []
    for s in range(calc.wfs.nspins):
        f_sn.append(calc.get_occupation_numbers(spin=s).copy())

    # homo and lumo are both located from the occupations of spin[0]
    donor_f_n = np.asarray(f_sn[spin[0]])
    lumo = int(np.count_nonzero(donor_f_n > 0))
    homo = lumo - 1

    f_sn[spin[0]][homo + i] -= 1.0
    f_sn[spin[1]][lumo + a] += 1.0

    if not sort:
        return f_sn

    for s in spin:
        for kpt in calc.wfs.kpt_u:
            if kpt.s != s:
                continue
            kpt.f_n = f_sn[s]
            changedocc = sort_orbitals_according_to_occ_kpt(
                calc.wfs, kpt, update_mom=False)[0]
            if changedocc:
                # Hand the reordered occupations back to the caller
                f_sn[s] = kpt.f_n

    return f_sn
def sort_orbitals_according_to_occ(
        wfs, constraints=None, update_mom=False, update_eps=True):
    """
    Sort the orbitals of every k-point according to the occupation
    numbers so that there are no holes in the distribution of
    occupation numbers.

    :param wfs: wave functions object; its kpt_u list is looped over
    :param constraints: optional constraints indexed by k-point; entries
                        of re-sorted k-points are replaced by re-indexed
                        ones
    :param update_mom: forwarded to sort_orbitals_according_to_occ_kpt
    :param update_eps: forwarded to sort_orbitals_according_to_occ_kpt
    :return: True if the orbitals of at least one k-point were re-sorted
    """
    any_resorted = False
    for kpt in wfs.kpt_u:
        resorted, order = sort_orbitals_according_to_occ_kpt(
            wfs, kpt, update_mom=update_mom, update_eps=update_eps)
        if not resorted:
            continue

        if constraints:
            k = wfs.kd.nibzkpts * kpt.s + kpt.q
            # The constrained orbitals now sit at different positions,
            # so their indices have to follow the new ordering
            constraints[k] = update_constraints_kpt(
                constraints[k], list(order))
        any_resorted = True

    return any_resorted
def sort_orbitals_according_to_occ_kpt(
        wfs, kpt, update_mom=False, update_eps=True):
    """
    Sort the orbitals of one k-point according to the occupation
    numbers so that there are no holes in the distribution of
    occupation numbers.

    :param wfs: wave functions object
    :param kpt: k-point whose orbitals, f_n and (optionally) eps_n are
                reordered in place
    :param update_mom: if True, update_mom_numbers is called after
                       sorting
    :param update_eps: if True, kpt.eps_n is reordered as well
    :return: tuple (changedocc, ind); changedocc tells whether anything
             was reordered and ind is the applied ordering (an empty
             array if nothing changed)
    """
    project_after_sort = True

    # Need to initialize the wave functions if
    # restarting from gpw file in fd or pw mode
    psit_nG = kpt.psit_nG
    if psit_nG is not None and not isinstance(psit_nG, np.ndarray):
        wfs.initialize_wave_functions_from_restart_file()
        project_after_sort = False

    n_occ, occupied = get_n_occ(kpt)

    # Nothing to do when no orbital is occupied or when the first n_occ
    # orbitals carry no zero occupation number
    if n_occ == 0.0 or np.min(kpt.f_n[:n_occ]) != 0:
        return False, np.array([])

    # Occupied orbitals first, unoccupied ones after them
    order = np.squeeze(
        np.vstack((np.argwhere(occupied), np.argwhere(~occupied))))

    if hasattr(wfs.eigensolver, 'dm_helper'):
        wfs.eigensolver.dm_helper.sort_orbitals(wfs, kpt, order)
    else:
        sort_orbitals_kpt(wfs, kpt, order, project_after_sort)

    kpt.f_n = kpt.f_n[order]
    if update_eps:
        kpt.eps_n[:] = kpt.eps_n[order]

    if update_mom:
        # OccupationsMOM.numbers needs
        # to be updated after sorting
        update_mom_numbers(wfs, kpt)

    return True, order
def sort_orbitals_according_to_energies(
        ham, wfs, constraints=None):
    """
    Sort orbitals according to the eigenvalues or
    the diagonal elements of the Hamiltonian matrix

    :param ham: Hamiltonian, only handed to dm_helper.orbital_energies
                (used for the 'etdm-lcao' eigensolver with SIC)
    :param wfs: wave functions object; orbitals, occupation numbers and
                energies of every k-point in wfs.kpt_u are reordered in
                place
    :param constraints: optional constraints indexed by k-point; every
                        entry is replaced by a re-indexed one
    """
    eigensolver_name = getattr(wfs.eigensolver, "name", None)

    # Eigensolvers exposing neither 'dm_helper' nor 'odd' are treated
    # as non-SIC; without this default is_sic would be unbound for them
    is_sic = False
    if hasattr(wfs.eigensolver, 'dm_helper'):
        dm_helper = wfs.eigensolver.dm_helper
        is_sic = 'SIC' in dm_helper.func.name
    else:
        dm_helper = None
        if hasattr(wfs.eigensolver, 'odd'):
            is_sic = 'SIC' in wfs.eigensolver.odd.name

    lcao_sic = eigensolver_name == 'etdm-lcao' and is_sic
    fdpw_sic = eigensolver_name == 'etdm-fdpw' and is_sic

    for kpt in wfs.kpt_u:
        k = wfs.kd.nibzkpts * kpt.s + kpt.q
        if lcao_sic:
            orb_energies = wfs.eigensolver.dm_helper.orbital_energies(
                wfs, ham, kpt)
        elif fdpw_sic:
            orb_energies = wfs.eigensolver.odd.lagr_diag_s[k]
        else:
            orb_energies = kpt.eps_n

        if is_sic:
            # For SIC, we sort occupied and unoccupied orbitals
            # separately, so the occupation numbers of canonical
            # and optimal orbitals are always consistent
            # NOTE(review): the '+ n_occ' offset presumably relies on
            # the occupied orbitals being stored first -- confirm
            n_occ, occupied = get_n_occ(kpt)
            ind_occ = np.argsort(orb_energies[occupied])
            ind_unocc = np.argsort(orb_energies[~occupied])
            ind = np.concatenate((ind_occ, ind_unocc + n_occ))

            # For SIC, we need to sort both the diagonal elements of
            # the Lagrange matrix and the self-interaction energies.
            # They live on the solver's 'odd' object when there is no
            # dm_helper, and on dm_helper.func otherwise.
            if dm_helper is None:
                sic_data = wfs.eigensolver.odd
            else:
                sic_data = dm_helper.func
            sic_data.lagr_diag_s[k] = orb_energies[ind]
            sic_data.e_sic_by_orbitals[k] = (
                sic_data.e_sic_by_orbitals)[k][ind_occ]
        else:
            ind = np.argsort(orb_energies)
            kpt.eps_n[np.arange(len(ind))] = orb_energies[ind]

        # now sort wfs according to orbital energies
        if dm_helper:
            dm_helper.sort_orbitals(wfs, kpt, ind)
        else:
            sort_orbitals_kpt(wfs, kpt, ind, update_proj=True)

        assert len(ind) == len(kpt.f_n)
        kpt.f_n = kpt.f_n[ind]

        occ_name = getattr(wfs.occupations, "name", None)
        if occ_name == 'mom':
            # OccupationsMOM.numbers needs to be updated
            # after sorting
            update_mom_numbers(wfs, kpt)
        if constraints:
            # Identities of the constrained orbitals have
            # changed and need to be updated
            constraints[k] = update_constraints_kpt(
                constraints[k], list(ind))
def update_mom_numbers(wfs, kpt):
    """
    Store the occupation numbers of kpt in wfs.occupations.numbers.

    kpt.f_n is divided by the k-point weight and by the degeneracy
    (2 for a collinear calculation with a single spin channel, 1
    otherwise) and written to wfs.occupations.numbers[kpt.s].

    :param wfs: wave functions object holding the occupations object
    :param kpt: k-point providing f_n, weightk and the index s
    """
    spin_paired = wfs.collinear and wfs.nspins == 1
    degeneracy = 2 if spin_paired else 1
    scale = kpt.weightk * degeneracy
    wfs.occupations.numbers[kpt.s] = kpt.f_n / scale
def sort_orbitals_kpt(wfs, kpt, ind, update_proj=False):
    """
    Reorder the orbitals of one k-point in place.

    :param wfs: wave functions object; wfs.mode selects between the LCAO
                coefficients (kpt.C_nM) and the wave functions
                (kpt.psit_nG)
    :param kpt: k-point whose orbitals are reordered
    :param ind: new ordering; entry n of the result is the old orbital
                ind[n]
    :param update_proj: outside LCAO mode, recompute the projections
                        kpt.P_ani after sorting; in LCAO mode the
                        projections are always recalculated
    """
    targets = np.arange(len(ind))

    if wfs.mode == 'lcao':
        kpt.C_nM[targets, :] = kpt.C_nM[ind, :]
        wfs.atomic_correction.calculate_projections(wfs, kpt)
        return

    kpt.psit_nG[targets] = kpt.psit_nG[ind]
    if update_proj:
        wfs.pt.integrate(kpt.psit_nG, kpt.P_ani, kpt.q)
def update_constraints_kpt(constraints, ind):
    """
    Translate the constraint indices to a new indexation, e.g. after
    the orbitals have been sorted. The input is left unmodified.

    :param constraints: The list of constraints for one K-point
    :param ind: List describing the new ordering; an old index is mapped
                to its position in this list
    :return: Copy of the constraints holding the new indices
    """
    remapped = deepcopy(constraints)
    for i, constraint in enumerate(constraints):
        for k, old_index in enumerate(constraint):
            remapped[i][k] = ind.index(old_index)
    return remapped
def dict_to_array(x):
    """
    Flatten a dictionary with integer keys into one long array by
    appending its values in iteration order.

    :param x: Dictionary
    :return: Long array, dimensions of original dictionary parts, total
             dimensions
    """
    flat = []
    dim = []
    for key, part in x.items():
        assert isinstance(key, int), (
            'Cannot convert dict to array if keys are not '
            'integer.')
        flat.extend(part)
        dim.append(len(part))
    return np.asarray(flat), dim, sum(dim)
def array_to_dict(x, dim):
    """
    Split a long array into consecutive pieces and collect them in a
    dictionary with the integer keys 0, 1, 2, ...

    :param x: Array
    :param dim: List with dimensionalities of parts of the dictionary
    :return: Dictionary
    """
    parts = {}
    offset = 0
    for key, length in enumerate(dim):
        parts[key] = x[offset: offset + length]
        offset += length
    return parts
def rotate_orbitals(etdm, wfs, indices, angles, channels):
    """
    Applies rotations between pairs of orbitals.

    :param etdm: ETDM object for a converged or at least initialized
                 calculation
    :param wfs: wave functions object whose orbitals are rotated
    :param indices: List of indices. Each element must be a list of an
                    orbital pair corresponding to the orbital rotation.
                    For occupied-virtual rotations (unitary invariant or
                    sparse representations), the first index represents the
                    occupied, the second the virtual orbital.
                    For occupied-occupied rotations (sparse representation
                    only), the first index must always be smaller than the
                    second.
    :param angles: List of angles in degrees; they are converted to
                   radians and their sign is flipped before use.
    :param channels: List of spin channels.
    """
    # Degrees -> radians, with the sign reversed
    angles_rad = np.negative(np.array(angles)) * np.pi / 180.0

    # This has to come first: building the rotation vector sorts the
    # orbitals, and the reference coefficients must be copied afterwards
    rotation_u = get_a_vec_u(etdm, wfs, indices, angles_rad, channels)

    reference_c = {}
    for kpt in wfs.kpt_u:
        k = etdm.kpointval(kpt)
        # NOTE(review): the coefficients are looked up as wfs.kpt_u[k]
        # and not taken from kpt itself -- confirm that k is always a
        # valid position in wfs.kpt_u
        reference_c[k] = wfs.kpt_u[k].C_nM.copy()

    etdm.rotate_wavefunctions(wfs, rotation_u, reference_c)
def get_a_vec_u(etdm, wfs, indices, angles, channels, occ=None):
    """
    Creates an orbital rotation vector based on given indices, angles and
    corresponding spin channels.

    The orbitals in wfs are first sorted according to their occupation
    numbers (which also updates etdm.constraints and the MOM numbers).
    The input lists are not modified.

    :param etdm: ETDM object for a converged or at least initialized
                 calculation
    :param wfs: wave functions object
    :param indices: List of indices. Each element must be a list of an
                    orbital pair corresponding to the orbital rotation.
                    For occupied-virtual rotations (unitary invariant or
                    sparse representations), the first index represents the
                    occupied, the second the virtual orbital.
                    For occupied-occupied rotations (sparse representation
                    only), the first index must always be smaller than the
                    second.
    :param angles: List of rotation values, written unchanged into the
                   returned vector.
    :param channels: List of spin channels; each one must be a key of
                     etdm.a_vec_u.
    :param occ: Occupation numbers for each k-point. Must be specified
                if the orbitals in the ETDM object are not ordered
                canonically, as the user orbital indexation is different
                from the one in the ETDM object then.

    :return new_vec_u: Orbital rotation coordinate vector containing the
                       specified values.
    :raises ValueError: if a requested orbital pair is not part of
                        etdm.ind_up for its channel.
    """

    sort_orbitals_according_to_occ(wfs, etdm.constraints, update_mom=True)

    ind_up = etdm.ind_up
    new_vec_u = {}
    # Orbital reordering per channel, stored only for channels whose
    # occupation numbers contain holes. Keyed by channel so that a
    # channel without occupied orbitals cannot shift the other entries.
    conversion = {}
    for k, a_vec in etdm.a_vec_u.items():
        new_vec_u[k] = np.zeros_like(a_vec)
        if occ is None:
            continue
        f_n = occ[k]
        occupied = f_n > 1.0e-10
        n_occ = np.count_nonzero(occupied)
        if n_occ == 0:
            continue
        if np.min(f_n[:n_occ]) == 0:
            order = np.squeeze(
                np.vstack((np.argwhere(occupied), np.argwhere(~occupied))))
            conversion[k] = list(order)

    for pair, ang, s in zip(indices, angles, channels):
        # Work on local copies so that the caller's index lists stay intact
        first, second = pair[0], pair[1]
        reorder = conversion.get(s)
        if reorder is not None:
            # Translate the user indexation to the sorted one
            first = reorder.index(first)
            second = reorder.index(second)

        m = np.where(ind_up[s][0] == first)[0]
        n = np.where(ind_up[s][1] == second)[0]
        # Positions at which both orbitals of the pair match
        common = np.intersect1d(m, n)
        if common.size == 0:
            raise ValueError('Orbital rotation does not exist.')
        new_vec_u[s][common[-1]] = ang

    return new_vec_u
def get_n_occ(kpt):
    """
    Count the occupied orbitals of a k-point.

    An orbital counts as occupied if its occupation number in kpt.f_n
    is larger than 1.0e-10.

    :param kpt: k-point providing the occupation numbers f_n
    :return: number of occupied orbitals and the boolean mask marking
             them
    """
    occupied = kpt.f_n > 1.0e-10
    return int(np.count_nonzero(occupied)), occupied
def get_indices(dimens):
    """
    Indices of the strictly lower triangle of a square matrix.

    :param dimens: dimension of the square matrix
    :return: tuple of row and column index arrays, diagonal excluded
    """
    below_diagonal = -1
    return np.tril_indices(dimens, below_diagonal)
def random_a(shape, dtype, rng: RNG = cast(RNG, np.random)):
    """
    Draw an array of random numbers from rng.random.

    :param shape: shape handed to rng.random
    :param dtype: if equal to complex, a second, independent draw is
                  added as the imaginary part
    :param rng: random number source providing a random(shape) method;
                the numpy.random module by default
    :return: array of the requested shape
    """
    draw: Callable[[IntVector], ArrayND] = rng.random

    real_part = draw(shape)
    if dtype != complex:
        return real_part

    # The real part is drawn first, the imaginary part second
    return real_part.astype(complex) + 1.0j * draw(shape)