Coverage for gpaw/test/tddft/test_molecule.py: 92%

78 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1import pytest 

2 

3from ase.build import molecule 

4 

5from gpaw import GPAW 

6from gpaw.tddft import TDDFT, DipoleMomentWriter 

7from gpaw.mpi import world, serial_comm 

8from gpaw.utilities import compiled_with_sl 

9 

10from ..lcaotddft.test_molecule import only_on_master 

11 

12 

# Every test in this module uses the ``module_tmp_path`` fixture (defined
# elsewhere, presumably in a conftest). The module-scoped fixtures below write
# gs.gpw, td.gpw and dm*.dat with relative paths, and the tests read those
# files back through ``module_tmp_path``.
pytestmark = [pytest.mark.usefixtures('module_tmp_path')]

14 

15 

def calculate_time_propagation(gpw_fpath, *,
                               iterations=3,
                               kick=(1e-5, 1e-5, 1e-5),
                               propagator='SICN',
                               communicator=world,
                               write_and_continue=False,
                               force_new_dm_file=False,
                               parallel=None,
                               time_step=20):
    """Run a short real-time TDDFT propagation starting from *gpw_fpath*.

    The dipole moment is recorded to ``dm.dat`` and the calculator text
    output to ``td.out``, both in the current working directory.

    Parameters
    ----------
    gpw_fpath
        Restart file to start from (ground-state or time-propagation file).
    iterations
        Number of propagation steps; with *write_and_continue* this many
        steps are run both before and after writing ``td.gpw``.
    kick
        Strength of the absorption kick applied before propagating.
        ``None`` skips the kick (used when continuing from ``td.gpw``).
    propagator
        Name of the TDDFT propagator.
    communicator
        Communicator handed to the TDDFT calculator; a barrier is
        executed on it before returning.
    write_and_continue
        If true, write ``td.gpw`` (``mode='all'``) after the first stage,
        redirect the dipole moment output to a fresh ``dm2.dat`` and
        propagate another *iterations* steps.
    force_new_dm_file
        Forwarded as ``force_new_file`` to the writer of ``dm.dat``.
    parallel
        Parallelization keywords for TDDFT; ``None`` means ``{}``.
    time_step
        Time step passed to ``TDDFT.propagate`` (attoseconds in GPAW);
        the default of 20 matches the stored reference data.
    """
    # Immutable defaults: a shared list/dict default could be modified by
    # one call and leak into later calls.
    if parallel is None:
        parallel = {}

    td_calc = TDDFT(gpw_fpath,
                    propagator=propagator,
                    communicator=communicator,
                    parallel=parallel,
                    txt='td.out')
    DipoleMomentWriter(td_calc, 'dm.dat',
                       force_new_file=force_new_dm_file)
    if kick is not None:
        td_calc.absorption_kick(kick)
    td_calc.propagate(time_step, iterations)

    if write_and_continue:
        td_calc.write('td.gpw', mode='all')
        # Switch dipole moment writer and output.
        # NOTE(review): assumes the writer attached above is the most
        # recently registered observer, so pop() detaches exactly it.
        td_calc.observers.pop()
        dm = DipoleMomentWriter(td_calc, 'dm2.dat', force_new_file=True)
        # Record the current dipole moment as the first line of dm2.dat,
        # so it starts at the time where dm.dat ends.
        dm._update(td_calc)
        td_calc.propagate(time_step, iterations)

    communicator.barrier()

42 

43 

def check_dm(ref_fpath, fpath, rtol=1e-8, atol=1e-12):
    """Assert that the dipole-moment file *fpath* matches *ref_fpath*.

    Time stamps must agree exactly; dipole moments must agree within the
    relative tolerance *rtol* and the absolute tolerance *atol*.
    """
    from gpaw.tddft.spectrum import read_dipole_moment_file

    def times_and_dipoles(path):
        # Only the time and dipole-moment arrays are compared.
        _, times, _, dipoles = read_dipole_moment_file(path)
        return times, dipoles

    # Synchronize all ranks before reading, presumably so that files
    # produced on another rank are complete.
    world.barrier()
    ref_time_t, ref_dm_tv = times_and_dipoles(ref_fpath)
    time_t, dm_tv = times_and_dipoles(fpath)

    # Giving only ``abs`` to pytest.approx disables the relative
    # tolerance, so abs=0 demands exact equality of the time stamps.
    assert time_t == pytest.approx(ref_time_t, abs=0)
    assert dm_tv == pytest.approx(ref_dm_tv, rel=rtol, abs=atol)

52 

53 

# Generate different parallelization options: band parallelization only
# with more than one rank, and every option is repeated with sl_auto when
# GPAW is compiled with ScaLAPACK.
parallel_i = [{}, {'band': 2}] if world.size > 1 else [{}]
if compiled_with_sl():
    parallel_i += [{'sl_auto': True, **options} for options in parallel_i]

62 

63 

