Coverage for gpaw/lcaotddft/magneticmomentwriter.py: 95%

234 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import json 

2import re 

3from typing import Tuple 

4 

5import numpy as np 

6from ase import Atoms 

7from ase.units import Bohr 

8from ase.utils import IOContext 

9from gpaw.fd_operators import Gradient 

10from gpaw.lcaotddft.densitymatrix import DensityMatrix 

11from gpaw.lcaotddft.observer import TDDFTObserver 

12from gpaw.typing import Vector 

13from gpaw.utilities.tools import coordinates 

14 

15 

def calculate_magnetic_moment_on_grid(wfs, grad_v, r_vG, dM_vaii, *,
                                      only_pseudo=False):
    """Calculate magnetic moment on grid.

    Parameters
    ----------
    wfs
        Wave functions object
    grad_v
        List of gradient operators, one per Cartesian direction
    r_vG
        Grid point coordinates
    dM_vaii
        Atomic PAW corrections for magnetic moment
    only_pseudo
        If true, do not add atomic corrections

    Returns
    -------
    Magnetic moment vector, ``-0.5 * Im(sum_n f_n <psi_n| r x nabla |psi_n>)``
    with the occupation weights taken from ``kpt.f_n``
    """
    gd = wfs.gd
    mode = wfs.mode
    bd = wfs.bd
    kpt_u = wfs.kpt_u

    rxnabla_v = np.zeros(3, dtype=complex)
    if mode == 'lcao':
        # LCAO coefficients are expanded onto this grid buffer band by band
        psit_G = gd.empty(dtype=complex)
    # The gradient buffer is required in every mode (fd and lcao)
    nabla_psit_vG = gd.empty(3, dtype=complex)
    for kpt in kpt_u:
        for n, f in enumerate(kpt.f_n):
            if mode == 'lcao':
                psit_G[:] = 0.0
                wfs.basis_functions.lcao_to_grid(kpt.C_nM[n], psit_G, kpt.q)
            else:
                psit_G = kpt.psit_nG[n]

            for v in range(3):
                grad_v[v].apply(psit_G, nabla_psit_vG[v], kpt.phase_cd)

            # The bra is shared by all three components; conjugate it once
            psitc_G = psit_G.conj()

            # rxnabla = <psi1| r x nabla |psi2>
            # rxnabla_x = <psi1| r_y nabla_z - r_z nabla_y |psi2>
            # rxnabla_y = <psi1| r_z nabla_x - r_x nabla_z |psi2>
            # rxnabla_z = <psi1| r_x nabla_y - r_y nabla_x |psi2>
            for v in range(3):
                v1 = (v + 1) % 3
                v2 = (v + 2) % 3
                rnabla_psit_G = (r_vG[v1] * nabla_psit_vG[v2]
                                 - r_vG[v2] * nabla_psit_vG[v1])
                rxnabla_v[v] += f * gd.integrate(psitc_G * rnabla_psit_G)

    if not only_pseudo:
        # Projections are distributed over the domain communicator by atom,
        # so the atomic contributions are summed over gd.comm
        paw_rxnabla_v = np.zeros(3, dtype=complex)
        for kpt in kpt_u:
            for v in range(3):
                for a, P_ni in kpt.P_ani.items():
                    paw_rxnabla_v[v] += np.einsum('n,ni,ij,nj',
                                                  kpt.f_n, P_ni.conj(),
                                                  dM_vaii[v][a], P_ni,
                                                  optimize=True)
        gd.comm.sum(paw_rxnabla_v)
        rxnabla_v += paw_rxnabla_v

    # Each band-communicator rank only handled its own bands
    bd.comm.sum(rxnabla_v)
    return -0.5 * rxnabla_v.imag

82 

83 

