Coverage for gpaw/response/pw_parallelization.py: 57% (143 statements)


import numpy as np

from gpaw.blacs import BlacsDescriptor, BlacsGrid, Redistributor


class Blocks1D:
    def __init__(self, blockcomm, N):
        self.blockcomm = blockcomm
        self.N = N  # Global number of points

        self.blocksize = (N + blockcomm.size - 1) // blockcomm.size
        self.a = min(blockcomm.rank * self.blocksize, N)
        self.b = min(self.a + self.blocksize, N)
        self.nlocal = self.b - self.a

        self.myslice = slice(self.a, self.b)

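    # Worked example (assumed numbers, not part of the original module):
    # for N = 10 and blockcomm.size == 4, blocksize = (10 + 3) // 4 = 3 and
    # the ranks hold the slices [0, 3), [3, 6), [6, 9) and [9, 10), so
    # nlocal is 3 on ranks 0-2 and 1 on rank 3.
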

    def all_gather(self, in_myix):
        """All-gather array where the first dimension is block distributed.

        Here, myi is understood as the distributed index, whereas x are the
        remaining (global) dimensions."""
        # Set up buffers
        buf_myix = self.local_communication_buffer(in_myix)
        buf_ix = self.global_communication_buffer(in_myix)


        # Execute all-gather

        self.blockcomm.all_gather(buf_myix, buf_ix)
        out_ix = buf_ix[:self.N]

        return out_ix

    def gather(self, in_myix, root=0):
        """Gather array to root where the first dimension is block distributed.
        """
        assert root in range(self.blockcomm.size)

        # Set up buffers
        buf_myix = self.local_communication_buffer(in_myix)
        if self.blockcomm.rank == root:
            buf_ix = self.global_communication_buffer(in_myix)
        else:
            buf_ix = None


        # Execute gather

        self.blockcomm.gather(buf_myix, root, buf_ix)
        if self.blockcomm.rank == root:
            out_ix = buf_ix[:self.N]
        else:
            out_ix = None

        return out_ix

    def local_communication_buffer(self, in_myix):
        """Set up local communication buffer."""
        if in_myix.shape[0] == self.blocksize and in_myix.flags.contiguous:
            buf_myix = in_myix  # Use input array as communication buffer
        else:
            assert in_myix.shape[0] == self.nlocal
            buf_myix = np.empty(
                (self.blocksize,) + in_myix.shape[1:], in_myix.dtype)
            buf_myix[:self.nlocal] = in_myix

        return buf_myix

    def global_communication_buffer(self, in_myix):
        """Set up global communication buffer."""
        buf_ix = np.empty(
            (self.blockcomm.size * self.blocksize,) + in_myix.shape[1:],
            dtype=in_myix.dtype)
        return buf_ix

    def find_global_index(self, i):
        """Find rank and local index of the global index i"""
        rank = i // self.blocksize
        li = i % self.blocksize

        return rank, li


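# Illustrative usage sketch for Blocks1D (an assumed example, not part of
# the original module), to be run under MPI with a gpaw.mpi communicator:
# each rank fills its local slice of a global (N, 3) array, which is then
# collected on every rank or on the root only.
#
#     from gpaw.mpi import world
#     blocks = Blocks1D(world, 10)               # distribute N = 10 rows
#     in_myix = np.zeros((blocks.nlocal, 3))
#     in_myix[:] = np.arange(blocks.a, blocks.b)[:, np.newaxis]
#     out_ix = blocks.all_gather(in_myix)        # (10, 3) on every rank
#     root_ix = blocks.gather(in_myix, root=0)   # (10, 3) on root, else None
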

def block_partition(comm, nblocks):
    r"""Partition the communicator into a 2D array with horizontal
    and vertical communication.

            Communication between blocks (blockcomm)
    <----------------------------------------------->
     _______________________________________________
    |     |     |     |     |     |     |     |     |  ⋀
    |  0  |  1  |  2  |  3  |  4  |  5  |  6  |  7  |  |
    |_____|_____|_____|_____|_____|_____|_____|_____|  |
    |     |     |     |     |     |     |     |     |  |  Communication inside
    |  8  |  9  |  10 |  11 |  12 |  13 |  14 |  15 |  |  blocks
    |_____|_____|_____|_____|_____|_____|_____|_____|  |  (intrablockcomm)
    |     |     |     |     |     |     |     |     |  |
    |  16 |  17 |  18 |  19 |  20 |  21 |  22 |  23 |  |
    |_____|_____|_____|_____|_____|_____|_____|_____|  ⋁

    """

    if nblocks == 'max':
        # Maximize the number of blocks
        nblocks = comm.size
    assert isinstance(nblocks, int)
    assert nblocks > 0 and nblocks <= comm.size, comm.size
    assert comm.size % nblocks == 0, comm.size

    # Communicator between different blocks
    if nblocks == comm.size:
        blockcomm = comm
    else:
        rank1 = comm.rank // nblocks * nblocks
        rank2 = rank1 + nblocks
        blockcomm = comm.new_communicator(range(rank1, rank2))

    # Communicator inside each block
    ranks = range(comm.rank % nblocks, comm.size, nblocks)
    if nblocks == 1:
        assert len(ranks) == comm.size
        intrablockcomm = comm
    else:
        intrablockcomm = comm.new_communicator(ranks)

    assert blockcomm.size * intrablockcomm.size == comm.size

    return blockcomm, intrablockcomm


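# Illustrative sketch for block_partition (an assumed example, not part of
# the original module): with 8 MPI ranks and nblocks=4, each blockcomm
# spans 4 consecutive ranks and each intrablockcomm couples the 2 ranks
# that hold the same block index.
#
#     from gpaw.mpi import world       # assume world.size == 8
#     blockcomm, intrablockcomm = block_partition(world, nblocks=4)
#     assert blockcomm.size == 4
#     assert intrablockcomm.size == 2
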

class PlaneWaveBlockDistributor:
    """Functionality to shuffle block distribution of pair functions
    in the plane wave basis."""

    def __init__(self, world, blockcomm, intrablockcomm):
        self.world = world
        self.blockcomm = blockcomm
        self.intrablockcomm = intrablockcomm

    @property
    def fully_block_distributed(self):
        return self.world.compare(self.blockcomm) == 'ident'

    def new_distributor(self, *, nblocks):
        """Set up a new PlaneWaveBlockDistributor."""
        world = self.world
        blockcomm, intrablockcomm = block_partition(comm=world,
                                                    nblocks=nblocks)
        blockdist = PlaneWaveBlockDistributor(world, blockcomm, intrablockcomm)

        return blockdist


    def _redistribute(self, in_wGG, nw):
        """Redistribute array.

        Switch between two kinds of parallel distributions:

        1) parallel over G-vectors (second dimension of in_wGG)
        2) parallel over frequency (first dimension of in_wGG)

        Returns a newly allocated array with the switched distribution.
        """


        comm = self.blockcomm

        if comm.size == 1:
            return in_wGG

        mynw = (nw + comm.size - 1) // comm.size
        nG = in_wGG.shape[2]
        mynG = (nG + comm.size - 1) // comm.size

        bg1 = BlacsGrid(comm, comm.size, 1)
        bg2 = BlacsGrid(comm, 1, comm.size)
        md1 = BlacsDescriptor(bg1, nw, nG**2, mynw, nG**2)
        md2 = BlacsDescriptor(bg2, nw, nG**2, nw, mynG * nG)

        if len(in_wGG) == nw:
            mdin = md2
            mdout = md1
        else:
            mdin = md1
            mdout = md2

        r = Redistributor(comm, mdin, mdout)

        # mdout.shape[1] is always divisible by nG because
        # every block starts at a multiple of nG, and the last block
        # ends at nG² which of course also is divisible. Nevertheless:
        assert mdout.shape[1] % nG == 0
        # (If it were not divisible, we would "lose" some numbers and the
        # redistribution would be corrupted.)

        inbuf = in_wGG.reshape(mdin.shape)
        # numpy.reshape does not *guarantee* that the reshaped view will
        # be contiguous. To support redistribution of input arrays with an

        # arbitrary allocation layout, we make sure that the corresponding
        # input BLACS buffer is contiguous.

        inbuf = np.ascontiguousarray(inbuf)

        outbuf = np.empty(mdout.shape, complex)

        r.redistribute(inbuf, outbuf)

        outshape = (mdout.shape[0], mdout.shape[1] // nG, nG)
        out_wGG = outbuf.reshape(outshape)
        assert out_wGG.flags.contiguous  # Since mdout.shape[1] % nG == 0

        return out_wGG

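    # Worked example of the two layouts in _redistribute (assumed numbers,
    # not part of the original module): for comm.size = 4, nw = 12 and
    # nG = 100, mynw = 3 and mynG = 25, so the frequency-distributed local
    # buffer (md1) has shape (3, 10000) and the G-distributed local buffer
    # (md2) has shape (12, 2500); both hold 30000 complex numbers per rank.
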

    def has_distribution(self, array, nw, distribution):
        """Check if array 'array' has distribution 'distribution'."""
        if distribution not in ['wGG', 'WgG', 'zGG', 'ZgG']:
            raise ValueError(f'Invalid dist_type: {distribution}')

        comm = self.blockcomm
        if comm.size == 1:
            # In serial, all distributions are equivalent
            return True

        # At the moment, only wGG and WgG distributions are supported. zGG and
        # ZgG are complex frequency aliases for wGG and WgG respectively
        assert len(array.shape) == 3
        nG = array.shape[2]
        if array.shape[1] < array.shape[2]:
            # Looks like array is WgG/ZgG distributed
            gblocks = Blocks1D(comm, nG)
            assert array.shape == (nw, gblocks.nlocal, nG)
            return distribution in ['WgG', 'ZgG']
        else:
            # Looks like array is wGG/zGG distributed
            wblocks = Blocks1D(comm, nw)
            assert array.shape == (wblocks.nlocal, nG, nG)
            return distribution in ['wGG', 'zGG']


    def distribute_as(self, array, nw, distribution):
        """Redistribute array.

        Switch between two kinds of parallel distributions:

        1) parallel over G-vectors (distribution in ['WgG', 'ZgG'])
        2) parallel over frequency (distribution in ['wGG', 'zGG'])
        """
        if self.has_distribution(array, nw, distribution):
            # If the array already has the requested distribution, do nothing
            return array
        else:
            return self._redistribute(array, nw)

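    # Illustrative usage sketch (assumed placeholder names, not part of the
    # original module): convert a frequency-distributed pair function to the
    # G-distributed layout and back again.
    #
    #     chi_wGG = ...  # local shape (wblocks.nlocal, nG, nG)
    #     chi_WgG = blockdist.distribute_as(chi_wGG, nw, 'WgG')
    #     chi_wGG = blockdist.distribute_as(chi_WgG, nw, 'wGG')
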

    def distribute_frequencies(self, in_wGG, nw):
        """Distribute frequencies to all cores."""

        world = self.world
        comm = self.blockcomm

        if world.size == 1:
            return in_wGG

        mynw = (nw + world.size - 1) // world.size
        nG = in_wGG.shape[2]
        mynG = (nG + comm.size - 1) // comm.size

        wa = min(world.rank * mynw, nw)
        wb = min(wa + mynw, nw)

        if self.blockcomm.size == 1:
            return in_wGG[wa:wb].copy()

        if self.intrablockcomm.rank == 0:
            bg1 = BlacsGrid(comm, 1, comm.size)
            in_wGG = in_wGG.reshape((nw, -1))
        else:
            bg1 = BlacsGrid(None, 1, 1)
            # bg1 = DryRunBlacsGrid(mpi.serial_comm, 1, 1)
            in_wGG = np.zeros((0, 0), complex)
        md1 = BlacsDescriptor(bg1, nw, nG**2, nw, mynG * nG)

        bg2 = BlacsGrid(world, world.size, 1)
        md2 = BlacsDescriptor(bg2, nw, nG**2, mynw, nG**2)

        r = Redistributor(world, md1, md2)
        shape = (wb - wa, nG, nG)
        out_wGG = np.empty(shape, complex)
        r.redistribute(in_wGG, out_wGG.reshape((wb - wa, nG**2)))

        return out_wGG
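

# End-to-end sketch (assumed placeholder names chi_WgG and nw, not part of
# the original module): build a distributor over all ranks and hand every
# core its own frequency slice of a block-distributed pair function.
#
#     from gpaw.mpi import world
#     blockcomm, intrablockcomm = block_partition(world, nblocks='max')
#     blockdist = PlaneWaveBlockDistributor(world, blockcomm, intrablockcomm)
#     # chi_WgG: all nw frequencies, G-vectors block distributed
#     my_chi_wGG = blockdist.distribute_frequencies(chi_WgG, nw)
#     # -> each world rank now holds its slice of shape (wb - wa, nG, nG)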