Coverage for gpaw/test/response/test_parallel_kptpair_extraction.py: 39%

71 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import pytest 

2from itertools import product 

3 

4import numpy as np 

5 

6from gpaw import GPAW 

7from gpaw.mpi import world 

8from gpaw.response import ResponseContext, ResponseGroundStateAdapter 

9from gpaw.response.pw_parallelization import block_partition 

10from gpaw.response.kspair import KohnShamKPointPairExtractor 

11from gpaw.response.pair_transitions import PairTransitions 

12from gpaw.response.pair_integrator import KPointPairPointIntegral 

13from gpaw.response.symmetry import QSymmetryAnalyzer 

14 

15from gpaw.test.response.test_chiks import (generate_system_s, generate_qrel_q, 

16 get_q_c, generate_nblocks_n) 

17from gpaw.test.gpwfile import response_band_cutoff 

18 

# Every test in this module compares a serial against a parallel
# extraction, which requires more than one MPI rank.
pytestmark = pytest.mark.skipif(world.size == 1,
                                reason='world.size == 1')

20 

21 

22# ---------- Actual tests ---------- # 

23 

24 

@pytest.mark.response
@pytest.mark.kspair
@pytest.mark.parametrize('system,qrel,nblocks',
                         product(generate_system_s(),
                                 generate_qrel_q(),
                                 generate_nblocks_n()))
def test_parallel_extract_kptdata(in_tmp_dir, gpw_files,
                                  system, qrel, nblocks):
    """Test that the KohnShamKPointPair data extracted from a serial and a
    parallel calculator object is identical.

    The same gpw file is loaded twice: through
    ``ResponseGroundStateAdapter.from_gpw_file`` (asserted to be serial)
    and through a ``GPAW`` calculator with ``parallel=dict(domain=1)``
    (asserted to be parallelized). Both ground states drive their own
    k-point pair extractor and point integral, and the k-point pairs they
    generate are compared iteration by iteration.
    """
    # ---------- Inputs ---------- #

    wfs, spincomponent = system
    q_c = get_q_c(wfs, qrel)
    nbands = response_band_cutoff[wfs]

    # ---------- Script ---------- #

    # Serial ground state adapter, read directly from the gpw file
    gs_serial = ResponseGroundStateAdapter.from_gpw_file(gpw_files[wfs])
    assert not gs_serial.is_parallelized()

    # Parallel ground state adapter, wrapping a (parallel) calculator
    calc = GPAW(gpw_files[wfs], parallel=dict(domain=1))
    gs_parallel = ResponseGroundStateAdapter(calc)
    assert gs_parallel.is_parallelized()

    # Both extractors share the same context and block communicators, so
    # any difference can only come from the ground state parallelization
    context = ResponseContext()
    tcomm, kcomm = block_partition(context.comm, nblocks)
    extractor_serial = initialize_extractor(gs_serial, context,
                                            tcomm, kcomm)
    assert extractor_serial.gs.world.size == 1
    extractor_parallel = initialize_extractor(gs_parallel, context,
                                              tcomm, kcomm)
    assert extractor_parallel.gs.world.size > 1
    integral_serial = initialize_integral(extractor_serial, context, q_c)
    integral_parallel = initialize_integral(extractor_parallel, context,
                                            q_c)

    # The transitions must not depend on how the ground state is distributed
    transitions = initialize_transitions(gs_serial, spincomponent, nbands)
    transitions_parallel = initialize_transitions(gs_parallel,
                                                  spincomponent, nbands)
    for attr in ('n1_t', 'n2_t', 's1_t', 's2_t'):
        assert np.allclose(getattr(transitions_parallel, attr),
                           getattr(transitions, attr))

    # Step both k-point pair generators in lockstep and compare the pairs
    ni = integral_serial.ni  # number of iterations in the kptpair generator
    assert integral_parallel.ni == ni
    pairs_serial = integral_serial.weighted_kpoint_pairs(transitions)
    pairs_parallel = integral_parallel.weighted_kpoint_pairs(transitions)
    for _iteration in range(ni):
        serial_pair, _weight = next(pairs_serial)
        parallel_pair, _weight = next(pairs_parallel)
        compare_kptpairs(serial_pair, parallel_pair)

82 

83 

84# ---------- Test functionality ---------- # 

