Coverage for gpaw/test/parallel/test_pblas.py: 95%
190 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1"""Test of PBLAS Level 2 & 3 : rk, r2k, gemv, gemm.
The test generates random matrices A0, B0, X0, etc. on a
1-by-1 BLACS grid. They are redistributed to a mprocs-by-nprocs
BLACS grid, PBLAS operations are performed in parallel, and
results are compared against serial BLAS or NumPy references.
7"""
9import pytest
10import numpy as np
12from gpaw.mpi import world, rank, broadcast_float
13from gpaw.blacs import BlacsGrid, Redistributor
14from gpaw.utilities import compiled_with_sl
15from gpaw.utilities.blas import r2k, rk
16from gpaw.utilities.scalapack import \
17 pblas_simple_gemm, pblas_gemm, \
18 pblas_simple_gemv, pblas_gemv, \
19 pblas_simple_r2k, pblas_simple_rk, \
20 pblas_simple_hemm, pblas_hemm, \
21 pblas_simple_symm, pblas_symm
22from gpaw.utilities.tools import tri2full
# Skip every test in this module when GPAW was built without ScaLAPACK.
pytestmark = pytest.mark.skipif(
    not compiled_with_sl(),
    reason='not compiled with scalapack')

# Absolute error tolerance; may need to be increased if the
# mprocs-by-nprocs BLACS grid becomes larger.
tol = 5.0e-13

# (minimum world size, BLACS grid shapes enabled from that size on)
_GRID_SHAPES_BY_MIN_SIZE = (
    (2, [(1, 2), (2, 1)]),
    (4, [(2, 2)]),
    (8, [(2, 4), (4, 2)]),
)

# The 1-by-1 grid is always tested; larger grids only when enough
# MPI ranks are available.
mnprocs_i = [(1, 1)] + [shape
                        for min_size, shapes in _GRID_SHAPES_BY_MIN_SIZE
                        if world.size >= min_size
                        for shape in shapes]
def initialize_random(seed, dtype):
    """Return a sampler backed by a PCG64 generator seeded with ``seed``.

    The returned callable forwards its arguments to
    ``numpy.random.Generator.random``.  If ``dtype`` is ``complex`` the
    sampler draws twice per call: first the real part, then the
    imaginary part.  Otherwise a single real draw is returned.
    """
    rng = np.random.Generator(np.random.PCG64(seed))

    def real_sample(*args):
        return rng.random(*args)

    def complex_sample(*args):
        # Draw order matters for reproducibility: real part first.
        real_part = rng.random(*args)
        imag_part = rng.random(*args)
        return real_part + 1.0j * imag_part

    return complex_sample if dtype == complex else real_sample
def initialize_alpha_beta(simple, random):
    """Return the scalars ``(alpha, beta)`` used in the PBLAS calls.

    For the "simple" routines the scalars are fixed to ``alpha=1.0`` and
    ``beta=0.0`` and ``random`` is not called.  Otherwise ``random()`` is
    called twice, first for alpha and then for beta.
    """
    if simple:
        return 1.0, 0.0

    # alpha is drawn before beta to keep the random stream order fixed.
    alpha = random()
    beta = random()
    return alpha, beta
def initialize_matrix(grid, M, N, mb, nb, random):
    """Create a random M-by-N matrix and distribute it over ``grid``.

    ``random`` is called with the shape of the serial descriptor to
    produce the source array, which is then redistributed to a
    descriptor created with block sizes ``mb`` and ``nb``.

    Returns the tuple ``(serial array, distributed array,
    distributed descriptor)``.
    """
    dist_desc = grid.new_descriptor(M, N, mb, nb)
    serial_desc = dist_desc.as_serial()

    serial_mat = np.ascontiguousarray(random(serial_desc.shape))
    serial_desc.checkassert(serial_mat)

    dist_mat = serial_desc.redistribute(dist_desc, serial_mat)
    dist_desc.checkassert(dist_mat)
    return serial_mat, dist_mat, dist_desc
