Coverage for gpaw/test/response/test_parallelization.py: 94%
18 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1import numpy as np
3import pytest
5from gpaw.mpi import world
6from gpaw.response.pw_parallelization import Blocks1D
@pytest.mark.response
def test_blocks1d_collect():
    """Test the ability to collect an array distributed over the first
    dimension.

    Arrays of rank 1, 2 and 3 (all built from the same 150 integers) are
    sliced with ``blocks.myslice`` (presumably this rank's share of the
    first axis) and then reassembled with both ``all_gather``, whose result
    is checked on every rank, and ``gather``, whose result is checked on
    rank 0 and expected to be ``None`` on all other ranks.
    """
    dat_i = np.arange(150)
    dat_ij = dat_i.reshape((10, 15))
    dat_ijk = dat_i.reshape((5, 3, 10))

    for array in [dat_i, dat_ij, dat_ijk]:
        blocks = Blocks1D(world, array.shape[0])
        local_array = array[blocks.myslice]

        # Test all-gather.
        # np.array_equal also requires identical shapes; a bare
        # ``np.all(a == b)`` broadcasts and could accept a wrongly shaped
        # collected array.
        collected_array = blocks.all_gather(local_array)
        assert np.array_equal(array, collected_array)

        # Test gather: only rank 0 is expected to receive the full array.
        collected_array = blocks.gather(local_array)
        if world.rank == 0:
            assert np.array_equal(array, collected_array)
        else:
            assert collected_array is None