Coverage for gpaw/pes/tddft.py: 81%
165 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
"""PES using approximate LrTDDFT scheme.

Photoelectron spectra (binding energies and peak strengths) are obtained
from Kohn-Sham orbital overlaps between the ground state of a mother
system and the ground and excited states of a daughter system, the latter
described by a linear-response TDDFT object; see ``TDDFTPES``.
"""
4import numpy as np
6from ase.units import Hartree
8from gpaw.pes import BasePES
9from gpaw.pes.state import State
10from gpaw.utilities import packed_index
12from numpy import sqrt, pi
class TDDFTPES(BasePES):
    """Photoelectron spectrum from an approximate linear-response TDDFT scheme.

    The mother system is ionized into the daughter system. The daughter's
    ground state and excitations come from the linear-response object
    ``excited_daughter``. Peak strengths are built from determinants of
    Kohn-Sham overlap matrices between the occupied mother orbitals and the
    daughter configurations. Peak positions are the daughter excitation
    energies plus an optional shift.

    Parameters
    ----------
    mother : calculator
        Ground-state calculator of the mother system. It must have exactly
        one more occupied spin-orbital than the daughter (checked).
    excited_daughter : linear-response object
        Daughter linear-response object (``kss``, ``diagonalize``,
        ``get_energies``, indexable by excitation). Its ``calculator`` is
        used as the daughter calculator when it is not None.
    daughter : calculator, optional
        Daughter ground-state calculator. Only used when
        ``excited_daughter.calculator`` is None.
    shift : bool or float
        True: shift binding energies by E(daughter) - E(mother) from
        ``get_potential_energy``. False: no shift. A number: used directly
        as the shift (presumably eV, like ``get_potential_energy``).
    tolerance : dict, optional
        Overrides for the default tolerances 'occupation', 'magnetic' and
        'grid'. Unknown keys raise RuntimeError.
    """

    def __init__(self, mother, excited_daughter, daughter=None,
                 shift=True, tolerance=None):
        self.tolerance = {
            'occupation': 1e-10,
            'magnetic': 2e-6,
            'grid': 0,
        }
        # ``None`` instead of a shared ``{}`` default.
        for key, value in (tolerance or {}).items():
            if key not in self.tolerance:
                raise RuntimeError("Tolerance key '%s' not known."
                                   % key)
            self.tolerance[key] = value

        if excited_daughter.calculator is not None:
            self.c_d = excited_daughter.calculator
        else:
            self.c_d = daughter

        self.c_m = mother
        self.gd = self.c_m.wfs.gd
        self.lr_d = excited_daughter

        self.c_m.converge_wave_functions()
        self.c_d.converge_wave_functions()
        self.lr_d.diagonalize()

        self.check_systems()
        self.lr_d.jend = self.lr_d.kss[-1].j

        # Largest i and j band indices occurring in the daughter's
        # Kohn-Sham transitions, plus one.
        # TODO: make a good way for initialising these.
        kmax = 0
        lmax = 0
        for kss in self.lr_d.kss:
            kmax = max(kmax, kss.i)
            lmax = max(lmax, kss.j)
        self.kmax = kmax + 1
        self.lmax = lmax + 1

        self.f = None
        self.be = None
        self.shift = shift

        self.gs_m = self._gs_orbitals(self.c_m)
        self.imax = len(self.gs_m)
        self.gs_d = self._gs_orbitals(self.c_d)

        if len(self.gs_m) != len(self.gs_d) + 1:
            raise RuntimeError(('Mother valence %d does not correspond ' +
                                'to daughter valence %d. ' +
                                'Modify tolerance["occupation"] ?') %
                               (len(self.gs_m), len(self.gs_d)))

    def _gs_orbitals(self, calc):
        """Return combined indices of the occupied orbitals of ``calc``.

        Band ``i`` of spin channel ``s`` gets index ``i + s * nbands``.
        For a spin-paired calculation every occupied band is listed for
        both spin channels (``i`` and ``i + nbands``).
        """
        nbands = calc.get_number_of_bands()
        spin = calc.get_number_of_spins() == 2
        # Threshold doubled when spin-paired, presumably because those
        # occupation numbers count both spins.
        f_tolerance = (2 - spin) * self.tolerance['occupation']
        indices = []
        for kpt in calc.wfs.kpt_u:
            for i in range(nbands):
                if kpt.f_n[i] > f_tolerance:
                    indices.append(i + kpt.s * nbands)
                    if not spin:
                        indices.append(i + nbands)
        return indices

    def _calculate(self):
        """Compute all overlaps and then the spectrum (``f`` and ``be``)."""
        self.ks_overlaps()
        self.single_overlaps()
        self.full_overlap_matrix()

        self._create_f()

    def ks_overlaps(self):
        """Evaluate KS overlaps of mother and daughter.

        Sets ``self.overlap``, of shape (2 * bands_m, 2 * bands_d). Entry
        ``[s * bands_m + i_m, s * bands_d + j_d]`` is the grid integral of
        the product of mother band ``i_m`` and daughter band ``j_d`` in
        spin channel ``s``, plus the atomic correction from
        ``_nuc_corr``. Entries coupling different spin channels stay zero.
        """
        bands_m = self.c_m.get_number_of_bands()
        spin_m = self.c_m.get_number_of_spins() == 2
        bands_d = self.c_d.get_number_of_bands()
        spin_d = self.c_d.get_number_of_spins() == 2

        self.overlap = np.zeros((2 * bands_m, 2 * bands_d))
        for i_m in range(bands_m):
            for s_m in range(2):
                # Spin-paired calculators have a single kpt for both spins.
                k_m = spin_m * s_m
                k_d = spin_d * s_m
                wf_m = self.c_m.wfs.kpt_u[k_m].psit_nG[i_m]
                psit_d = self.c_d.wfs.kpt_u[k_d].psit_nG
                for j_d in range(bands_d):
                    # NOTE(review): no complex conjugation, so real
                    # wave functions are assumed.
                    me = self.gd.integrate(wf_m * psit_d[j_d])
                    i = s_m * bands_m + i_m
                    j = s_m * bands_d + j_d
                    self.overlap[i, j] = me + self._nuc_corr(i_m, j_d,
                                                             k_m, k_d)

    def single_overlaps(self):
        """Determinants for mother-hole / daughter single excitations.

        ``self.singles[i, kl]`` is the determinant of ``self.overlap``
        restricted to two sets of indices:

        - rows: the occupied mother orbitals without ``self.gs_m[i]``;
        - columns: the occupied daughter orbitals with transition ``kl``
          applied (``kss.i`` removed, ``kss.j`` appended).

        Transitions with ``fij`` at or below the occupation tolerance are
        left at zero.
        """
        self.singles = np.zeros((self.imax, len(self.lr_d)))
        nbands_d = self.c_d.get_number_of_bands()

        for i, i_m in enumerate(self.gs_m):
            keep_row = list(self.gs_m)
            keep_row.remove(i_m)
            for kl, kss in enumerate(self.lr_d.kss):
                if kss.fij <= self.tolerance['occupation']:
                    continue
                spin = kss.pspin
                k_d = kss.i + spin * nbands_d
                l_d = kss.j + spin * nbands_d
                keep_col = list(self.gs_d)
                keep_col.remove(k_d)
                keep_col.append(l_d)

                d_ikl = self.overlap[np.ix_(keep_row, keep_col)]
                self.singles[i, kl] = np.linalg.det(d_ikl)

    def gs_gs_overlaps(self):
        """Evaluate overlap matrix of mother and daughter ground states.

        Returns an array of length ``self.imax``. Element ``i`` is the
        signed minor ``(-1)**(imax + i) * det`` of ``self.overlap``, with
        rows = occupied mother orbitals without ``self.gs_m[i]`` and
        columns = occupied daughter orbitals.
        """
        g0 = np.zeros(self.imax)
        keep_col = list(self.gs_d)
        for i, i_m in enumerate(self.gs_m):
            keep_row = list(self.gs_m)
            keep_row.remove(i_m)
            d_i00 = self.overlap[np.ix_(keep_row, keep_col)]
            g0[i] = (-1) ** (self.imax + i) * np.linalg.det(d_i00)
        return g0

    def full_overlap_matrix(self):
        """Full overlap matrix of mother and daughter many particle states.

        Sets ``self.g_Ii`` of shape (len(self.lr_d) + 1, self.imax).
        Row 0 is the daughter ground state (``gs_gs_overlaps``). Row
        ``1 + I`` combines ``self.singles`` with the coefficients
        ``self.lr_d[I].f`` of excitation ``I``, using the same alternating
        sign convention.
        """
        nexc = len(self.lr_d)
        self.g_Ii = np.zeros((nexc + 1, self.imax))
        self.g_Ii[0, :] = self.gs_gs_overlaps()

        sign_i = (-1.) ** (self.imax + np.arange(self.imax))
        for I in range(nexc):
            f_kl = np.asarray(self.lr_d[I].f)[:nexc]
            self.g_Ii[1 + I] = sign_i * self.singles.dot(f_kl)

    def _create_f(self):
        """Set peak strengths ``self.f`` and binding energies ``self.be``.

        ``self.f[I]`` is the squared norm of row ``I`` of ``self.g_Ii``.
        ``self.be[0]`` equals the shift. ``self.be[I]`` for I >= 1 adds the
        excitation energy ``I - 1`` from ``self.lr_d.get_energies()``,
        multiplied by ``ase.units.Hartree``.
        """
        self.f = (self.g_Ii * self.g_Ii).sum(axis=1)

        # Dispatch on the type rather than on truthiness. Otherwise a
        # non-zero numeric shift would be treated like True and silently
        # replaced by the total-energy difference.
        if isinstance(self.shift, (bool, np.bool_)):
            if self.shift:
                shift = (self.c_d.get_potential_energy() -
                         self.c_m.get_potential_energy())
            else:
                shift = 0.0
        else:
            shift = float(self.shift)

        energies = np.asarray(self.lr_d.get_energies()) * Hartree
        self.be = np.concatenate(([0.0], energies)) + shift

    def _nuc_corr(self, i_m, j_d, k_m, k_d):
        """Atom-centred correction to the overlap of two orbitals.

        Returns ``sqrt(4 pi) * sum_a sum_ij P_m[a][i_m, i] P_d[a][j_d, j]
        Delta_pL[ij, 0]``. The sum runs over the atoms in the mother's
        ``P_ani`` for kpt ``k_m`` and is summed over ``self.gd.comm``.
        """
        ma = 0.0

        for a, P_ni_m in self.c_m.wfs.kpt_u[k_m].P_ani.items():
            P_ni_d = self.c_d.wfs.kpt_u[k_d].P_ani[a]
            Pi_i = P_ni_m[i_m]
            Pj_i = P_ni_d[j_d]
            Delta_pL = self.c_m.wfs.setups[a].Delta_pL
            ni = len(Pi_i)

            for i in range(ni):
                for j in range(len(Pj_i)):
                    pij = Pi_i[i] * Pj_i[j]
                    ij = packed_index(i, j, ni)
                    ma += Delta_pL[ij, 0] * pij

        ma = self.gd.comm.sum_scalar(ma)
        return sqrt(4 * pi) * ma

    def check_systems(self):
        """Check that mother and daughter systems correspond to each other.

        Raises RuntimeError in any of these cases:

        - the cells or grid spacings differ by more than
          ``tolerance['grid']``;
        - the atomic positions differ;
        - the magnetic moments do not differ by 1 within
          ``tolerance['magnetic']``.
        """
        gtol = self.tolerance['grid']
        mtol = self.tolerance['magnetic']
        if (np.abs(self.c_m.wfs.gd.cell_cv -
                   self.c_d.wfs.gd.cell_cv) > gtol).any():
            raise RuntimeError('Not the same grid:' +
                               str(self.c_m.wfs.gd.cell_cv) + ' !=' +
                               str(self.c_d.wfs.gd.cell_cv))
        if (np.abs(self.c_m.wfs.gd.h_cv -
                   self.c_d.wfs.gd.h_cv) > gtol).any():
            raise RuntimeError('Not the same grid')
        # Compare mother with daughter (not mother with itself).
        # array_equal also handles a different number of atoms.
        if not np.array_equal(self.c_m.atoms.positions,
                              self.c_d.atoms.positions):
            raise RuntimeError('Not the same atomic positions')
        magmom_m = self.c_m.get_magnetic_moment()
        magmom_d = self.c_d.get_magnetic_moment()
        if np.abs(np.abs(magmom_m - magmom_d) - 1) > mtol:
            raise RuntimeError(('Mother (%g) ' % magmom_m) +
                               ('and daughter spin (%g) ' % magmom_d) +
                               'are not compatible' +
                               (' (magnetic tolerance %g)' % mtol))

    def Dyson_orbital(self, I):
        """Return the Dyson orbital corresponding to excitation I.

        ``I = 0`` refers to the daughter ground state and ``I >= 1`` to
        row ``I`` of ``self.g_Ii``. The orbital is the ``g_Ii[I]``-weighted
        sum of the occupied mother orbitals. Its energy is set to
        ``-self.be[I]``. Overlaps and the stacked mother orbitals are
        computed once and then reused.
        """
        if not hasattr(self, 'g_Ii'):
            self._calculate()
        if not hasattr(self, 'morbitals_ig'):
            nbands = self.c_m.get_number_of_bands()
            spin = self.c_m.get_number_of_spins() == 2
            morbitals_ig = []
            for i in self.gs_m:
                # Undo the combined index i + s * nbands from _gs_orbitals.
                upper = int(i >= nbands)
                k = upper * spin
                band = i - nbands * upper
                morbitals_ig.append(self.c_m.wfs.kpt_u[k].psit_nG[band])
            self.morbitals_ig = np.array(morbitals_ig)

        dyson = State(self.gd)
        gridnn = self.gd.zeros()
        for g, orbital in zip(self.g_Ii[I], self.morbitals_ig):
            gridnn += g * orbital
        dyson.set_grid(gridnn)
        dyson.set_energy(-self.be[I])
        return dyson