Coverage for gpaw/directmin/tools.py: 86%

271 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1""" 

2Tools for directmin 

3""" 

4 

5import numpy as np 

6import scipy.linalg as lalg 

7from copy import deepcopy 

8from typing import Callable, cast 

9from gpaw.typing import ArrayND, IntVector, RNG 

10 

11 

def expm_ed(a_mat, evalevec=False):
    """
    Calculate the matrix exponential exp(a_mat) by eigendecomposition.

    a_mat is assumed to be skew-Hermitian (anti-symmetric if real), so that
    1j * a_mat is Hermitian and can be diagonalized with ``eigh``:

        1j * a_mat = V diag(w) V^H  =>  exp(a_mat) = V diag(exp(-1j w)) V^H

    :param a_mat: skew-Hermitian matrix to be exponentiated
    :param evalevec: if True, also return the eigenvectors V and the
                     eigenvalues w of 1j * a_mat (in that order)

    :return: exp(a_mat) as a C-contiguous array (real part only if a_mat
             has dtype float), optionally followed by V and w
    """
    w, v = np.linalg.eigh(1.0j * a_mat)
    phases = np.exp(-1.0j * w)
    # V diag(phases) V^H, with the diagonal applied by broadcasting
    exp_a = (v * phases) @ v.T.conj()

    if a_mat.dtype == float:
        exp_a = exp_a.real
    exp_a = np.ascontiguousarray(exp_a)

    if evalevec:
        return exp_a, v, w
    return exp_a

34 

35 

def expm_ed_unit_inv(a_upp_r, oo_vo_blockonly=False):
    """
    Calculate the matrix exponential of the skew-Hermitian matrix

        A = [[0, X], [-X^H, 0]],   X = a_upp_r  (shape dim_o x dim_v)

    using Eq. (6) from
    J. Hutter, M. Parrinello, and S. Vogel,
    J. Chem. Phys., 101, 3862 (1994).

    With P = X X^H the blocks of U = exp(A) are

        U_oo = cos(sqrt(P))
        U_ov = sin(sqrt(P)) / sqrt(P) X
        U_vo = -X^H sin(sqrt(P)) / sqrt(P)
        U_vv = 1 + X^H (cos(sqrt(P)) - 1) / P X

    :param a_upp_r: X (see eq in paper), 2-D array of shape (dim_o, dim_v)
    :param oo_vo_blockonly: if True, return only the first dim_o columns
                            [U_oo; U_vo], shape (dim_o + dim_v, dim_o)
    :return: unitary matrix U (or its first dim_o columns), C-contiguous
    """
    dim_o, dim_v = a_upp_r.shape
    dtype = a_upp_r.dtype

    if np.allclose(a_upp_r, np.zeros_like(a_upp_r)):
        # exp(0) = 1: skip the diagonalization
        if oo_vo_blockonly:
            return np.vstack([np.eye(dim_o, dtype=dtype),
                              np.zeros(shape=(dim_v, dim_o), dtype=dtype)])
        return np.eye(dim_o + dim_v, dtype=dtype)

    p_nn = a_upp_r @ a_upp_r.T.conj()
    eigval, evec = np.linalg.eigh(p_nn)
    # P is positive semi-definite; remove round-off negatives before sqrt.
    # No division by the eigenvalues happens below, so zero is safe.
    eigval = np.clip(eigval, 0.0, None)
    sqrt_eval = np.sqrt(eigval)

    def from_eig(values):
        # f(P) = V diag(f(lambda)) V^H
        return (evec * values) @ evec.T.conj()

    # np.sinc(t) = sin(pi t) / (pi t), so this is sin(s) / s with the
    # correct limit 1 at s = 0
    psin = from_eig(np.sinc(sqrt_eval / np.pi))
    u_oo = from_eig(np.cos(sqrt_eval))
    u_vo = - a_upp_r.T.conj() @ psin

    if oo_vo_blockonly:
        return np.ascontiguousarray(np.vstack([u_oo, u_vo]))

    u_ov = psin @ a_upp_r
    # (cos(s) - 1) / s^2 = -2 sin^2(s / 2) / s^2 = -sinc^2(s / (2 pi)) / 2:
    # avoids the cancellation in cos(s) - 1 for small s and the division
    # by (near-)zero eigenvalues
    pcos = from_eig(-0.5 * np.sinc(sqrt_eval / (2.0 * np.pi)) ** 2)
    u_vv = np.eye(dim_v) + a_upp_r.T.conj() @ pcos @ a_upp_r
    u = np.vstack([
        np.hstack([u_oo, u_ov]),
        np.hstack([u_vo, u_vv])])

    return np.ascontiguousarray(u)