def calculate_magnetic_moment_atomic_corrections(R_av, setups, partition):
    """Build the atom-centred PAW augmentation corrections for the moment.

    Around each atom a, the position is split as r = (r - R_a) + R_a::

        <psi1| r x nabla |psi2> = <psi1| (r - R_a) x nabla |psi2>
                                  + R_a x <psi1| nabla |psi2>

    The first term comes from the setup integrals ``rxnabla_iiv`` and the
    second one is assembled from ``nabla_iiv``.

    Parameters
    ----------
    R_av
        Atom positions
    setups
        PAW setups object, indexed by atom index
    partition
        Atom partition object

    Returns
    -------
    List of three atomic array dictionaries (one per Cartesian component);
    entry ``[v][a]`` is the complex (ni, ni) correction matrix of atom a
    """
    def ii_shape(a):
        ni = setups[a].ni
        return ni, ni

    dM_vaii = [partition.arraydict(shapes=ii_shape, dtype=complex)
               for _ in range(3)]

    for a in partition.my_indices:
        setup = setups[a]
        # R_a x nabla taken along the last (Cartesian) axis
        Rxnabla_iiv = np.cross(R_av[a], setup.nabla_iiv)
        dM_iiv = Rxnabla_iiv + setup.rxnabla_iiv
        for v, dM_aii in enumerate(dM_vaii):
            dM_aii[a][:] = dM_iiv[:, :, v]

    return dM_vaii

132 

133 

def calculate_magnetic_moment_matrix(kpt_u, bfs, correction, r_vG, dM_vaii, *,
                                     only_pseudo=False):
    """Calculate magnetic moment matrix in LCAO basis.

    Parameters
    ----------
    kpt_u
        K-points
    bfs
        Basis functions object
    correction
        Correction object
    r_vG
        Grid point coordinates
    dM_vaii
        Atomic PAW corrections for magnetic moment
    only_pseudo
        If true, do not add PAW corrections

    Returns
    -------
    Magnetic moment matrix: a real array of shape (3, Mstop - Mstart, nao)
    holding this rank's rows Mstart:Mstop, scaled by -0.5
    """
    Mstart = correction.Mstart
    Mstop = correction.Mstop
    mynao = Mstop - Mstart
    nao = bfs.Mmax

    # Basis functions and atomic correction must cover the same row slice
    assert bfs.Mstart == Mstart
    assert bfs.Mstop == Mstop

    M_vmM = np.zeros((3, mynao, nao), dtype=complex)
    rnabla_vmM = np.empty((3, mynao, nao), dtype=complex)

    # Each potential-derivative matrix obtained with "potential" r_v
    # contributes to the two cross-product components that contain r_v
    for v in range(3):
        v1 = (v + 1) % 3
        v2 = (v + 2) % 3
        rnabla_vmM[:] = 0.0
        bfs.calculate_potential_matrix_derivative(r_vG[v], rnabla_vmM, 0)
        M_vmM[v1] += rnabla_vmM[v2]
        M_vmM[v2] -= rnabla_vmM[v1]

    if not only_pseudo:
        # Only k-point index 0 is supported here
        for kpt in kpt_u:
            assert kpt.k == 0

        for v in range(3):
            correction.calculate(kpt_u[0].q, dM_vaii[v], M_vmM[v],
                                 Mstart, Mstop)

    # The matrices should be real. np.any (unlike np.max) also handles an
    # empty row slice (Mstart == Mstop) on this rank.
    assert not np.any(M_vmM.imag)
    return -0.5 * M_vmM.real

188 

189 

