Coverage for gpaw/test/parallel/test_arraydict_redist.py: 94%
50 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import numpy as np
2from gpaw.mpi import world
3from gpaw.utilities.partition import AtomPartition
4from gpaw.arraydict import ArrayDict
def test_parallel_arraydict_redist():
    """Round-trip an ArrayDict through several atom partitions.

    An ``ArrayDict`` is created on a (seeded) random atom partition and each
    atom's array is filled with that atom's index.  The dict is then
    redistributed to an even partition, to a serial partition (every atom on
    rank 0), and finally back to the original partition.  After each step
    the dict-like interface is checked for internal consistency, and at the
    end the local data must be identical -- shape and values -- to what it
    was before any redistribution.
    """
    # Every rank seeds the generator identically, so all ranks draw the same
    # rank_a below and therefore agree on the partition.
    gen = np.random.RandomState(0)

    def shape(a):
        # Per-atom array shape.  Atoms 0 and 1 get zero-sized arrays, so the
        # empty-array code paths are exercised as well.
        return (a, a // 2)  # Shapes: (0, 0), (1, 0), (2, 1), ...

    natoms = 33

    if world.size == 1:
        rank_a = np.zeros(natoms, int)
    else:
        # When on more than 2 cores, make sure that at least one core
        # (rank=0) has zero entries:
        lower = 0 if world.size == 2 else 1
        rank_a = gen.randint(lower, world.size, natoms)
    assert (rank_a < world.size).all()

    # All atoms on rank 0.
    serial = AtomPartition(world, np.zeros(natoms, int))
    partition = AtomPartition(world, rank_a)
    even_partition = partition.as_even_partition()

    def check(atomdict, title):
        """Assert that *atomdict* behaves like a consistently ordered dict.

        Compares *atomdict* against a plain ``dict`` built from it and
        verifies that ``keys()``, iteration, ``values()`` and ``items()``
        agree with each other in both content and order.  The dict is
        printed on rank 0 and on the middle rank for debugging.
        """
        if world.rank == world.size // 2 or world.rank == 0:
            print('rank %d %s: %s' % (world.rank, title.rjust(10), atomdict))

        # Create a normal, "well-behaved" dict against which to test arraydict.
        ref = dict(atomdict)
        assert set(atomdict.keys()) == set(ref.keys())  # check keys()
        for a in atomdict:  # check __iter__, __getitem__
            assert ref[a] is atomdict[a]

        # AtomDict guarantees fixed ordering of keys.  Check that values()
        # has one entry per key and that its ordering is consistent with
        # loop ordering:
        values = list(atomdict.values())
        assert len(values) == len(ref)
        for i, key in enumerate(atomdict):
            assert values[i] is atomdict[key]

        # items() must have one entry per key, agree with __getitem__ and the
        # reference dict, and yield the same sequence on repeated calls.
        items = list(atomdict.items())
        assert len(items) == len(ref)
        for i, (key, item) in enumerate(atomdict.items()):
            assert item is atomdict[key]
            assert item is ref[key]
            assert items[i][0] == key
            assert items[i][1] is item

    ad = ArrayDict(partition, shape, float)
    for key in ad:
        ad[key][:] = key  # tag each atom's data with its own index
    # NOTE(review): toarray() presumably packs this rank's local arrays into
    # one array -- confirm against ArrayDict.toarray.
    array0 = ad.toarray()
    check(ad, 'new')

    ad.redistribute(even_partition)
    array1 = ad.toarray()
    # NOTE(review): relies on the seeded random partition giving every
    # rank > 1 a different amount of local data than the even partition.
    if world.rank > 1:
        assert array1.shape != array0.shape
    check(ad, 'even')

    ad.redistribute(serial)
    check(ad, 'serial')
    ad.redistribute(partition)
    check(ad, 'back')

    # Back on the original partition, the local data must match exactly.
    # np.array_equal also compares shapes.  A bare ``==`` would broadcast
    # (e.g. a length-1 array against a longer one), or on NumPy >= 1.25
    # raise for incompatible shapes instead of failing the assertion.
    array2 = ad.toarray()
    assert np.array_equal(array0, array2)