82 

83 

def d_matrix(omega):
    """
    Helper function for calculation of gradient
    w.r.t. skew-hermitian matrix
    see eq. 40 from
    A. V. Ivanov, E. Jónsson, T. Vegge, and H. Jónsson
    Comput. Phys. Commun., 267, 108047 (2021).
    arXiv:2101.12597 [physics.comp-ph]

    Returns the matrix D with elements

        D_ij = 1j * (exp(-1j * x_ij) - 1) / x_ij,   x_ij = omega_i - omega_j,

    where D_ij = 1 in the limit x_ij -> 0.

    :param omega: 1-D array of eigenvalues
    :return: complex array D of shape (len(omega), len(omega))
    """
    x_mm = omega[:, np.newaxis] - omega[np.newaxis, :]
    # 1j * (exp(-1j x) - 1) / x = exp(-1j x / 2) * sin(x / 2) / (x / 2), and
    # np.sinc(t) = sin(pi t) / (pi t) with np.sinc(0) = 1, so the x -> 0
    # limit is exact. No 0 / 0 has to be patched afterwards, and NaN in
    # omega propagates instead of being silently replaced by 1.
    return np.exp(-0.5j * x_mm) * np.sinc(x_mm / (2.0 * np.pi))

106 

107 

def minimum_cubic_interpol(x_0, x_1, f_0, f_1, df_0, df_1):
    """
    given f, f' at boundaries of interval [x0, x1]
    calc. x_min where cubic interpolation is minimal

    The cubic p(t) = a t^3 + b t^2 + c t + d with t = x - x_0 matches f and
    f' at both ends. Its local minimum t* = (-b + sqrt(b^2 - 3ac)) / (3a) is
    returned if it lies strictly inside the interval and is lower than both
    end values; otherwise the end point with the lower f is returned. When
    the data are exactly quadratic (a == 0) the parabola's minimum is used.

    :return: x_min
    """
    if x_0 > x_1:
        x_0, x_1 = x_1, x_0
        f_0, f_1 = f_1, f_0
        df_0, df_1 = df_1, df_0

    # Fallback: the end point with the lower function value
    x_end = x_0 if f_0 < f_1 else x_1

    r = x_1 - x_0
    if r == 0:
        # Degenerate interval: nothing to interpolate
        return x_0

    a = - 2.0 * (f_1 - f_0) / r ** 3.0 + \
        (df_1 + df_0) / r ** 2.0
    b = 3.0 * (f_1 - f_0) / r ** 2.0 - \
        (df_1 + 2.0 * df_0) / r
    c = df_0
    d = f_0
    disc = b ** 2.0 - 3.0 * a * c

    if disc < 0.0:
        # p'(t) has no real root: p is monotonic on the interval
        return x_end

    sqrt_disc = np.sqrt(disc)
    if b > 0.0:
        # Algebraically equal to (-b + sqrt_disc) / (3a), but free of the
        # cancellation for small |a| and finite for a == 0, where the cubic
        # reduces to a parabola with minimum -c / (2b)
        t_min = -c / (b + sqrt_disc)
    elif a != 0.0:
        t_min = (-b + sqrt_disc) / (3.0 * a)
    else:
        # a == 0 and b <= 0: linear or concave, no interior minimum
        return x_end

    r0 = t_min + x_0
    if not x_0 < r0 < x_1:
        return x_end

    f_r0 = a * t_min ** 3 + b * t_min ** 2 + c * t_min + d
    if f_0 > f_r0 and f_1 > f_r0:
        return r0
    return x_end