def calculate_error(ref_A0, A, block_desc):
    """Return the largest absolute deviation of ``A`` from ``ref_A0``.

    The distributed matrix ``A`` is redistributed to the serial
    descriptor.  Only rank 0 of the grid communicator evaluates the
    difference against ``ref_A0``; the other ranks contribute NaN.
    The value is then passed through ``broadcast_float`` (presumably
    so that all ranks return the rank-0 result -- verify against
    gpaw.mpi).
    """
    serial_desc = block_desc.as_serial()
    gathered = block_desc.redistribute(serial_desc, A)
    comm = block_desc.blacsgrid.comm

    local_err = np.abs(ref_A0 - gathered).max() if comm.rank == 0 else np.nan
    return broadcast_float(local_err, comm)
@pytest.mark.parametrize('mprocs, nprocs', mnprocs_i)
@pytest.mark.parametrize('dtype', [float, complex])
def test_pblas_rk_r2k(dtype, mprocs, nprocs,
                      M=160, K=140, seed=42):
    """Compare pblas_simple_rk and pblas_simple_r2k with serial rk and r2k.

    Random M-by-K matrices A and D are created with single-block
    descriptors, the reference rank-k / rank-2k updates are computed on
    rank 0, and the same updates are performed on block-distributed
    copies.  The distributed results are collected back and the maximum
    absolute deviations must vanish within ``tol``.
    """
    rng = np.random.RandomState(seed)
    grid = BlacsGrid(world, mprocs, nprocs)

    # Factor multiplying the second random draw (the imaginary part).
    imag_weight = 1.0j if dtype == complex else 0.0

    # Descriptors for matrices on master (one block spans the matrix):
    desc_A_glob = grid.new_descriptor(M, K, M, K)
    desc_D_glob = grid.new_descriptor(M, K, M, K)
    desc_S_glob = grid.new_descriptor(M, M, M, M)
    desc_U_glob = grid.new_descriptor(M, M, M, M)

    # Populate matrices local to master (first draw, then second draw):
    A_first = rng.rand(*desc_A_glob.shape)
    A_second = rng.rand(*desc_A_glob.shape)
    A_glob = A_first + imag_weight * A_second
    D_first = rng.rand(*desc_D_glob.shape)
    D_second = rng.rand(*desc_D_glob.shape)
    D_glob = D_first + imag_weight * D_second

    # Local result matrices; zeros needed for rank-updates
    S_ref = desc_S_glob.zeros(dtype=dtype)
    U_ref = desc_U_glob.zeros(dtype=dtype)

    # Local reference matrix product:
    if rank == 0:
        r2k(1.0, A_glob, D_glob, 0.0, S_ref)
        rk(1.0, A_glob, 0.0, U_ref)
    assert desc_A_glob.check(A_glob)
    assert desc_D_glob.check(D_glob)
    assert desc_S_glob.check(S_ref)
    assert desc_U_glob.check(U_ref)

    # Create distributed descriptors with various block sizes:
    desc_A_dist = grid.new_descriptor(M, K, 2, 2)
    desc_D_dist = grid.new_descriptor(M, K, 2, 3)
    desc_S_dist = grid.new_descriptor(M, M, 2, 2)
    desc_U_dist = grid.new_descriptor(M, M, 2, 2)

    # Distributed matrices; S and U must be zeroed for rank-updates
    A_dist = desc_A_dist.empty(dtype=dtype)
    D_dist = desc_D_dist.empty(dtype=dtype)
    S_dist = desc_S_dist.zeros(dtype=dtype)
    U_dist = desc_U_dist.zeros(dtype=dtype)
    Redistributor(world, desc_A_glob, desc_A_dist).redistribute(A_glob,
                                                                A_dist)
    Redistributor(world, desc_D_glob, desc_D_dist).redistribute(D_glob,
                                                                D_dist)

    pblas_simple_r2k(desc_A_dist, desc_D_dist, desc_S_dist,
                     A_dist, D_dist, S_dist)
    pblas_simple_rk(desc_A_dist, desc_U_dist, A_dist, U_dist)

    # Collect result back on master
    S_back = desc_S_glob.zeros(dtype=dtype)
    U_back = desc_U_glob.zeros(dtype=dtype)
    Redistributor(world, desc_S_dist, desc_S_glob).redistribute(S_dist,
                                                                S_back)
    Redistributor(world, desc_U_dist, desc_U_glob).redistribute(U_dist,
                                                                U_back)

    if rank == 0:
        r2k_err = abs(S_back - S_ref).max()
        rk_err = abs(U_back - U_ref).max()
        print('r2k err', r2k_err)
        print('rk_err', rk_err)
    else:
        r2k_err = rk_err = 0.0

    # We don't like exceptions on only one cpu, so every rank
    # gets the summed error before asserting.
    r2k_err = world.sum_scalar(r2k_err)
    rk_err = world.sum_scalar(rk_err)

    assert r2k_err == pytest.approx(0, abs=tol)
    assert rk_err == pytest.approx(0, abs=tol)
