Coverage for gpaw/utilities/partition.py: 63%

165 statements  

coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

import numpy as np

from gpaw.arraydict import ArrayDict


def to_parent_comm(partition):
    # XXX assume communicator is strided, i.e. regular.
    # This actually imposes implicit limitations on things, but is not
    # "likely" to cause trouble with the usual communicators, i.e.
    # for gd/kd/bd.
    parent = partition.comm.parent
    if parent is None:
        # This should not ordinarily be necessary, but when running with
        # AtomPAW, it is.  So let's stay out of trouble.
        return partition

    members = partition.comm.get_members()
    parent_rank_a = members[partition.rank_a]

    # XXXX we hope and pray that our communicator is "equivalent" to
    # that which includes parent's rank0.
    assert min(members) == members[0]
    parent_rank_a -= members[0]  # yuckkk
    return AtomPartition(parent, parent_rank_a,
                         name='parent-%s' % partition.name)

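The rank translation above can be illustrated without MPI by stand-in communicator objects exposing just the attributes the function touches (.parent, .rank, .size, get_members()). This sketch is not part of the module, AtomPartition is defined further down in this file, and all names here are hypothetical.

from types import SimpleNamespace

parent = SimpleNamespace(rank=0, size=8)
# A strided subcommunicator whose four members have parent ranks 0, 2, 4, 6:
child = SimpleNamespace(parent=parent, rank=0, size=4,
                        get_members=lambda: np.array([0, 2, 4, 6]))
part = AtomPartition(child, [0, 1, 3, 2, 0], name='demo')
print(to_parent_comm(part).rank_a)  # [0 2 6 4 0]: each atom now labeled by its parent rank
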

class AtomicMatrixDistributor:
    """Class to distribute atomic dictionaries like dH_asp and D_asp."""
    def __init__(self, atom_partition, broadcast_comm,
                 work_partition=None):
        # Assumptions on communicators are as follows.
        #
        # atom_partition represents standard domain decomposition, and
        # broadcast_comm is the corresponding kpt/band communicator;
        # together they encompass wfs.world.
        #
        # Initially, dH_asp are distributed over domains according to the
        # physical location of each atom, but duplicated across band
        # and k-point communicators.
        #
        # The idea is to transfer dH_asp so they are distributed equally
        # among all ranks on wfs.world, and back, when necessary.
        self.broadcast_comm = broadcast_comm

        self.grid_partition = atom_partition
        self.grid_unique_partition = to_parent_comm(self.grid_partition)

        # This represents a full distribution across grid, kpt, and band.
        if work_partition is None:
            work_partition = self.grid_unique_partition.as_even_partition()
        self.work_partition = work_partition

    def distribute(self, D_asp):
        # Right now the D are duplicated across the band/kpt comms.
        # Here we pick out a set of unique D.  With duplicates out,
        # we can redistribute one-to-one to the larger work_partition.
        # assert D_asp.partition == self.grid_partition

        Ddist_asp = ArrayDict(self.grid_unique_partition, D_asp.shapes_a,
                              dtype=D_asp.dtype)

        if self.broadcast_comm.rank != 0:
            assert len(Ddist_asp) == 0
        for a in Ddist_asp:
            Ddist_asp[a] = D_asp[a]
        Ddist_asp.redistribute(self.work_partition)
        return Ddist_asp

    def collect(self, dHdist_asp):
        # We have an array on work_partition.  We want first to
        # collect it on grid_unique_partition, then broadcast it
        # to grid_partition.

        # First receive one-to-one from everywhere.
        # assert dHdist_asp.partition == self.work_partition
        dHdist_asp = dHdist_asp.copy()
        dHdist_asp.redistribute(self.grid_unique_partition)

        dH_asp = ArrayDict(self.grid_partition, dHdist_asp.shapes_a,
                           dtype=dHdist_asp.dtype)
        if self.broadcast_comm.rank == 0:
            buf = dHdist_asp.toarray()
            assert not np.isnan(buf).any()
        else:
            buf = dH_asp.toarray()
            buf[:] = np.nan  # Let's be careful for now, like --debug mode
        self.broadcast_comm.broadcast(buf, 0)
        assert not np.isnan(buf).any()
        dH_asp.fromarray(buf)
        return dH_asp

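The round trip described in these comments follows a simple call pattern, sketched below. The names atom_partition (the domain-decomposition AtomPartition), kptband_comm (the combined k-point/band communicator acting as broadcast_comm) and D_asp (an ArrayDict of per-atom matrices on atom_partition) are hypothetical placeholders assumed to come from the surrounding calculation; this is not a runnable standalone script.

distributor = AtomicMatrixDistributor(atom_partition, kptband_comm)

Dwork_asp = distributor.distribute(D_asp)  # unique copies, spread over the work partition
# ... update the per-atom matrices in Dwork_asp ...
D_asp = distributor.collect(Dwork_asp)     # collect per domain, then broadcast the duplicates back
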

class EvenPartitioning:
    """Represents an even partitioning of N elements over a communicator.

    For example N=17 and comm.size=5 will result in this distribution:

     * rank 0 has 3 local elements: 0, 1, 2
     * rank 1 has 3 local elements: 3, 4, 5
     * rank 2 has 3 local elements: 6, 7, 8
     * rank 3 has 4 local elements: 9, 10, 11, 12
     * rank 4 has 4 local elements: 13, 14, 15, 16

    This class uses only the 'rank' and 'size' communicator attributes."""
    def __init__(self, comm, N):
        # Conventions:
        #  n, N: local/global size
        #  i, I: local/global index
        self.comm = comm
        self.N = N
        self.nlong = -(-N // comm.size)  # size of a long slice
        self.nshort = N // comm.size  # size of a short slice
        self.longcount = N % comm.size  # number of ranks with a long slice
        self.shortcount = comm.size - self.longcount  # ranks with short slice

    def nlocal(self, rank=None):
        """Get the number of locally stored elements."""
        if rank is None:
            rank = self.comm.rank
        if rank < self.shortcount:
            return self.nshort
        else:
            return self.nlong

    def minmax(self, rank=None):
        """Get the minimum and maximum index of elements stored locally."""
        if rank is None:
            rank = self.comm.rank
        I1 = self.nshort * rank
        if rank < self.shortcount:
            I2 = I1 + self.nshort
        else:
            I1 += rank - self.shortcount
            I2 = I1 + self.nlong
        return I1, I2

    def slice(self, rank=None):
        """Get the list of indices of locally stored elements."""
        I1, I2 = self.minmax(rank=rank)
        return np.arange(I1, I2)

    def global2local(self, I):
        """Get a tuple (rank, local index) from global index I."""
        nIshort = self.nshort * self.shortcount
        if I < nIshort:
            return I // self.nshort, I % self.nshort
        else:
            Ioffset = I - nIshort
            return (self.shortcount + Ioffset // self.nlong,
                    Ioffset % self.nlong)

    def local2global(self, i, rank=None):
        """Get global index I corresponding to local index i on rank."""
        if rank is None:
            rank = self.comm.rank
        return rank * self.nshort + max(rank - self.shortcount, 0) + i

    def as_atom_partition(self, strided=False, name='unnamed-even'):
        rank_a = [self.global2local(i)[0] for i in range(self.N)]
        if strided:
            rank_a = np.arange(self.comm.size).repeat(self.nlong)
            rank_a = rank_a.reshape(self.comm.size, -1).T.ravel()
            rank_a = rank_a[self.shortcount:].copy()
        return AtomPartition(self.comm, rank_a, name=name)

    def get_description(self):
        lines = []
        for a in range(self.comm.size):
            elements = ', '.join(map(str, self.slice(a)))
            line = 'rank %d has %d local elements: %s' % (a, self.nlocal(a),
                                                          elements)
            lines.append(line)
        return '\n'.join(lines)

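Because EvenPartitioning reads only the rank and size attributes of its communicator, the N=17, comm.size=5 case from the docstring can be reproduced without MPI. This is an illustrative sketch, not part of the module; the SimpleNamespace stand-in is hypothetical.

from types import SimpleNamespace

comm5 = SimpleNamespace(rank=0, size=5)
even = EvenPartitioning(comm5, 17)
print(even.get_description())
# rank 0 has 3 local elements: 0, 1, 2
# ...
# rank 4 has 4 local elements: 13, 14, 15, 16
assert even.global2local(10) == (3, 1)    # element 10 sits on rank 3 as local index 1
assert even.local2global(1, rank=3) == 10
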

# Interface for things that can be redistributed with general_redistribute
class Redistributable:
    def get_recvbuffer(self, a):
        raise NotImplementedError

    def get_sendbuffer(self, a):
        raise NotImplementedError

    def assign(self, a, b_x):
        raise NotImplementedError


# Let's keep this as an independent function for now in case we change the
# classes 5 times, like we do
def general_redistribute(comm, src_rank_a, dst_rank_a, redistributable):
    # To do: it should be possible to specify duplication to several ranks.
    # But how is this done best?
    requests = []
    flags = (src_rank_a != dst_rank_a)
    my_incoming_atom_indices = np.argwhere(
        np.bitwise_and(flags, dst_rank_a == comm.rank)).ravel()
    my_outgoing_atom_indices = np.argwhere(
        np.bitwise_and(flags, src_rank_a == comm.rank)).ravel()

    for a in my_incoming_atom_indices:
        # Get matrix from old domain:
        buf = redistributable.get_recvbuffer(a)
        requests.append(comm.receive(buf, src_rank_a[a], tag=a, block=False))
        # These arrays are not supposed to be pointers into a larger,
        # contiguous buffer, so we should make a copy - except we
        # must wait until we have completed the send/receiving
        # into them, so we will do it a few lines down.
        redistributable.assign(a, buf)

    for a in my_outgoing_atom_indices:
        # Send matrix to new domain:
        buf = redistributable.get_sendbuffer(a)
        requests.append(comm.send(buf, dst_rank_a[a], tag=a, block=False))

    comm.waitall(requests)

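A hedged illustration of the Redistributable interface (not part of the module): assuming gpaw.mpi.world and the general_redistribute/Redistributable names defined above, the sketch moves a plain dict of per-atom arrays from rank 0 to a round-robin layout. DictRedist and all variable names are hypothetical.

from gpaw.mpi import world


class DictRedist(Redistributable):
    """Move {atom: ndarray} entries of a plain dict between two rank maps."""
    def __init__(self, d_a, shape):
        self.d_a = d_a
        self.shape = shape

    def get_recvbuffer(self, a):
        return np.empty(self.shape)

    def get_sendbuffer(self, a):
        return self.d_a.pop(a)

    def assign(self, a, b_x):
        self.d_a[a] = b_x


natoms = 4
src_rank_a = np.zeros(natoms, int)           # all atoms start on rank 0
dst_rank_a = np.arange(natoms) % world.size  # target: round-robin over ranks
d_a = ({a: np.full((2, 2), float(a)) for a in range(natoms)}
       if world.rank == 0 else {})
general_redistribute(world, src_rank_a, dst_rank_a, DictRedist(d_a, (2, 2)))
# Afterwards each rank's d_a holds exactly the atoms that dst_rank_a assigns to it.
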

class AtomPartition:
    """Represents atoms distributed on a standard grid descriptor."""
    def __init__(self, comm, rank_a, name='unnamed'):
        self.comm = comm
        self.rank_a = np.array(rank_a)
        self.my_indices = self.get_indices(comm.rank)
        self.natoms = len(rank_a)
        self.name = name

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, AtomPartition):
            return NotImplemented
        return (self.comm.compare(other.comm) in ['ident', 'congruent']
                and np.array_equal(self.rank_a, other.rank_a))

    def __ne__(self, other):
        return not self == other

    def as_serial(self):
        return AtomPartition(self.comm, np.zeros(self.natoms, int),
                             name='%s-serial' % self.name)

    def get_indices(self, rank):
        return np.where(self.rank_a == rank)[0]

    def as_even_partition(self):
        even_part = EvenPartitioning(self.comm, len(self.rank_a))
        return even_part.as_atom_partition()

    def redistribute(self, new_partition, atomdict_ax, get_empty):
        # XXX we want the two communicators to be equal according to
        # some proper criterion like MPI_Comm_compare -> MPI_IDENT.
        # But that is not implemented, so we don't.
        if self.comm.compare(new_partition.comm) not in ['ident',
                                                         'congruent']:
            msg = ('Incompatible partitions %s --> %s. '
                   'Communicators must be at least congruent'
                   % (self, new_partition))
            raise ValueError(msg)

        # atomdict_ax may be a dictionary or a list of dictionaries

        has_many = not hasattr(atomdict_ax, 'items')
        if has_many:
            class Redist:
                def get_recvbuffer(self, a):
                    return get_empty(a)

                def assign(self, a, b_x):
                    for u, d_ax in enumerate(atomdict_ax):
                        assert a not in d_ax
                        atomdict_ax[u][a] = b_x[u]

                def get_sendbuffer(self, a):
                    return np.array([d_ax.data.pop(a) for d_ax in atomdict_ax])
        else:
            class Redist:
                def get_recvbuffer(self, a):
                    return get_empty(a)

                def assign(self, a, b_x):
                    assert a not in atomdict_ax
                    atomdict_ax[a] = b_x

                def get_sendbuffer(self, a):
                    return atomdict_ax.data.pop(a)

        try:
            general_redistribute(self.comm, self.rank_a,
                                 new_partition.rank_a, Redist())
        except ValueError as err:
            raise ValueError('redistribute %s --> %s: %s'
                             % (self, new_partition, err))
        if isinstance(atomdict_ax, ArrayDict):
            atomdict_ax.partition = new_partition  # XXX
            atomdict_ax.check_consistency()

    def __repr__(self):
        indextext = ', '.join(map(str, self.my_indices))
        return ('%s %s@rank%d/%d (%d/%d): [%s]'
                % (self.__class__.__name__, self.name, self.comm.rank,
                   self.comm.size, len(self.my_indices), self.natoms,
                   indextext))

    def arraydict(self, shapes, dtype=float):
        return ArrayDict(self, shapes, dtype)

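AtomPartition itself needs only the communicator's rank (plus size, via EvenPartitioning, for as_even_partition), so its bookkeeping can likewise be shown with a stand-in communicator. A hypothetical sketch, not part of the module:

from types import SimpleNamespace

comm4 = SimpleNamespace(rank=1, size=4)
part = AtomPartition(comm4, [0, 1, 1, 3, 0, 2], name='demo')
print(part.my_indices)                  # [1 2]: atoms 1 and 2 are local to rank 1
print(part.get_indices(0))              # [0 4]
print(part.as_even_partition().rank_a)  # [0 1 2 2 3 3]: 6 atoms spread evenly over 4 ranks
print(part.as_serial().rank_a)          # [0 0 0 0 0 0]: everything on rank 0
print(part)                             # AtomPartition demo@rank1/4 (2/6): [1, 2]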