159 

160 

def matrix_function(evals, evecs, func=lambda x: x):
    """
    Evaluate the matrix function func(A) from the eigendecomposition of A.

    For A = V diag(w) V^H this returns V diag(func(w)) V^H.

    :param evals: eigenvalues w of A
    :param evecs: eigenvectors of A stored column-wise (the matrix V)
    :param func: element-wise function applied to the eigenvalues; the
                 identity by default, which reconstructs A itself
    :return: func(A)
    """
    scaled_columns = evecs * func(evals)
    return scaled_columns @ np.conj(evecs).T

170 

171 

def loewdin_lcao(C_nM, S_MM):
    """
    Loewdin based orthonormalization
    for LCAO mode

    C_nM <- sum_m C_nM[m] [1/sqrt(S)]_mn
    S_mn = (C_nM[m].conj(), S_MM C_nM[n])

    :param C_nM: LCAO coefficients, one orbital per row
    :param S_MM: Overlap matrix between basis functions
    :return: Orthonormalized coefficients so that new S_mn = delta_mn
    """
    overlap_nn = C_nM.conj() @ S_MM @ C_nM.T
    eigvals, eigvecs = np.linalg.eigh(overlap_nn)
    # S^(-1/2) = V diag(1 / sqrt(w)) V^H, diagonal applied by broadcasting
    inv_sqrt_nn = (eigvecs * (1.0 / np.sqrt(eigvals))) @ eigvecs.T.conj()
    return inv_sqrt_nn.T @ C_nM

191 

192 

def gramschmidt_lcao(C_nM, S_MM):
    """
    Gram-Schmidt orthonormalization using Cholesky decomposition
    for LCAO mode

    With the orbital overlap S_nn = C_nM S_MM^* C_nM^H = L L^H (L lower
    triangular), the new coefficients L^(-1) C_nM have unit overlap, and
    orbital k of the result is a combination of input orbitals 0..k only.

    :param C_nM: LCAO coefficients, one orbital per row
    :param S_MM: Overlap matrix between basis functions
    :return: Orthonormalized coefficients so that new S_mn = delta_mn
    """

    S_nn = C_nM @ S_MM.conj() @ C_nM.T.conj()
    L_nn = lalg.cholesky(S_nn, lower=True,
                         overwrite_a=True, check_finite=False)
    # L_nn is lower triangular, so forward substitution suffices; the
    # general lalg.solve would redo an LU factorization of L_nn
    return lalg.solve_triangular(L_nn, C_nM, lower=True)

207 

208 

def excite(calc, i, a, spin=(0, 0), sort=False):
    """Helper function to initialize a variational excited state calculation.

    Promote an electron from homo + i of spin channel spin[0] to lumo + a of
    spin channel spin[1]. The homo is determined from the occupation numbers
    of spin[0] and the lumo from those of spin[1], so spin-flip excitations
    of open-shell systems address the intended orbitals.

    Parameters
    ----------
    calc: GPAW instance
        GPAW calculator object.
    i: int
        Subtract 1 from the occupation number of the homo + i orbital of
        spin channel spin[0]. E.g. if i=-1, an electron is removed from the
        homo - 1 orbital.
    a: int
        Add 1 to the occupation number of the lumo + a orbital of spin
        channel spin[1]. E.g. if a=1, an electron is added to the lumo + 1
        orbital.
    spin: tuple of two int
        spin[0] is the spin channel from which an electron is removed and
        spin[1] is the spin channel where an electron is added.
    sort: bool
        If True, sort the orbitals in the wfs object according to the new
        occupation numbers, and modify the f_n attribute of the kpt objects.
        Default is False.

    Returns
    -------
    list of numpy.ndarray
        List of new occupation numbers. Can be supplied to
        mom.prepare_mom_calculation to initialize an excited state calculation
        with MOM.
    """
    f_sn = [calc.get_occupation_numbers(spin=s).copy()
            for s in range(calc.wfs.nspins)]

    # Both counts are taken before any occupation is modified
    f_n_from = np.asarray(f_sn[spin[0]])
    f_n_to = np.asarray(f_sn[spin[1]])
    homo = len(f_n_from[f_n_from > 0]) - 1
    lumo = len(f_n_to[f_n_to > 0])

    f_sn[spin[0]][homo + i] -= 1.0
    f_sn[spin[1]][lumo + a] += 1.0

    if sort:
        # dict.fromkeys keeps the order but visits a channel only once when
        # spin[0] == spin[1]
        for s in dict.fromkeys(spin):
            for kpt in calc.wfs.kpt_u:
                if kpt.s == s:
                    kpt.f_n = f_sn[s]
                    changedocc = sort_orbitals_according_to_occ_kpt(
                        calc.wfs, kpt, update_mom=False)[0]
                    if changedocc:
                        f_sn[s] = kpt.f_n

    return f_sn

