Coverage for gpaw/test/linalg/test_mmm.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1"""Test BLAS matrix-matrix-multiplication interface.""" 

2import numpy as np 

3from gpaw.utilities.blas import mmm 

4 

5 

def test_linalg_mmm(rng):
    """Check ``mmm`` against ``np.dot`` for every operation-flag combination.

    For real and complex inputs, and for each pair of BLAS-style flags
    ``opa``/``opb`` drawn from ``'N'`` (as is), ``'T'`` (transpose) and
    ``'C'`` (conjugate transpose), the product ``op(opa, a) @ op(opb, b)``
    is first computed with NumPy into ``C``.  The test then calls
    ``mmm(1, a, opa, b, opb, -1, C)`` and asserts that ``C`` ends up
    numerically zero.  That is, it assumes ``mmm`` overwrites ``C`` in place
    with ``1 * op(a) @ op(b) + (-1) * C`` (GEMM-style); confirm against the
    ``gpaw.utilities.blas.mmm`` documentation.
    """

    def op(o, m):
        # Apply a BLAS-style operation flag to m; anything other than
        # 'N' or 'T' is treated as conjugate transpose ('C').
        if o == 'N':
            return m
        if o == 'T':
            return m.T
        return m.T.conj()

    def matrix(shape, dtype):
        # Random test matrix; complex matrices get independent real and
        # imaginary parts drawn from the same generator.
        if dtype == float:
            return rng.random(shape)
        return rng.random(shape) + 1j * rng.random(shape)

    for dtype in [float, complex]:
        a = matrix((2, 3), dtype)
        for opa in 'NTC':
            A = op(opa, a)
            B = matrix((A.shape[1], 4), dtype)
            for opb in 'NTC':
                # Build b so that op(opb, b) == B.  The .copy() turns the
                # (possibly transposed) view into a freshly allocated
                # C-ordered array; presumably mmm needs contiguous input.
                b = op(opb, B).copy()
                C = np.dot(A, B)
                mmm(1, a, opa, b, opb, -1, C)
                err = abs(C).max()
                # Report the failing combination so a loop-internal
                # failure is diagnosable without re-running by hand.
                assert err < 1e-14, (
                    f'mmm mismatch for dtype={dtype.__name__}, '
                    f'opa={opa!r}, opb={opb!r}: max |C| = {err:.3e}')