def calculate_magnetic_moment_in_lcao(ksl, rho_mm, M_vmm):
    """Evaluate the magnetic moment from an LCAO density matrix.

    Parameters
    ----------
    ksl
        Kohn-Sham Layouts object; when ``ksl.using_blacs`` is true the
        per-rank results are summed over its BLACS grid communicator
    rho_mm
        Density matrix in LCAO basis
    M_vmm
        Magnetic moment matrix in LCAO basis (float dtype, one matrix per
        Cartesian component along the first axis)

    Returns
    -------
    Magnetic moment vector: for each component v, the sum over all matrix
    elements of ``Im(rho_mm) * M_vmm[v]``
    """
    assert M_vmm.dtype == float
    # Only the imaginary part of the density matrix enters the moment
    weighted_vmm = M_vmm * rho_mm.imag
    mm_v = weighted_vmm.sum(axis=(1, 2))
    if not ksl.using_blacs:
        return mm_v
    # Presumably each rank holds only a block of the distributed matrices,
    # so the partial sums are combined here
    ksl.mmdescriptor.blacsgrid.comm.sum(mm_v)
    return mm_v

211 

212 

def get_origin_coordinates(atoms: Atoms,
                           origin: str,
                           origin_shift: Vector = None) -> np.ndarray:
    """Get origin coordinates.

    Parameters
    ----------
    atoms
        Atoms object
    origin
        See :class:`~gpaw.tddft.MagneticMomentWriter`
    origin_shift
        See :class:`~gpaw.tddft.MagneticMomentWriter`.
        ``None`` means no shift.

    Returns
    -------
    Origin coordinates in atomic units

    Raises
    ------
    ValueError
        *origin* is not one of ``'COM'``, ``'COC'`` or ``'zero'``
    """
    if origin == 'COM':
        origin_v = atoms.get_center_of_mass()
    elif origin == 'COC':
        origin_v = 0.5 * atoms.get_cell().sum(0)
    elif origin == 'zero':
        origin_v = np.zeros(3, dtype=float)
    else:
        raise ValueError(f'unknown origin: {origin!r}')
    # np.asarray(None, dtype=float) would be NaN, so None is handled
    # explicitly; the sum is out-of-place to never mutate origin_v's source
    if origin_shift is not None:
        origin_v = origin_v + np.asarray(origin_shift, dtype=float)
    return origin_v / Bohr

241 

242 

def parse_header(line: str) -> Tuple[str, int, dict]:
    """Parse header line.

    Example header line (keyword arguments as json):

        NameOfWriter[version=1](**{"arg1": "abc", ...})

    Parameters
    ----------
    line
        Header line (a single trailing newline is tolerated)

    Returns
    -------
    name
        Name
    version
        Version
    kwargs
        Keyword arguments

    Raises
    ------
    ValueError
        Line cannot be parsed, or the keyword arguments are not a JSON object
    """
    regexp = r"^(?P<name>\w+)\[version=(?P<ver>\d+)\]\(\*\*(?P<args>.*)\)$"
    m = re.match(regexp, line)
    if m is None:
        raise ValueError('unable to parse header')
    name = m.group('name')
    version = int(m.group('ver'))
    try:
        kwargs = json.loads(m.group('args'))
    except json.JSONDecodeError as err:
        raise ValueError('unable to parse keyword arguments') from err
    # Valid JSON such as a list or a number cannot serve as keyword arguments
    if not isinstance(kwargs, dict):
        raise ValueError('keyword arguments are not a JSON object')
    return name, version, kwargs

280 

281 

