Coverage for gpaw/inducedfield/inducedfield_tddft.py: 70%
248 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import numpy as np
3from gpaw import debug
4from gpaw.transformers import Transformer
5from gpaw.lfc import BasisFunctions
6from gpaw.lcaotddft.observer import TDDFTObserver
7from gpaw.utilities import unpack_density, is_contiguous
9from gpaw.inducedfield.inducedfield_base import BaseInducedField, \
10 sendreceive_dict
class TDDFTInducedField(BaseInducedField, TDDFTObserver):
    """Induced field class for time propagation TDDFT.

    During propagation the induced densities (current minus ground state)
    are Fourier transformed on the fly, one frequency at a time.

    Attributes (see also ``BaseInducedField``):
    -------------------------------------------
    time: float
        Current time
    Fnt_wsG: ndarray (complex)
        Fourier transform of induced pseudo density
    n0t_sG: ndarray (float)
        Ground state pseudo density
    FD_awsp: dict of ndarray (complex)
        Fourier transform of induced D_asp
    D0_asp: dict of ndarray (float)
        Ground state D_asp
    """

    def __init__(self, filename=None, paw=None,
                 frequencies=None, folding='Gauss', width=0.08,
                 interval=1, restart_file=None
                 ):
        """
        Parameters (see also ``BaseInducedField``):
        -------------------------------------------
        paw: TDDFT object
            TDDFT object for time propagation
        width: float
            Width in eV for the Gaussian (sigma) or Lorentzian (eta) folding
            Gaussian   = exp(- (1/2) * sigma^2 * t^2)
            Lorentzian = exp(- eta * t)
        interval: int
            Number of timesteps between calls (used when attaching)
        restart_file: string
            Name of the restart file
        """

        TDDFTObserver.__init__(self, paw, interval)
        # From observer:
        # self.niter
        # self.interval
        # self.timer
        # Observer does also paw.attach(self, ...)

        # Restart file
        self.restart_file = restart_file

        # These are allocated in allocate()
        self.Fnt_wsG = None
        self.n0t_sG = None
        self.FD_awsp = None
        self.D0_asp = None

        # Read/write mode string -> list of quantity name prefixes.
        # Must be set before BaseInducedField.__init__ (which may read a
        # file) is called below.
        self.readwritemode_str_to_list = \
            {'': ['Fnt', 'n0t', 'FD', 'D0', 'atoms'],
             'all': ['Fnt', 'n0t', 'FD', 'D0',
                     'Frho', 'Fphi', 'Fef', 'Ffe', 'atoms'],
             'field': ['Frho', 'Fphi', 'Fef', 'Ffe', 'atoms']}

        BaseInducedField.__init__(self, filename, paw,
                                  frequencies, folding, width)

    def initialize(self, paw, allocate=True):
        """Initialize from ``paw`` and sync time/iteration with TDDFT."""
        BaseInducedField.initialize(self, paw, allocate)

        if self.has_paw:
            assert hasattr(paw, 'time') and hasattr(paw, 'niter'), 'Use TDDFT!'
            self.time = paw.time  # !
            self.niter = paw.niter

    def set_folding(self, folding, width):
        """Set the folding and the matching time-domain envelope.

        ``self.envelope(t)`` is the damping factor multiplied onto every
        time sample before it is added to the Fourier transforms.

        Raises
        ------
        RuntimeError
            If the folding is not None, 'Gauss' or 'Lorentz'.
        """
        BaseInducedField.set_folding(self, folding, width)

        if self.folding is None:
            self.envelope = lambda t: 1.0
        elif self.folding == 'Gauss':
            self.envelope = lambda t: np.exp(- 0.5 * self.width**2 * t**2)
        elif self.folding == 'Lorentz':
            self.envelope = lambda t: np.exp(- self.width * t)
        else:
            raise RuntimeError(f'unknown folding "{self.folding}"')

    def allocate(self):
        """Allocate ground state and Fourier transformed arrays (once).

        ``n0t_sG`` is filled with NaN; ``_update`` uses that to detect that
        the ground state density has not been recorded yet.
        """
        if not self.allocated:

            # Ground state pseudo density (NaN until recorded)
            self.n0t_sG = self.gd.empty((self.nspins, )) + np.nan

            # Fourier transformed pseudo density
            self.Fnt_wsG = self.gd.zeros((self.nw, self.nspins),
                                         dtype=self.dtype)

            # Ground state D_asp
            self.D0_asp = {a: D_sp.copy()
                           for a, D_sp in self.density.D_asp.items()}

            # Size of D_p for each atom
            self.np_a = {a: np.array([len(D_sp[0])])
                         for a, D_sp in self.D0_asp.items()}

            # Fourier transformed D_asp
            self.FD_awsp = {a: np.zeros((self.nw, self.nspins, np_[0]),
                                        dtype=self.dtype)
                            for a, np_ in self.np_a.items()}

            self.allocated = True

            if debug:
                assert is_contiguous(self.Fnt_wsG, self.dtype)

    def deallocate(self):
        """Release all arrays allocated by ``allocate``."""
        BaseInducedField.deallocate(self)
        self.n0t_sG = None
        self.Fnt_wsG = None
        self.D0_asp = None
        self.FD_awsp = None

    def _update(self, paw):
        """Observer callback; action depends on ``paw.action``.

        * ``'init'``: store the ground state pseudo density (iteration 0).
        * ``'kick'``: store the kick strength as background field.
        * ``'propagate'``: add the induced densities weighted by
          ``exp(i omega t) * envelope(t) * dt`` to the Fourier transforms.

        Other actions are ignored.
        """
        if paw.action == 'init':
            if paw.niter == 0:
                self.n0t_sG[:] = paw.density.nt_sG
            return
        elif paw.action == 'kick':
            # Background electric field
            self.Fbgef_v = paw.kick_strength
            return
        elif paw.action != 'propagate':
            return

        assert (self.Fbgef_v is not None
                and not np.any(np.isnan(self.n0t_sG))), \
            f'Attach {self.__class__.__name__} before absorption kick'

        # Update time
        self.time = paw.time
        time_step = paw.time_step

        # Complex exponential with envelope
        f_w = (np.exp(1.0j * self.omega_w * self.time)
               * self.envelope(self.time) * time_step)

        # Induced quantities (current minus ground state); these do not
        # depend on the frequency, so compute them only once.
        dnt_sG = self.density.nt_sG - self.n0t_sG
        dD_asp = {a: D_sp - self.D0_asp[a]
                  for a, D_sp in self.density.D_asp.items()}

        # Update Fourier transforms
        for w in range(self.nw):
            f = f_w[w]
            self.Fnt_wsG[w] += dnt_sG * f
            for a, dD_sp in dD_asp.items():
                self.FD_awsp[a][w] += dD_sp * f

        # Restart file
        # XXX remove this once deprecated dump_interval is removed,
        # but keep write_restart() as it'll be still used
        # (see TDDFTObserver class)
        if (paw.restart_file is not None
                and self.niter % paw.dump_interval == 0):
            self.write_restart()

    def write_restart(self):
        """Write the current state to ``self.restart_file`` if it is set."""
        if self.restart_file is not None:
            self.write(self.restart_file)
            self.log(f'{self.__class__.__name__}: Wrote restart file')

    def interpolate_pseudo_density(self, gridrefinement=2):
        """Interpolate the Fourier transformed pseudo density.

        Parameters
        ----------
        gridrefinement: int
            Refinement factor; must be a positive power of two
            (1, 2, 4, ...).

        Returns
        -------
        (Fnt_wsg, gd)
            The interpolated density (always a new array, never
            ``self.Fnt_wsG`` itself, since callers add corrections to it in
            place) and the descriptor of the refined grid.

        Raises
        ------
        NotImplementedError
            If ``gridrefinement`` is not a positive power of two.
        """
        if gridrefinement < 1:
            raise NotImplementedError(
                f'gridrefinement must be a positive power of two, '
                f'got {gridrefinement}')

        # Find m for
        # gridrefinement = 2**m
        m1 = np.log(gridrefinement) / np.log(2.)
        m = int(np.round(m1))

        # Check if m is really integer
        if np.absolute(m - m1) >= 1e-8:
            raise NotImplementedError(
                f'gridrefinement must be a power of two, '
                f'got {gridrefinement}')

        gd = self.gd
        if m == 0:
            # No refinement: hand out a copy so callers may modify it
            return self.Fnt_wsG.copy(), gd

        # Boundary phases passed to the interpolator (all ones); constant,
        # so build it once.  The interpolator only reads its input, so the
        # first pass can read self.Fnt_wsG directly without a copy.
        phase_cd = np.ones((3, 2), dtype=complex)
        Fnt_wsg = self.Fnt_wsG
        for _ in range(m):
            gd2 = gd.refine()

            # Interpolate
            interpolator = Transformer(gd, gd2, self.stencil,
                                       dtype=self.dtype)
            Fnt2_wsg = gd2.empty((self.nw, self.nspins), dtype=self.dtype)
            for w in range(self.nw):
                for s in range(self.nspins):
                    interpolator.apply(Fnt_wsg[w][s], Fnt2_wsg[w][s],
                                       phase_cd)

            gd = gd2
            Fnt_wsg = Fnt2_wsg

        return Fnt_wsg, gd

    def comp_charge_correction(self, gridrefinement=2):
        """Return the spin-summed pseudo density plus compensation charges.

        The compensation charge coefficients are
        ``FQ_L = sum_s FD_sp . Delta_pL`` per atom and frequency.  They are
        complex while ``ghat`` is added onto a real work array, so the real
        and imaginary parts are added separately.

        Returns
        -------
        (Frhot_wg, gd)
            Density on the refined grid and its grid descriptor.

        Raises
        ------
        NotImplementedError
            If ``gridrefinement`` is not 2.
        """
        # TODO: implement for gr==1 also
        if gridrefinement != 2:
            raise NotImplementedError(
                'Compensation charge correction is only implemented for '
                f'gridrefinement=2, got {gridrefinement}')

        # Density
        Fnt_wsg, gd = self.interpolate_pseudo_density(gridrefinement)
        Frhot_wg = Fnt_wsg.sum(axis=1)

        tmp_g = gd.empty(dtype=float)
        for w in range(self.nw):
            # Determine compensation charge coefficients:
            FQ_aL = {a: np.dot(FD_wsp[w].sum(axis=0),
                               self.setups[a].Delta_pL)
                     for a, FD_wsp in self.FD_awsp.items()}

            # Add real part of compensation charges
            # (copies make the coefficient arrays contiguous)
            tmp_g[:] = 0
            FQ2_aL = {a: FQ_L.real.copy() for a, FQ_L in FQ_aL.items()}
            self.density.ghat.add(tmp_g, FQ2_aL)
            Frhot_wg[w] += tmp_g

            # Add imag part of compensation charges
            tmp_g[:] = 0
            FQ2_aL = {a: FQ_L.imag.copy() for a, FQ_L in FQ_aL.items()}
            self.density.ghat.add(tmp_g, FQ2_aL)
            Frhot_wg[w] += 1.0j * tmp_g

        return Frhot_wg, gd

    def paw_corrections(self, gridrefinement=2):
        """Return the interpolated density with PAW corrections added.

        For each frequency and spin the block-diagonal matrix of Fourier
        transformed atomic density matrices is built, and the densities
        constructed from it with the all-electron (``phi``) minus the
        pseudo (``phit``) partial waves are added to the interpolated
        pseudo density.  Real and imaginary parts are handled separately
        because the partial-wave basis functions are real.

        Returns
        -------
        (Fn_wsg, gd)
            Corrected density per frequency and spin, and its grid
            descriptor.
        """
        Fn_wsg, gd = self.interpolate_pseudo_density(gridrefinement)

        # Splines (loaded once per setup id)
        splines = {}
        phi_aj = []
        phit_aj = []
        for a, setup_id in enumerate(self.setups.id_a):
            if setup_id not in splines:
                # Load splines:
                phi_j, phit_j = self.setups[a].get_partial_waves()[:2]
                splines[setup_id] = (phi_j, phit_j)
            phi_j, phit_j = splines[setup_id]
            phi_aj.append(phi_j)
            phit_aj.append(phit_j)

        # Create localized functions from splines
        phi = BasisFunctions(gd, phi_aj, dtype=float)
        phit = BasisFunctions(gd, phit_aj, dtype=float)
        spos_ac = self.atoms.get_scaled_positions()
        phi.set_positions(spos_ac)
        phit.set_positions(spos_ac)

        tmp_g = gd.empty(dtype=float)
        rho_MM = np.zeros((phi.Mmax, phi.Mmax), dtype=self.dtype)
        # phi/phit are real (dtype=float), so the matrix handed to
        # construct_density must be real as well: one real work matrix
        # that receives the real and then the imaginary part of rho_MM.
        rho2_MM = np.zeros((phi.Mmax, phi.Mmax), dtype=float)
        for w in range(self.nw):
            for s in range(self.nspins):
                rho_MM[:] = 0
                M1 = 0
                for a, setup in enumerate(self.setups):
                    ni = setup.ni
                    FD_wsp = self.FD_awsp.get(a)
                    if FD_wsp is None:
                        # Atom owned by another rank: receive buffer
                        FD_p = np.empty(ni * (ni + 1) // 2,
                                        dtype=self.dtype)
                    else:
                        FD_p = FD_wsp[w][s]
                    if gd.comm.size > 1:
                        gd.comm.broadcast(FD_p, self.rank_a[a])
                    D_ij = unpack_density(FD_p)
                    # unpack_density builds a Hermitian matrix: upper
                    # triangle and diagonal hold the packed values as is,
                    # the strictly lower triangle their complex conjugate.
                    # FD_p is a Fourier transform, so we want the symmetric
                    # matrix instead: undo the conjugation of the strictly
                    # lower triangle only (the diagonal is not conjugated).
                    D_ij = np.triu(D_ij) + np.conj(np.tril(D_ij, -1))

                    M2 = M1 + ni
                    rho_MM[M1:M2, M1:M2] = D_ij
                    M1 = M2

                # Add real and imaginary parts of the AE corrections
                for rho_part_MM, factor in ((rho_MM.real, 1.0),
                                            (rho_MM.imag, 1.0j)):
                    tmp_g[:] = 0
                    rho2_MM[:] = rho_part_MM
                    # TODO: use ae_valence_density_correction
                    phi.construct_density(rho2_MM, tmp_g, q=-1)
                    phit.construct_density(-rho2_MM, tmp_g, q=-1)
                    Fn_wsg[w][s] += factor * tmp_g

        return Fn_wsg, gd

    def get_induced_density(self, from_density, gridrefinement):
        """Return the Fourier transformed induced charge density.

        Electrons carry negative charge, hence the overall sign flip.

        Parameters
        ----------
        from_density: str
            'pseudo' (interpolated pseudo density), 'comp' (plus
            compensation charges) or 'ae' (plus PAW corrections).
        gridrefinement: int
            Passed on to the selected density routine.

        Returns
        -------
        (Frho_wg, gd)

        Raises
        ------
        RuntimeError
            For an unknown ``from_density``.
        """
        # Return charge density (electrons = negative charge)
        if from_density == 'pseudo':
            Fn_wsg, gd = self.interpolate_pseudo_density(gridrefinement)
            return - Fn_wsg.sum(axis=1), gd
        if from_density == 'comp':
            Frho_wg, gd = self.comp_charge_correction(gridrefinement)
            return - Frho_wg, gd
        if from_density == 'ae':
            Fn_wsg, gd = self.paw_corrections(gridrefinement)
            return - Fn_wsg.sum(axis=1), gd
        raise RuntimeError(f'unknown from_density "{from_density}"')

    def _read(self, reader, reads):
        """Read the TDDFT-specific quantities listed in ``reads``.

        Raises
        ------
        OSError
            If a calculator is attached and its time differs from the
            stored one.
        """
        BaseInducedField._read(self, reader, reads)

        r = reader
        time = r.time
        if self.has_paw:
            # Test time
            if abs(time - self.time) >= 1e-9:
                raise OSError('Timestamp is incompatible with calculator.')
        else:
            self.time = time

        # Allocate
        self.allocate()

        # Dimensions for D_p for all atoms
        self.np_a = r.np_a

        def readarray(name):
            # Distribute a global grid array over the domain
            if name.split('_')[0] in reads:
                self.gd.distribute(r.get(name), getattr(self, name))

        # Read arrays
        readarray('n0t_sG')
        readarray('Fnt_wsG')

        # Keep only the atoms owned by this domain rank
        if 'D0' in reads:
            D0_asp = r.D0_asp
            self.D0_asp = {a: D0_asp[a] for a in range(self.na)
                           if self.domain_comm.rank == self.rank_a[a]}

        if 'FD' in reads:
            FD_awsp = r.FD_awsp
            self.FD_awsp = {a: FD_awsp[a] for a in range(self.na)
                            if self.domain_comm.rank == self.rank_a[a]}

    def _collect_atom_dict(self, local_a, empty_buffer):
        """Gather a per-atom dict onto the domain master.

        Only ranks with ``kpt_comm.rank == 0`` and ``band_comm.rank == 0``
        take part in the communication.

        Parameters
        ----------
        local_a: dict
            Entries owned by this rank (keys are atom indices).
        empty_buffer: callable
            ``empty_buffer(a)`` returns the receive buffer for atom ``a``;
            only called on the domain master.

        Returns
        -------
        dict
            The gathered dict on the domain master of the participating
            ranks; an empty dict on every other rank (so that callers can
            always pass it on to the writer).
        """
        collected_a = {}
        if self.kpt_comm.rank == 0 and self.band_comm.rank == 0:
            # Create empty dict on domain master
            if self.domain_comm.rank == 0:
                collected_a = {a: empty_buffer(a) for a in range(self.na)}
            # Collect dict to master
            sendreceive_dict(self.domain_comm, collected_a, 0,
                             local_a, self.rank_a, range(self.na))
        return collected_a

    def _write(self, writer, writes):
        """Write the TDDFT-specific quantities listed in ``writes``."""
        BaseInducedField._write(self, writer, writes)

        # Collect np_a to master
        np_a = self._collect_atom_dict(
            self.np_a, lambda a: np.empty(1, dtype=int))

        # Write time propagation status
        writer.write(time=self.time, np_a=np_a)

        def writearray(name, shape, dtype):
            # Collect and write a grid array one frequency at a time
            if name.split('_')[0] in writes:
                writer.add_array(name, shape, dtype)
                a_wxg = getattr(self, name)
                for w in range(self.nw):
                    writer.fill(self.gd.collect(a_wxg[w]))

        ng = tuple(self.gd.get_size_of_global_array())

        # Write time propagation arrays
        if 'n0t' in writes:
            writer.write(n0t_sG=self.gd.collect(self.n0t_sG))
        writearray('Fnt_wsG', (self.nw, self.nspins) + ng, self.dtype)

        if 'D0' in writes:
            # Collect D0_asp to world master
            D0_asp = self._collect_atom_dict(
                self.D0_asp,
                lambda a: np.empty((self.nspins, np_a[a][0]), dtype=float))
            # Write
            writer.write(D0_asp=D0_asp)

        if 'FD' in writes:
            # Collect FD_awsp to world master; buffer dtype matches the
            # one used in allocate()
            FD_awsp = self._collect_atom_dict(
                self.FD_awsp,
                lambda a: np.empty((self.nw, self.nspins, np_a[a][0]),
                                   dtype=self.dtype))
            # Write
            writer.write(FD_awsp=FD_awsp)