Coverage for gpaw/test/linalg/test_blas.py: 100%

55 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import numpy as np 

2from gpaw.utilities.blas import axpy, r2k, rk, gemmdot, mmm, mmmx 

3from gpaw.utilities.tools import tri2full 

4 

5 

def test_gemm_size_zero():
    """Exercise ``mmm`` with operands that have a zero-length dimension."""
    # Zero-length inner dimension: (3, 0) times (0, 3) into a (3, 3) output
    # that starts out filled with ones.  The test expects every entry of
    # the output to be zero after the call.
    filled = np.ones((3, 3))
    zero_rows = np.zeros((0, 3))
    zero_cols = np.zeros((3, 0))
    mmm(1.0, zero_cols, 'N', zero_rows, 'N', 0.0, filled)
    assert (filled == 0.0).all()

    # Zero-sized result: (0, 0) times (0, 3) into a (0, 3) output.  Nothing
    # is asserted here; the call only has to complete without raising.
    empty_square = np.zeros((0, 0))
    empty_out = np.zeros((0, 3))
    mmm(1.0, empty_square, 'N', zero_rows, 'N', 0.0, empty_out)

16 

17 

def test_linalg_blas():
    """Cross-check the BLAS wrappers against NumPy reference results.

    The first two sections compare ``gemmdot`` return values directly with
    ``np.dot``.  Every later section pre-fills an output array with the
    NumPy reference, calls the wrapper with a coefficient of -1 applied to
    that existing output, and asserts that the array ends up all zero.
    """
    # gemmdot on real operands.  For real data the 't' and 'c' flags are
    # both compared against the plain transpose.
    lhs = np.arange(5 * 7).reshape(5, 7) + 4.0
    rhs = np.arange(3 * 7).reshape(3, 7) + 3.0
    vec = np.arange(7) - 2.0

    assert (np.dot(lhs, vec) == gemmdot(lhs, vec)).all()
    for flag in ('t', 'c'):
        assert (np.dot(lhs, rhs.T) == gemmdot(lhs, rhs, trans=flag)).all()
    assert np.dot(vec, vec) == gemmdot(vec, vec)

    # gemmdot on complex operands; here 'c' is compared against the
    # conjugated reference.
    zlhs = lhs * (2 + 1.0j)
    zrhs = rhs * (-1 + 3.0j)
    zvec = vec * (3 - 2.0j)

    assert (np.dot(zlhs, zvec) == gemmdot(zlhs, zvec)).all()
    assert (np.dot(zlhs, zrhs.T) == gemmdot(zlhs, zrhs, trans='t')).all()
    assert (np.dot(zlhs, zrhs.T.conj())
            == gemmdot(zlhs, zrhs, trans='c')).all()
    assert np.dot(zvec, zvec) == gemmdot(zvec, zvec, trans='n')
    assert np.dot(zvec, zvec.conj()) == gemmdot(zvec, zvec, trans='c')

    # mmmx with both operands flagged 'N': (5, 7) against (7, 5, 1, 3).
    tens7 = np.arange(7 * 5 * 1 * 3).reshape(7, 5, 1, 3)
    tens7 = tens7 * (-1.0 + 4.0j) + 3.0
    resid = np.tensordot(zlhs, tens7, [1, 0])
    mmmx(1.0, zlhs, 'N', tens7, 'N', -1.0, resid)
    assert not resid.any()

    # mmmx with the second operand flagged 'C': contract the three trailing
    # axes of a (4, 5, 1, 3) array with the conjugate of the (7, 5, 1, 3) one.
    trailing = [1, 2, 3]
    tens4 = np.arange(4 * 5 * 1 * 3).reshape(4, 5, 1, 3)
    tens4 = tens4 * (3.0 - 2.0j) + 4.0
    resid = np.tensordot(tens4, tens7.conj(), [trailing, trailing])
    mmmx(1.0, tens4, 'N', tens7, 'C', -1.0, resid)
    assert not resid.any()

    # axpy: the output starts as 5j * x, and -5j * x is expected to cancel it.
    resid = 5.0j * tens4
    axpy(-5.0j, tens4, resid)
    assert not resid.any()

    # rk: tri2full is applied before the check (presumably rk updates only
    # one triangle of the output -- confirm against the rk documentation).
    resid = np.tensordot(tens4, tens4.conj(), [trailing, trailing])
    rk(1.0, tens4, -1.0, resid)
    tri2full(resid)
    assert not resid.any()

    # gemmdot with trans='c', writing into a pre-filled ``out`` array.
    resid = np.tensordot(tens4, tens7.conj(), [-1, -1])
    gemmdot(tens4, tens7, beta=-1.0, out=resid, trans='c')
    assert not resid.any()

    # gemmdot with trans='n'; the second operand is reshaped in place so
    # that its leading axis matches the last axis of the first operand.
    tens7.shape = (3, 7, 5, 1)
    resid = np.tensordot(tens4, tens7, [-1, 0])
    gemmdot(tens4, tens7, beta=-1.0, out=resid, trans='n')
    assert not resid.any()

    # r2k with the second operand equal to 5 * x; tri2full is applied
    # before the check, as in the rk section.
    scaled = 5.0 * tens4
    resid = np.tensordot(tens4, scaled.conj(), [trailing, trailing])
    r2k(0.5, tens4, scaled, -1.0, resid)
    tri2full(resid)
    assert not resid.any()