Coverage for gpaw/test/linalg/test_blas.py: 100%
55 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import numpy as np
2from gpaw.utilities.blas import axpy, r2k, rk, gemmdot, mmm, mmmx
3from gpaw.utilities.tools import tri2full
def test_gemm_size_zero():
    """Check that ``mmm`` copes with operands having a zero-length dimension.

    The shapes below suggest ``mmm(alpha, x, 'N', y, 'N', beta, out)``
    computes ``out <- alpha * x @ y + beta * out`` (assumption from the
    shapes used; confirm against gpaw.utilities.blas).
    """
    c = np.ones((3, 3))
    a = np.zeros((0, 3))
    b = np.zeros((3, 0))
    d = np.zeros((0, 0))
    e = np.zeros((0, 3))

    # (3, 0) @ (0, 3): the summed dimension is empty, so with beta=0 the
    # initial ones in ``c`` must be discarded and replaced by zeros.
    mmm(1.0, b, 'N', a, 'N', 0.0, c)
    assert (c == 0.0).all()

    # (0, 0) @ (0, 3): the result itself is empty; this call only checks
    # that the wrapper does not raise on fully degenerate input.
    mmm(1.0, d, 'N', a, 'N', 0.0, e)
def test_linalg_blas():
    """Compare gpaw's BLAS wrappers against NumPy reference results.

    All operands hold integer-valued (real or complex) entries, so the
    wrapper results are compared with exact ``==`` rather than a tolerance.
    Wrappers that accumulate into an output array are checked by pre-filling
    the output with the NumPy reference and passing beta=-1: a correct call
    must then leave the output identically zero.
    """
    a = np.arange(5 * 7).reshape(5, 7) + 4.
    a2 = np.arange(3 * 7).reshape(3, 7) + 3.
    b = np.arange(7) - 2.

    # gemmdot with real arrays; for real data trans='c' must equal trans='t'.
    assert np.all(np.dot(a, b) == gemmdot(a, b))
    assert np.all(np.dot(a, a2.T) == gemmdot(a, a2, trans='t'))
    assert np.all(np.dot(a, a2.T) == gemmdot(a, a2, trans='c'))
    assert np.dot(b, b) == gemmdot(b, b)

    # gemmdot with complex arrays; ``trans`` acts on the second operand
    # ('t' = transpose, 'c' = conjugate transpose).
    a = a * (2 + 1.j)
    a2 = a2 * (-1 + 3.j)
    b = b * (3 - 2.j)
    assert np.all(np.dot(a, b) == gemmdot(a, b))
    assert np.all(np.dot(a, a2.T) == gemmdot(a, a2, trans='t'))
    assert np.all(np.dot(a, a2.T.conj()) == gemmdot(a, a2, trans='c'))
    assert np.dot(b, b) == gemmdot(b, b, trans='n')
    assert np.dot(b, b.conj()) == gemmdot(b, b, trans='c')

    # mmmx with both operands untransposed: contracts the last axis of the
    # (5, 7) ``a`` with the first axis of the 4-D (7, 5, 1, 3) ``a2``.
    a2 = np.arange(7 * 5 * 1 * 3).reshape(7, 5, 1, 3) * (-1. + 4.j) + 3.
    c = np.tensordot(a, a2, [1, 0])
    mmmx(1., a, 'N', a2, 'N', -1., c)
    assert not c.any()

    # mmmx with the second operand conjugate-transposed ('C'): contracts all
    # trailing axes of ``a`` against all trailing axes of conj(a2).
    a = np.arange(4 * 5 * 1 * 3).reshape(4, 5, 1, 3) * (3. - 2.j) + 4.
    c = np.tensordot(a, a2.conj(), [[1, 2, 3], [1, 2, 3]])
    mmmx(1., a, 'N', a2, 'C', -1., c)
    assert not c.any()

    # axpy: c += alpha * a, with alpha chosen to cancel c exactly.
    c = 5.j * a
    axpy(-5.j, a, c)
    assert not c.any()

    # rk: c = a @ a^H over the trailing axes.  tri2full is applied before
    # checking because rk presumably fills only one triangle -- confirm.
    c = np.tensordot(a, a.conj(), [[1, 2, 3], [1, 2, 3]])
    rk(1., a, -1., c)
    tri2full(c)
    assert not c.any()

    # gemmdot with trans='c' on N-D arrays: contracts the last axis of ``a``
    # with the last axis of conj(a2).
    c = np.tensordot(a, a2.conj(), [-1, -1])
    gemmdot(a, a2, beta=-1., out=c, trans='c')
    assert not c.any()

    # gemmdot with trans='n' on N-D arrays: contracts the last axis of ``a``
    # with the first axis of ``a2``.  The reshape reinterprets the same 105
    # elements (it is not a transpose); reshape is used instead of assigning
    # to ``a2.shape``, which NumPy discourages.
    a2 = a2.reshape(3, 7, 5, 1)
    c = np.tensordot(a, a2, [-1, 0])
    gemmdot(a, a2, beta=-1., out=c, trans='n')
    assert not c.any()

    # r2k: with a2 = 5 * a, both rank-2k terms equal 5 * a @ a^H, so
    # 0.5 * (sum of the two terms) reproduces c and beta=-1 cancels it.
    a2 = 5. * a
    c = np.tensordot(a, a2.conj(), [[1, 2, 3], [1, 2, 3]])
    r2k(.5, a, a2, -1., c)
    tri2full(c)
    assert not c.any()