Coverage for gpaw/response/wgg.py: 28%

139 statements  


1"""Parallelization scheme for frequency–planewave–planewave arrays.""" 

2from gpaw.mpi import world 

3from gpaw.response.pw_parallelization import block_partition 

4from gpaw.utilities.scalapack import scalapack_set, scalapack_solve 

5from gpaw.blacs import BlacsGrid 

6import numpy as np 

7 

8 

9def get_blocksize(length, commsize): 

10 return -(-length // commsize) 
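# Illustration (added for this listing, not part of the original module):
# with the sizes used by main() below, nG = 31 planewaves split over 2 or
# 4 ranks give
#
#   >>> get_blocksize(31, 2)
#   16
#   >>> get_blocksize(31, 4)
#   8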

def get_strides(cpugrid):
    return np.array([cpugrid[1] * cpugrid[2], cpugrid[2], 1], int)
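# Illustration (added for this listing, not part of the original module):
# row-major strides of the (w, g, g) processor grid, used below to map an
# MPI rank to its grid position, e.g.
#
#   >>> get_strides((1, 2, 2))
#   array([4, 2, 1])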

class Grid:
    def __init__(self, comm, shape, cpugrid=None, blocksize=None):
        self.comm = comm
        self.shape = shape

        if cpugrid is None:
            cpugrid = choose_parallelization(shape[0], shape[1],
                                             comm.size)

        self.cpugrid = cpugrid

        if blocksize is None:
            blocksize = [get_blocksize(size, commsize)
                         for size, commsize in zip(shape, cpugrid)]
            # XXX scalapack blocksize hack
            blocksize[1] = blocksize[2] = max(blocksize[1:])

        # XXX Our scalapack interface does NOT like it when blocksizes
        # are not the same.  There must be a bug.
        assert blocksize[1] == blocksize[2]

        self.blocksize = tuple(blocksize)
        self.cpugrid = cpugrid

        self.myparpos = self.rank2parpos(self.comm.rank)

        n_cp = self.get_domains()

        shape = []
        myslice_c = []
        for i in range(3):
            n_p = n_cp[i]
            parpos = self.myparpos[i]
            myend = n_p[parpos + 1]
            mystart = n_p[parpos]
            myslice_c.append(slice(mystart, myend))
            size = myend - mystart
            shape.append(size)

        self.myshape = tuple(shape)
        self.myslice = tuple(myslice_c)
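    # Illustration (added for this listing, not part of the original module):
    # for shape = (3, 31, 31) on 4 ranks, choose_parallelization picks
    # cpugrid = (1, 2, 2), the blocksize hack gives blocksize = (3, 16, 16),
    # and each rank ends up with a contiguous block of the planewave axes:
    #
    #   >>> grid = Grid(world, (3, 31, 31))   # with world.size == 4
    #   >>> grid.myshape                      # on rank 0
    #   (3, 16, 16)
    #   >>> grid.myshape                      # on rank 3
    #   (3, 15, 15)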

    # TODO inherit these from array descriptor
    def zeros(self, dtype=float):
        return np.zeros(self.myshape, dtype=dtype)

    def get_domains(self):
        """Get definition of domains.

        Returns domains_cp where domains_cp[c][r + 1] - domains_cp[c][r]
        is the number of points in domain r along direction c.

        The second axis contains the "fencepost" locations
        of the grid: [0, blocksize, 2 * blocksize, ...]
        """
        domains_cp = []

        for i in range(3):
            n_p = np.empty(self.cpugrid[i] + 1, int)
            n_p[0] = 0
            n_p[1:] = self.blocksize[i]
            n_p[:] = n_p.cumsum().clip(0, self.shape[i])
            domains_cp.append(n_p)

        return domains_cp
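    # Illustration (added for this listing, not part of the original module):
    # with shape = (3, 31, 31), cpugrid = (1, 2, 2) and blocksize = (3, 16, 16)
    # the fencepost arrays come out as
    #
    #   [array([0, 3]), array([ 0, 16, 31]), array([ 0, 16, 31])]
    #
    # i.e. the last G-domain in each direction is clipped to 15 points.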

    def rank2parpos(self, rank):
        # XXX Borrowing from gd -- we should eliminate this duplication.

        strides = get_strides(self.cpugrid)
        cpugrid_coord = np.array(
            [rank // strides[0],
             (rank % strides[0]) // strides[1],
             rank % strides[1]])

        return cpugrid_coord
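    # Illustration (added for this listing, not part of the original module):
    # for cpugrid = (1, 2, 2) the strides are [4, 2, 1], so ranks 0..3 map to
    # the grid positions [0, 0, 0], [0, 0, 1], [0, 1, 0] and [0, 1, 1].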

    def redistribute(self, dstgrid, srcarray, dstarray):
        from gpaw.utilities.grid_redistribute import general_redistribute
        domains1 = self.get_domains()
        domains2 = dstgrid.get_domains()
        general_redistribute(self.comm, domains1, domains2,
                             self.rank2parpos, dstgrid.rank2parpos,
                             srcarray, dstarray, behavior='overwrite')

    def invert_inplace(self, x_wgg):
        # Build wgg grid choosing scalapack
        nscalapack_cores = np.prod(self.cpugrid[1:])
        blacs_comm, wcomm = block_partition(self.comm, nscalapack_cores)
        assert wcomm.size == self.cpugrid[0]
        assert blacs_comm.size * wcomm.size == self.comm.size
        for iw, x_gg in enumerate(x_wgg):
            bg = BlacsGrid(blacs_comm, *self.cpugrid[1:][::-1])
            desc = bg.new_descriptor(
                *self.shape[1:],
                *self.blocksize[1:])

            xtmp_gg = desc.empty(dtype=x_wgg.dtype)
            xtmp_gg[:] = x_gg.T

            righthand = desc.zeros(dtype=complex)
            scalapack_set(desc, righthand, alpha=0.0, beta=1.0, uplo='U')

            scalapack_solve(desc, desc, xtmp_gg, righthand)
            x_gg[:] = righthand.T
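    # Added note (not part of the original module): each frequency block is
    # replaced by its inverse by solving a linear system whose right-hand
    # side is the identity (scalapack_set puts ones on the diagonal); the
    # transposes bridge numpy's row-major layout and ScaLAPACK's
    # column-major convention.  A rough serial analogue of the
    # per-frequency step would be
    #
    #   identity = np.eye(x_gg.shape[0], dtype=x_gg.dtype)
    #   x_gg[:] = np.linalg.solve(x_gg, identity)   # i.e. the inverse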

def get_x_WGG(WGG_grid):
    x_WGG = WGG_grid.zeros(dtype=complex)
    rng = np.random.RandomState(42)

    x_WGG.flat[:] = rng.random(x_WGG.size)
    x_WGG.flat[:] += rng.random(x_WGG.size) * 1j
    # XXX write also to imaginary parts

    nG = x_WGG.shape[1]

    xinv_WGG = np.zeros_like(x_WGG)
    if WGG_grid.comm.rank == 0:
        assert x_WGG.shape == WGG_grid.myshape
        for iw, x_GG in enumerate(x_WGG):
            x_GG += x_GG.T.conj().copy()
            x_GG += np.identity(nG) * 5
            eigs = np.linalg.eigvals(x_GG)
            assert all(eigs.real > 0)
            xinv_WGG[iw] = np.linalg.inv(x_GG)
    else:
        assert np.prod(x_WGG.shape) == 0
    return x_WGG, xinv_WGG
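# Added note (not part of the original module): the test matrices are made
# Hermitian (x + x†) and shifted by a multiple of the identity to push the
# spectrum away from zero; the eigenvalue assertion then checks that they are
# positive definite and hence invertible.  A minimal sketch of the same
# construction for a single block:
#
#   a = rng.random((nG, nG)) + 1j * rng.random((nG, nG))
#   a = a + a.T.conj() + 5 * np.identity(nG)   # Hermitian, shifted spectrum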

def factorize(N):
    for n in range(1, N + 1):
        if N % n == 0:
            yield N // n, n
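# Illustration (added for this listing, not part of the original module):
# all ordered two-factor splittings of a communicator size, e.g.
#
#   >>> list(factorize(6))
#   [(6, 1), (3, 2), (2, 3), (1, 6)]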

def get_products(N):
    for a1, a2 in factorize(N):
        for a2p, a3 in factorize(a2):
            yield a1, a2p, a3
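# Illustration (added for this listing, not part of the original module):
# all ordered three-factor splittings, used as candidate (w, g, g) CPU grids:
#
#   >>> list(get_products(4))
#   [(4, 1, 1), (2, 2, 1), (2, 1, 2), (1, 4, 1), (1, 2, 2), (1, 1, 4)]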

def choose_parallelization(nW, nG, commsize):
    min_badness = 10000000

    for wGG in get_products(commsize):
        wsize, gsize1, gsize2 = wGG
        nw = (nW + wsize - 1) // wsize

        if nw > nW:
            continue

        number_of_cores_with_zeros = (wsize * nw - nW) // nw
        scalapack_skew = (gsize1 - gsize2)**2
        scalapack_size = gsize1 * gsize2
        badness = (number_of_cores_with_zeros * 1000
                   + 10 * scalapack_skew + scalapack_size)

        # print(wsize, gsize1, gsize2, nw, number_of_cores_with_zeros, badness)
        if badness < min_badness:
            wGG_min = wGG
            min_badness = badness
    return wGG_min
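# Illustration (added for this listing, not part of the original module):
# the heuristic heavily penalizes ranks left without any frequencies,
# then penalizes non-square BLACS grids, and uses the BLACS-grid size as a
# tie-breaker.  For the self-test in main() on 4 ranks:
#
#   >>> choose_parallelization(3, 31, 4)
#   (1, 2, 2)
#
# i.e. the frequency axis is not distributed and the planewave axes form a
# square 2x2 BLACS grid.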

def main(comm=world):
    nW = 3
    nG = 31

    cpugrid = choose_parallelization(nW, nG, comm.size)

    WGG = (nW, nG, nG)
    dtype = complex

    # Build serial grid (data only on rank 0)
    # and establish matrix and its inverse
    WGG_grid = Grid(comm, WGG, cpugrid, blocksize=WGG)
    x_WGG, xinv_WGG = get_x_WGG(WGG_grid)

    # Distribute to WgG grid:
    WgG_grid = Grid(comm, WGG, (1, comm.size, 1))
    x_WgG = np.zeros(WgG_grid.myshape, dtype=dtype)
    WGG_grid.redistribute(WgG_grid, x_WGG, x_WgG)

    wgg_grid = Grid(comm, WGG, cpugrid)
    print(f'cpugrid={cpugrid} blocksize={wgg_grid.blocksize} '
          f'shape={wgg_grid.shape} myshape={wgg_grid.myshape}')

    x_wgg = wgg_grid.zeros(dtype=dtype)
    WgG_grid.redistribute(wgg_grid, x_WgG, x_wgg)

    # Now distribute wgg back to WgG to check that the numbers
    # are the same:
    x1_WgG = WgG_grid.zeros(dtype=dtype)
    wgg_grid.redistribute(WgG_grid, x_wgg, x1_WgG)
    assert np.allclose(x_WgG, x1_WgG)

    wgg_grid.invert_inplace(x_wgg)

    # Distribute the inverse wgg back to WGG:
    inv_x_WGG = WGG_grid.zeros(dtype=dtype)
    wgg_grid.redistribute(WGG_grid, x_wgg, inv_x_WGG)

    from gpaw.utilities.tools import tri2full
    if comm.rank == 0:
        for inv_x_GG in inv_x_WGG:
            tri2full(inv_x_GG, 'L')

        for x_GG, inv_x_GG in zip(x_WGG, inv_x_WGG):
            assert np.allclose(x_GG @ inv_x_GG, np.identity(nG))

if __name__ == '__main__':
    main()
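# Added note (not part of the original module): the self-test only exercises
# the redistribution and the ScaLAPACK path when run on several MPI ranks,
# e.g. (assuming an MPI-enabled GPAW installation) something like
#
#   mpiexec -n 4 python -m gpaw.response.wgg
#
# The exact launcher depends on the local MPI/GPAW setup.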