Coverage for gpaw/band_descriptor.py: 36%

155 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1# Copyright (C) 2003 CAMP 

2# Please see the accompanying LICENSE file for further information. 

3 

4"""Band-descriptors for blocked/strided groups. 

5 

6This module contains classes defining two kinds of band groups: 

7 

8* Blocked groups with contiguous band indices. 

9* Strided groups with evenly-spaced band indices. 

10""" 

11 

12import numpy as np 

13 

14from gpaw.mpi import serial_comm 

15 

# Handed to ``comm.send`` as its fourth positional argument by
# ``BandDescriptor.distribute``; the object returned by that call is later
# passed to ``comm.wait``.
# NOTE(review): presumably ``False`` selects a non-blocking send -- confirm
# against the ``gpaw.mpi`` communicator API.
NONBLOCKING = False

17 

18 

class BandDescriptor:
    r"""Descriptor-class for ordered lists of bands

    A ``BandDescriptor`` object holds information on how functions, such
    as wave functions and corresponding occupation numbers, are divided
    into groups according to band indices. The main information here is
    how many bands are stored on each processor and who gets what.

    This is how a 12 band array is laid out in memory on 3 cpu's::

      a) Blocked groups       b) Strided groups

           3    7   11         9   10   11
     myn   2 \  6 \ 10   myn   6    7    8
      |    1  \ 5  \ 9    |    3    4    5
      |    0    4    8    |    0    1    2
      |                   |
      +----- band_rank    +----- band_rank

    Example (each row lists the global band indices of one band_rank;
    ``a`` is the blocked layout and ``b.T`` the strided layout):

    >>> a = np.zeros((3, 4))
    >>> a.ravel()[:] = range(12)
    >>> a
    array([[ 0.,  1.,  2.,  3.],
           [ 4.,  5.,  6.,  7.],
           [ 8.,  9., 10., 11.]])
    >>> b = np.zeros((4, 3))
    >>> b.ravel()[:] = range(12)
    >>> b.T
    array([[ 0.,  3.,  6.,  9.],
           [ 1.,  4.,  7., 10.],
           [ 2.,  5.,  8., 11.]])
    """

    def __init__(self, nbands: int, comm=None, strided=False):
        """Construct band-descriptor object.

        Parameters:

        nbands: int
            Global number of bands.
        comm: MPI-communicator
            Communicator for band-groups.  Defaults to ``serial_comm``.
        strided: bool
            Enable strided band distribution for better
            load balancing with many unoccupied bands.  Requires that
            ``nbands`` is divisible by ``comm.size``.

        Note that if comm.size is 1, then all bands are contained on
        a single CPU and blocked/strided grouping loses its meaning.

        Attributes:

        =============== ===================================================
        ``nbands``      Number of bands in total.
        ``mynbands``    Number of bands on this CPU.
        ``maxmynbands`` Largest number of bands on any CPU
                        (``nbands / comm.size`` rounded up).
        ``beg``         Beginning of band indices in group (inclusive).
        ``end``         End of band indices in group (exclusive).
        ``step``        Stride for band indices between ``beg`` and ``end``.
        ``strided``     Whether the strided layout is used.
        ``comm``        MPI-communicator for band distribution.
        =============== ===================================================
        """

        if comm is None:
            comm = serial_comm

        self.comm = comm
        self.nbands = nbands
        self.strided = strided

        # Ceiling division: the block size of the fullest rank.
        self.maxmynbands = (nbands + comm.size - 1) // comm.size

        if strided:
            assert nbands % comm.size == 0, (
                'strided band distribution requires nbands to be '
                'divisible by comm.size')
            self.beg = comm.rank
            self.end = nbands
            self.step = comm.size
            self.mynbands = self.maxmynbands
        else:
            self.beg, self.end = self._blocked_bounds(comm.rank)
            self.mynbands = self.end - self.beg
            self.step = 1

    def _blocked_bounds(self, band_rank):
        """Return ``(beg, end)`` of the contiguous block owned by a rank.

        Both values are clamped to ``nbands`` so that trailing ranks get
        a short (possibly empty) block when ``nbands`` is not a multiple
        of ``comm.size``.
        """
        beg = min(self.nbands, band_rank * self.maxmynbands)
        end = min(self.nbands, beg + self.maxmynbands)
        return beg, end

    def __len__(self):
        """Return the number of bands stored on this rank."""
        return self.mynbands

    def get_slice(self, band_rank=None):
        """Return the slice of global bands which belong to a given rank.

        ``band_rank`` defaults to the rank of this process in ``comm``.
        """
        if band_rank is None:
            band_rank = self.comm.rank
        assert band_rank in range(self.comm.size)

        if self.strided:
            return slice(band_rank, None, self.comm.size)

        # The bounds must be derived from the requested rank; using
        # self.beg/self.end would always describe the calling rank.
        beg, end = self._blocked_bounds(band_rank)
        return slice(beg, end)

    def get_band_indices(self, band_rank=None):
        """Return the global band indices which belong to a given rank."""
        nslice = self.get_slice(band_rank)
        return np.arange(*nslice.indices(self.nbands))

    def get_band_ranks(self):
        """Return array of ranks as a function of global band indices."""
        # Start from an invalid rank so the check below really detects
        # band indices that no rank claimed.
        rank_n = np.full(self.nbands, -1, dtype=int)
        for band_rank in range(self.comm.size):
            rank_n[self.get_slice(band_rank)] = band_rank
        assert (rank_n >= 0).all() and (rank_n < self.comm.size).all()
        return rank_n

    def who_has(self, n):
        """Convert global band index to rank information and local index.

        Returns the tuple ``(band_rank, myn)``.
        """
        if self.strided:
            myn, band_rank = divmod(n, self.comm.size)
        else:
            band_rank, myn = divmod(n, self.maxmynbands)
        return band_rank, myn

    def global_index(self, myn, band_rank=None):
        """Convert rank information and local index to global index.

        ``band_rank`` defaults to the rank of this process in ``comm``.
        """
        if band_rank is None:
            band_rank = self.comm.rank

        if self.strided:
            return band_rank + myn * self.comm.size

        return band_rank * self.maxmynbands + myn

    def get_size_of_global_array(self):
        """Return the shape ``(nbands,)`` of an array over all bands."""
        return (self.nbands,)

    def zeros(self, n=(), dtype=float, global_array=False):
        """Return new zeroed array with the band index as first axis.

        The first axis has length ``mynbands``, or ``nbands`` when
        ``global_array=True``.  The type can be set with the ``dtype``
        keyword (default: ``float``).  Extra trailing dimensions can be
        added with ``n=dim`` (an integer or a sequence of integers).
        """
        return self._new_array(n, dtype, True, global_array)

    def empty(self, n=(), dtype=float, global_array=False):
        """Return new uninitialized array with the band index as first axis.

        The first axis has length ``mynbands``, or ``nbands`` when
        ``global_array=True``.  The type can be set with the ``dtype``
        keyword (default: ``float``).  Extra trailing dimensions can be
        added with ``n=dim`` (an integer or a sequence of integers).
        """
        return self._new_array(n, dtype, False, global_array)

    def _new_array(self, n=(), dtype=float, zero=True, global_array=False):
        """Allocate a local or global band array; see zeros() and empty()."""
        if global_array:
            shape = tuple(self.get_size_of_global_array())
        else:
            shape = (self.mynbands,)

        # Accept plain and NumPy integers as well as any sequence.
        if isinstance(n, (int, np.integer)):
            extra = (int(n),)
        else:
            extra = tuple(n)

        allocate = np.zeros if zero else np.empty
        return allocate(shape + extra, dtype)

    def collect(self, a_nx, broadcast=False):
        """Collect distributed array to master-CPU or all CPU's.

        ``a_nx`` is the local part, with the band index as first axis.
        Returns the global array on rank 0 (on every rank if
        ``broadcast`` is true) and ``None`` elsewhere.  With a single
        rank, ``a_nx`` itself is returned.
        """
        comm = self.comm
        if comm.size == 1:
            return a_nx

        xshape = a_nx.shape[1:]

        # Optimization for blocked groups
        if not self.strided:
            if comm.size * self.maxmynbands > self.nbands:
                # Ranks hold different numbers of bands: padding needed.
                return self.nasty_non_strided_collect(a_nx, broadcast)
            if broadcast:
                A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
                comm.all_gather(a_nx, A_nx)
                return A_nx

            if comm.rank == 0:
                A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
            else:
                A_nx = None
            comm.gather(a_nx, 0, A_nx)
            return A_nx

        # Collect all arrays on the master:
        if comm.rank != 0:
            # There can be several sends before the corresponding receives
            # are posted, so use synchronous send here
            comm.ssend(a_nx, 0, 3011)
            if broadcast:
                A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
                comm.broadcast(A_nx, 0)
                return A_nx
            return None

        # Put the band groups from the slaves into the big array
        # for the whole collection of bands:
        A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
        A_nx[self.get_slice(0), ...] = a_nx
        for band_rank in range(1, comm.size):
            buf_nx = self.empty(xshape, a_nx.dtype, global_array=False)
            comm.receive(buf_nx, band_rank, 3011)
            A_nx[self.get_slice(band_rank), ...] = buf_nx

        if broadcast:
            comm.broadcast(A_nx, 0)
        return A_nx

    def distribute(self, B_nx, b_nx):
        """Distribute full array B_nx to band groups, result in b_nx.

        b_nx must be allocated.  The contents of B_nx are only read on
        rank 0 when there is more than one rank.
        """
        comm = self.comm
        S = comm.size

        if S == 1:
            b_nx[:] = B_nx
            return

        # Optimization for blocked groups
        if not self.strided:
            M2 = self.maxmynbands
            if M2 * S == self.nbands:
                comm.scatter(B_nx, b_nx, 0)
                return

            # Uneven blocks: scatter from a buffer padded to S * M2 bands.
            if comm.rank == 0:
                C_nx = np.empty((S * M2,) + B_nx.shape[1:], B_nx.dtype)
                C_nx[:self.nbands] = B_nx
            else:
                C_nx = None

            short_block = self.mynbands < M2
            if short_block:
                c_nx = np.empty((M2,) + b_nx.shape[1:], b_nx.dtype)
            else:
                c_nx = b_nx

            comm.scatter(C_nx, c_nx, 0)

            if short_block:
                b_nx[:] = c_nx[:self.mynbands]
            return

        if comm.rank != 0:
            comm.receive(b_nx, 0, 421)
            return

        b_nx[:] = B_nx[self.get_slice(), ...]

        requests = []
        for band_rank in range(1, S):
            a_nx = B_nx[self.get_slice(band_rank), ...].copy()
            request = comm.send(a_nx, band_rank, 421, NONBLOCKING)
            # Remember to store a reference to the
            # send buffer (a_nx) so that it isn't
            # deallocated:
            requests.append((request, a_nx))

        for request, a_nx in requests:
            comm.wait(request)

    def nasty_non_strided_collect(self, a_nx, broadcast):
        """Collect blocked arrays when ranks hold unequal numbers of bands.

        Every rank contributes a block padded to ``maxmynbands`` bands;
        the padding is cut off from the gathered result.  Returns the
        global array on rank 0 (on every rank if ``broadcast`` is true)
        and ``None`` elsewhere.
        """
        xshape = a_nx.shape[1:]
        if broadcast:
            A_nx = self.nasty_non_strided_collect(a_nx, False)
            if A_nx is None:
                A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
            self.comm.broadcast(A_nx, 0)
            return A_nx

        S = self.comm.size
        M2 = self.maxmynbands
        if self.comm.rank == 0:
            A_nx = np.empty((S * M2,) + xshape, a_nx.dtype)
        else:
            A_nx = None

        if self.mynbands < M2:
            # Pad the local block; the tail stays uninitialized and is
            # discarded by the final slicing on rank 0.
            b_nx = np.empty((M2,) + xshape, a_nx.dtype)
            b_nx[:self.mynbands] = a_nx
        else:
            b_nx = a_nx

        self.comm.gather(b_nx, 0, A_nx)

        if self.comm.rank > 0:
            return None

        return A_nx[:self.nbands]

319 return A_nx[:self.nbands]