Coverage for gpaw/test/lcaotddft/test_periodic.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import numpy as np 

2import pytest 

3 

4from ase.build import fcc111 

5 

6from gpaw import GPAW 

7from gpaw.mpi import world, serial_comm 

8from gpaw.lcaotddft.wfwriter import WaveFunctionReader 

9 

10from gpaw.test import only_on_master 

11from . import (parallel_options, calculate_error, calculate_time_propagation, 

12 check_wfs) 

13 

# Request the ``module_tmp_path`` fixture for every test in this module.
pytestmark = pytest.mark.usefixtures('module_tmp_path')

# Parallelization settings over which ``test_propagation`` is parametrized.
parallel_i = parallel_options(include_kpt=True)

17 

18 

@pytest.fixture(scope='module')
@only_on_master(world)
def initialize_system():
    """Produce the ground-state and reference time-propagation files.

    Module-scoped fixture wrapped in ``only_on_master(world)``.  Both the
    ground-state calculator and the time propagation are handed
    ``serial_comm`` as their communicator.  The ground state is written
    to ``gs.out`` / ``gs.gpw`` using relative paths, after which
    ``calculate_time_propagation`` is run on ``gs.gpw``.

    NOTE(review): the tests of this module read ``wf.ulm`` from
    ``module_tmp_path``, so the propagation helper presumably writes that
    file into the current directory -- confirm against the helper.
    """
    # Ground-state calculation
    slab = fcc111('Al', size=(1, 1, 2), vacuum=4.0)
    slab.symbols[0] = 'Li'  # swap the first Al atom for Li
    gs_parameters = dict(
        nbands=4,
        h=0.4,
        kpts={'size': (3, 3, 1)},
        basis='sz(dzp)',
        mode='lcao',
        convergence={'density': 1e-8},
        symmetry={'point_group': False},
        communicator=serial_comm,
        txt='gs.out',
    )
    gs_calc = GPAW(**gs_parameters)
    slab.calc = gs_calc
    slab.get_potential_energy()
    gs_calc.write('gs.gpw', mode='all')

    # Time-propagation calculation
    calculate_time_propagation(
        'gs.gpw',
        kick=[0, 0, 1e-5],
        communicator=serial_comm,
    )

44 

45 

@pytest.mark.rttddft
def test_propagated_wave_function(initialize_system, module_tmp_path):
    """Check selected final-step coefficients against stored references.

    Reads the last entry of ``wf.ulm`` in ``module_tmp_path``, picks a
    small sub-block of the wave-function coefficients, fixes the overall
    sign, and requires ``calculate_error`` against the hard-coded
    reference values to stay below ``7e-9``.
    """
    reader = WaveFunctionReader(module_tmp_path / 'wf.ulm')
    last_frame = reader[-1]
    all_coefficients = last_frame.wave_functions.coefficients

    # Open-mesh pick of a few indices along each of the four axes.
    # NOTE(review): axes are presumably (spin, k-point, band, basis
    # function) -- confirm against the writer.
    selection = np.ix_([0], [0, 1], [1, 3], [0, 1, 2])
    subset = all_coefficients[selection]

    # Normalize the wave function sign: multiply by the sign of the real
    # part of the first element along the last axis.
    leading_sign = np.sign(subset.real[..., 0, np.newaxis])
    normalized = leading_sign * subset

    expected = [[[[5.4119034398864430e-01 + 4.6958807325576735e-01j,
                   -5.8836045927143954e-01 - 5.1047688429408378e-01j,
                   -6.5609314466400698e-06 - 5.8109609173527947e-06j],
                  [1.6425837099429430e-06 - 1.4779657236004961e-06j,
                   -8.7230715222772428e-07 + 8.9374679369814926e-07j,
                   3.1300283337601806e+00 - 2.7306795126551076e+00j]],
                 [[1.9820345503468246e+00 + 1.0562314330323577e+00j,
                   -1.5008623926242098e-01 + 4.5817475674967340e-01j,
                   -4.8385783015916195e-01 - 5.3676335879786385e-01j],
                  [2.4227856141643818e+00 + 3.7767002050641824e-01j,
                   -2.6174901880264838e+00 + 1.9885717875694848e+00j,
                   7.2641847473298660e-01 + 1.6020733667409095e+00j]]]]

    deviation = calculate_error(normalized, expected)
    assert deviation < 7e-9

67 

68 

@pytest.mark.rttddft
@pytest.mark.parametrize('parallel', parallel_i)
def test_propagation(initialize_system, module_tmp_path, parallel, in_tmp_dir):
    """Repeat the propagation with ``parallel`` and compare to the reference.

    Restarts from ``gs.gpw`` in ``module_tmp_path`` with the same kick as
    the fixture, then uses ``check_wfs`` to compare the reference
    ``wf.ulm`` in ``module_tmp_path`` with the relative path ``wf.ulm``
    at ``atol=1e-12``.

    NOTE(review): the relative ``wf.ulm`` is presumably the file this
    run writes into the directory provided by ``in_tmp_dir`` -- confirm.
    """
    ground_state_file = module_tmp_path / 'gs.gpw'
    reference_wfs = module_tmp_path / 'wf.ulm'

    calculate_time_propagation(
        ground_state_file,
        kick=[0, 0, 1e-5],
        parallel=parallel,
    )
    check_wfs(reference_wfs, 'wf.ulm', atol=1e-12)