Coverage for gpaw/test/linalg/test_mmm.py: 100%
23 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1"""Test BLAS matrix-matrix-multiplication interface."""
2import numpy as np
3from gpaw.utilities.blas import mmm
def test_linalg_mmm(rng):
    """Compare ``mmm`` with ``np.dot`` for every operator-flag combination.

    Loops over real and complex dtypes and over all pairs of operator
    flags 'N' (as-is), 'T' (transpose) and 'C' (conjugate transpose).
    A reference product is built with ``np.dot`` and then handed to
    ``mmm`` with ``alpha=1`` and ``beta=-1``; the assertion requires the
    result to vanish to within 1e-14.
    """
    def apply_op(flag, mat):
        # Return the matrix that a BLAS-style operator flag denotes.
        # Any flag other than 'N' or 'T' is treated as 'C'.
        if flag == 'N':
            return mat
        elif flag == 'T':
            return mat.T
        else:
            return mat.T.conj()

    def random_matrix(shape, dtype):
        # Entries come from rng.random; complex matrices draw the real
        # part first and then the imaginary part.
        real_part = rng.random(shape)
        if dtype == float:
            return real_part
        return real_part + 1j * rng.random(shape)

    for dtype in (float, complex):
        lhs = random_matrix((2, 3), dtype)
        for flag_a in 'NTC':
            lhs_op = apply_op(flag_a, lhs)
            # rhs has as many rows as op(lhs) has columns, so the product
            # lhs_op @ rhs is well defined.
            rhs = random_matrix((lhs_op.shape[1], 4), dtype)
            for flag_b in 'NTC':
                # apply_op(flag_b, rhs_stored) == rhs for every flag.
                rhs_stored = apply_op(flag_b, rhs).copy()
                residual = np.dot(lhs_op, rhs)
                # NOTE(review): this assumes mmm(alpha, a, opa, b, opb,
                # beta, c) updates c in place as
                # c <- alpha * op(a) @ op(b) + beta * c (BLAS gemm
                # convention). Under that convention, residual should
                # become zero.
                mmm(1, lhs, flag_a, rhs_stored, flag_b, -1, residual)
                assert np.abs(residual).max() < 1e-14