@pytest.fixture(scope='module')
@only_on_master(world)
def ground_state():
    """Converge the SiH4 ground state on a serial communicator.

    Writes the result to ``gs.gpw`` with ``mode='all'`` (wave functions
    included) so the time propagation can start from it; text output
    goes to ``gs.out``. ``only_on_master`` presumably restricts the work
    to the master rank.
    """
    atoms = molecule('SiH4')
    atoms.center(vacuum=4.0)

    gs_settings = dict(mode='fd',
                       nbands=6,
                       h=0.4,
                       convergence={'density': 1e-8},
                       communicator=serial_comm,
                       xc='LDA',
                       symmetry={'point_group': False},
                       txt='gs.out')
    atoms.calc = GPAW(**gs_settings)
    # Requesting the energy runs the self-consistent calculation.
    atoms.get_potential_energy()
    atoms.calc.write('gs.gpw', mode='all')

79 

80 

@pytest.fixture(scope='module')
@only_on_master(world)
def time_propagation_reference(ground_state):
    """Run the serial reference propagation starting from ``gs.gpw``.

    Besides ``dm.dat`` this writes ``td.gpw`` and continues the
    propagation into ``dm2.dat``, which the restart test compares with.
    """
    # ``ground_state`` is requested only so that gs.gpw exists first.
    reference_options = {'communicator': serial_comm,
                         'write_and_continue': True}
    calculate_time_propagation('gs.gpw', **reference_options)

87 

88 

def test_dipole_moment_values(time_propagation_reference,
                              module_tmp_path, in_tmp_dir):
    """Compare the reference propagation with hard-coded dipole moments.

    The expected ``dm.dat`` (kick and first stage) and ``dm2.dat``
    (continuation after writing ``td.gpw``) are written to the current
    directory and checked against the files in ``module_tmp_path``.
    """
    expected_files = {
        'dm.dat': '''
# DipoleMomentWriter[version=1](center=False, density='comp')
# time norm dmx dmy dmz
# Start; Time = 0.00000000
 0.00000000 -8.62679509e-16 8.856042552837e-09 5.230011358635e-11 1.624559936066e-10
# Kick = [ 1.000000000000e-05, 1.000000000000e-05, 1.000000000000e-05]; Time = 0.00000000
 0.00000000 9.59295128e-16 8.826542661185e-09 -1.968118480737e-10 -1.260104338852e-10
 0.82682747 1.64702342e-15 6.016062457419e-05 6.015263997632e-05 6.015070820074e-05
 1.65365493 1.36035859e-15 1.075409609786e-04 1.075366083805e-04 1.075339737111e-04
 2.48048240 1.53109134e-15 1.388608139179e-04 1.388701618472e-04 1.388666380740e-04
'''.strip(),  # noqa: E501
        'dm2.dat': '''
# DipoleMomentWriter[version=1](center=False, density='comp')
# time norm dmx dmy dmz
 2.48048240 1.53109134e-15 1.388608139179e-04 1.388701618472e-04 1.388666380740e-04
 3.30730987 1.36214053e-15 1.528275514998e-04 1.528424797241e-04 1.528388409079e-04
 4.13413733 -5.46885441e-16 1.498039918400e-04 1.498178744836e-04 1.498147055362e-04
 4.96096480 -3.62630566e-16 1.324275745486e-04 1.324479404917e-04 1.324450415780e-04
'''.strip(),  # noqa: E501
    }

    for fname, text in expected_files.items():
        with open(fname, 'w') as fd:
            fd.write(text)

    tolerances = {'rtol': 5e-4, 'atol': 1e-7}
    for fname in expected_files:
        check_dm(fname, module_tmp_path / fname, **tolerances)

118 

119 

@pytest.mark.parametrize('parallel', parallel_i)
@pytest.mark.parametrize('propagator', [
    'SICN', 'ECN', 'ETRSCN', 'SIKE'])
def test_propagation(time_propagation_reference,
                     parallel, propagator,
                     module_tmp_path, in_tmp_dir):
    """Propagate from gs.gpw and compare with the serial SICN reference.

    The propagator and parallelization vary; the tolerances are chosen
    per combination.
    """
    calculate_time_propagation(module_tmp_path / 'gs.gpw',
                               propagator=propagator,
                               parallel=parallel)

    band_parallel = 'band' in parallel
    if propagator == 'SICN':
        # Same propagator as the reference, so deviations come only from
        # parallelization (XXX band parallelization is inaccurate!).
        rtol = 7e-4 if band_parallel else 1e-8
    else:
        # Other propagators only match the reference qualitatively.
        rtol = 5e-2
    # XXX band parallelization is inaccurate! It needs a looser atol.
    atol = 5e-8 if band_parallel else 1e-12
    check_dm(module_tmp_path / 'dm.dat', 'dm.dat', rtol=rtol, atol=atol)

145 

146 

@pytest.mark.parametrize('parallel', parallel_i)
def test_restart(time_propagation_reference,
                 parallel,
                 module_tmp_path, in_tmp_dir):
    """Resume from td.gpw without a kick and compare with dm2.dat.

    The reference dm2.dat was produced by continuing the original
    propagation after td.gpw was written.
    """
    restart_options = dict(kick=None,
                           force_new_dm_file=True,
                           parallel=parallel)
    calculate_time_propagation(module_tmp_path / 'td.gpw', **restart_options)
    # Band parallelization is less accurate, so loosen the tolerance.
    rtol = 5e-4 if 'band' in parallel else 1e-8
    check_dm(module_tmp_path / 'dm2.dat', 'dm.dat', rtol=rtol)
158 check_dm(module_tmp_path / 'dm2.dat', 'dm.dat', rtol=rtol)