262 

263 

def sort_orbitals_according_to_occ(
        wfs, constraints=None, update_mom=False, update_eps=True):
    """
    Sort the orbitals of every k-point according to the occupation
    numbers so that there are no holes in the distribution of
    occupation numbers.

    If ``constraints`` is given, the constraint indices of each re-sorted
    k-point are remapped to the new orbital ordering.

    :return: True if the orbitals of at least one k-point were re-sorted
    """
    any_sorted = False
    for kpt in wfs.kpt_u:
        resorted, perm = sort_orbitals_according_to_occ_kpt(
            wfs, kpt, update_mom=update_mom, update_eps=update_eps)
        if not resorted:
            continue
        if constraints:
            # The identities of the constrained orbitals have changed
            k = wfs.kd.nibzkpts * kpt.s + kpt.q
            constraints[k] = update_constraints_kpt(
                constraints[k], list(perm))
        any_sorted = True
    return any_sorted

287 

288 

def sort_orbitals_according_to_occ_kpt(
        wfs, kpt, update_mom=False, update_eps=True):
    """
    Sort the orbitals of one k-point according to the occupation numbers
    so that there are no holes in the distribution of occupation numbers.

    Occupied orbitals (as defined by get_n_occ) are moved in front of the
    unoccupied ones, keeping the relative order within each group. The
    permutation is applied to the wave functions, kpt.f_n and, if
    update_eps is True, kpt.eps_n.

    :param update_mom: if True, update the MOM occupation numbers afterwards
    :param update_eps: if True, permute kpt.eps_n as well
    :return: (changedocc, ind) - whether the orbitals were re-sorted and the
             permutation applied (an empty array if nothing was done)
    """
    update_proj = True

    # Need to initialize the wave functions if
    # restarting from gpw file in fd or pw mode
    if kpt.psit_nG is not None and not isinstance(kpt.psit_nG, np.ndarray):
        wfs.initialize_wave_functions_from_restart_file()
        update_proj = False

    n_occ, occupied = get_n_occ(kpt)
    has_hole = n_occ != 0.0 and np.min(kpt.f_n[:n_occ]) == 0
    if not has_hole:
        return False, np.array([])

    # Occupied orbitals first, then unoccupied ones, each in original order
    ind = np.concatenate((np.flatnonzero(occupied),
                          np.flatnonzero(~occupied)))

    if hasattr(wfs.eigensolver, 'dm_helper'):
        wfs.eigensolver.dm_helper.sort_orbitals(wfs, kpt, ind)
    else:
        sort_orbitals_kpt(wfs, kpt, ind, update_proj)

    kpt.f_n = kpt.f_n[ind]
    if update_eps:
        kpt.eps_n[:] = kpt.eps_n[ind]

    if update_mom:
        # OccupationsMOM.numbers needs to be updated after sorting
        update_mom_numbers(wfs, kpt)

    return True, ind

332 

333 

