Coverage for gpaw/test/tddft/test_molecule.py: 92%
78 statements
coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1import pytest
3from ase.build import molecule
5from gpaw import GPAW
6from gpaw.tddft import TDDFT, DipoleMomentWriter
7from gpaw.mpi import world, serial_comm
8from gpaw.utilities import compiled_with_sl
10from ..lcaotddft.test_molecule import only_on_master
# Every test in this module requests the ``module_tmp_path`` fixture.
# NOTE(review): presumably this switches into a module-wide temporary
# directory so the files written by the module-scoped fixtures
# (gs.gpw, dm.dat, td.gpw, dm2.dat) are shared -- confirm.
pytestmark = [pytest.mark.usefixtures('module_tmp_path')]
def calculate_time_propagation(gpw_fpath, *,
                               iterations=3,
                               kick=(1e-5, 1e-5, 1e-5),
                               propagator='SICN',
                               communicator=world,
                               write_and_continue=False,
                               force_new_dm_file=False,
                               parallel=None):
    """Run a short TDDFT propagation starting from ``gpw_fpath``.

    Dipole moments are written to ``dm.dat`` and the calculator log to
    ``td.out``, both relative to the current working directory.

    Parameters
    ----------
    gpw_fpath
        Ground-state or restart file handed to ``TDDFT``.
    iterations
        Number of steps passed to each ``propagate`` call.
    kick
        Absorption-kick strength vector, or ``None`` to propagate without
        a kick (e.g. when continuing from a restart file).
    propagator
        Name of the TDDFT propagator.
    communicator
        Communicator for the calculation; synchronized before returning.
    write_and_continue
        If true, write ``td.gpw`` after the first leg, then propagate
        ``iterations`` more steps while recording dipole moments in
        ``dm2.dat`` instead of ``dm.dat``.
    force_new_dm_file
        Forwarded as ``force_new_file`` to the ``dm.dat`` writer.
    parallel
        GPAW ``parallel`` keyword dict; ``None`` is treated as ``{}``.
    """
    # Immutable/None defaults: a mutable default object would be shared
    # between all calls of this function.
    if parallel is None:
        parallel = {}
    td_calc = TDDFT(gpw_fpath,
                    propagator=propagator,
                    communicator=communicator,
                    parallel=parallel,
                    txt='td.out')
    # NOTE(review): the writer is not stored; presumably it registers
    # itself in ``td_calc.observers`` (relied on by the pop() below).
    DipoleMomentWriter(td_calc, 'dm.dat',
                       force_new_file=force_new_dm_file)
    if kick is not None:
        # Hand over a list, exactly as the former list default did.
        td_calc.absorption_kick(list(kick))
    td_calc.propagate(20, iterations)
    if write_and_continue:
        td_calc.write('td.gpw', mode='all')
        # Switch dipole moment writer and output: drop the last attached
        # observer (presumably the dm.dat writer created above).
        td_calc.observers.pop()
        dm = DipoleMomentWriter(td_calc, 'dm2.dat', force_new_file=True)
        # Record the state at the restart point as the first dm2.dat entry.
        dm._update(td_calc)
        td_calc.propagate(20, iterations)
    communicator.barrier()
def check_dm(ref_fpath, fpath, rtol=1e-8, atol=1e-12):
    """Assert that the dipole-moment file ``fpath`` matches ``ref_fpath``.

    Time stamps are compared with ``pytest.approx`` using its default
    relative tolerance and zero absolute tolerance; dipole moments are
    compared with the given ``rtol``/``atol``. All ranks of ``world``
    synchronize before the files are read.
    """
    from gpaw.tddft.spectrum import read_dipole_moment_file

    def _times_and_dipoles(path):
        # Only the second (times) and fourth (dipole moments) of the four
        # returned entries are needed here.
        _first, times, _third, dipoles = read_dipole_moment_file(path)
        return times, dipoles

    world.barrier()
    expected_times, expected_dipoles = _times_and_dipoles(ref_fpath)
    actual_times, actual_dipoles = _times_and_dipoles(fpath)
    assert actual_times == pytest.approx(expected_times, abs=0)
    assert actual_dipoles == pytest.approx(expected_dipoles,
                                           rel=rtol, abs=atol)
def _parallelization_options():
    """Return the ``parallel`` keyword dicts the tests are run with.

    Always includes the default ``{}``; band parallelization is added only
    when ``world`` has more than one rank, and ``sl_auto`` variants only
    when ``compiled_with_sl()`` is true (presumably ScaLAPACK support).
    """
    multi_rank = world.size > 1
    options = [{}]
    if multi_rank:
        options.append({'band': 2})
    if compiled_with_sl():
        options.append({'sl_auto': True})
        if multi_rank:
            options.append({'sl_auto': True, 'band': 2})
    return options


