Coverage for gpaw/test/parallel/test_pblas.py: 95%

190 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

"""Test of PBLAS Level 2 & 3: rk, r2k, gemv, gemm, hemm, symm.

The test generates random matrices A0, B0, X0, etc. on the master
rank (in a serial 1-by-1 layout, or a single block holding the whole
matrix). They are redistributed to an mprocs-by-nprocs BLACS grid,
the operations are performed in parallel with PBLAS, and the results
are compared against serial references: BLAS for rk and r2k, numpy
for gemv, gemm, hemm and symm.
"""

8 

9import pytest 

10import numpy as np 

11 

12from gpaw.mpi import world, rank, broadcast_float 

13from gpaw.blacs import BlacsGrid, Redistributor 

14from gpaw.utilities import compiled_with_sl 

15from gpaw.utilities.blas import r2k, rk 

16from gpaw.utilities.scalapack import \ 

17 pblas_simple_gemm, pblas_gemm, \ 

18 pblas_simple_gemv, pblas_gemv, \ 

19 pblas_simple_r2k, pblas_simple_rk, \ 

20 pblas_simple_hemm, pblas_hemm, \ 

21 pblas_simple_symm, pblas_symm 

22from gpaw.utilities.tools import tri2full 

23 

pytestmark = pytest.mark.skipif(
    not compiled_with_sl(), reason='not compiled with scalapack')

# Absolute tolerance on the max-abs difference between a parallel PBLAS
# result and its serial reference.  It may need to be increased if the
# mprocs-by-nprocs BLACS grid becomes larger.
tol = 5.0e-13

# BLACS grid shapes (mprocs, nprocs) to test.  Larger grids are only
# added when world has at least that many ranks.
mnprocs_i = [(1, 1)]
mnprocs_i += [(1, 2), (2, 1)] if world.size >= 2 else []
mnprocs_i += [(2, 2)] if world.size >= 4 else []
mnprocs_i += [(2, 4), (4, 2)] if world.size >= 8 else []

38 

39 

def initialize_random(seed, dtype):
    """Create a reproducible random-value factory.

    Parameters
    ----------
    seed : int
        Seed for a numpy ``Generator`` backed by ``PCG64``.
    dtype : type
        If ``dtype == complex``, values are ``u + 1j * v`` with ``u`` and
        ``v`` drawn one after the other from [0, 1).  Otherwise values
        are real.

    Returns
    -------
    callable
        Takes the same positional arguments as ``Generator.random``:
        no argument yields a scalar, a shape yields an array.
    """
    rng = np.random.Generator(np.random.PCG64(seed))

    if dtype != complex:
        return rng.random

    def random_complex(*args):
        # Real part is drawn before the imaginary part.
        real_part = rng.random(*args)
        imag_part = rng.random(*args)
        return real_part + 1.0j * imag_part

    return random_complex

49 

50 

def initialize_alpha_beta(simple, random):
    """Return the scalars ``(alpha, beta)`` of a BLAS-style update.

    For ``simple`` the fixed pair ``(1.0, 0.0)`` is returned and
    ``random`` is never called.  Otherwise ``alpha`` and then ``beta``
    are drawn, in that order, from ``random()``.
    """
    return (1.0, 0.0) if simple else (random(), random())

59 

60 

def initialize_matrix(grid, M, N, mb, nb, random):
    """Create an M-by-N random matrix and distribute it over ``grid``.

    The values are drawn with ``random`` in the serial layout given by
    ``as_serial()`` of the distributed descriptor.  They are then
    redistributed into ``mb``-by-``nb`` blocks on ``grid``.  Both arrays
    are validated with ``checkassert`` against their descriptors.

    Returns
    -------
    tuple
        ``(serial_array, distributed_array, distributed_descriptor)``.
    """
    dist_desc = grid.new_descriptor(M, N, mb, nb)
    serial_desc = dist_desc.as_serial()

    # ``random`` is called on every rank with that rank's serial shape.
    serial_array = np.ascontiguousarray(random(serial_desc.shape))
    serial_desc.checkassert(serial_array)

    dist_array = serial_desc.redistribute(dist_desc, serial_array)
    dist_desc.checkassert(dist_array)
    return serial_array, dist_array, dist_desc

70 

71 