85 

86 

def compare_kptpairs(kptpair1, kptpair2):
    """Assert that two k-point pair objects carry identical data.

    Compares the k-point indices ``K1``/``K2``, the ``deps_myt`` and
    ``df_myt`` arrays (to within ``np.allclose`` tolerance, with matching
    shapes) and finally both constituent k-points via ``compare_ikpts``.

    ``kptpair1`` may be ``None``: due to the k-point distribution, not all
    ranks necessarily have a kptpair to integrate, in which case
    ``kptpair2`` must be ``None`` too.

    Raises
    ------
    AssertionError
        If the two k-point pairs differ.
    """
    if kptpair1 is None:
        # Due to k-point distribution, all ranks don't necessarily have a
        # kptpair to integrate
        assert kptpair2 is None
        return
    # Report a missing counterpart as a test failure, not an AttributeError
    assert kptpair2 is not None
    assert kptpair1.K1 == kptpair2.K1
    assert kptpair1.K2 == kptpair2.K2
    # np.allclose broadcasts its arguments, so compare shapes explicitly to
    # avoid e.g. a length-1 array "matching" a longer one
    assert np.shape(kptpair1.deps_myt) == np.shape(kptpair2.deps_myt)
    assert np.allclose(kptpair1.deps_myt, kptpair2.deps_myt)
    assert np.shape(kptpair1.df_myt) == np.shape(kptpair2.df_myt)
    assert np.allclose(kptpair1.df_myt, kptpair2.df_myt)

    compare_ikpts(kptpair1.ikpt1, kptpair2.ikpt1)
    compare_ikpts(kptpair1.ikpt2, kptpair2.ikpt2)

100 

101 

def compare_ikpts(ikpt1, ikpt2):
    """Assert that two k-point objects carry identical data.

    Compares the index ``ik``, the ``Ph.array`` and ``psit_hG`` arrays (to
    within ``np.allclose`` tolerance) and the ``h_myt`` arrays (exactly).
    Shapes are checked explicitly because ``np.allclose`` and ``==``
    broadcast, which would let a smaller array silently match a larger one.

    Raises
    ------
    AssertionError
        If any of the compared attributes differ.
    """
    assert ikpt1.ik == ikpt2.ik
    assert np.shape(ikpt1.Ph.array) == np.shape(ikpt2.Ph.array)
    assert np.allclose(ikpt1.Ph.array, ikpt2.Ph.array)
    assert np.shape(ikpt1.psit_hG) == np.shape(ikpt2.psit_hG)
    assert np.allclose(ikpt1.psit_hG, ikpt2.psit_hG)
    # Exact comparison; np.array_equal also rejects mismatched shapes
    assert np.array_equal(ikpt1.h_myt, ikpt2.h_myt)

107 

108 

def initialize_extractor(gs, context, tcomm, kcomm):
    """Create a KohnShamKPointPairExtractor for the ground state ``gs``.

    ``tcomm`` is used as the transitions block communicator and ``kcomm``
    as the k-point block communicator.
    """
    blockcomms = dict(transitions_blockcomm=tcomm, kpts_blockcomm=kcomm)
    return KohnShamKPointPairExtractor(gs, context, **blockcomms)

113 

114 

def initialize_integral(extractor, context, q_c):
    """Build a KPointPairPointIntegral for the wave vector ``q_c``.

    The k-point generator comes from a QSymmetryAnalyzer applied to the
    extractor's ground state k-points; the other value returned by the
    analysis is not used here.
    """
    analyzer = QSymmetryAnalyzer()
    q_c = np.asarray(q_c)
    _symmetries, kpoint_generator = analyzer.analyze(
        q_c, extractor.gs.kpoints, context)
    return KPointPairPointIntegral(extractor, kpoint_generator)

119 

120 

def initialize_transitions(gs, spincomponent, nbands):
    """Set up pairwise PairTransitions for ``spincomponent``.

    ``nocc1``, ``nocc2`` and ``nspins`` are taken from the ground state
    adapter ``gs``; ``nbands`` is passed on unchanged.
    """
    nocc1, nocc2, nspins = gs.nocc1, gs.nocc2, gs.nspins
    return PairTransitions.from_transitions_domain_arguments(
        spincomponent, nbands, nocc1, nocc2, nspins, 'pairwise')