Coverage for gpaw/lcao/scissors.py: 100%
108 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1"""Scissors operator for LCAO."""
2from __future__ import annotations
4from typing import Sequence
6import numpy as np
7from ase.units import Ha
9from gpaw.lcao.eigensolver import DirectLCAO
10from gpaw.new.calculation import DFTCalculation
11from gpaw.new.lcao.eigensolver import LCAOEigensolver
12from gpaw.new.symmetry import Symmetries
13from gpaw.core.matrix import Matrix
def non_self_consistent_scissors_shift(
        shifts: Sequence[tuple[float, float, int]],
        dft: DFTCalculation) -> np.ndarray:
    """Apply non self-consistent scissors shift.

    Return eigenvalues as a::

        (nspins, nibzkpts, nbands)

    shaped ndarray in eV units.

    The *shifts* are given as a sequence of tuples
    (energy shifts in eV)::

        [(<shift for occupied states>,
          <shift for unoccupied states>,
          <number of atoms>),
         ...]

    Consecutive tuples refer to consecutive groups of atoms, starting at
    atom 0.  Here we open a gap for states on atoms with indices 3, 4
    and 5::

        eig_skn = non_self_consistent_scissors_shift(
            [(0.0, 0.0, 3),
             (-0.5, 0.5, 3)],
            dft)
    """
    ibzwfs = dft.ibzwfs

    # The symmetry check works on the shifts exactly as given (in eV).
    check_symmetries(ibzwfs.ibz.symmetries, shifts)

    # Everything below works in Hartree.
    shifts_in_hartree = [(occupied / Ha, unoccupied / Ha, count)
                         for occupied, unoccupied, count in shifts]

    base_calculator = (
        dft.scf_loop.hamiltonian.create_hamiltonian_matrix_calculator(
            dft.potential))
    scissors_calculator = MyMatCalc(base_calculator, shifts_in_hartree)

    shape = (ibzwfs.nspins, len(ibzwfs.ibz), ibzwfs.nbands)
    nbands = shape[2]
    eig_skn = np.zeros(shape)
    for wfs in ibzwfs:
        H_MM = scissors_calculator.calculate_matrix(wfs)
        eig_M = H_MM.eighg(wfs.L_MM, wfs.domain_comm)
        # Keep only the lowest nbands eigenvalues.
        eig_skn[wfs.spin, wfs.k] = eig_M[:nbands]

    # Entries not computed on this rank are still zero, so summing over
    # the k-point communicator combines all (spin, k) results.
    ibzwfs.kpt_comm.sum(eig_skn)
    return eig_skn * Ha
def check_symmetries(symmetries: Symmetries,
                     shifts: Sequence[tuple[float, float, int]]) -> None:
    """Make sure shifts don't break any symmetries.

    *shifts* is a sequence of ``(homo_shift, lumo_shift, natoms)`` tuples.
    Consecutive tuples apply to consecutive groups of atoms, starting at
    atom 0. Atoms not covered by any tuple get a ``(0.0, 0.0)`` shift.
    Every symmetry operation must map each atom onto an atom with the
    same ``(homo_shift, lumo_shift)`` pair.

    Raises ValueError if an atom count is negative, if the shifts cover
    more atoms than the system contains, or if a symmetry maps an atom
    onto an atom with a different shift.

    >>> from gpaw.new.symmetry import create_symmetries_object
    >>> from ase import Atoms
    >>> atoms = Atoms('HH', [(0, 0, 1), (0, 0, -1)], cell=[3, 3, 3])
    >>> sym = create_symmetries_object(atoms)
    >>> check_symmetries(sym, [(1.0, 1.0, 1)])
    Traceback (most recent call last):
    ...
    ValueError: A symmetry maps atom 0 onto atom 1,
    but those atoms have different scissors shifts
    """
    # b_sa[s, a]: the atom that symmetry s maps atom a onto.
    b_sa = symmetries.atommap_sa
    natoms_total = b_sa.shape[1]

    shift_a = []
    for ho, lu, natoms in shifts:
        if natoms < 0:
            # A negative count would silently be treated as zero here but
            # would shift the atom ranges used elsewhere backwards.
            raise ValueError('Negative number of atoms in scissors '
                             f'shift: {natoms}')
        shift_a += [(ho, lu)] * natoms

    if len(shift_a) > natoms_total:
        # Previously the padding count went negative and the surplus
        # entries were silently ignored.
        raise ValueError(f'Scissors shifts are given for {len(shift_a)} '
                         f'atoms, but there are only {natoms_total} atoms')

    # Atoms not covered by any shift are not shifted.
    shift_a += [(0.0, 0.0)] * (natoms_total - len(shift_a))

    for b_a in b_sa:
        for a, b in enumerate(b_a):
            if shift_a[a] != shift_a[b]:
                raise ValueError(f'A symmetry maps atom {a} onto atom {b},\n'
                                 'but those atoms have different '
                                 'scissors shifts')
