Coverage for gpaw/arraydict.py: 78%
100 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from collections.abc import MutableMapping
2import numpy as np
class DefaultKeyMapping:
    """Identity key mapping: indices and keys are the same objects."""

    def a2key(self, a):
        """Return the key for index ``a``, which is ``a`` itself."""
        return a

    def key2a(self, key):
        """Return the index for ``key``, which is ``key`` itself."""
        return key
class TableKeyMapping:
    """Key mapping backed by two user-supplied lookup tables.

    ``a2key_a`` is indexed by an index ``a`` and yields the key;
    ``key2a_k`` is indexed by a key and yields the index.  Both tables
    are stored as given and are only ever accessed with ``table[item]``.
    """

    def __init__(self, a2key_a, key2a_k):
        self.a2key_a = a2key_a
        self.key2a_k = key2a_k

    def a2key(self, a):
        """Look up the key stored for index ``a``."""
        table = self.a2key_a
        return table[a]

    def key2a(self, key):
        """Look up the index stored for ``key``."""
        table = self.key2a_k
        return table[key]
class ArrayDict(MutableMapping):
    """Distributed dictionary of fixed-size, fixed-dtype arrays.

    Implements a map [0, ..., N] -> [A0, ..., AN]

    Elements are initialized as empty numpy arrays.

    Unlike a normal dictionary, this class implements a strict loop ordering
    which is consistent with that of the underlying atom partition."""

    def __init__(self, partition, shapes_a, dtype=float, d=None,
                 keymap=None):
        """Create the dictionary for the locally owned indices.

        Parameters:

        partition:
            Object providing ``natoms`` and ``my_indices`` (and, for the
            redistribution methods and ``repr``, also ``redistribute``,
            ``rank_a`` and ``comm``).
        shapes_a:
            Sequence with one shape per index, or a callable ``a -> shape``
            which is evaluated for every index in ``range(natoms)``.
        dtype:
            dtype that every stored array must have.
        d:
            Optional mapping with the initial arrays.  If None, an
            uninitialized array is allocated for each local index.
        keymap:
            Optional object with ``a2key``/``key2a`` methods; defaults to
            ``DefaultKeyMapping()``.
        """
        self.partition = partition
        if callable(shapes_a):
            shapes_a = [shapes_a(a) for a in range(self.partition.natoms)]
        self.shapes_a = shapes_a  # global
        assert len(shapes_a) == partition.natoms
        self.dtype = dtype
        self.data = {}

        # This will be a terrible hack to make it easier to customize
        # the "keys".  Normally keys are 0...N and correspond to
        # rank_a of the AtomPartition.  But the keymap allows the user
        # to pass around a one-to-one mapping between the "physical"
        # keys 0...N and some other objects that the user may fancy.
        if keymap is None:
            keymap = DefaultKeyMapping()
        self.keymap = keymap

        if d is None:
            for a in partition.my_indices:
                self[a] = np.empty(self.shapes_a[a], dtype=dtype)
        else:
            self.update(d)
        self.check_consistency()

    # copy() is dangerous since redistributions write back
    # into arrays, and redistribution of a copy could lead to bugs
    # if the other copy changes.
    #
    # Let's just use deepcopy to stay out of trouble.  Then the
    # dictionary copy() is not invoked confusingly.
    #
    # -askhl
    def copy(self):
        """Return a new ArrayDict holding copies of all local arrays.

        The partition, shapes, dtype and keymap objects are shared with
        the original; only the array contents are duplicated.
        """
        copy = ArrayDict(self.partition, self.shapes_a, dtype=self.dtype,
                         keymap=self.keymap)
        for a in self:
            # Ellipsis assignment also works for zero-dimensional arrays,
            # where a [:] slice would raise IndexError.
            copy[a][...] = self[a]
        return copy

    def __getitem__(self, a):
        """Return the array stored for ``a`` after checking shape and dtype."""
        value = self.data[a]
        assert value.shape == self.shapes_a[a]
        assert value.dtype == self.dtype
        return value

    def __setitem__(self, a, value):
        """Store ``value`` for ``a``; it must match the declared shape/dtype."""
        assert value.shape == self.shapes_a[a], \
            f'defined shape {self.shapes_a[a]} vs new {value.shape}'
        assert value.dtype == self.dtype
        self.data[a] = value

    def redistribute(self, partition):
        """Redistribute according to specified partition."""
        def get_empty(a):
            return np.empty(self.shapes_a[a], self.dtype)

        self.partition.redistribute(partition, self, get_empty)
        self.partition = partition  # Better with immutable partition?
        self.check_consistency()

    def check_consistency(self):
        """Assert that stored keys, shapes and dtypes match the declaration.

        The set of stored keys must equal the set of the partition's
        ``my_indices``, and every array must have the declared dtype and
        the shape given by ``shapes_a``.
        """
        k1 = set(self.partition.my_indices)
        k2 = set(self.data.keys())
        assert k1 == k2, f'Required keys {k1} different from actual {k2}'
        for a, array in self.items():
            assert array.dtype == self.dtype
            assert array.shape == self.shapes_a[a], \
                (f'array shape {array.shape} '
                 f'vs specified shape {self.shapes_a[a]}')

    def toarray(self, axis=None):
        """Concatenate the local arrays into one new array.

        With ``axis=None`` each array is flattened and the result is
        one-dimensional; otherwise the arrays are concatenated along
        ``axis``.  Arrays are taken in the order of the partition's
        ``my_indices``.  If there are no local arrays, an empty
        one-dimensional array is returned regardless of ``axis``.
        """
        # We could also implement it as a contiguous buffer.
        if len(self) == 0:
            # XXXXXX how should we deal with globally or locally empty arrays?
            # This will probably lead to bugs unless we get all the
            # dimensions right.
            return np.empty(0, self.dtype)
        if axis is None:
            return np.concatenate([self[a].ravel()
                                   for a in self.partition.my_indices])
        else:
            # XXX self[a].shape must all be consistent except along axis
            return np.concatenate([self[a] for a in self.partition.my_indices],
                                  axis=axis)

    def fromarray(self, data):
        """Fill the local arrays in place from the flat array ``data``.

        This is the inverse of ``toarray()`` with ``axis=None``: consecutive
        chunks of ``data`` are written into the arrays in the order of the
        partition's ``my_indices``.  ``data`` must have the declared dtype.
        """
        assert data.dtype == self.dtype
        M1 = 0
        for a in self.partition.my_indices:
            target = self[a]
            # Use the integer element count of the stored array.  np.prod
            # of an empty shape tuple is the float 1.0, which is not a
            # valid slice bound.
            M2 = M1 + target.size
            # Assign through the array itself instead of through ravel():
            # ravel() returns a copy for non-contiguous arrays, in which
            # case the write would be silently lost.
            target[...] = data[M1:M2].reshape(target.shape)
            M1 = M2

    def redistribute_and_broadcast(self, dist_comm, dup_comm):
        """Return a new ArrayDict distributed over ``dist_comm``.

        A copy of this dictionary is redistributed so that index ``a``
        lives on rank ``rank_a[a] % dist_comm.size`` of the current
        communicator.  The flattened data is then broadcast from rank 0
        of ``dup_comm`` and unpacked into a new ArrayDict whose partition
        is built on ``dist_comm``.  ``self`` is not modified.
        """
        # Data exists on self which is a "nice" distribution but now
        # we want it on sub_partition which has a smaller communicator
        # whose parent is self.comm.
        #
        # We want our own data replicated on each

        # XXX direct comparison of communicators are unsafe as we do not use
        # MPI_Comm_compare

        # assert subpartition.comm.parent == self.partition.comm
        from gpaw.utilities.partition import AtomPartition

        newrank_a = self.partition.rank_a % dist_comm.size
        masters_only_partition = AtomPartition(self.partition.comm, newrank_a)
        dst_partition = AtomPartition(dist_comm, newrank_a)
        copy = self.copy()
        copy.redistribute(masters_only_partition)

        dst = ArrayDict(dst_partition, self.shapes_a, dtype=self.dtype,
                        keymap=self.keymap)
        # toarray() is used here only to obtain a buffer of the right size.
        data = dst.toarray()
        if dup_comm.rank == 0:
            data0 = copy.toarray()
            data[:] = data0
        dup_comm.broadcast(data, 0)
        dst.fromarray(data)
        return dst

    def __iter__(self):
        # These functions enforce the same ordering as self.partition
        # when looping.
        return iter(self.partition.my_indices)

    def __repr__(self):
        tokens = []
        for key in sorted(self.keys()):
            shapestr = 'x'.join(map(str, self.shapes_a[key]))
            tokens.append(f'{self.keymap.a2key(key)}:{shapestr}')
        text = ', '.join(tokens)
        return '%s@rank%d/%d {%s}' % (self.__class__.__name__,
                                      self.partition.comm.rank,
                                      self.partition.comm.size,
                                      text)

    def __delitem__(self, a):
        # Actually this is not quite right; we effectively delete items
        # when we redistribute the arraydict.  But this is another
        # code path so let's resolve this later if necessary.
        raise TypeError('Deleting arraydict elements not supported since '
                        'doing so violates the input list-of-shapes')

    def __len__(self):
        return len(self.data)