Coverage for gpaw/lcaotddft/energywriter.py: 88%

67 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import numpy as np 

2 

3from ase.utils import IOContext 

4 

5from gpaw.lcaotddft.observer import TDDFTObserver 

6from gpaw.utilities.scalapack import scalapack_zero 

7 

8 

class EnergyWriter(TDDFTObserver):
    """Write energy components of an LCAO-TDDFT run to a text file.

    Each data row holds the time followed by the kinetic0, coulomb, zero,
    external, xc and band energies. All energies are given relative to the
    values evaluated when the observer receives the ``'init'`` action.

    Parameters
    ----------
    paw
        The TDDFT calculator being observed (presumably an ``LCAOTDDFT``
        instance). It must provide ``niter``, ``time``, ``action``,
        ``world``, ``wfs``, ``hamiltonian`` and ``td_hamiltonian``, plus
        ``kick_strength`` when kicks occur.
    dmat
        Object providing ``get_density_matrix(key)``. It is queried with
        the key ``(paw.niter, paw.action)``.
    filename
        Output file. It is truncated when ``paw.niter == 0`` and appended
        to otherwise (i.e. when continuing a run).
    interval
        Passed on to ``TDDFTObserver``; presumably the number of
        propagation steps between observer calls.
    """

    version = 1

    def __init__(self, paw, dmat, filename, interval=1):
        TDDFTObserver.__init__(self, paw, interval)
        self.ioctx = IOContext()
        self.dmat = dmat
        # Start a new file for a fresh run; append when continuing a run
        # that has already taken propagation steps.
        mode = 'w' if paw.niter == 0 else 'a'
        self.fd = self.ioctx.openfile(filename, comm=paw.world, mode=mode)

    def _write(self, line):
        """Write ``line`` and flush immediately so output is not buffered."""
        self.fd.write(line)
        self.fd.flush()

    def _write_header(self, paw):
        """Write the version tag and column names (fresh runs only)."""
        if paw.niter != 0:
            # The file was opened in append mode and already has a header.
            return
        line = f'# {self.__class__.__name__}[version={self.version}]\n'
        line += ('# %15s %22s %22s %22s %22s %22s %22s\n' %
                 ('time', 'kinetic0', 'coulomb', 'zero', 'external',
                  'xc', 'band'))
        self._write(line)

    def _write_kick(self, paw):
        """Write a comment line recording the kick strength and time."""
        time = paw.time
        kick = paw.kick_strength
        line = '# Kick = [%22.12le, %22.12le, %22.12le]; ' % tuple(kick)
        line += 'Time = %.8lf\n' % time
        self._write(line)

    def _get_energies(self, paw):
        """Return the current energy components as a length-6 array.

        The order is (kinetic0, coulomb, zero, external, xc, band). The band
        energy is computed here from the density matrix and the
        time-dependent Hamiltonian matrix. The other terms are read from
        ``paw.hamiltonian``.

        Side effect: stores the band energy in
        ``paw.wfs.occupations.e_band`` and sets
        ``paw.wfs.occupations.e_entropy`` to 0.0.
        """
        e_band = 0.0
        rho_uMM = self.dmat.get_density_matrix((paw.niter, paw.action))
        get_H_MM = paw.td_hamiltonian.get_hamiltonian_matrix
        ksl = paw.wfs.ksl
        for u, kpt in enumerate(paw.wfs.kpt_u):
            rho_MM = rho_uMM[u]

            # Request the Hamiltonian matrix without the fxc and potential
            # contributions (addfxc=False, addpot=False).
            H_MM = get_H_MM(kpt, paw.time, addfxc=False, addpot=False)

            if ksl.using_blacs:
                # General case would be (rho_MM * H_MM).real; only real
                # parts are used here because the Hamiltonian is real.
                rhoH_MM = rho_MM.real * H_MM.real
                # Hamiltonian has correct values only in lower half, so
                # 1. Add lower half and diagonal twice
                scalapack_zero(ksl.mmdescriptor, rhoH_MM, 'U')
                e = 2 * np.sum(rhoH_MM)
                # 2. Reduce the extra diagonal
                scalapack_zero(ksl.mmdescriptor, rhoH_MM, 'L')
                e -= np.sum(rhoH_MM)
                # Sum the local blocks over all ranks
                e = ksl.block_comm.sum(e)
            else:
                # NOTE(review): the elementwise sum equals Tr(rho H^T),
                # which is Tr(rho H) only for a symmetric H_MM; this relies
                # on the Hamiltonian being real (see BLACS branch).
                e = np.sum(rho_MM * H_MM).real

            e_band += e

        # NOTE(review): e_band is summed only over the locally stored kpt_u;
        # confirm whether a reduction over the k-point communicator is
        # needed when k-points/spins are distributed over ranks.
        paw.wfs.occupations.e_band = e_band
        paw.wfs.occupations.e_entropy = 0.0
        e_kinetic0 = paw.hamiltonian.e_kinetic0
        e_coulomb = paw.hamiltonian.e_coulomb
        e_zero = paw.hamiltonian.e_zero
        e_external = paw.hamiltonian.e_external
        e_xc = paw.hamiltonian.e_xc

        return np.array((e_kinetic0, e_coulomb, e_zero,
                         e_external, e_xc, e_band))

    def _write_energy(self, paw):
        """Write one data row: time and energies relative to ``energy0_i``.

        Requires ``self.energy0_i``, which is set on the ``'init'`` action.
        """
        time = paw.time
        energy_i = self._get_energies(paw) - self.energy0_i
        line = (
            '%20.8lf %22.12le %22.12le %22.12le %22.12le %22.12le %22.12le\n' %
            ((time, ) + tuple(energy_i)))
        self._write(line)

    def _update(self, paw):
        """Handle an observer call, dispatching on ``paw.action``.

        ``'init'`` writes the header and stores the reference energies.
        ``'kick'`` writes a kick comment line. Every action then writes an
        energy row.
        """
        if paw.action == 'init':
            self._write_header(paw)
            # Reference energies: all subsequent rows are relative to these.
            self.energy0_i = self._get_energies(paw)
        elif paw.action == 'kick':
            self._write_kick(paw)
        self._write_energy(paw)

    def __del__(self):
        # __init__ may have raised before self.ioctx was assigned (e.g. in
        # TDDFTObserver.__init__); avoid a spurious AttributeError in the
        # finalizer in that case.
        ioctx = getattr(self, 'ioctx', None)
        if ioctx is not None:
            ioctx.close()