def sort_orbitals_according_to_energies(
        ham, wfs, constraints=None):
    """
    Sort orbitals according to the eigenvalues or
    the diagonal elements of the Hamiltonian matrix

    For SIC functionals occupied and unoccupied orbitals are sorted
    separately; otherwise all orbitals are sorted by energy. The wave
    functions and kpt.f_n are permuted accordingly. If the occupations are
    of type 'mom', their numbers are updated as well, and if constraints
    are given, their orbital indices are remapped to the new ordering.

    :param ham: Hamiltonian, passed to dm_helper.orbital_energies
    :param wfs: wave functions object, modified in place
    :param constraints: optional per-k-point constraints, modified in place
    """
    eigensolver_name = getattr(wfs.eigensolver, "name", None)
    # Default for eigensolvers with neither a dm_helper nor an odd
    # attribute; without it is_sic would be unbound below
    is_sic = False
    if hasattr(wfs.eigensolver, 'dm_helper'):
        dm_helper = wfs.eigensolver.dm_helper
        is_sic = 'SIC' in dm_helper.func.name
    else:
        dm_helper = None
        if hasattr(wfs.eigensolver, 'odd'):
            is_sic = 'SIC' in wfs.eigensolver.odd.name

    lcao_sic = eigensolver_name == 'etdm-lcao' and is_sic
    fdpw_sic = eigensolver_name == 'etdm-fdpw' and is_sic
    use_mom = getattr(wfs.occupations, "name", None) == 'mom'

    for kpt in wfs.kpt_u:
        k = wfs.kd.nibzkpts * kpt.s + kpt.q
        if lcao_sic:
            orb_energies = wfs.eigensolver.dm_helper.orbital_energies(
                wfs, ham, kpt)
        elif fdpw_sic:
            orb_energies = wfs.eigensolver.odd.lagr_diag_s[k]
        else:
            orb_energies = kpt.eps_n

        if is_sic:
            # For SIC, we sort occupied and unoccupied orbitals
            # separately, so the occupation numbers of canonical
            # and optimal orbitals are always consistent.
            # NOTE(review): assumes the occupied orbitals are the first
            # n_occ ones (ind_unocc is offset by n_occ).
            n_occ, occupied = get_n_occ(kpt)
            ind_occ = np.argsort(orb_energies[occupied])
            ind_unocc = np.argsort(orb_energies[~occupied])
            ind = np.concatenate((ind_occ, ind_unocc + n_occ))
            # For SIC, we need to sort both the diagonal elements of
            # the Lagrange matrix and the self-interaction energies
            if dm_helper is None:
                # Directly sort the solver energies
                wfs.eigensolver.odd.lagr_diag_s[k] = orb_energies[ind]
                wfs.eigensolver.odd.e_sic_by_orbitals[k] = (
                    wfs.eigensolver.odd.e_sic_by_orbitals)[k][ind_occ]
            else:
                dm_helper.func.lagr_diag_s[k] = orb_energies[ind]
                dm_helper.func.e_sic_by_orbitals[k] = (
                    dm_helper.func.e_sic_by_orbitals)[k][ind_occ]
        else:
            ind = np.argsort(orb_energies)
            kpt.eps_n[np.arange(len(ind))] = orb_energies[ind]

        # now sort wfs according to orbital energies
        if dm_helper is not None:
            dm_helper.sort_orbitals(wfs, kpt, ind)
        else:
            sort_orbitals_kpt(wfs, kpt, ind, update_proj=True)

        assert len(ind) == len(kpt.f_n)
        kpt.f_n = kpt.f_n[ind]

        if use_mom:
            # OccupationsMOM.numbers needs to be updated
            # after sorting
            update_mom_numbers(wfs, kpt)
        if constraints:
            # The identities of the constrained orbitals have
            # changed and need to be updated
            constraints[k] = update_constraints_kpt(
                constraints[k], list(ind))

400 if constraints: 

401 # Identity if the constrained orbitals have 

402 # changed and need to be updated 

403 constraints[k] = update_constraints_kpt( 

404 constraints[k], list(ind)) 

405 

406 

def update_mom_numbers(wfs, kpt):
    """
    Store the occupation numbers of ``kpt`` in wfs.occupations.numbers.

    kpt.f_n is divided by the k-point weight and by the spin degeneracy,
    which is 2 for a collinear spin-paired calculation and 1 otherwise.
    """
    spin_paired = wfs.collinear and wfs.nspins == 1
    degeneracy = 2 if spin_paired else 1
    wfs.occupations.numbers[kpt.s] = kpt.f_n / (kpt.weightk * degeneracy)

