Coverage for gpaw/test/response/test_parallel_kptpair_extraction.py: 39%
71 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import pytest
2from itertools import product
4import numpy as np
6from gpaw import GPAW
7from gpaw.mpi import world
8from gpaw.response import ResponseContext, ResponseGroundStateAdapter
9from gpaw.response.pw_parallelization import block_partition
10from gpaw.response.kspair import KohnShamKPointPairExtractor
11from gpaw.response.pair_transitions import PairTransitions
12from gpaw.response.pair_integrator import KPointPairPointIntegral
13from gpaw.response.symmetry import QSymmetryAnalyzer
15from gpaw.test.response.test_chiks import (generate_system_s, generate_qrel_q,
16 get_q_c, generate_nblocks_n)
17from gpaw.test.gpwfile import response_band_cutoff
# Every test in this module compares a serial against a parallel code path,
# so the whole module is skipped when running on a single MPI rank.
pytestmark = pytest.mark.skipif(world.size == 1, reason='world.size == 1')
# ---------- Actual tests ---------- #
@pytest.mark.response
@pytest.mark.kspair
@pytest.mark.parametrize(
    'system,qrel,nblocks',
    product(generate_system_s(), generate_qrel_q(), generate_nblocks_n()))
def test_parallel_extract_kptdata(in_tmp_dir, gpw_files,
                                  system, qrel, nblocks):
    """Check that k-point pair extraction does not depend on whether the
    ground state adapter is serial or parallelized.

    One adapter is created directly from the gpw file and one from a GPAW
    calculator opened with ``parallel=dict(domain=1)``. An extractor and a
    point integral are set up on top of each, and the k-point pairs that
    the two integrals generate are compared one by one.

    Parameters
    ----------
    in_tmp_dir, gpw_files
        pytest fixtures; ``gpw_files`` is indexed with the system name.
    system
        Tuple ``(wfs, spincomponent)``.
    qrel
        Passed to ``get_q_c`` together with ``wfs`` to obtain ``q_c``.
    nblocks
        Passed to ``block_partition`` to split the context communicator.
    """
    # ---------- Inputs ---------- #

    wfs, spincomponent = system
    q_c = get_q_c(wfs, qrel)
    nbands = response_band_cutoff[wfs]

    # ---------- Script ---------- #

    # Ground state adapter read straight from file (serial)
    serial_gs = ResponseGroundStateAdapter.from_gpw_file(gpw_files[wfs])
    assert not serial_gs.is_parallelized()

    # Ground state adapter wrapping a calculator object (parallel)
    calc = GPAW(gpw_files[wfs], parallel=dict(domain=1))
    parallel_gs = ResponseGroundStateAdapter(calc)
    assert parallel_gs.is_parallelized()

    # Both extractors share the same context and block communicators
    context = ResponseContext()
    tcomm, kcomm = block_partition(context.comm, nblocks)
    serial_extractor = initialize_extractor(serial_gs, context, tcomm, kcomm)
    assert serial_extractor.gs.world.size == 1
    parallel_extractor = initialize_extractor(
        parallel_gs, context, tcomm, kcomm)
    assert parallel_extractor.gs.world.size > 1
    serial_integral = initialize_integral(serial_extractor, context, q_c)
    parallel_integral = initialize_integral(parallel_extractor, context, q_c)

    # The transitions must agree irrespective of the adapter they were
    # generated from
    transitions = initialize_transitions(serial_gs, spincomponent, nbands)
    parallel_transitions = initialize_transitions(
        parallel_gs, spincomponent, nbands)
    for index_name in ('n1_t', 'n2_t', 's1_t', 's2_t'):
        assert np.allclose(getattr(parallel_transitions, index_name),
                           getattr(transitions, index_name))

    # Step both kptpair generators in lockstep and compare what they yield.
    # Both integrals are driven with the serially generated transitions,
    # which were asserted equal to the parallel ones above.
    n_iterations = serial_integral.ni
    assert parallel_integral.ni == n_iterations
    serial_pairs = serial_integral.weighted_kpoint_pairs(transitions)
    parallel_pairs = parallel_integral.weighted_kpoint_pairs(transitions)
    for _ in range(n_iterations):
        # Only the first element of each yielded 2-tuple is compared
        serial_kptpair, _unused = next(serial_pairs)
        parallel_kptpair, _unused = next(parallel_pairs)
        compare_kptpairs(serial_kptpair, parallel_kptpair)
