Coverage for gpaw/lcaotddft/frequencydensitymatrix.py: 92%
149 statements
Report generated by coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1import numpy as np
3from ase.io.ulm import Reader
4from gpaw.io import Writer
6from gpaw.tddft.folding import Frequency
7from gpaw.tddft.folding import FoldedFrequencies
8from gpaw.lcaotddft.observer import TDDFTObserver
9from gpaw.lcaotddft.utilities import read_uMM
10from gpaw.lcaotddft.utilities import read_wuMM
11from gpaw.lcaotddft.utilities import write_uMM
12from gpaw.lcaotddft.utilities import write_wuMM
def generate_freq_w(foldedfreqs_f):
    """Flatten groups of folded frequencies into one list.

    Parameters
    ----------
    foldedfreqs_f
        Iterable of objects exposing ``frequencies`` (an iterable of
        frequency values) and ``folding``.

    Returns
    -------
    list
        One ``Frequency(f, ff.folding, 'au')`` per frequency value ``f``,
        ordered group by group and, within each group, in the order of
        ``ff.frequencies``.
    """
    return [Frequency(f, ff.folding, 'au')
            for ff in foldedfreqs_f
            for f in ff.frequencies]
class FrequencyDensityMatrixReader:
    """Reader for files written by ``FrequencyDensityMatrix.write``.

    The small quantities (``version``, ``time``, ``foldedfreqs_f``,
    ``freq_w``, ``Nw``) are read in the constructor.  The matrices
    ``rho0_uMM``, ``FReDrho_wuMM`` and ``FImDrho_wuMM`` are read from the
    file every time the corresponding attribute is accessed.  Any other
    attribute is forwarded to the underlying file reader.
    """

    def __init__(self, filename, ksl, kpt_u):
        """Open *filename* and read the small quantities.

        Parameters
        ----------
        filename
            File to open with ``Reader``.
        ksl, kpt_u
            Passed on to ``read_uMM`` / ``read_wuMM`` when matrices are
            requested.

        Raises
        ------
        RuntimeError
            If the tag of the file is not ``FrequencyDensityMatrix.ulmtag``.
        """
        self.ksl = ksl
        self.kpt_u = kpt_u
        self.reader = Reader(filename)
        tag = self.reader.get_tag()
        if tag != FrequencyDensityMatrix.ulmtag:
            # Do not leave the file open when it is rejected.
            self.reader.close()
            raise RuntimeError('Unknown tag %s' % tag)
        self.version = self.reader.version

        # Read small vectors
        self.time = self.reader.time
        self.foldedfreqs_f = [FoldedFrequencies(**ff)
                              for ff in self.reader.foldedfreqs_f]
        self.freq_w = generate_freq_w(self.foldedfreqs_f)
        self.Nw = len(self.freq_w)

    def __getattr__(self, attr):
        """Resolve attributes that are not set on the instance.

        Matrix names trigger a read from the file; everything else is
        looked up on the underlying reader.

        Raises
        ------
        AttributeError
            If the attribute cannot be resolved.
        """
        if attr == 'rho0_uMM':
            return read_uMM(self.kpt_u, self.ksl, self.reader, attr)
        if attr in ('FReDrho_wuMM', 'FImDrho_wuMM'):
            reim = attr[1:3]  # 'Re' or 'Im'
            return self.read_FDrho(reim, range(self.Nw))

        # Go through __dict__ here: 'self.reader' / 'self.version' would
        # re-enter __getattr__ and recurse without end when they have not
        # been set (e.g. the constructor failed or was bypassed).
        state = self.__dict__
        reader = state.get('reader')
        if reader is not None:
            try:
                return getattr(reader, attr)
            except (KeyError, AttributeError):
                # The reader may signal a missing entry with either type;
                # report both with the versioned message below.
                pass

        raise AttributeError('Attribute %s not defined in version %s' %
                             (repr(attr), repr(state.get('version'))))

    def read_FDrho(self, reim, wlist):
        """Read the ``'Re'`` or ``'Im'`` matrices for the indices *wlist*.

        Parameters
        ----------
        reim
            Either ``'Re'`` or ``'Im'``; selects ``FReDrho_wuMM`` or
            ``FImDrho_wuMM``.
        wlist
            Frequency indices forwarded to ``read_wuMM``.
        """
        assert reim in ['Re', 'Im']
        attr = 'F%sDrho_wuMM' % reim
        return read_wuMM(self.kpt_u, self.ksl, self.reader, attr, wlist)

    def close(self):
        """Close the underlying file reader."""
        self.reader.close()
