Coverage for gpaw/test/lcaotddft/__init__.py: 97%
71 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1import numpy as np
3from gpaw.mpi import world, broadcast_float
4from gpaw.lcaotddft import LCAOTDDFT
5from gpaw.lcaotddft.dipolemomentwriter import DipoleMomentWriter
6from gpaw.lcaotddft.wfwriter import WaveFunctionWriter, WaveFunctionReader
7from gpaw.lcaotddft.densitymatrix import DensityMatrix
8from gpaw.lcaotddft.frequencydensitymatrix import FrequencyDensityMatrix
9from gpaw.tddft.folding import frequencies
10from gpaw.utilities import compiled_with_sl
def parallel_options(*, include_kpt=False, fix_sl_auto=False):
    """Generate different parallelization options.

    Returns a list of ``parallel`` dictionaries enumerating combinations of
    ScaLAPACK automatic layout (``sl_auto``), band parallelization
    (``band`` = 1 or 2) and k-point parallelization (``kpt``).
    ``sl_auto=True`` is only offered when GPAW is compiled with ScaLAPACK.
    Any combination that needs more ranks (``band * kpt``) than
    ``world.size`` is skipped.

    Parameters
    ----------
    include_kpt : bool
        If true, enumerate ``kpt`` in (1, 2); otherwise only ``kpt=1``.
    fix_sl_auto : bool
        If true and running on a single process, replace ``sl_auto=True``
        with an explicit ``sl_default=(1, 1, 8)`` BLACS grid.
    """
    sl_auto_values = [False, True] if compiled_with_sl() else [False]
    kpt_values = [1, 2] if include_kpt else [1]

    options = []
    for sl_auto in sl_auto_values:
        for band in [1, 2]:
            for kpt in kpt_values:
                if band * kpt > world.size:
                    # Not enough ranks for this decomposition
                    continue
                if fix_sl_auto and world.size == 1 and sl_auto:
                    # Choose BLACS grid manually as the one given by sl_auto
                    # doesn't work well for the small test system and 1 process
                    options.append({'band': band, 'kpt': kpt,
                                    'sl_default': (1, 1, 8)})
                else:
                    options.append({'sl_auto': sl_auto, 'band': band,
                                    'kpt': kpt})
    return options
def calculate_time_propagation(gs_fpath, *, kick,
                               communicator=world, parallel=None,
                               do_fdm=False):
    """Run a short LCAO-TDDFT propagation starting from a ground-state file.

    Side effects: writes ``td.out`` (calculator text output), ``dm.dat``
    (via ``DipoleMomentWriter``) and ``wf.ulm`` (via ``WaveFunctionWriter``)
    in the current directory, plus ``fdm.ulm`` when ``do_fdm`` is true.

    Parameters
    ----------
    gs_fpath
        Path of the ground-state file passed to ``LCAOTDDFT``.
    kick
        Argument forwarded to ``LCAOTDDFT.absorption_kick``.
    communicator
        MPI communicator for the calculator; ``barrier()`` is called on it
        before returning.
    parallel : dict or None
        Parallelization options for ``LCAOTDDFT``; ``None`` means ``{}``.
    do_fdm : bool
        If true, also accumulate and write a ``FrequencyDensityMatrix``.

    Returns
    -------
    The ``FrequencyDensityMatrix`` if ``do_fdm`` is true, otherwise ``None``.
    """
    # Avoid a shared mutable default dict across calls
    if parallel is None:
        parallel = {}
    td_calc = LCAOTDDFT(gs_fpath,
                        communicator=communicator,
                        parallel=parallel,
                        txt='td.out')
    fdm = None
    if do_fdm:
        dmat = DensityMatrix(td_calc)
        # Frequencies 0, 5, ..., 30 with Gaussian folding of width 0.1
        # (units as defined by gpaw.tddft.folding.frequencies — confirm)
        ffreqs = frequencies(range(0, 31, 5), 'Gauss', 0.1)
        fdm = FrequencyDensityMatrix(td_calc, dmat, frequencies=ffreqs)
    DipoleMomentWriter(td_calc, 'dm.dat')
    WaveFunctionWriter(td_calc, 'wf.ulm')
    td_calc.absorption_kick(kick)
    # NOTE(review): presumably (time step, number of steps) — a very short
    # propagation suitable for tests
    td_calc.propagate(20, 3)
    if fdm is not None:
        fdm.write('fdm.ulm')

    communicator.barrier()
    return fdm
def calculate_error(a, ref_a):
    """Return the maximum absolute difference between *a* and *ref_a*.

    The value is computed and printed only on rank 0 of ``world``. It is then
    distributed with ``broadcast_float``, so every rank returns the same
    number. Non-root ranks contribute NaN as a placeholder.
    """
    err = np.nan
    if world.rank == 0:
        err = np.abs(a - ref_a).max()
        print()
        print('ERR', err)
    return broadcast_float(err, world)
def check_txt_data(ref_fpath, data_fpath, atol):
    """Assert that a numeric text file matches a reference within *atol*.

    Both files are read with ``np.loadtxt`` on every rank. The maximum
    absolute difference, computed by ``calculate_error``, must be strictly
    below *atol*.

    Raises
    ------
    AssertionError
        If the array shapes differ or the error is not below *atol*.
    """
    # Synchronize all ranks before reading; presumably so the data file
    # has been completely written by the writer rank.
    world.barrier()
    ref = np.loadtxt(ref_fpath, encoding='utf-8')
    data = np.loadtxt(data_fpath, encoding='utf-8')
    # Every rank loads both files, so this check fails identically on all
    # ranks. Without it, mismatched shapes could broadcast silently, or raise
    # only on rank 0 inside calculate_error while the other ranks wait in
    # the broadcast.
    assert data.shape == ref.shape, (
        f'shape mismatch: {data_fpath} {data.shape} '
        f'vs {ref_fpath} {ref.shape}')
    err = calculate_error(data, ref)
    print('err', err, atol)
    assert err < atol
def check_wfs(wf_ref_fpath, wf_fpath, atol=1e-12):
    """Assert that two wave-function files agree entry by entry.

    Both files are opened with ``WaveFunctionReader`` and must contain the
    same number of entries. For every entry from index 1 onward, the maximum
    absolute difference of ``wave_functions.coefficients`` must be below
    *atol*. Entry 0 is not compared.
    """
    reader_ref = WaveFunctionReader(wf_ref_fpath)
    reader = WaveFunctionReader(wf_fpath)
    n_entries = len(reader)
    assert n_entries == len(reader_ref)
    # NOTE(review): index 0 is deliberately skipped; confirm why (e.g. it
    # may be the pre-propagation state).
    for idx in range(1, n_entries):
        ref_coeff = reader_ref[idx].wave_functions.coefficients
        coeff = reader[idx].wave_functions.coefficients
        err = calculate_error(coeff, ref_coeff)
        assert err < atol, f'error at i={idx}'
def copy_and_cut_file(src, dst, *, cut_lines=0):
    """Copy UTF-8 text file *src* to *dst*, omitting its last *cut_lines* lines.

    If *cut_lines* is zero or negative, the file is copied unchanged. If it
    exceeds the number of lines, *dst* ends up empty.
    """
    with open(src, 'r', encoding='utf-8') as fd:
        all_lines = fd.readlines()
    kept = all_lines[:-cut_lines] if cut_lines > 0 else all_lines
    with open(dst, 'w', encoding='utf-8') as fd:
        fd.writelines(kept)