# Generate different parallelization options
parallel_i = _parallelization_options()
@pytest.fixture(scope='module')
@only_on_master(world)
def ground_state():
    """Converge the SiH4 ground state serially and write ``gs.gpw``.

    Runs under ``only_on_master(world)``; the log goes to ``gs.out`` and
    the wave functions to ``gs.gpw`` in the current working directory.
    """
    sih4 = molecule('SiH4')
    sih4.center(vacuum=4.0)

    gs_parameters = dict(mode='fd', nbands=6, h=0.4,
                         convergence={'density': 1e-8},
                         communicator=serial_comm,
                         xc='LDA',
                         symmetry={'point_group': False},
                         txt='gs.out')
    gs_calc = GPAW(**gs_parameters)
    sih4.calc = gs_calc
    sih4.get_potential_energy()
    gs_calc.write('gs.gpw', mode='all')
81@pytest.fixture(scope='module')
82@only_on_master(world)
83def time_propagation_reference(ground_state):
84 calculate_time_propagation('gs.gpw',
85 communicator=serial_comm,
86 write_and_continue=True)
def test_dipole_moment_values(time_propagation_reference,
                              module_tmp_path, in_tmp_dir):
    """Check the reference run's dipole moments against stored values."""
    # Insertion order matters: files are written, then checked, in order.
    reference_text = {
        'dm.dat': '''
# DipoleMomentWriter[version=1](center=False, density='comp')
# time norm dmx dmy dmz
# Start; Time = 0.00000000
 0.00000000 -8.62679509e-16 8.856042552837e-09 5.230011358635e-11 1.624559936066e-10
# Kick = [ 1.000000000000e-05, 1.000000000000e-05, 1.000000000000e-05]; Time = 0.00000000
 0.00000000 9.59295128e-16 8.826542661185e-09 -1.968118480737e-10 -1.260104338852e-10
 0.82682747 1.64702342e-15 6.016062457419e-05 6.015263997632e-05 6.015070820074e-05
 1.65365493 1.36035859e-15 1.075409609786e-04 1.075366083805e-04 1.075339737111e-04
 2.48048240 1.53109134e-15 1.388608139179e-04 1.388701618472e-04 1.388666380740e-04
'''.strip(),  # noqa: E501
        'dm2.dat': '''
# DipoleMomentWriter[version=1](center=False, density='comp')
# time norm dmx dmy dmz
 2.48048240 1.53109134e-15 1.388608139179e-04 1.388701618472e-04 1.388666380740e-04
 3.30730987 1.36214053e-15 1.528275514998e-04 1.528424797241e-04 1.528388409079e-04
 4.13413733 -5.46885441e-16 1.498039918400e-04 1.498178744836e-04 1.498147055362e-04
 4.96096480 -3.62630566e-16 1.324275745486e-04 1.324479404917e-04 1.324450415780e-04
'''.strip(),  # noqa: E501
    }

    for fname, text in reference_text.items():
        with open(fname, 'w') as fd:
            fd.write(text)

    rtol = 5e-4
    atol = 1e-7
    for fname in reference_text:
        check_dm(fname, module_tmp_path / fname, rtol=rtol, atol=atol)
@pytest.mark.parametrize('parallel', parallel_i)
@pytest.mark.parametrize('propagator', ['SICN', 'ECN', 'ETRSCN', 'SIKE'])
def test_propagation(time_propagation_reference,
                     parallel, propagator,
                     module_tmp_path, in_tmp_dir):
    """Propagate from ``gs.gpw`` and compare with the reference ``dm.dat``."""
    calculate_time_propagation(module_tmp_path / 'gs.gpw',
                               propagator=propagator,
                               parallel=parallel)

    band_parallel = 'band' in parallel
    if propagator == 'SICN':
        # Same propagator as the reference, so any deviation comes only
        # from parallelization.
        rtol = 7e-4 if band_parallel else 1e-8
    else:
        # Other propagators are only expected to match qualitatively.
        rtol = 5e-2
    # XXX band parallelization is inaccurate, hence the looser atol.
    atol = 5e-8 if band_parallel else 1e-12
    check_dm(module_tmp_path / 'dm.dat', 'dm.dat', rtol=rtol, atol=atol)
@pytest.mark.parametrize('parallel', parallel_i)
def test_restart(time_propagation_reference,
                 parallel,
                 module_tmp_path, in_tmp_dir):
    """Restart from ``td.gpw`` and compare with the reference ``dm2.dat``."""
    restart_file = module_tmp_path / 'td.gpw'
    calculate_time_propagation(restart_file,
                               kick=None,
                               force_new_dm_file=True,
                               parallel=parallel)
    # XXX band parallelization is inaccurate, so loosen the tolerance.
    rtol = 5e-4 if 'band' in parallel else 1e-8
    check_dm(module_tmp_path / 'dm2.dat', 'dm.dat', rtol=rtol)