class MagneticMomentWriter(TDDFTObserver):
    """Observer for writing time-dependent magnetic moment data.

    The data is written in atomic units.

    The observer attaches to the TDDFT calculator during creation.

    Parameters
    ----------
    paw
        TDDFT calculator
    filename
        File for writing magnetic moment data
    origin
        Origin of the coordinate system used in calculation.
        Possible values:
        ``'COM'``: center of mass (default),
        ``'COC'``: center of cell,
        ``'zero'``: (0, 0, 0)
    origin_shift
        Vector in Å shifting the origin from the position defined
        by *origin*. Any sequence of three numbers (list, tuple or
        array) is accepted.
    dmat
        Density matrix object.
        Define this for LCAO calculations to avoid
        expensive recalculations of the density matrix.
    calculate_on_grid
        Parameter for testing.
        In LCAO mode, if true, calculation is performed on real-space grid.
        In fd mode, calculation is always performed on real-space grid
        and this parameter is neglected.
    only_pseudo
        Parameter for testing.
        If true, PAW corrections are neglected.
    interval
        Update interval. Value of 1 corresponds to evaluating and
        writing data after every propagation step.

    Notes
    -----
    On a restart (``paw.niter != 0``) the settings are read back from the
    header of *filename*; passing *origin*, *origin_shift*,
    *calculate_on_grid* or *only_pseudo* then raises ``ValueError``.
    """
    version = 5

    def __init__(self, paw, filename: str, *,
                 origin: str = None,
                 origin_shift: Vector = None,
                 dmat: DensityMatrix = None,
                 calculate_on_grid: bool = None,
                 only_pseudo: bool = None,
                 interval: int = 1):
        super().__init__(paw, interval)
        self.ioctx = IOContext()
        mode = paw.wfs.mode
        assert mode in ['fd', 'lcao'], f'unknown mode: {mode}'
        if paw.niter == 0:
            # Fresh run: fill in defaults and record them in the header
            if origin is None:
                origin = 'COM'
            if origin_shift is None:
                origin_shift = [0., 0., 0.]
            else:
                # Normalize tuples/arrays to a JSON-serializable list of
                # floats, the same form that is read back on restart
                origin_shift = np.asarray(origin_shift, dtype=float).tolist()
            if mode == 'fd':
                # The LCAO matrix route does not exist in fd mode, so the
                # parameter is ignored there (as documented)
                calculate_on_grid = True
            elif calculate_on_grid is None:
                calculate_on_grid = False
            if only_pseudo is None:
                only_pseudo = False
            _kwargs = dict(origin=origin,
                           origin_shift=origin_shift,
                           calculate_on_grid=calculate_on_grid,
                           only_pseudo=only_pseudo)

            # Initialize
            self.fd = self.ioctx.openfile(filename, comm=paw.world, mode='w')
            self._write_header(paw, _kwargs)
        else:
            if origin is not None:
                raise ValueError('Do not set origin in restart')
            if origin_shift is not None:
                raise ValueError('Do not set origin_shift in restart')
            if calculate_on_grid is not None:
                raise ValueError('Do not set calculate_on_grid in restart')
            if only_pseudo is not None:
                raise ValueError('Do not set only_pseudo in restart')

            # Read and continue
            _kwargs = self._read_header(filename)
            origin = _kwargs['origin']  # type: ignore
            origin_shift = _kwargs['origin_shift']  # type: ignore
            calculate_on_grid = _kwargs['calculate_on_grid']  # type: ignore
            only_pseudo = _kwargs['only_pseudo']  # type: ignore
            self.fd = self.ioctx.openfile(filename, comm=paw.world, mode='a')

        atoms = paw.atoms
        gd = paw.wfs.gd
        self.timer = paw.timer

        assert isinstance(origin, str)
        assert isinstance(origin_shift, list)
        origin_v = get_origin_coordinates(atoms, origin, origin_shift)
        R_av = atoms.positions / Bohr - origin_v[np.newaxis, :]
        r_vG, _ = coordinates(gd, origin=origin_v)

        dM_vaii = calculate_magnetic_moment_atomic_corrections(
            R_av, paw.setups, paw.hamiltonian.dH_asp.partition)

        self.calculate_on_grid = calculate_on_grid
        if self.calculate_on_grid:
            self.only_pseudo = only_pseudo
            self.r_vG = r_vG
            self.dM_vaii = dM_vaii

            grad_v = []
            for v in range(3):
                grad_v.append(Gradient(gd, v, dtype=complex, n=2))
            self.grad_v = grad_v
        else:
            M_vmM = calculate_magnetic_moment_matrix(
                paw.wfs.kpt_u, paw.wfs.basis_functions,
                paw.wfs.atomic_correction, r_vG, dM_vaii,
                only_pseudo=only_pseudo)

            # TODO: All observers recalculate density matrix
            # unless dmat is given.
            # Calculator itself could have a density matrix object to avoid
            # this expensive recalculation in a clean way.
            if dmat is None:
                self.dmat = DensityMatrix(paw)
            else:
                self.dmat = dmat
            ksl = paw.wfs.ksl
            if ksl.using_blacs:
                self.M_vmm = ksl.distribute_overlap_matrix(M_vmM)
            else:
                # Each domain rank integrated only its part of the grid
                gd.comm.sum(M_vmM)
                self.M_vmm = M_vmM

    def _write(self, line):
        self.fd.write(line)
        self.fd.flush()

    def _write_header(self, paw, kwargs):
        origin_v = get_origin_coordinates(
            paw.atoms, kwargs['origin'], kwargs['origin_shift'])
        lines = [f'{self.__class__.__name__}[version={self.version}]'
                 f'(**{json.dumps(kwargs)})',
                 'origin_v = [%.6f, %.6f, %.6f] Å' % tuple(origin_v * Bohr)]
        self._write('# ' + '\n# '.join(lines) + '\n')
        self._write(f'# {"time":>15} {"mmx":>17} {"mmy":>22} {"mmz":>22}\n')

    def _read_header(self, filename):
        with open(filename, encoding='utf-8') as fd:
            line = fd.readline()
        # Strip the leading '# ' comment marker written by _write_header
        try:
            name, version, kwargs = parse_header(line[2:])
        except ValueError as e:
            raise ValueError(f'File {filename} cannot be parsed: {e}') from e
        if name != self.__class__.__name__:
            raise ValueError(f'File {filename} is not '
                             f'for {self.__class__.__name__}')
        if version != self.version:
            raise ValueError(f'File {filename} is not '
                             f'of version {self.version}')
        return kwargs

    def _write_init(self, paw):
        time = paw.time
        line = '# Start; Time = %.8lf\n' % time
        self._write(line)

    def _write_kick(self, paw):
        time = paw.time
        kick = paw.kick_strength
        gauge = paw.kick_gauge
        line = '# Kick = [%22.12le, %22.12le, %22.12le]; ' % tuple(kick)
        line += 'Gauge = %s; ' % gauge
        line += 'Time = %.8lf\n' % time
        self._write(line)

    def _calculate_mm(self, paw):
        if self.calculate_on_grid:
            self.timer.start('Calculate magnetic moment on grid')
            mm_v = calculate_magnetic_moment_on_grid(
                paw.wfs, self.grad_v, self.r_vG, self.dM_vaii,
                only_pseudo=self.only_pseudo)
            self.timer.stop('Calculate magnetic moment on grid')
        else:
            self.timer.start('Calculate magnetic moment in LCAO')

            mm_v = 0.0
            for kpt in paw.wfs.kpt_u:
                assert kpt.q == 0
            for rho_mm in self.dmat.get_density_matrix((paw.niter,
                                                        paw.action)):
                mm_v += calculate_magnetic_moment_in_lcao(
                    paw.wfs.ksl, rho_mm, self.M_vmm)
            self.timer.stop('Calculate magnetic moment in LCAO')
        assert mm_v.shape == (3,)
        assert mm_v.dtype == float
        return mm_v

    def _write_mm(self, paw):
        time = paw.time
        mm_v = self._calculate_mm(paw)
        line = ('%20.8lf %22.12le %22.12le %22.12le\n'
                % (time, mm_v[0], mm_v[1], mm_v[2]))
        self._write(line)

    def _update(self, paw):
        if paw.action == 'init':
            self._write_init(paw)
        elif paw.action == 'kick':
            self._write_kick(paw)
        self._write_mm(paw)

    def __del__(self):
        # __init__ may have failed before the IO context was created
        ioctx = getattr(self, 'ioctx', None)
        if ioctx is not None:
            ioctx.close()