Coverage for gpaw/lcaotddft/frequencydensitymatrix.py: 92%

149 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1import numpy as np 

2 

3from ase.io.ulm import Reader 

4from gpaw.io import Writer 

5 

6from gpaw.tddft.folding import Frequency 

7from gpaw.tddft.folding import FoldedFrequencies 

8from gpaw.lcaotddft.observer import TDDFTObserver 

9from gpaw.lcaotddft.utilities import read_uMM 

10from gpaw.lcaotddft.utilities import read_wuMM 

11from gpaw.lcaotddft.utilities import write_uMM 

12from gpaw.lcaotddft.utilities import write_wuMM 

13 

14 

def generate_freq_w(foldedfreqs_f):
    """Flatten groups of folded frequencies into single ``Frequency`` objects.

    Every frequency of every ``FoldedFrequencies`` entry in ``foldedfreqs_f``
    becomes one ``Frequency`` that carries that entry's folding and the unit
    ``'au'``. Order is preserved: groups in input order, then frequencies in
    the order stored in each group.

    Returns a list; it is empty when there are no frequencies.
    """
    return [Frequency(freq, group.folding, 'au')
            for group in foldedfreqs_f
            for freq in group.frequencies]

21 

22 

23class FrequencyDensityMatrixReader: 

24 def __init__(self, filename, ksl, kpt_u): 

25 self.ksl = ksl 

26 self.kpt_u = kpt_u 

27 self.reader = Reader(filename) 

28 tag = self.reader.get_tag() 

29 if tag != FrequencyDensityMatrix.ulmtag: 

30 raise RuntimeError('Unknown tag %s' % tag) 

31 self.version = self.reader.version 

32 

33 # Read small vectors 

34 self.time = self.reader.time 

35 self.foldedfreqs_f = [FoldedFrequencies(**ff) 

36 for ff in self.reader.foldedfreqs_f] 

37 self.freq_w = generate_freq_w(self.foldedfreqs_f) 

38 self.Nw = len(self.freq_w) 

39 

40 def __getattr__(self, attr): 

41 if attr in ['rho0_uMM']: 

42 return read_uMM(self.kpt_u, self.ksl, self.reader, attr) 

43 if attr in ['FReDrho_wuMM', 'FImDrho_wuMM']: 

44 reim = attr[1:3] 

45 wlist = range(self.Nw) 

46 return self.read_FDrho(reim, wlist) 

47 

48 try: 

49 return getattr(self.reader, attr) 

50 except KeyError: 

51 pass 

52 

53 raise AttributeError('Attribute %s not defined in version %s' % 

54 (repr(attr), repr(self.version))) 

55 

56 def read_FDrho(self, reim, wlist): 

57 assert reim in ['Re', 'Im'] 

58 attr = 'F%sDrho_wuMM' % reim 

59 return read_wuMM(self.kpt_u, self.ksl, self.reader, attr, wlist) 

60 

61 def close(self): 

62 self.reader.close() 

63 

64 