class ScissorsLCAOEigensolver(LCAOEigensolver):
    """LCAO eigensolver that adds scissors operators to the Hamiltonian.

    *shifts* is a sequence of ``(homo_shift, lumo_shift, natoms)`` tuples
    with the energy shifts in eV. Consecutive tuples refer to consecutive
    groups of atoms, starting at atom 0.
    """

    def __init__(self,
                 basis,
                 shifts: Sequence[tuple[float, float, int]],
                 symmetries: Symmetries):
        """Scissors-operator eigensolver.

        Raises ValueError (via ``check_symmetries``) if the shifts would
        break one of the *symmetries*.
        """
        check_symmetries(symmetries, shifts)
        super().__init__(basis)
        # Stored in Hartree; __repr__ converts back to eV for printing.
        self.shifts = [(homo / Ha, lumo / Ha, natoms)
                       for homo, lumo, natoms in shifts]

    def iterate(self,
                ibzwfs,
                density,
                potential,
                hamiltonian,
                pot_calc=None,
                energies=None):  # -> tuple[float, float, DFTEnergies]
        """Run the base-class LCAO step with the scissors operators.

        Returns ``(eps_error, wfs_error, energies)``. *eps_error* and
        *energies* come from the base class. *wfs_error* is NaN while the
        first wave-function object has no occupation numbers yet and 0.0
        otherwise. Presumably the NaN signals that the scissors terms
        (which need occupation numbers, see ``MyMatCalc``) could not be
        applied yet — confirm.
        """
        eps_error, _, energies = super().iterate(
            ibzwfs, density, potential, hamiltonian, pot_calc, energies)
        if ibzwfs.wfs_qs[0][0]._occ_n is None:
            wfs_error = np.nan
        else:
            wfs_error = 0.0
        return eps_error, wfs_error, energies

    def iterate1(self,
                 wfs,
                 weight_n,
                 matrix_calculator):
        """Delegate to the base class with a scissors-shifting calculator.

        The base-class result is returned unchanged, so this override is
        transparent to callers.
        """
        return super().iterate1(wfs, weight_n,
                                MyMatCalc(matrix_calculator, self.shifts))

    def __repr__(self):
        # NOTE(review): DirectLCAO is not a base class of this eigensolver.
        # Its __repr__ is borrowed only for the header line, which works as
        # long as DirectLCAO.__repr__ needs no DirectLCAO-specific
        # attributes — confirm.
        txt = DirectLCAO.__repr__(self)
        txt += '\n Scissors operators:\n'
        a1 = 0
        for homo, lumo, natoms in self.shifts:
            a2 = a1 + natoms
            # self.shifts is in Hartree; print in eV.
            txt += (f' Atoms {a1}-{a2 - 1}: '
                    f'VB: {homo * Ha:+.3f} eV, '
                    f'CB: {lumo * Ha:+.3f} eV\n')
            a1 = a2
        return txt