@pytest.mark.parametrize('mprocs, nprocs', mnprocs_i)
@pytest.mark.parametrize('simple', [True, False])
@pytest.mark.parametrize('transa', ['N', 'T', 'C'])
@pytest.mark.parametrize('dtype', [float, complex])
def test_pblas_gemv(dtype, simple, transa, mprocs, nprocs,
                    M=160, N=120, seed=42):
    """Test pblas_simple_gemv, pblas_gemv

    The operation is
    * y <- alpha*op(A)*x + beta*y

    where op(A) is A, A^T or A^H for transa 'N', 'T' or 'C'.

    Additional options
    * alpha=1 and beta=0 if simple == True
    """
    random = initialize_random(seed, dtype)
    grid = BlacsGrid(world, mprocs, nprocs)

    # Scalars are drawn before the matrices (fixed random stream order).
    alpha, beta = initialize_alpha_beta(simple, random)

    # A is always M-by-N; x and y swap lengths when A is transposed.
    len_x, len_y = {'N': (N, M), 'T': (M, N), 'C': (M, N)}[transa]
    A0, A, descA = initialize_matrix(grid, M, N, 2, 2, random)
    X0, X, descX = initialize_matrix(grid, len_x, 1, 4, 1, random)
    Y0, Y, descY = initialize_matrix(grid, len_y, 1, 3, 1, random)

    if grid.comm.rank == 0:
        print(A0)

        # Calculate reference with numpy
        if transa == 'N':
            op_A0 = A0
        elif transa == 'T':
            op_A0 = A0.T
        else:
            op_A0 = A0.T.conj()
        ref_Y0 = alpha * np.dot(op_A0, X0) + beta * Y0
    else:
        ref_Y0 = None

    # Calculate with scalapack
    if simple:
        pblas_simple_gemv(descA, descX, descY,
                          A, X, Y,
                          transa=transa)
    else:
        pblas_gemv(alpha, A, X, beta, Y,
                   descA, descX, descY,
                   transa=transa)

    # Check error
    assert calculate_error(ref_Y0, Y, descY) < tol
@pytest.mark.parametrize('mprocs, nprocs', mnprocs_i)
@pytest.mark.parametrize('transb', ['N', 'T', 'C'])
@pytest.mark.parametrize('transa', ['N', 'T', 'C'])
@pytest.mark.parametrize('simple', [True, False])
@pytest.mark.parametrize('dtype', [float, complex])
def test_pblas_gemm(dtype, simple, transa, transb, mprocs, nprocs,
                    M=160, N=120, K=140, seed=42):
    """Test pblas_simple_gemm, pblas_gemm

    The operation is
    * C <- alpha*op(A)*op(B) + beta*C

    where op(X) is X, X^T or X^H for the corresponding trans flag
    'N', 'T' or 'C'.

    Additional options
    * alpha=1 and beta=0 if simple == True
    """
    random = initialize_random(seed, dtype)
    grid = BlacsGrid(world, mprocs, nprocs)

    def apply_op(trans, mat):
        # Reference implementation of op(mat) for a validated trans flag.
        if trans == 'N':
            return mat
        if trans == 'T':
            return mat.T
        return mat.T.conj()

    # Scalars are drawn before the matrices (fixed random stream order).
    alpha, beta = initialize_alpha_beta(simple, random)

    # Stored shapes are swapped for (conjugate-)transposed operands so
    # that op(A) is M-by-K and op(B) is K-by-N.
    is_transposed = {'N': False, 'T': True, 'C': True}
    rows_A, cols_A = (K, M) if is_transposed[transa] else (M, K)
    rows_B, cols_B = (N, K) if is_transposed[transb] else (K, N)
    A0, A, descA = initialize_matrix(grid, rows_A, cols_A, 2, 2, random)
    B0, B, descB = initialize_matrix(grid, rows_B, cols_B, 2, 4, random)
    C0, C, descC = initialize_matrix(grid, M, N, 3, 2, random)

    if grid.comm.rank == 0:
        print(A0)

        # Calculate reference with numpy
        product = np.dot(apply_op(transa, A0), apply_op(transb, B0))
        ref_C0 = alpha * product + beta * C0
    else:
        ref_C0 = None

    # Calculate with scalapack
    if simple:
        pblas_simple_gemm(descA, descB, descC,
                          A, B, C,
                          transa=transa, transb=transb)
    else:
        pblas_gemm(alpha, A, B, beta, C,
                   descA, descB, descC,
                   transa=transa, transb=transb)

    # Check error
    assert calculate_error(ref_C0, C, descC) < tol