# ---------- Test functionality ---------- #
def compare_kptpairs(kptpair1, kptpair2):
    """Assert that two k-point pair objects hold identical data.

    Either both arguments are ``None`` or neither is. For actual pairs,
    ``K1`` and ``K2`` must be equal, ``deps_myt`` and ``df_myt`` must agree
    within ``np.allclose`` tolerances, and the ``ikpt1``/``ikpt2`` members
    are checked with ``compare_ikpts``.

    Raises
    ------
    AssertionError
        If any of the comparisons fail, including the case where exactly
        one of the two arguments is ``None``.
    """
    if kptpair1 is None:
        # Due to k-point distribution, all ranks don't necessarily have a
        # kptpair to integrate
        assert kptpair2 is None
        return
    # Report a missing counterpart as a test failure instead of letting the
    # attribute access below raise an AttributeError
    assert kptpair2 is not None

    assert kptpair1.K1 == kptpair2.K1
    assert kptpair1.K2 == kptpair2.K2
    assert np.allclose(kptpair1.deps_myt, kptpair2.deps_myt)
    assert np.allclose(kptpair1.df_myt, kptpair2.df_myt)

    compare_ikpts(kptpair1.ikpt1, kptpair2.ikpt1)
    compare_ikpts(kptpair1.ikpt2, kptpair2.ikpt2)
def compare_ikpts(ikpt1, ikpt2):
    """Assert that two irreducible k-point data objects are identical.

    ``ik`` must be equal, ``Ph.array`` and ``psit_hG`` must agree within
    ``np.allclose`` tolerances, and ``h_myt`` must match exactly.

    Raises
    ------
    AssertionError
        If any of the comparisons fail.
    """
    assert ikpt1.ik == ikpt2.ik
    assert np.allclose(ikpt1.Ph.array, ikpt2.Ph.array)
    assert np.allclose(ikpt1.psit_hG, ikpt2.psit_hG)
    # np.array_equal also requires equal shapes; an elementwise ``==``
    # would broadcast and could let index arrays of different length pass
    assert np.array_equal(ikpt1.h_myt, ikpt2.h_myt)
def initialize_extractor(gs, context, tcomm, kcomm):
    """Create a KohnShamKPointPairExtractor for the given ground state.

    ``tcomm`` is handed over as the ``transitions_blockcomm`` and ``kcomm``
    as the ``kpts_blockcomm`` of the extractor.
    """
    blockcomms = dict(transitions_blockcomm=tcomm, kpts_blockcomm=kcomm)
    return KohnShamKPointPairExtractor(gs, context, **blockcomms)
def initialize_integral(extractor, context, q_c):
    """Create a KPointPairPointIntegral for the wave vector ``q_c``.

    A fresh QSymmetryAnalyzer analyzes ``q_c`` (converted with
    ``np.asarray``) against the k-points of the extractor's ground state.
    Of the two objects returned by the analysis only the second one is
    used; it is passed on to the integral together with the extractor.
    """
    analyzer = QSymmetryAnalyzer()
    analysis = analyzer.analyze(
        np.asarray(q_c), extractor.gs.kpoints, context)
    _, generator = analysis
    return KPointPairPointIntegral(extractor, generator)
def initialize_transitions(gs, spincomponent, nbands):
    """Create the PairTransitions for a given spin component and band cutoff.

    The transitions are built with
    ``PairTransitions.from_transitions_domain_arguments`` from ``nocc1``,
    ``nocc2`` and ``nspins`` of the ground state ``gs``, always using the
    'pairwise' band summation.
    """
    return PairTransitions.from_transitions_domain_arguments(
        spincomponent, nbands, gs.nocc1, gs.nocc2, gs.nspins, 'pairwise')