Coverage for gpaw/test/parallel/test_arraydict_redist.py: 94%

50 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import numpy as np 

2from gpaw.mpi import world 

3from gpaw.utilities.partition import AtomPartition 

4from gpaw.arraydict import ArrayDict 

5 

6 

def test_parallel_arraydict_redist():
    """Round-trip an ArrayDict through several atom partitions.

    An ArrayDict holding one array per atom (atom ``a`` gets shape
    ``shape(a)`` and is filled with the value ``a``) is redistributed from a
    random partition to an even partition, then to a serial partition (all
    atoms on rank 0), and finally back to the original partition.  After
    every step the mapping interface is checked for self-consistency and
    every local array is checked to still carry its original shape and
    contents; where the owning ranks are known, the set of local atoms is
    checked as well.
    """
    gen = np.random.RandomState(0)

    def shape(a):
        # Per-atom array shape; shapes are (0, 0), (1, 0), (2, 1), ...
        return (a, a // 2)

    natoms = 33

    if world.size == 1:
        rank_a = np.zeros(natoms, int)
    else:
        # When on more than 2 cores, make sure that at least one core
        # (rank=0) has zero entries:
        lower = 0 if world.size == 2 else 1
        rank_a = gen.randint(lower, world.size, natoms)
    assert (rank_a < world.size).all()

    serial = AtomPartition(world, np.zeros(natoms, int))
    partition = AtomPartition(world, rank_a)
    even_partition = partition.as_even_partition()

    # Atoms owned by this rank under the original partition and under the
    # serial partition (everything on rank 0, nothing elsewhere).
    my_atoms = np.flatnonzero(rank_a == world.rank).tolist()
    serial_atoms = list(range(natoms)) if world.rank == 0 else []

    def check(atomdict, title, expected_keys=None):
        """Verify that *atomdict* behaves like a consistent mapping.

        Checks that keys(), iteration, values() and items() agree with each
        other and with a plain dict copy, that each local array has shape
        ``shape(a)`` and is filled with ``a``, and, when *expected_keys* is
        given, that exactly those atoms are stored on this rank.
        """
        if world.rank == world.size // 2 or world.rank == 0:
            print('rank %d %s: %s' % (world.rank, title.rjust(10), atomdict))

        # Create a normal, "well-behaved" dict against which to test arraydict.
        ref = dict(atomdict)
        assert set(atomdict.keys()) == set(ref.keys())  # check keys()
        for a in atomdict:  # check __iter__, __getitem__
            assert ref[a] is atomdict[a]

        values = list(atomdict.values())
        for i, key in enumerate(atomdict):
            # ArrayDict guarantees fixed ordering of keys.  Check that
            # values() ordering is consistent with loop ordering:
            assert values[i] is atomdict[key]

        items = list(atomdict.items())
        for i, (key, item) in enumerate(atomdict.items()):
            assert item is atomdict[key]
            assert item is ref[key]
            assert items[i][0] == key
            assert items[i][1] is item

        # The data must survive every redistribution unchanged: each atom's
        # array keeps its shape and the fill value it was created with.
        for a, array in ref.items():
            assert array.shape == shape(a)
            assert (array == a).all()

        if expected_keys is not None:
            assert sorted(atomdict) == sorted(expected_keys)

    # Hopefully this should verify all the complicated stuff

    ad = ArrayDict(partition, shape, float)
    for key in ad:
        ad[key][:] = key
    array0 = ad.toarray()

    check(ad, 'new', my_atoms)
    ad.redistribute(even_partition)
    array1 = ad.toarray()
    if world.rank > 1:
        # NOTE(review): relies on the seeded random partition differing from
        # the even partition on these ranks; deterministic for seed 0.
        assert array1.shape != array0.shape
    check(ad, 'even')
    ad.redistribute(serial)
    check(ad, 'serial', serial_atoms)
    ad.redistribute(partition)
    check(ad, 'back', my_atoms)

    array2 = ad.toarray()
    assert (array0 == array2).all()
34 # print atomdict 

35 assert set(atomdict.keys()) == set(ref.keys()) # check keys() 

36 for a in atomdict: # check __iter__, __getitem__ 

37 assert ref[a] is atomdict[a] 

38 values = list(atomdict.values()) 

39 for i, key in enumerate(atomdict): 

40 # AtomDict guarantees fixed ordering of keys. Check that 

41 # values() ordering is consistent with loop ordering: 

42 assert values[i] is atomdict[key] 

43 

44 items = list(atomdict.items()) 

45 

46 for i, (key, item) in enumerate(atomdict.items()): 

47 assert item is atomdict[key] 

48 assert item is ref[key] 

49 assert items[i][0] == key 

50 assert items[i][1] is item 

51 

52 # Hopefully this should verify all the complicated stuff 

53 

54 ad = ArrayDict(partition, shape, float) 

55 for key in ad: 

56 ad[key][:] = key 

57 array0 = ad.toarray() 

58 

59 _ = dict(ad) 

60 check(ad, 'new') 

61 ad.redistribute(even_partition) 

62 array1 = ad.toarray() 

63 if world.rank > 1: 

64 assert array1.shape != array0.shape 

65 check(ad, 'even') 

66 ad.redistribute(serial) 

67 check(ad, 'serial') 

68 ad.redistribute(partition) 

69 check(ad, 'back') 

70 

71 array2 = ad.toarray() 

72 assert (array0 == array2).all()