class MyMatCalc:
    """Hamiltonian-matrix calculator wrapper that adds scissors operators.

    Delegates to a wrapped calculator. It then adds the scissors-operator
    terms for each group of atoms listed in ``shifts`` to the matrix that
    calculator returns.
    """

    def __init__(self, matcalc, shifts):
        # matcalc: object whose calculate_matrix(wfs) returns H_MM.
        # shifts: sequence of (homo, lumo, natoms) tuples. The callers in
        # this module convert the energies to Hartree before passing them.
        self.matcalc = matcalc
        self.shifts = shifts

    def calculate_matrix(self, wfs):
        """Return the Hamiltonian matrix for *wfs* with scissors terms added.

        The number of occupied states is the rounded sum of ``wfs.occ_n``.
        If computing it raises ValueError, the unmodified matrix from the
        wrapped calculator is returned. Presumably this happens because
        occupation numbers are not available yet — confirm.
        """
        H_MM = self.matcalc.calculate_matrix(wfs)

        try:
            nocc = int(round(wfs.occ_n.sum()))
        except ValueError:
            return H_MM

        self.add_scissors(wfs, H_MM, nocc)
        return H_MM

    def add_scissors(self, wfs, H_MM, nocc):
        """Add the scissors-operator terms to *H_MM* in place and return it.

        With ``Z = S^(1/2)`` and ``R`` the density matrix built from the
        first *nocc* rows of ``wfs.C_nM`` (see below), each group of atoms
        ``a1:a2`` with basis functions ``M1:M2`` adds::

            Z[:, M1:M2] @ ((homo - lumo) * R[M1:M2, M1:M2]
                           + lumo * identity) @ Z^H[M1:M2, :]

        The code below is a distributed (parallel) version.

        Serial implementation for readability::

            C_nM = wfs.C_nM.data
            S_MM = wfs.S_MM.data

            # Find Z=S^(1/2):
            e_N, U_MN = np.linalg.eigh(S_MM)
            # We now have: S_MM @ U_MN = U_MN @ diag(e_N)
            Z_MM = U_MN @ (e_N[np.newaxis]**0.5 * U_MN).T.conj()

            # Density matrix:
            A_nM = C_nM[:nocc].conj() @ Z_MM
            R_MM = A_nM.conj().T @ A_nM

            M1 = 0
            a1 = 0
            for homo, lumo, natoms in self.shifts:
                a2 = a1 + natoms
                M2 = M1 + sum(setup.nao for setup in wfs.setups[a1:a2])
                H_MM.data += (
                    Z_MM[:, M1:M2]
                    @ ((homo - lumo) * R_MM[M1:M2, M1:M2]
                       + np.eye(M2 - M1) * lumo)
                    @ Z_MM.conj().T[M1:M2, :])
                a1 = a2
                M1 = M2

            return H_MM
        """

        # Parallel implementation:
        U_NM = wfs.S_MM.copy()

        C_nM = wfs.C_nM
        comm = wfs.C_nM.dist.comm
        # NOTE(review): presumably a comm.size x 1 process grid, i.e. rows
        # distributed over all ranks of comm — confirm against Matrix.
        dist = (comm, comm.size, 1)

        M = C_nM.shape[1]
        # Collect C_nM on rank 0 and keep only its first nocc rows. Then
        # redistribute that nocc x M block according to *dist*.
        C0_nM = C_nM.gather()
        C1_nM = Matrix(nocc, M, dtype=C_nM.dtype, dist=(comm, 1, 1))
        if comm.rank == 0:
            C1_nM.data[:] = C0_nM.data[:nocc, :]
        C_nM = C1_nM.new(dist=dist)
        C1_nM.redist(C_nM)

        # Find Z=S^(1/2):
        e_N = U_NM.eigh()
        e_NM = U_NM.copy()
        # We now have: S_MM @ U_MN = U_MN @ diag(e_N)
        # (assumes eigh() overwrites U_NM with the eigenvectors stored as
        # rows, i.e. U_NM = U_MN.T — TODO confirm)

        # Next: Z_MM = U_MN @ (e_N[np.newaxis]**0.5 * U_MN).T.conj()
        n1, n2 = U_NM.dist.my_row_range()
        # Scale the locally stored rows by the square roots of their
        # eigenvalues. NOTE(review): a slightly negative eigenvalue of S
        # (numerical noise) would give NaN here.
        e_NM.data *= e_N[n1:n2, None]**0.5
        e_NM.complex_conjugate()
        Z_MM = U_NM.multiply(e_NM, opa='T')

        # Density matrix:
        C_nM.complex_conjugate()
        Q_nM = C_nM.multiply(Z_MM, opb='C')

        n = Q_nM.shape[0]

        M1 = 0
        a1 = 0
        for homo, lumo, natoms in self.shifts:
            # Basis-function range M1:M2 for atoms a1:a2. Basis functions
            # are taken to be ordered atom by atom.
            a2 = a1 + natoms
            M2 = M1 + sum(setup.nao for setup in wfs.setups[a1:a2])
            # Column slices of Z and Q for this group of atoms.
            # NOTE(review): these local row-block copies assume Z_MM and
            # Q_nM already use the same row distribution as *dist* — confirm.
            A_Mm = Matrix(M, M2 - M1, dtype=Z_MM.dtype, dist=dist)
            A_Mm.data[:] = Z_MM.data[:, M1:M2]
            Q_nm = Matrix(n, M2 - M1, dtype=Q_nM.dtype, dist=dist)
            Q_nm.data[:] = Q_nM.data[:, M1:M2]

            Q2_nm = Q_nm.copy()
            Q2_nm.complex_conjugate()

            # R_mm = conj(Q)^T Q, then (homo - lumo) * R + lumo * identity.
            R_mm = Q2_nm.multiply(Q_nm, opa='T')
            R_mm.data *= (homo - lumo)
            R_mm.add_to_diagonal(lumo)
            # H_MM += A_Mm @ R_mm @ A_Mm^H. beta=1.0 presumably accumulates
            # into H_MM, matching the += in the serial code.
            B_mM = R_mm.multiply(A_Mm, opb='C')
            A_Mm.multiply(B_mM, beta=1.0, out=H_MM)

            a1 = a2
            M1 = M2

        return H_MM