def calculate_error(ref_A0, A, block_desc):
    """Return the max-abs deviation of distributed ``A`` from ``ref_A0``.

    ``A``, laid out according to ``block_desc``, is redistributed to the
    serial layout.  Only rank 0 of the BLACS communicator compares it
    with ``ref_A0``, so other ranks may pass ``None`` for ``ref_A0``.
    The other ranks contribute NaN, and the value is passed through
    ``broadcast_float`` on that communicator.  A NaN anywhere in the
    difference propagates through ``max()``.
    """
    comm = block_desc.blacsgrid.comm
    serial_desc = block_desc.as_serial()
    # Executed on every rank (presumably a collective redistribution),
    # before the rank-0-only comparison.
    gathered = block_desc.redistribute(serial_desc, A)
    err = np.abs(ref_A0 - gathered).max() if comm.rank == 0 else np.nan
    return broadcast_float(err, comm)

82 

83 

@pytest.mark.parametrize('mprocs, nprocs', mnprocs_i)
@pytest.mark.parametrize('dtype', [float, complex])
def test_pblas_rk_r2k(dtype, mprocs, nprocs,
                      M=160, K=140, seed=42):
    """Test pblas_simple_r2k and pblas_simple_rk against serial BLAS.

    Random M-by-K matrices A and D are created in descriptors whose
    single block holds the whole matrix.  The serial references
    ``r2k(1.0, A, D, 0.0, S)`` and ``rk(1.0, A, 0.0, U)`` are computed on
    the master.  The same updates are then done with PBLAS on an
    mprocs-by-nprocs BLACS grid, gathered back and compared within
    ``tol``.
    """
    rng = np.random.RandomState(seed)
    grid = BlacsGrid(world, mprocs, nprocs)

    # The imaginary draw is always made (scaled by zero for real dtype),
    # so the real parts come from the same random stream for both dtypes.
    imag_scale = 1.0j if dtype == complex else 0.0

    def random_matrix(desc):
        return rng.rand(*desc.shape) + imag_scale * rng.rand(*desc.shape)

    # Descriptors whose single block covers the whole matrix (master data)
    master_a = grid.new_descriptor(M, K, M, K)
    master_d = grid.new_descriptor(M, K, M, K)
    master_s = grid.new_descriptor(M, M, M, M)
    master_u = grid.new_descriptor(M, M, M, M)

    A0 = random_matrix(master_a)
    D0 = random_matrix(master_d)

    # Zero-initialised outputs are needed for the rank updates
    S0 = master_s.zeros(dtype=dtype)
    U0 = master_u.zeros(dtype=dtype)

    # Serial reference products, computed on the master only
    if rank == 0:
        r2k(1.0, A0, D0, 0.0, S0)
        rk(1.0, A0, 0.0, U0)
    assert master_a.check(A0)
    assert master_d.check(D0)
    assert master_s.check(S0)
    assert master_u.check(U0)

    # Distributed descriptors with various block sizes
    dist_a = grid.new_descriptor(M, K, 2, 2)
    dist_d = grid.new_descriptor(M, K, 2, 3)
    dist_s = grid.new_descriptor(M, M, 2, 2)
    dist_u = grid.new_descriptor(M, M, 2, 2)

    A = dist_a.empty(dtype=dtype)
    D = dist_d.empty(dtype=dtype)
    S = dist_s.zeros(dtype=dtype)  # zeros needed for rank-updates
    U = dist_u.zeros(dtype=dtype)  # zeros needed for rank-updates
    Redistributor(world, master_a, dist_a).redistribute(A0, A)
    Redistributor(world, master_d, dist_d).redistribute(D0, D)

    pblas_simple_r2k(dist_a, dist_d, dist_s, A, D, S)
    pblas_simple_rk(dist_a, dist_u, A, U)

    # Gather the parallel results back onto the master
    S1 = master_s.zeros(dtype=dtype)
    U1 = master_u.zeros(dtype=dtype)
    Redistributor(world, dist_s, master_s).redistribute(S, S1)
    Redistributor(world, dist_u, master_u).redistribute(U, U1)

    if rank == 0:
        r2k_err = abs(S1 - S0).max()
        rk_err = abs(U1 - U0).max()
        print('r2k err', r2k_err)
        print('rk_err', rk_err)
    else:
        r2k_err = rk_err = 0.0

    # Share the master's errors with every rank, so that a failure is
    # raised on all cpus rather than on only one
    r2k_err = world.sum_scalar(r2k_err)
    rk_err = world.sum_scalar(rk_err)

    assert r2k_err == pytest.approx(0, abs=tol)
    assert rk_err == pytest.approx(0, abs=tol)