414 

415 

def sort_orbitals_kpt(wfs, kpt, ind, update_proj=False):
    """
    Permute the orbitals of ``kpt`` in place: new orbital n is the old
    orbital ind[n], for n < len(ind).

    In LCAO mode the coefficients kpt.C_nM are permuted and the projections
    are recalculated. Otherwise kpt.psit_nG is permuted, and the projections
    kpt.P_ani are recalculated only if update_proj is True.
    """
    n_sorted = len(ind)
    if wfs.mode != 'lcao':
        # The fancy-indexed right-hand side is a copy, so overwriting rows
        # in place cannot read already-overwritten data
        kpt.psit_nG[:n_sorted] = kpt.psit_nG[ind]
        if update_proj:
            wfs.pt.integrate(kpt.psit_nG, kpt.P_ani, kpt.q)
        return
    kpt.C_nM[:n_sorted, :] = kpt.C_nM[ind, :]
    wfs.atomic_correction.calculate_projections(wfs, kpt)

424 

425 

def update_constraints_kpt(constraints, ind):
    """
    Change the constraint indices to match a new indexation, e.g. due to
    sorting the orbitals.

    Each old orbital index o is replaced by its position in ``ind``
    (i.e. ind.index(o)). The input is not modified; a deep copy is returned.

    :param constraints: The list of constraints for one K-point
    :param ind: List containing information about the change in indexation
    :return: The remapped constraints
    """
    remapped = deepcopy(constraints)
    for row, old_row in enumerate(constraints):
        for col, orbital in enumerate(old_row):
            remapped[row][col] = ind.index(orbital)
    return remapped

440 

441 

def dict_to_array(x):
    """
    Concatenate the values of a dictionary with integer keys into one
    long array, in the dictionary's iteration order.

    :param x: Dictionary with integer keys and sequence values
    :return: Long array, list with the length of each value, and the
             total length
    """
    chunks = []
    dims = []
    for key, value in x.items():
        assert isinstance(key, int), (
            'Cannot convert dict to array if keys are not '
            'integer.')
        chunks.extend(value)
        dims.append(len(value))
    return np.asarray(chunks), dims, sum(dims)

461 

462 

def array_to_dict(x, dim):
    """
    Split a long array into a dictionary with integer keys 0, 1, ...;
    entry i holds the next dim[i] elements (the inverse of dict_to_array).

    :param x: Array
    :param dim: List with dimensionalities of parts of the dictionary
    :return: Dictionary
    """
    parts = {}
    offset = 0
    for key, size in enumerate(dim):
        parts[key] = x[offset: offset + size]
        offset += size
    return parts

480 

481 

def rotate_orbitals(etdm, wfs, indices, angles, channels):
    """
    Applies rotations between pairs of orbitals.

    :param etdm: ETDM object for a converged or at least initialized
                 calculation
    :param wfs: wave functions object whose LCAO coefficients (kpt.C_nM)
                are rotated
    :param indices: List of indices. Each element must be a list of an
                    orbital pair corresponding to the orbital rotation.
                    For occupied-virtual rotations (unitary invariant or
                    sparse representations), the first index represents the
                    occupied, the second the virtual orbital.
                    For occupied-occupied rotations (sparse representation
                    only), the first index must always be smaller than the
                    second.
    :param angles: List of angles in degrees (converted to radians here).
    :param channels: List of spin channels.
    """

    # degrees -> radians.
    # NOTE(review): the sign flip presumably matches the orientation
    # convention of etdm's rotation parametrization; confirm.
    angles = - np.array(angles) * np.pi / 180.0
    a_vec_u = get_a_vec_u(etdm, wfs, indices, angles, channels)
    # Coefficients of each k-point, keyed like a_vec_u. Read from the kpt
    # being iterated rather than re-indexing wfs.kpt_u with the k-point
    # value, which need not equal the local list position.
    c = {etdm.kpointval(kpt): kpt.C_nM.copy() for kpt in wfs.kpt_u}
    etdm.rotate_wavefunctions(wfs, a_vec_u, c)

