Coverage for gpaw/inducedfield/inducedfield_tddft.py: 70%

248 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import numpy as np 

2 

3from gpaw import debug 

4from gpaw.transformers import Transformer 

5from gpaw.lfc import BasisFunctions 

6from gpaw.lcaotddft.observer import TDDFTObserver 

7from gpaw.utilities import unpack_density, is_contiguous 

8 

9from gpaw.inducedfield.inducedfield_base import BaseInducedField, \ 

10 sendreceive_dict 

11 

12 

class TDDFTInducedField(BaseInducedField, TDDFTObserver):
    """Induced field class for time propagation TDDFT.

    While attached to a propagating calculator, the observer accumulates
    for every frequency in ``omega_w`` a discrete Fourier sum of the
    deviation of ``nt_sG`` and ``D_asp`` from their stored ground-state
    values (see ``_update``).

    Attributes (see also ``BaseInducedField``):
    -------------------------------------------
    time: float
        Current time
    Fnt_wsG: ndarray (complex)
        Fourier transform of induced pseudo density
    n0t_sG: ndarray (float)
        Ground state pseudo density
    FD_awsp: dict of ndarray (complex)
        Fourier transform of induced D_asp
    D0_asp: dict of ndarray (float)
        Ground state D_asp
    """

    def __init__(self, filename=None, paw=None,
                 frequencies=None, folding='Gauss', width=0.08,
                 interval=1, restart_file=None
                 ):
        """
        Parameters (see also ``BaseInducedField``):
        -------------------------------------------
        paw: TDDFT object
            TDDFT object for time propagation
        width: float
            Width in eV for the Gaussian (sigma) or Lorentzian (eta) folding
            Gaussian   = exp(- (1/2) * sigma^2 * t^2)
            Lorentzian = exp(- eta * t)
        interval: int
            Number of timesteps between calls (used when attaching)
        restart_file: string
            Name of the restart file
        """

        TDDFTObserver.__init__(self, paw, interval)
        # From observer:
        # self.niter
        # self.interval
        # self.timer
        # Observer does also paw.attach(self, ...)

        # Restart file
        self.restart_file = restart_file

        # These are allocated in allocate()
        self.Fnt_wsG = None
        self.n0t_sG = None
        self.FD_awsp = None
        self.D0_asp = None

        # Maps a read/write mode string to the list of quantities handled
        self.readwritemode_str_to_list = {
            '': ['Fnt', 'n0t', 'FD', 'D0', 'atoms'],
            'all': ['Fnt', 'n0t', 'FD', 'D0',
                    'Frho', 'Fphi', 'Fef', 'Ffe', 'atoms'],
            'field': ['Frho', 'Fphi', 'Fef', 'Ffe', 'atoms'],
        }

        BaseInducedField.__init__(self, filename, paw,
                                  frequencies, folding, width)

    def initialize(self, paw, allocate=True):
        """Run the base-class initialization and pick up the TDDFT clock.

        When a calculator is present (``self.has_paw``), it must expose
        ``time`` and ``niter``; both are copied onto this object.
        """
        BaseInducedField.initialize(self, paw, allocate)

        if self.has_paw:
            assert hasattr(paw, 'time') and hasattr(paw, 'niter'), 'Use TDDFT!'
            self.time = paw.time
            self.niter = paw.niter

    def set_folding(self, folding, width):
        """Set the folding and build the matching time-domain envelope.

        ``self.envelope`` becomes a callable of time ``t``:
        constant 1.0 for no folding, ``exp(-width**2 * t**2 / 2)`` for
        ``'Gauss'`` and ``exp(-width * t)`` for ``'Lorentz'``.

        Raises
        ------
        RuntimeError
            If ``self.folding`` is neither ``None``, ``'Gauss'`` nor
            ``'Lorentz'`` after the base class has processed the arguments.
        """
        BaseInducedField.set_folding(self, folding, width)

        # The lambdas read self.width at call time, not at definition time.
        if self.folding is None:
            self.envelope = lambda t: 1.0
        elif self.folding == 'Gauss':
            self.envelope = lambda t: np.exp(- 0.5 * self.width**2 * t**2)
        elif self.folding == 'Lorentz':
            self.envelope = lambda t: np.exp(- self.width * t)
        else:
            # f-string so that a non-string folding still gives RuntimeError
            raise RuntimeError(f'unknown folding "{self.folding}"')

    def allocate(self):
        """Allocate the ground-state and Fourier-transform arrays once."""
        if not self.allocated:

            # Ground state pseudo density; NaN-filled so that _update can
            # detect that it has not been set yet
            self.n0t_sG = self.gd.empty((self.nspins, )) + np.nan

            # Fourier transformed pseudo density
            self.Fnt_wsG = self.gd.zeros((self.nw, self.nspins),
                                         dtype=self.dtype)

            # Ground state D_asp (copied from the current density)
            self.D0_asp = {a: D_sp.copy()
                           for a, D_sp in self.density.D_asp.items()}

            # Size of D_p for each atom (stored as length-1 arrays)
            self.np_a = {a: np.array([len(D_sp[0])])
                         for a, D_sp in self.D0_asp.items()}

            # Fourier transformed D_asp
            self.FD_awsp = {a: np.zeros((self.nw, self.nspins, np_[0]),
                                        dtype=self.dtype)
                            for a, np_ in self.np_a.items()}

            self.allocated = True

        if debug:
            assert is_contiguous(self.Fnt_wsG, self.dtype)

    def deallocate(self):
        """Release the arrays created by ``allocate``."""
        BaseInducedField.deallocate(self)
        self.n0t_sG = None
        self.Fnt_wsG = None
        self.D0_asp = None
        self.FD_awsp = None

    def _update(self, paw):
        """Observer callback; dispatches on ``paw.action``.

        * ``'init'``: store the ground-state pseudo density (only when
          ``paw.niter == 0``).
        * ``'kick'``: store ``paw.kick_strength`` as ``Fbgef_v``.
        * ``'propagate'``: add the current time step to the Fourier sums.
        * anything else: ignored.
        """
        if paw.action == 'init':
            if paw.niter == 0:
                self.n0t_sG[:] = paw.density.nt_sG
            return
        elif paw.action == 'kick':
            # Background electric field
            self.Fbgef_v = paw.kick_strength
            return
        elif paw.action != 'propagate':
            return

        assert (self.Fbgef_v is not None
                and not np.any(np.isnan(self.n0t_sG))), \
            f'Attach {self.__class__.__name__} before absorption kick'

        # Update time
        self.time = paw.time
        time_step = paw.time_step

        # Complex exponential with envelope
        f_w = np.exp(1.0j * self.omega_w * self.time) * \
            self.envelope(self.time) * time_step

        # Deviations from the ground state; these do not depend on the
        # frequency, so compute them once instead of once per frequency
        dnt_sG = self.density.nt_sG - self.n0t_sG
        dD_asp = {a: D_sp - self.D0_asp[a]
                  for a, D_sp in self.density.D_asp.items()}

        # Update Fourier transforms
        for w in range(self.nw):
            self.Fnt_wsG[w] += dnt_sG * f_w[w]
            for a, dD_sp in dD_asp.items():
                self.FD_awsp[a][w] += dD_sp * f_w[w]

        # Restart file
        # XXX remove this once deprecated dump_interval is removed,
        # but keep write_restart() as it'll be still used
        # (see TDDFTObserver class)
        if (paw.restart_file is not None
                and self.niter % paw.dump_interval == 0):
            self.write_restart()

    def write_restart(self):
        """Write the restart file, if one was configured."""
        if self.restart_file is not None:
            self.write(self.restart_file)
            self.log(f'{self.__class__.__name__}: Wrote restart file')

    def interpolate_pseudo_density(self, gridrefinement=2):
        """Interpolate ``Fnt_wsG`` to a grid refined by ``gridrefinement``.

        The refinement is done by repeatedly calling ``gd.refine()``, so
        ``gridrefinement`` must be a power of two (1, 2, 4, ...).

        Returns
        -------
        (Fnt_wsg, gd): the interpolated array (a copy of ``Fnt_wsG`` when
        ``gridrefinement == 1``) and the grid descriptor it lives on.

        Raises
        ------
        NotImplementedError
            If ``gridrefinement`` is not a positive power of two.
        """
        # Reject non-positive values up front: log() of them gives
        # -inf/nan, which would fail in int() with an unrelated error.
        if gridrefinement < 1:
            raise NotImplementedError(
                f'gridrefinement must be a power of two, '
                f'got {gridrefinement!r}')

        gd = self.gd
        Fnt_wsg = self.Fnt_wsG.copy()

        # Find m for
        # gridrefinement = 2**m
        m1 = np.log(gridrefinement) / np.log(2.)
        m = int(np.round(m1))

        # Check if m is really integer
        if np.absolute(m - m1) >= 1e-8:
            raise NotImplementedError(
                f'gridrefinement must be a power of two, '
                f'got {gridrefinement!r}')

        # Phase argument passed to every Transformer.apply call
        phases = np.ones((3, 2), dtype=complex)

        for _ in range(m):
            gd2 = gd.refine()

            # Interpolate
            interpolator = Transformer(gd, gd2, self.stencil,
                                       dtype=self.dtype)
            Fnt2_wsg = gd2.empty((self.nw, self.nspins), dtype=self.dtype)
            for w in range(self.nw):
                for s in range(self.nspins):
                    interpolator.apply(Fnt_wsg[w][s], Fnt2_wsg[w][s],
                                       phases)

            gd = gd2
            Fnt_wsg = Fnt2_wsg

        return Fnt_wsg, gd

    def comp_charge_correction(self, gridrefinement=2):
        """Return the spin-summed interpolated pseudo density plus
        compensation charges, and the grid descriptor it lives on.

        Only ``gridrefinement == 2`` is supported.
        """
        # TODO: implement for gr==1 also
        assert gridrefinement == 2

        # Density
        Fnt_wsg, gd = self.interpolate_pseudo_density(gridrefinement)
        Frhot_wg = Fnt_wsg.sum(axis=1)

        tmp_g = gd.empty(dtype=float)
        for w in range(self.nw):
            # Determine compensation charge coefficients:
            FQ_aL = {a: np.dot(FD_wsp[w].sum(axis=0),
                               self.setups[a].Delta_pL)
                     for a, FD_wsp in self.FD_awsp.items()}

            # tmp_g is a real array, so the real and imaginary parts of
            # the coefficients are added in two separate passes.

            # Add real part of compensation charges
            tmp_g[:] = 0
            # Take copy to make array contiguous
            FQ2_aL = {a: FQ_L.real.copy() for a, FQ_L in FQ_aL.items()}
            self.density.ghat.add(tmp_g, FQ2_aL)
            Frhot_wg[w] += tmp_g

            # Add imag part of compensation charges
            tmp_g[:] = 0
            FQ2_aL = {a: FQ_L.imag.copy() for a, FQ_L in FQ_aL.items()}
            self.density.ghat.add(tmp_g, FQ2_aL)
            Frhot_wg[w] += 1.0j * tmp_g

        return Frhot_wg, gd

    def paw_corrections(self, gridrefinement=2):
        """Return the interpolated pseudo density with atom-centered
        corrections built from the partial waves, and its grid descriptor.

        For each frequency and spin, ``FD_awsp`` is unpacked into a
        block-diagonal matrix and ``phi`` minus ``phit`` densities
        constructed from it are added to the interpolated ``Fnt_wsG``.
        """
        Fn_wsg, gd = self.interpolate_pseudo_density(gridrefinement)

        # Splines (cached per setup id so each setup is loaded only once)
        splines = {}
        phi_aj = []
        phit_aj = []
        for a, setup_id in enumerate(self.setups.id_a):
            if setup_id not in splines:
                # Load splines:
                splines[setup_id] = tuple(
                    self.setups[a].get_partial_waves()[:2])
            phi_j, phit_j = splines[setup_id]
            phi_aj.append(phi_j)
            phit_aj.append(phit_j)

        # Create localized functions from splines
        phi = BasisFunctions(gd, phi_aj, dtype=float)
        phit = BasisFunctions(gd, phit_aj, dtype=float)
        spos_ac = self.atoms.get_scaled_positions()
        phi.set_positions(spos_ac)
        phit.set_positions(spos_ac)

        tmp_g = gd.empty(dtype=float)
        rho_MM = np.zeros((phi.Mmax, phi.Mmax), dtype=self.dtype)
        rho2_MM = np.zeros_like(rho_MM)
        for w in range(self.nw):
            for s in range(self.nspins):
                rho_MM[:] = 0
                M1 = 0
                for a, setup in enumerate(self.setups):
                    ni = setup.ni
                    FD_wsp = self.FD_awsp.get(a)
                    if FD_wsp is None:
                        # Atom not stored on this rank: receive buffer
                        FD_p = np.empty((ni * (ni + 1) // 2),
                                        dtype=self.dtype)
                    else:
                        FD_p = FD_wsp[w][s]
                    if gd.comm.size > 1:
                        gd.comm.broadcast(FD_p, self.rank_a[a])
                    D_ij = unpack_density(FD_p)
                    # unpack does complex conjugation that we don't want so
                    # remove conjugation
                    D_ij = np.triu(D_ij, 1) + np.conj(np.tril(D_ij))

                    M2 = M1 + ni
                    rho_MM[M1:M2, M1:M2] = D_ij
                    M1 = M2

                # tmp_g is real, so add the real and the imaginary part
                # of the AE corrections in two passes
                # TODO: use ae_valence_density_correction
                for factor, part_MM in ((1.0, rho_MM.real),
                                        (1.0j, rho_MM.imag)):
                    tmp_g[:] = 0
                    rho2_MM[:] = part_MM
                    phi.construct_density(rho2_MM, tmp_g, q=-1)
                    phit.construct_density(-rho2_MM, tmp_g, q=-1)
                    Fn_wsg[w][s] += factor * tmp_g

        return Fn_wsg, gd

    def get_induced_density(self, from_density, gridrefinement):
        """Return the induced charge density and its grid descriptor.

        Parameters
        ----------
        from_density: str
            ``'pseudo'`` (interpolated pseudo density), ``'comp'`` (with
            compensation charges) or ``'ae'`` (with PAW corrections).
        gridrefinement: int
            Passed on to the selected routine.

        Raises
        ------
        RuntimeError
            If ``from_density`` is not one of the values above.
        """
        # Return charge density (electrons = negative charge)
        if from_density == 'pseudo':
            Fn_wsg, gd = self.interpolate_pseudo_density(gridrefinement)
            return - Fn_wsg.sum(axis=1), gd
        if from_density == 'comp':
            Frho_wg, gd = self.comp_charge_correction(gridrefinement)
            return - Frho_wg, gd
        if from_density == 'ae':
            Fn_wsg, gd = self.paw_corrections(gridrefinement)
            return - Fn_wsg.sum(axis=1), gd
        # f-string so that a non-string argument still gives RuntimeError
        raise RuntimeError(f'unknown from_density "{from_density}"')

    def _read(self, reader, reads):
        """Read the time-propagation state listed in ``reads``.

        Raises
        ------
        OSError
            If a calculator is attached and its time differs from the
            time stored in the file by 1e-9 or more.
        """
        BaseInducedField._read(self, reader, reads)

        r = reader
        time = r.time
        if self.has_paw:
            # Test time
            if abs(time - self.time) >= 1e-9:
                raise OSError('Timestamp is incompatible with calculator.')
        else:
            self.time = time

        # Allocate
        self.allocate()

        # Dimensions for D_p for all atoms
        self.np_a = r.np_a

        def readarray(name):
            # The part of the name before '_' is the key used in ``reads``
            if name.split('_')[0] in reads:
                self.gd.distribute(r.get(name), getattr(self, name))

        # Read arrays
        readarray('n0t_sG')
        readarray('Fnt_wsG')

        # Keep only the atoms whose rank (per rank_a) is this domain rank
        if 'D0' in reads:
            D0_asp = r.D0_asp
            self.D0_asp = {a: D0_asp[a] for a in range(self.na)
                           if self.domain_comm.rank == self.rank_a[a]}

        if 'FD' in reads:
            FD_awsp = r.FD_awsp
            self.FD_awsp = {a: FD_awsp[a] for a in range(self.na)
                            if self.domain_comm.rank == self.rank_a[a]}

    def _collect_to_domain_master(self, local_a, new_buffer):
        """Collect a per-atom dict of arrays to the domain master.

        ``new_buffer(a)`` must return an empty receive buffer for atom
        ``a``; it is only called on the domain master. Ranks that are not
        in the k-point/band master group skip the communication and get
        an empty dict, so the result is always defined.
        """
        collected_a = {}
        if self.kpt_comm.rank == 0 and self.band_comm.rank == 0:
            # Create receive buffers on domain master only
            if self.domain_comm.rank == 0:
                for a in range(self.na):
                    collected_a[a] = new_buffer(a)
            # Collect dict to master
            sendreceive_dict(self.domain_comm, collected_a, 0,
                             local_a, self.rank_a, range(self.na))
        return collected_a

    def _write(self, writer, writes):
        """Write the time-propagation state listed in ``writes``."""
        BaseInducedField._write(self, writer, writes)

        # Collect np_a to master
        np_a = self._collect_to_domain_master(
            self.np_a, lambda a: np.empty(1, dtype=int))

        # Write time propagation status
        writer.write(time=self.time, np_a=np_a)

        def writearray(name, shape, dtype):
            # The part of the name before '_' is the key used in ``writes``
            if name.split('_')[0] in writes:
                writer.add_array(name, shape, dtype)
                a_wxg = getattr(self, name)
                # Collect and write one frequency at a time
                for w in range(self.nw):
                    writer.fill(self.gd.collect(a_wxg[w]))

        ng = tuple(self.gd.get_size_of_global_array())

        # Write time propagation arrays
        if 'n0t' in writes:
            writer.write(n0t_sG=self.gd.collect(self.n0t_sG))
        writearray('Fnt_wsG', (self.nw, self.nspins) + ng, self.dtype)

        if 'D0' in writes:
            # Collect D0_asp to world master
            D0_asp = self._collect_to_domain_master(
                self.D0_asp,
                lambda a: np.empty((self.nspins, np_a[a][0]), dtype=float))
            # Write
            writer.write(D0_asp=D0_asp)

        if 'FD' in writes:
            # Collect FD_awsp to world master
            FD_awsp = self._collect_to_domain_master(
                self.FD_awsp,
                lambda a: np.empty((self.nw, self.nspins, np_a[a][0]),
                                   dtype=complex))
            # Write
            writer.write(FD_awsp=FD_awsp)