class FrequencyDensityMatrix(TDDFTObserver):
    """Observer accumulating Fourier sums of the density matrix change.

    At every ``'propagate'`` action the difference between the current
    density matrix and the reference ``rho0_uMM`` is multiplied by
    ``exp(1j * frequency * time) * envelope(time) * time_step`` and added
    to ``FReDrho_wuMM`` (real part of the difference) and ``FImDrho_wuMM``
    (imaginary part of the difference) for every frequency.
    """

    version = 1
    ulmtag = 'FDM'

    def __init__(self,
                 paw,
                 dmat,
                 filename=None,
                 frequencies=None,
                 restart_filename=None,
                 interval=1):
        """Set up the observer, optionally restoring its state from a file.

        Parameters
        ----------
        paw
            Calculator; ``world``, ``wfs.ksl``, ``wfs.kd``, ``wfs.kpt_u``,
            ``log`` and ``time`` are taken from it.
        dmat
            Object providing ``zeros(dtype)`` and
            ``get_density_matrix(...)``.
        filename
            If given, the state is read from this file and *frequencies*
            is ignored.
        frequencies
            A ``FoldedFrequencies`` object or an iterable of them; used
            when *filename* is None.
        restart_filename
            File written by ``write_restart``; None disables it.
        interval
            Forwarded to ``TDDFTObserver.__init__``.
        """
        TDDFTObserver.__init__(self, paw, interval)
        self.has_initialized = False
        self.dmat = dmat
        self.filename = filename
        self.restart_filename = restart_filename
        self.world = paw.world
        self.ksl = paw.wfs.ksl
        self.kd = paw.wfs.kd
        self.kpt_u = paw.wfs.kpt_u
        self.log = paw.log
        if self.ksl.using_blacs:
            ksl_comm = self.ksl.block_comm
            kd_comm = self.kd.comm
            assert self.world.size == ksl_comm.size * kd_comm.size

        assert self.world.rank == self.ksl.world.rank

        if filename is not None:
            self.read(filename)
            return

        self.time = paw.time
        if isinstance(frequencies, FoldedFrequencies):
            frequencies = [frequencies]
        self.foldedfreqs_f = frequencies
        self.freq_w = generate_freq_w(self.foldedfreqs_f)
        # Count the flattened list (as the reader does): np.sum of an empty
        # list is the float 0.0, which range() does not accept.
        self.Nw = len(self.freq_w)

    def initialize(self):
        """Allocate the reference and Fourier-sum matrices (only once)."""
        if self.has_initialized:
            return

        # Real reference matrix when kd.gamma is true, complex otherwise.
        self.rho0_dtype = float if self.kd.gamma else complex

        zeros = self.dmat.zeros
        self.rho0_uMM = [zeros(self.rho0_dtype) for _ in self.kpt_u]
        self.FReDrho_wuMM = [[zeros(complex) for _ in self.kpt_u]
                             for _ in range(self.Nw)]
        self.FImDrho_wuMM = [[zeros(complex) for _ in self.kpt_u]
                             for _ in range(self.Nw)]
        self.has_initialized = True

    def _update(self, paw):
        """Handle one calculator action.

        ``'init'`` checks the time stamp, allocates the matrices and, at
        iteration 0, stores the reference density matrix.  ``'kick'`` is
        ignored.  ``'propagate'`` adds the current contribution to the
        Fourier sums.

        Raises
        ------
        RuntimeError
            If, at ``'init'``, the stored time differs from ``paw.time``.
        """
        if paw.action == 'init':
            if self.time != paw.time:
                raise RuntimeError('Timestamp do not match with '
                                   'the calculator')
            self.initialize()
            if paw.niter == 0:
                rho_uMM = self.dmat.get_density_matrix(paw.niter)
                for u, rho_MM in enumerate(rho_uMM[:len(self.kpt_u)]):
                    if self.rho0_dtype == float:
                        # Same check as a zero maximum of |imag|, but also
                        # well defined for a zero-size array.
                        assert not np.any(rho_MM.imag)
                        rho_MM = rho_MM.real
                    self.rho0_uMM[u][:] = rho_MM
            return

        if paw.action == 'kick':
            return

        assert paw.action == 'propagate'

        time_step = paw.time - self.time
        self.time = paw.time

        # Complex exponentials with envelope
        exp_w = []
        for ff in self.foldedfreqs_f:
            exp_i = (np.exp(1.0j * ff.frequencies * self.time) *
                     ff.folding.envelope(self.time) * time_step)
            exp_w.extend(exp_i.tolist())

        rho_uMM = self.dmat.get_density_matrix((paw.niter, paw.action))
        for u in range(len(self.kpt_u)):
            Drho_MM = rho_uMM[u] - self.rho0_uMM[u]
            for w, exp in enumerate(exp_w):
                # Update Fourier transforms
                self.FReDrho_wuMM[w][u] += Drho_MM.real * exp
                self.FImDrho_wuMM[w][u] += Drho_MM.imag * exp

    def write_restart(self):
        """Write the state to ``restart_filename`` if one was given."""
        if self.restart_filename is None:
            return
        self.write(self.restart_filename)

    def write(self, filename):
        """Write version, time, frequencies and all matrices to *filename*."""
        self.log(f'{self.__class__.__name__}: Writing to {filename}')
        writer = Writer(filename, self.world, mode='w',
                        tag=self.__class__.ulmtag)
        try:
            writer.write(version=self.__class__.version)
            writer.write(time=self.time)
            writer.write(foldedfreqs_f=[ff.todict()
                                        for ff in self.foldedfreqs_f])
            write_uMM(self.kd, self.ksl, writer, 'rho0_uMM', self.rho0_uMM)
            wlist = range(self.Nw)
            write_wuMM(self.kd, self.ksl, writer, 'FReDrho_wuMM',
                       self.FReDrho_wuMM, wlist)
            write_wuMM(self.kd, self.ksl, writer, 'FImDrho_wuMM',
                       self.FImDrho_wuMM, wlist)
        finally:
            # Close the file even when one of the writes fails.
            writer.close()

    def read(self, filename):
        """Restore the state from *filename* and mark it as initialized."""
        reader = FrequencyDensityMatrixReader(filename, self.ksl, self.kpt_u)
        try:
            self.time = reader.time
            self.foldedfreqs_f = reader.foldedfreqs_f
            self.freq_w = reader.freq_w
            self.Nw = reader.Nw
            self.rho0_uMM = reader.rho0_uMM
            self.rho0_dtype = self.rho0_uMM[0].dtype
            self.FReDrho_wuMM = reader.FReDrho_wuMM
            self.FImDrho_wuMM = reader.FImDrho_wuMM
        finally:
            # Close the file even when reading fails part-way.
            reader.close()
        self.has_initialized = True