Coverage for gpaw/utilities/partition.py: 63%
165 statements
coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

import numpy as np
from gpaw.arraydict import ArrayDict


def to_parent_comm(partition):
    # XXX assume communicator is strided, i.e. regular.
    # This actually imposes implicit limitations on things, but is not
    # "likely" to cause trouble with the usual communicators, i.e.
    # for gd/kd/bd.
    parent = partition.comm.parent
    if parent is None:
        # This should not ordinarily be necessary, but when running with
        # AtomPAW, it is.  So let's stay out of trouble.
        return partition

    members = partition.comm.get_members()
    parent_rank_a = members[partition.rank_a]

    # XXX we hope and pray that our communicator is "equivalent" to
    # the one that includes the parent's rank 0.
    assert min(members) == members[0]
    parent_rank_a -= members[0]
    return AtomPartition(parent, parent_rank_a,
                         name='parent-%s' % partition.name)
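

# A toy illustration of the mapping above (hypothetical numbers, not taken
# from an actual run): if the child communicator's ranks correspond to
# parent ranks [2, 4, 6, 8] and rank_a = [0, 1, 1, 3], then
#
#     members = np.array([2, 4, 6, 8])
#     parent_rank_a = members[np.array([0, 1, 1, 3])] - members[0]
#     # -> array([0, 2, 2, 6])
#
# i.e. each atom is now labelled by its owner's rank relative to the first
# member of the child communicator within the parent.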


class AtomicMatrixDistributor:
    """Class to distribute atomic dictionaries like dH_asp and D_asp."""
    def __init__(self, atom_partition, broadcast_comm,
                 work_partition=None):
        # Assumptions on communicators are as follows.
        #
        # atom_partition represents the standard domain decomposition, and
        # broadcast_comm is the corresponding kpt/band communicator;
        # together they encompass wfs.world.
        #
        # Initially, dH_asp are distributed over domains according to the
        # physical location of each atom, but duplicated across band
        # and k-point communicators.
        #
        # The idea is to transfer dH_asp so they are distributed equally
        # among all ranks on wfs.world, and back, when necessary.
        self.broadcast_comm = broadcast_comm

        self.grid_partition = atom_partition
        self.grid_unique_partition = to_parent_comm(self.grid_partition)

        # This represents a full distribution across grid, kpt, and band.
        if work_partition is None:
            work_partition = self.grid_unique_partition.as_even_partition()
        self.work_partition = work_partition

    def distribute(self, D_asp):
        # Right now the D are duplicated across the band/kpt comms.
        # Here we pick out a set of unique D.  With duplicates out,
        # we can redistribute one-to-one to the larger work_partition.
        # assert D_asp.partition == self.grid_partition

        Ddist_asp = ArrayDict(self.grid_unique_partition, D_asp.shapes_a,
                              dtype=D_asp.dtype)
        if self.broadcast_comm.rank != 0:
            assert len(Ddist_asp) == 0
        for a in Ddist_asp:
            Ddist_asp[a] = D_asp[a]
        Ddist_asp.redistribute(self.work_partition)
        return Ddist_asp

    def collect(self, dHdist_asp):
        # We have an array on work_partition.  We want first to
        # collect it on grid_unique_partition, then broadcast it
        # to grid_partition.

        # First receive one-to-one from everywhere.
        # assert dHdist_asp.partition == self.work_partition
        dHdist_asp = dHdist_asp.copy()
        dHdist_asp.redistribute(self.grid_unique_partition)

        dH_asp = ArrayDict(self.grid_partition, dHdist_asp.shapes_a,
                           dtype=dHdist_asp.dtype)
        if self.broadcast_comm.rank == 0:
            buf = dHdist_asp.toarray()
            assert not np.isnan(buf).any()
        else:
            buf = dH_asp.toarray()
            buf[:] = np.nan  # Let's be careful for now, like --debug mode
        self.broadcast_comm.broadcast(buf, 0)
        assert not np.isnan(buf).any()
        dH_asp.fromarray(buf)
        return dH_asp
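

# A rough usage sketch of the round trip described in __init__ above.  The
# names wfs.atom_partition, wfs.kptband_comm and D_asp are assumptions for
# illustration and are not defined in this module:
#
#     distributor = AtomicMatrixDistributor(wfs.atom_partition,
#                                           wfs.kptband_comm)
#     Dwork_asp = distributor.distribute(D_asp)  # spread over all of world
#     ...update Dwork_asp on the work partition...
#     D_asp = distributor.collect(Dwork_asp)     # back to the domain layout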


class EvenPartitioning:
    """Represents an even partitioning of N elements over a communicator.

    For example N=17 and comm.size=5 will result in this distribution:

     * rank 0 has 3 local elements: 0, 1, 2
     * rank 1 has 3 local elements: 3, 4, 5
     * rank 2 has 3 local elements: 6, 7, 8
     * rank 3 has 4 local elements: 9, 10, 11, 12
     * rank 4 has 4 local elements: 13, 14, 15, 16

    This class uses only the 'rank' and 'size' communicator attributes."""
    def __init__(self, comm, N):
        # Conventions:
        #   n, N: local/global size
        #   i, I: local/global index
        self.comm = comm
        self.N = N
        self.nlong = -(-N // comm.size)  # size of a long slice
        self.nshort = N // comm.size  # size of a short slice
        self.longcount = N % comm.size  # number of ranks with a long slice
        self.shortcount = comm.size - self.longcount  # ranks with short slice
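
        # Worked numbers for the docstring example (N=17, comm.size=5):
        # nlong = 4, nshort = 3, longcount = 17 % 5 = 2 and shortcount = 3,
        # so ranks 0-2 hold 3 elements each and ranks 3-4 hold 4 each.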

    def nlocal(self, rank=None):
        """Get the number of locally stored elements."""
        if rank is None:
            rank = self.comm.rank
        if rank < self.shortcount:
            return self.nshort
        else:
            return self.nlong

    def minmax(self, rank=None):
        """Get the half-open range [I1, I2) of elements stored locally."""
        if rank is None:
            rank = self.comm.rank
        I1 = self.nshort * rank
        if rank < self.shortcount:
            I2 = I1 + self.nshort
        else:
            I1 += rank - self.shortcount
            I2 = I1 + self.nlong
        return I1, I2

    def slice(self, rank=None):
        """Get the list of indices of locally stored elements."""
        I1, I2 = self.minmax(rank=rank)
        return np.arange(I1, I2)

    def global2local(self, I):
        """Get a tuple (rank, local index) from global index I."""
        nIshort = self.nshort * self.shortcount
        if I < nIshort:
            return I // self.nshort, I % self.nshort
        else:
            Ioffset = I - nIshort
            return (self.shortcount + Ioffset // self.nlong,
                    Ioffset % self.nlong)

    def local2global(self, i, rank=None):
        """Get global index I corresponding to local index i on rank."""
        if rank is None:
            rank = self.comm.rank
        return rank * self.nshort + max(rank - self.shortcount, 0) + i
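
    # Index round trip for the docstring example (N=17, comm.size=5):
    # global2local(13) == (4, 0) and local2global(0, rank=4) == 13, while
    # global2local(5) == (1, 2) and local2global(2, rank=1) == 5.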

    def as_atom_partition(self, strided=False, name='unnamed-even'):
        rank_a = [self.global2local(i)[0] for i in range(self.N)]
        if strided:
            rank_a = np.arange(self.comm.size).repeat(self.nlong)
            rank_a = rank_a.reshape(self.comm.size, -1).T.ravel()
            rank_a = rank_a[self.shortcount:].copy()
        return AtomPartition(self.comm, rank_a, name=name)
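
    # For the docstring example (N=17, comm.size=5), as_atom_partition()
    # yields rank_a = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4];
    # with strided=True the ranks are interleaved instead:
    # rank_a = [3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4].
    # Either way each rank owns nlocal() atoms.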

    def get_description(self):
        lines = []
        for a in range(self.comm.size):
            elements = ', '.join(map(str, self.slice(a)))
            line = 'rank %d has %d local elements: %s' % (a, self.nlocal(a),
                                                          elements)
            lines.append(line)
        return '\n'.join(lines)


# Interface for things that can be redistributed with general_redistribute
class Redistributable:
    def get_recvbuffer(self, a):
        raise NotImplementedError

    def get_sendbuffer(self, a):
        raise NotImplementedError

    def assign(self, a, b_x):
        raise NotImplementedError


# Let's keep this as an independent function for now in case we change the
# classes 5 times, like we do
def general_redistribute(comm, src_rank_a, dst_rank_a, redistributable):
    # To do: it should be possible to specify duplication to several ranks,
    # but how is this done best?
    requests = []
    flags = (src_rank_a != dst_rank_a)
    my_incoming_atom_indices = np.argwhere(
        np.bitwise_and(flags, dst_rank_a == comm.rank)).ravel()
    my_outgoing_atom_indices = np.argwhere(
        np.bitwise_and(flags, src_rank_a == comm.rank)).ravel()

    for a in my_incoming_atom_indices:
        # Get matrix from old domain:
        buf = redistributable.get_recvbuffer(a)
        requests.append(comm.receive(buf, src_rank_a[a], tag=a, block=False))
        # These arrays are not supposed to be pointers into a larger,
        # contiguous buffer, so we should make a copy - except we
        # must wait until we have completed the sending/receiving
        # into them, so we will do it a few lines down.
        redistributable.assign(a, buf)

    for a in my_outgoing_atom_indices:
        # Send matrix to new domain:
        buf = redistributable.get_sendbuffer(a)
        requests.append(comm.send(buf, dst_rank_a[a], tag=a, block=False))

    comm.waitall(requests)
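

# A minimal sketch of a Redistributable backed by a plain dict of arrays.
# It only illustrates the interface above (d_ax, shapes_a and dtype are
# assumed inputs); the real implementations used in this module are the
# Redist classes inside AtomPartition.redistribute below.
class _DictRedistributable(Redistributable):
    def __init__(self, d_ax, shapes_a, dtype):
        self.d_ax = d_ax          # {atom index: array}, local part only
        self.shapes_a = shapes_a  # per-atom array shape
        self.dtype = dtype

    def get_recvbuffer(self, a):
        # Fresh buffer for an atom we are about to receive.
        return np.empty(self.shapes_a[a], self.dtype)

    def get_sendbuffer(self, a):
        # Hand over (and drop) the local array for an outgoing atom.
        return self.d_ax.pop(a)

    def assign(self, a, b_x):
        # Adopt the received buffer as the new local array.
        self.d_ax[a] = b_x

# It would then be used roughly as:
#     general_redistribute(comm, old.rank_a, new.rank_a,
#                          _DictRedistributable(d_ax, shapes_a, float))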


class AtomPartition:
    """Represents atoms distributed on a standard grid descriptor."""
    def __init__(self, comm, rank_a, name='unnamed'):
        self.comm = comm
        self.rank_a = np.array(rank_a)
        self.my_indices = self.get_indices(comm.rank)
        self.natoms = len(rank_a)
        self.name = name

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, AtomPartition):
            return NotImplemented
        return (self.comm.compare(other.comm) in ['ident', 'congruent']
                and np.array_equal(self.rank_a, other.rank_a))

    def __ne__(self, other):
        return not self == other

    def as_serial(self):
        return AtomPartition(self.comm, np.zeros(self.natoms, int),
                             name='%s-serial' % self.name)

    def get_indices(self, rank):
        return np.where(self.rank_a == rank)[0]

    def as_even_partition(self):
        even_part = EvenPartitioning(self.comm, len(self.rank_a))
        return even_part.as_atom_partition()

    def redistribute(self, new_partition, atomdict_ax, get_empty):
        # XXX we want the two communicators to be equal according to
        # some proper criterion like MPI_Comm_compare -> MPI_IDENT.
        # But that is not implemented, so we don't.
        if self.comm.compare(new_partition.comm) not in ['ident',
                                                         'congruent']:
            msg = ('Incompatible partitions %s --> %s. '
                   'Communicators must be at least congruent'
                   % (self, new_partition))
            raise ValueError(msg)

        # atomdict_ax may be a dictionary or a list of dictionaries.
        has_many = not hasattr(atomdict_ax, 'items')
        if has_many:
            class Redist:
                def get_recvbuffer(self, a):
                    return get_empty(a)

                def assign(self, a, b_x):
                    for u, d_ax in enumerate(atomdict_ax):
                        assert a not in d_ax
                        atomdict_ax[u][a] = b_x[u]

                def get_sendbuffer(self, a):
                    return np.array([d_ax.data.pop(a) for d_ax in atomdict_ax])
        else:
            class Redist:
                def get_recvbuffer(self, a):
                    return get_empty(a)

                def assign(self, a, b_x):
                    assert a not in atomdict_ax
                    atomdict_ax[a] = b_x

                def get_sendbuffer(self, a):
                    return atomdict_ax.data.pop(a)

        try:
            general_redistribute(self.comm, self.rank_a,
                                 new_partition.rank_a, Redist())
        except ValueError as err:
            raise ValueError('redistribute %s --> %s: %s'
                             % (self, new_partition, err))
        if isinstance(atomdict_ax, ArrayDict):
            atomdict_ax.partition = new_partition  # XXX
            atomdict_ax.check_consistency()
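
    # Hedged usage sketch (D_asp, shapes_a and the get_empty callback are
    # assumed to be defined by the caller; not taken from this file):
    #
    #     new = partition.as_even_partition()
    #     partition.redistribute(new, D_asp,
    #                            get_empty=lambda a: np.empty(shapes_a[a]))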

    def __repr__(self):
        indextext = ', '.join(map(str, self.my_indices))
        return ('%s %s@rank%d/%d (%d/%d): [%s]'
                % (self.__class__.__name__, self.name, self.comm.rank,
                   self.comm.size, len(self.my_indices), self.natoms,
                   indextext))

    def arraydict(self, shapes, dtype=float):
        return ArrayDict(self, shapes, dtype)