Coverage for gpaw/test/response/test_parallelization.py: 94%

18 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1import numpy as np 

2 

3import pytest 

4 

5from gpaw.mpi import world 

6from gpaw.response.pw_parallelization import Blocks1D 

7 

8 

@pytest.mark.response
def test_blocks1d_collect():
    """Test the ability to collect an array distributed over the first
    dimension.

    The same 150 consecutive integers are laid out as a 1D (150,), a 2D
    (10, 15) and a 3D (5, 3, 10) array, so the distributed first axis has
    length 150, 10 and 5 respectively. Each rank takes its local chunk
    ``array[blocks.myslice]`` and the chunks are reassembled with:

    * ``blocks.all_gather`` -- the full array is expected on every rank;
    * ``blocks.gather`` -- the full array is expected on rank 0 and
      ``None`` on all other ranks.

    NOTE(review): all_gather/gather are presumably collective calls, so an
    assertion failing on only some ranks may leave the remaining ranks
    blocked in the next collective instead of failing cleanly.
    """
    dat_i = np.arange(150)
    dat_ij = dat_i.reshape((10, 15))
    dat_ijk = dat_i.reshape((5, 3, 10))

    for array in [dat_i, dat_ij, dat_ijk]:
        blocks = Blocks1D(world, array.shape[0])
        local_array = array[blocks.myslice]

        # Test all-gather: every rank must hold the complete array.
        collected_array = blocks.all_gather(local_array)
        # assert_array_equal reports which elements differ; the explicit
        # shape check also rejects a scalar result, which
        # assert_array_equal would otherwise broadcast against the array.
        np.testing.assert_array_equal(collected_array, array)
        assert collected_array.shape == array.shape

        # Test gather: only the root rank (rank 0) receives the result.
        collected_array = blocks.gather(local_array)
        if world.rank == 0:
            np.testing.assert_array_equal(collected_array, array)
            assert collected_array.shape == array.shape
        else:
            assert collected_array is None