507 

508 

def get_a_vec_u(etdm, wfs, indices, angles, channels, occ=None):
    """
    Creates an orbital rotation vector based on given indices, angles and
    corresponding spin channels.

    :param etdm: ETDM object for a converged or at least initialized
                 calculation
    :param indices: List of indices. Each element must be a list of an
                    orbital pair corresponding to the orbital rotation.
                    For occupied-virtual rotations (unitary invariant or
                    sparse representations), the first index represents the
                    occupied, the second the virtual orbital.
                    For occupied-occupied rotations (sparse representation
                    only), the first index must always be smaller than the
                    second. The pairs are not modified.
    :param angles: List of angles in radians.
    :param channels: List of spin channels.
    :param occ: Occupation numbers for each k-point. Must be specified
                if the orbitals in the ETDM object are not ordered
                canonically, as the user orbital indexation is different
                from the one in the ETDM object then.

    :return new_vec_u: Orbital rotation coordinate vector containing the
                       specified values.
    :raises ValueError: if a requested orbital pair is not among the
                        rotations in etdm.ind_up
    """

    sort_orbitals_according_to_occ(wfs, etdm.constraints, update_mom=True)

    ind_up = etdm.ind_up
    new_vec_u = {k: np.zeros_like(vec) for k, vec in etdm.a_vec_u.items()}

    # Per k-point map from the user's orbital index to the index used by
    # etdm (occupied orbitals first); None where no re-indexing is needed.
    # Every k-point gets an entry, also those without occupied orbitals.
    conversion = {}
    if occ is not None:
        for k in new_vec_u:
            f_n = np.asarray(occ[k])
            occupied = f_n > 1.0e-10
            n_occ = len(f_n[occupied])
            if n_occ != 0 and np.min(f_n[:n_occ]) == 0:
                conversion[k] = list(np.concatenate(
                    (np.flatnonzero(occupied), np.flatnonzero(~occupied))))
            else:
                conversion[k] = None

    for ind, ang, s in zip(indices, angles, channels):
        # Local copies: the caller's index pairs must not be overwritten
        i_orb, j_orb = ind[0], ind[1]
        perm = conversion.get(s)
        if perm is not None:
            i_orb = perm.index(i_orb)
            j_orb = perm.index(j_orb)
        rows = np.flatnonzero(ind_up[s][0] == i_orb)
        cols = np.flatnonzero(ind_up[s][1] == j_orb)
        common = np.intersect1d(rows, cols)
        if common.size == 0:
            raise ValueError('Orbital rotation does not exist.')
        new_vec_u[s][common[-1]] = ang

    return new_vec_u

575 

576 

def get_n_occ(kpt):
    """
    Count the occupied orbitals of ``kpt``.

    An orbital counts as occupied if its occupation number exceeds 1e-10.

    :return: (n_occ, occupied) - the number of occupied orbitals and the
             boolean mask over kpt.f_n
    """
    mask = kpt.f_n > 1.0e-10
    return int(np.count_nonzero(mask)), mask

582 

583 

def get_indices(dimens):
    """
    Return the (row, column) index arrays of the strictly lower triangle
    of a dimens x dimens matrix; the diagonal is excluded.
    """
    rows, cols = np.tril_indices(dimens, k=-1)
    return rows, cols

586 

587 

def random_a(shape, dtype, rng: RNG = cast(RNG, np.random)):
    """
    Draw a random array of the given shape using ``rng.random``.

    For dtype complex, a second draw of the same shape provides the
    imaginary part; the real part is always drawn first.

    :param shape: shape of the array
    :param dtype: complex for a complex array, anything else for real
    :param rng: random generator providing ``random(shape)``
    :return: the random array
    """
    draw: Callable[[IntVector], ArrayND] = rng.random
    real_part = draw(shape)
    if dtype != complex:
        return real_part
    out = real_part.astype(complex)
    out += 1.0j * draw(shape)
    return out