156 

157 

@pytest.mark.parametrize('mprocs, nprocs', mnprocs_i)
@pytest.mark.parametrize('simple', [True, False])
@pytest.mark.parametrize('transa', ['N', 'T', 'C'])
@pytest.mark.parametrize('dtype', [float, complex])
def test_pblas_gemv(dtype, simple, transa, mprocs, nprocs,
                    M=160, N=120, seed=42):
    """Test pblas_simple_gemv, pblas_gemv

    The operation is
    * y <- alpha*op(A)*x + beta*y

    where A is M-by-N, x and y are column vectors stored as one-column
    matrices, and op(A) is
    * A if transa == 'N'
    * A^T if transa == 'T'
    * A^H (conjugate transpose) if transa == 'C'

    Additional options
    * alpha=1 and beta=0 if simple == True
    """
    random = initialize_random(seed, dtype)
    grid = BlacsGrid(world, mprocs, nprocs)

    # Initialize matrices; the lengths of x and y follow op(A)
    alpha, beta = initialize_alpha_beta(simple, random)
    shapeA = (M, N)
    shapeX = {'N': (N, 1), 'T': (M, 1), 'C': (M, 1)}[transa]
    shapeY = {'N': (M, 1), 'T': (N, 1), 'C': (N, 1)}[transa]
    A0, A, descA = initialize_matrix(grid, *shapeA, 2, 2, random)
    X0, X, descX = initialize_matrix(grid, *shapeX, 4, 1, random)
    Y0, Y, descY = initialize_matrix(grid, *shapeY, 3, 1, random)

    if grid.comm.rank == 0:
        print(A0)

        # Calculate reference with numpy.  The lambda argument is not
        # called M, to avoid shadowing the dimension parameter.
        op_t = {'N': lambda mat: mat,
                'T': lambda mat: np.transpose(mat),
                'C': lambda mat: np.conjugate(np.transpose(mat))}
        ref_Y0 = alpha * np.dot(op_t[transa](A0), X0) + beta * Y0
    else:
        ref_Y0 = None

    # Calculate with scalapack
    if simple:
        pblas_simple_gemv(descA, descX, descY,
                          A, X, Y,
                          transa=transa)
    else:
        pblas_gemv(alpha, A, X, beta, Y,
                   descA, descX, descY,
                   transa=transa)

    # Check error
    err = calculate_error(ref_Y0, Y, descY)
    assert err < tol

208 

209 

@pytest.mark.parametrize('mprocs, nprocs', mnprocs_i)
@pytest.mark.parametrize('transb', ['N', 'T', 'C'])
@pytest.mark.parametrize('transa', ['N', 'T', 'C'])
@pytest.mark.parametrize('simple', [True, False])
@pytest.mark.parametrize('dtype', [float, complex])
def test_pblas_gemm(dtype, simple, transa, transb, mprocs, nprocs,
                    M=160, N=120, K=140, seed=42):
    """Test pblas_simple_gemm, pblas_gemm

    The operation is
    * C <- alpha*op(A)*op(B) + beta*C

    where C is M-by-N, op(A) is M-by-K, op(B) is K-by-N, and op(X) is
    * X if trans == 'N'
    * X^T if trans == 'T'
    * X^H (conjugate transpose) if trans == 'C'
    selected independently for A by transa and for B by transb.

    Additional options
    * alpha=1 and beta=0 if simple == True
    """
    random = initialize_random(seed, dtype)
    grid = BlacsGrid(world, mprocs, nprocs)

    # Initialize matrices; A and B are stored untransposed, so their
    # shapes depend on transa/transb
    alpha, beta = initialize_alpha_beta(simple, random)
    shapeA = {'N': (M, K), 'T': (K, M), 'C': (K, M)}[transa]
    shapeB = {'N': (K, N), 'T': (N, K), 'C': (N, K)}[transb]
    shapeC = (M, N)
    A0, A, descA = initialize_matrix(grid, *shapeA, 2, 2, random)
    B0, B, descB = initialize_matrix(grid, *shapeB, 2, 4, random)
    C0, C, descC = initialize_matrix(grid, *shapeC, 3, 2, random)

    if grid.comm.rank == 0:
        print(A0)

        # Calculate reference with numpy.  The lambda argument is not
        # called M, to avoid shadowing the dimension parameter.
        op_t = {'N': lambda mat: mat,
                'T': lambda mat: np.transpose(mat),
                'C': lambda mat: np.conjugate(np.transpose(mat))}
        ref_C0 = (alpha * np.dot(op_t[transa](A0), op_t[transb](B0))
                  + beta * C0)
    else:
        ref_C0 = None

    # Calculate with scalapack
    if simple:
        pblas_simple_gemm(descA, descB, descC,
                          A, B, C,
                          transa=transa, transb=transb)
    else:
        pblas_gemm(alpha, A, B, beta, C,
                   descA, descB, descC,
                   transa=transa, transb=transb)

    # Check error
    err = calculate_error(ref_C0, C, descC)
    assert err < tol