@pytest.mark.parametrize('mprocs, nprocs', mnprocs_i)
@pytest.mark.parametrize('uplo', ['L', 'U'])
@pytest.mark.parametrize('side', ['L', 'R'])
@pytest.mark.parametrize('simple', [True, False])
@pytest.mark.parametrize('hemm', [True, False])
@pytest.mark.parametrize('dtype', [float, complex])
def test_pblas_hemm_symm(dtype, hemm, simple, side, uplo, mprocs, nprocs,
                         M=160, N=120, seed=42):
    """Test pblas_simple_hemm, pblas_simple_symm, pblas_hemm, pblas_symm

    The operation is
    * C <- alpha*A*B + beta*C  if side == 'L'
    * C <- alpha*B*A + beta*C  if side == 'R'

    The computations are done with
    * lower triangular of A  if uplo == 'L'
    * upper triangular of A  if uplo == 'U'

    Additional options
    * A is Hermitian  if hemm == True
    * A is symmetric  if hemm == False
    * alpha=1 and beta=0  if simple == True
    """
    random = initialize_random(seed, dtype)
    grid = BlacsGrid(world, mprocs, nprocs)

    def generate_A_matrix(shape):
        mat = random(shape)
        if grid.comm.rank == 0:
            # Hermitian (A + A^H) or symmetric (A + A^T) matrix
            mat = mat + (mat.T.conj() if hemm else mat.T)

            # Only lower or upper triangular is used, so
            # fill the other triangular with NaN to detect errors
            poison = mat * np.nan
            if uplo == 'L':
                mat += np.triu(poison, 1)
            else:
                mat += np.tril(poison, -1)
        return np.ascontiguousarray(mat)

    # Scalars are drawn before the matrices (fixed random stream order).
    alpha, beta = initialize_alpha_beta(simple, random)

    # A is square; its size must match the side it multiplies from.
    dim_A = {'L': M, 'R': N}[side]
    A0, A, descA = initialize_matrix(grid, dim_A, dim_A, 2, 2,
                                     generate_A_matrix)
    B0, B, descB = initialize_matrix(grid, M, N, 2, 4, random)
    C0, C, descC = initialize_matrix(grid, M, N, 3, 2, random)

    if grid.comm.rank == 0:
        print(A0)

        # Calculate reference with numpy; first restore the triangle
        # that was filled with NaN.
        tri2full(A0, uplo, map=np.conj if hemm else np.positive)
        product = np.dot(A0, B0) if side == 'L' else np.dot(B0, A0)
        ref_C0 = alpha * product + beta * C0
    else:
        ref_C0 = None

    # Calculate with scalapack
    if simple:
        simple_routine = pblas_simple_hemm if hemm else pblas_simple_symm
        simple_routine(descA, descB, descC,
                       A, B, C,
                       uplo=uplo, side=side)
    else:
        routine = pblas_hemm if hemm else pblas_symm
        routine(alpha, A, B, beta, C,
                descA, descB, descC,
                uplo=uplo, side=side)

    # Check error
    assert calculate_error(ref_C0, C, descC) < tol