class FrequencyDensityMatrix(TDDFTObserver):
    """Observer accumulating frequency-domain density matrices.

    During time propagation, the change of the density matrix relative to
    the reference ``rho0_uMM`` (stored at iteration 0) is multiplied by
    ``exp(1j * omega * t) * envelope(t) * dt`` for every requested frequency.
    The result is added to the running sums ``FReDrho_wuMM`` (from the real
    part of the change) and ``FImDrho_wuMM`` (from the imaginary part).

    Parameters
    ----------
    paw
        Calculator; ``world``, ``wfs.ksl``, ``wfs.kd``, ``wfs.kpt_u``,
        ``log`` and ``time`` are read from it.
    dmat
        Density-matrix provider offering ``zeros(dtype)`` and
        ``get_density_matrix(key)``.
    filename
        If given, the state is restored from this file and ``frequencies``
        is ignored.
    frequencies
        A ``FoldedFrequencies`` instance or an iterable of them. Required
        when ``filename`` is None.
    restart_filename
        Target of ``write_restart``; nothing is written if None.
    interval
        Passed to ``TDDFTObserver``.

    Raises
    ------
    TypeError
        If neither ``filename`` nor ``frequencies`` is given.
    """
    version = 1
    ulmtag = 'FDM'

    def __init__(self,
                 paw,
                 dmat,
                 filename=None,
                 frequencies=None,
                 restart_filename=None,
                 interval=1):
        TDDFTObserver.__init__(self, paw, interval)
        self.has_initialized = False
        self.dmat = dmat
        self.filename = filename
        self.restart_filename = restart_filename
        self.world = paw.world
        self.ksl = paw.wfs.ksl
        self.kd = paw.wfs.kd
        self.kpt_u = paw.wfs.kpt_u
        self.log = paw.log
        if self.ksl.using_blacs:
            ksl_comm = self.ksl.block_comm
            kd_comm = self.kd.comm
            assert self.world.size == ksl_comm.size * kd_comm.size

        assert self.world.rank == self.ksl.world.rank

        if filename is not None:
            self.read(filename)
            return

        if frequencies is None:
            raise TypeError(f'{self.__class__.__name__}: frequencies must be '
                            'given when not reading from a file')
        self.time = paw.time
        if isinstance(frequencies, FoldedFrequencies):
            frequencies = [frequencies]
        # Materialize: the groups are iterated again in _update and write,
        # so a one-shot iterable (generator) must not be consumed here.
        self.foldedfreqs_f = list(frequencies)
        self.freq_w = generate_freq_w(self.foldedfreqs_f)
        # Plain int (np.sum of an empty list would give the float 0.0);
        # consistent with FrequencyDensityMatrixReader.Nw.
        self.Nw = len(self.freq_w)

    def initialize(self):
        """Allocate zeroed reference and Fourier-transform matrices.

        Does nothing if already initialized (e.g. after ``read``).
        """
        if self.has_initialized:
            return

        # kd.gamma selects a real reference density matrix
        if self.kd.gamma:
            self.rho0_dtype = float
        else:
            self.rho0_dtype = complex

        self.rho0_uMM = [self.dmat.zeros(self.rho0_dtype)
                         for kpt in self.kpt_u]
        self.FReDrho_wuMM = []
        self.FImDrho_wuMM = []
        for w in range(self.Nw):
            self.FReDrho_wuMM.append([])
            self.FImDrho_wuMM.append([])
            for kpt in self.kpt_u:
                self.FReDrho_wuMM[-1].append(self.dmat.zeros(complex))
                self.FImDrho_wuMM[-1].append(self.dmat.zeros(complex))
        self.has_initialized = True

    def _update(self, paw):
        """Handle one observer callback, dispatching on ``paw.action``.

        * ``'init'``: verify that the stored time equals ``paw.time``,
          allocate arrays, and at ``paw.niter == 0`` store the current
          density matrix as the reference ``rho0_uMM``.
        * ``'kick'``: nothing to do.
        * ``'propagate'``: add this step's contribution to the Fourier
          transforms.

        Raises
        ------
        RuntimeError
            On ``'init'`` if the stored time differs from ``paw.time``.
        """
        if paw.action == 'init':
            if self.time != paw.time:
                raise RuntimeError('Timestamp do not match with '
                                   'the calculator')
            self.initialize()
            if paw.niter == 0:
                rho_uMM = self.dmat.get_density_matrix(paw.niter)
                for u, kpt in enumerate(self.kpt_u):
                    rho_MM = rho_uMM[u]
                    if self.rho0_dtype == float:
                        # Storing as real must not discard an imaginary part
                        assert np.max(np.absolute(rho_MM.imag)) == 0.0
                        rho_MM = rho_MM.real
                    self.rho0_uMM[u][:] = rho_MM
            return

        if paw.action == 'kick':
            return

        assert paw.action == 'propagate'

        # Step length since the previous update
        time_step = paw.time - self.time
        self.time = paw.time

        # Complex exponentials with envelope, one factor per frequency,
        # in the same order as self.freq_w
        exp_w = []
        for ff in self.foldedfreqs_f:
            exp_i = (np.exp(1.0j * ff.frequencies * self.time) *
                     ff.folding.envelope(self.time) * time_step)
            exp_w.extend(exp_i.tolist())

        rho_uMM = self.dmat.get_density_matrix((paw.niter, paw.action))
        for u, kpt in enumerate(self.kpt_u):
            Drho_MM = rho_uMM[u] - self.rho0_uMM[u]
            for w, exp in enumerate(exp_w):
                # Update Fourier transforms
                self.FReDrho_wuMM[w][u] += Drho_MM.real * exp
                self.FImDrho_wuMM[w][u] += Drho_MM.imag * exp

    def write_restart(self):
        """Write the state to ``restart_filename`` if one was given."""
        if self.restart_filename is None:
            return
        self.write(self.restart_filename)

    def write(self, filename):
        """Write time, frequencies, reference and transformed matrices."""
        self.log(f'{self.__class__.__name__}: Writing to {filename}')
        writer = Writer(filename, self.world, mode='w',
                        tag=self.__class__.ulmtag)
        writer.write(version=self.__class__.version)
        writer.write(time=self.time)
        writer.write(foldedfreqs_f=[ff.todict() for ff in self.foldedfreqs_f])
        write_uMM(self.kd, self.ksl, writer, 'rho0_uMM', self.rho0_uMM)
        wlist = range(self.Nw)
        write_wuMM(self.kd, self.ksl, writer, 'FReDrho_wuMM',
                   self.FReDrho_wuMM, wlist)
        write_wuMM(self.kd, self.ksl, writer, 'FImDrho_wuMM',
                   self.FImDrho_wuMM, wlist)
        writer.close()

    def read(self, filename):
        """Restore the full state from ``filename`` and mark as initialized."""
        reader = FrequencyDensityMatrixReader(filename, self.ksl, self.kpt_u)
        try:
            self.time = reader.time
            self.foldedfreqs_f = reader.foldedfreqs_f
            self.freq_w = reader.freq_w
            self.Nw = reader.Nw
            self.rho0_uMM = reader.rho0_uMM
            self.rho0_dtype = self.rho0_uMM[0].dtype
            self.FReDrho_wuMM = reader.FReDrho_wuMM
            self.FImDrho_wuMM = reader.FImDrho_wuMM
        finally:
            # Close the file even if reading one of the arrays fails
            reader.close()
        self.has_initialized = True
196 self.has_initialized = True