261 

262 

@pytest.mark.parametrize('mprocs, nprocs', mnprocs_i)
@pytest.mark.parametrize('uplo', ['L', 'U'])
@pytest.mark.parametrize('side', ['L', 'R'])
@pytest.mark.parametrize('simple', [True, False])
@pytest.mark.parametrize('hemm', [True, False])
@pytest.mark.parametrize('dtype', [float, complex])
def test_pblas_hemm_symm(dtype, hemm, simple, side, uplo, mprocs, nprocs,
                         M=160, N=120, seed=42):
    """Test pblas_simple_hemm, pblas_simple_symm, pblas_hemm, pblas_symm

    The routines compute
    * C <- alpha*A*B + beta*C when side == 'L' (A is M-by-M)
    * C <- alpha*B*A + beta*C when side == 'R' (A is N-by-N)
    with B and C of shape M-by-N.

    Only one triangle of A is referenced:
    * the lower one when uplo == 'L'
    * the upper one when uplo == 'U'

    Variants
    * hemm == True: A is Hermitian (pblas_*hemm)
    * hemm == False: A is symmetric (pblas_*symm)
    * simple == True: alpha=1 and beta=0 (pblas_simple_*)
    """
    random = initialize_random(seed, dtype)
    grid = BlacsGrid(world, mprocs, nprocs)
    on_master = grid.comm.rank == 0

    def generate_A_matrix(shape):
        matrix = random(shape)
        if on_master:
            # Make the matrix Hermitian (hemm) or symmetric (symm)
            partner = matrix.T.conj() if hemm else matrix.T
            matrix = matrix + partner

            # The routines must read only the `uplo` triangle, so the
            # opposite strict triangle is filled with NaN; any use of it
            # would turn the error into NaN.
            if uplo == 'L':
                matrix += np.triu(matrix * np.nan, 1)
            else:
                matrix += np.tril(matrix * np.nan, -1)
        return np.ascontiguousarray(matrix)

    # Initialize matrices
    alpha, beta = initialize_alpha_beta(simple, random)
    shapeA = {'L': (M, M), 'R': (N, N)}[side]
    A0, A, descA = initialize_matrix(grid, *shapeA, 2, 2, generate_A_matrix)
    B0, B, descB = initialize_matrix(grid, M, N, 2, 4, random)
    C0, C, descC = initialize_matrix(grid, M, N, 3, 2, random)

    if on_master:
        print(A0)

        # Reference with numpy: tri2full overwrites the NaN triangle from
        # the `uplo` triangle (conjugated for hemm), giving the full matrix
        tri2full(A0, uplo, map=np.conj if hemm else np.positive)
        product = np.dot(A0, B0) if side == 'L' else np.dot(B0, A0)
        ref_C0 = alpha * product + beta * C0
    else:
        ref_C0 = None

    # Calculate with scalapack
    if hemm:
        if simple:
            pblas_simple_hemm(descA, descB, descC,
                              A, B, C,
                              uplo=uplo, side=side)
        else:
            pblas_hemm(alpha, A, B, beta, C,
                       descA, descB, descC,
                       uplo=uplo, side=side)
    elif simple:
        pblas_simple_symm(descA, descB, descC,
                          A, B, C,
                          uplo=uplo, side=side)
    else:
        pblas_symm(alpha, A, B, beta, C,
                   descA, descB, descC,
                   uplo=uplo, side=side)

    # Check error
    err = calculate_error(ref_C0, C, descC)
    assert err < tol