Coverage for gpaw/test/lcaotddft/__init__.py: 97%

71 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1import numpy as np 

2 

3from gpaw.mpi import world, broadcast_float 

4from gpaw.lcaotddft import LCAOTDDFT 

5from gpaw.lcaotddft.dipolemomentwriter import DipoleMomentWriter 

6from gpaw.lcaotddft.wfwriter import WaveFunctionWriter, WaveFunctionReader 

7from gpaw.lcaotddft.densitymatrix import DensityMatrix 

8from gpaw.lcaotddft.frequencydensitymatrix import FrequencyDensityMatrix 

9from gpaw.tddft.folding import frequencies 

10from gpaw.utilities import compiled_with_sl 

11 

12 

def parallel_options(*, include_kpt=False, fix_sl_auto=False):
    """Generate different parallelization options.

    Enumerates every combination of ``sl_auto`` (False, and True only when
    ``compiled_with_sl()`` is true), band parallelization (1 or 2) and
    k-point parallelization (1, or 1 and 2 when ``include_kpt``) whose
    ``band * kpt`` does not exceed ``world.size``.

    Parameters
    ----------
    include_kpt : bool
        Also generate options with ``kpt=2``.
    fix_sl_auto : bool
        On a single-process run, replace ``sl_auto=True`` with the explicit
        BLACS grid ``sl_default=(1, 1, 8)``.

    Returns
    -------
    list of dict
        One ``parallel`` dictionary per admissible combination, ordered by
        ``sl_auto``, then ``band``, then ``kpt``.
    """
    sl_auto_choices = [False, True] if compiled_with_sl() else [False]
    kpt_choices = [1, 2] if include_kpt else [1]

    options = []
    for sl_auto in sl_auto_choices:
        for band in [1, 2]:
            for kpt in kpt_choices:
                if band * kpt > world.size:
                    # Not enough processes for this decomposition.
                    continue
                if sl_auto and fix_sl_auto and world.size == 1:
                    # The BLACS grid picked by sl_auto does not work well
                    # for the small test system on one process, so an
                    # explicit grid is given instead.
                    options.append({'band': band, 'kpt': kpt,
                                    'sl_default': (1, 1, 8)})
                else:
                    options.append({'sl_auto': sl_auto, 'band': band,
                                    'kpt': kpt})
    return options

33 

34 

def calculate_time_propagation(gs_fpath, *, kick,
                               communicator=world, parallel=None,
                               do_fdm=False):
    """Run a short LCAO-TDDFT propagation starting from a ground state.

    The calculator is created from ``gs_fpath`` with text output to
    ``td.out``. A dipole-moment writer (``dm.dat``) and a wave-function
    writer (``wf.ulm``) are constructed on it. An absorption kick is then
    applied and ``td_calc.propagate(20, 3)`` is called. All files are
    written relative to the current working directory.

    Parameters
    ----------
    gs_fpath : str
        Path to the ground-state file passed to ``LCAOTDDFT``.
    kick : object
        Passed unchanged to ``td_calc.absorption_kick``.
    communicator : optional
        MPI communicator for the calculation. It is also used for the
        final barrier.
    parallel : dict or None
        Parallelization options for ``LCAOTDDFT``. ``None`` (the default)
        means an empty dict.
    do_fdm : bool
        If True, a frequency-space density matrix is also accumulated
        during propagation and written to ``fdm.ulm``.

    Returns
    -------
    FrequencyDensityMatrix or None
        The frequency density matrix if ``do_fdm`` is True, otherwise None.
    """
    # Avoid a shared mutable default argument: build a fresh dict per call.
    if parallel is None:
        parallel = {}

    td_calc = LCAOTDDFT(gs_fpath,
                        communicator=communicator,
                        parallel=parallel,
                        txt='td.out')

    fdm = None
    if do_fdm:
        dmat = DensityMatrix(td_calc)
        ffreqs = frequencies(range(0, 31, 5), 'Gauss', 0.1)
        fdm = FrequencyDensityMatrix(td_calc, dmat, frequencies=ffreqs)

    # These writers are created only for their side effects; their return
    # values are discarded. Presumably they attach to td_calc as observers.
    DipoleMomentWriter(td_calc, 'dm.dat')
    WaveFunctionWriter(td_calc, 'wf.ulm')

    td_calc.absorption_kick(kick)
    td_calc.propagate(20, 3)

    if fdm is not None:
        fdm.write('fdm.ulm')

    # Make sure every rank has finished before callers read the outputs.
    communicator.barrier()

    return fdm

57 

58 

def calculate_error(a, ref_a):
    """Return the maximum absolute deviation of ``a`` from ``ref_a``.

    The value is computed and printed on rank 0 of ``world`` only. It is
    then broadcast with ``broadcast_float``, so every rank returns the same
    number. Non-root ranks contribute ``np.nan``, which the broadcast
    replaces.
    """
    err = np.nan
    if world.rank == 0:
        diff = a - ref_a
        err = np.abs(diff).max()
        print()
        print('ERR', err)
    return broadcast_float(err, world)

68 

69 

def check_txt_data(ref_fpath, data_fpath, atol):
    """Assert that two text data files agree to within ``atol``.

    Both files are read with ``np.loadtxt`` on every rank. The maximum
    absolute difference comes from ``calculate_error``, which computes it
    on rank 0 and broadcasts it. That value must be strictly below
    ``atol``.
    """
    # Synchronize first, presumably so files written by other ranks are
    # complete before they are read.
    world.barrier()
    ref, data = (np.loadtxt(path, encoding='utf-8')
                 for path in (ref_fpath, data_fpath))
    err = calculate_error(data, ref)
    print('err', err, atol)
    assert err < atol

77 

78 

def check_wfs(wf_ref_fpath, wf_fpath, atol=1e-12):
    """Assert that two wave-function files contain matching coefficients.

    Both files are opened with ``WaveFunctionReader`` and must hold the
    same number of entries. For every entry except index 0, the maximum
    absolute difference of ``wave_functions.coefficients`` must be strictly
    below ``atol``.
    """
    reader_ref = WaveFunctionReader(wf_ref_fpath)
    reader = WaveFunctionReader(wf_fpath)
    n_entries = len(reader)
    assert n_entries == len(reader_ref)
    # NOTE(review): entry 0 is deliberately not compared (presumably the
    # initial state before the kick) — confirm.
    for i in range(1, n_entries):
        ref = reader_ref[i].wave_functions.coefficients
        coeff = reader[i].wave_functions.coefficients
        assert calculate_error(coeff, ref) < atol, f'error at i={i}'

88 

89 

def copy_and_cut_file(src, dst, *, cut_lines=0):
    """Copy text file ``src`` to ``dst`` without its last ``cut_lines`` lines.

    Both files are handled as UTF-8 text. A ``cut_lines`` of zero or less
    copies every line. If ``cut_lines`` is at least the number of lines in
    ``src``, ``dst`` ends up empty.
    """
    with open(src, 'r', encoding='utf-8') as fd:
        content = fd.readlines()

    if cut_lines > 0:
        # Deleting a too-long tail slice simply empties the list.
        del content[-cut_lines:]

    with open(dst, 'w', encoding='utf-8') as fd:
        fd.writelines(content)