Coverage for gpaw/test/gpu/test_matrix.py: 94%
80 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1import numpy as np
2import pytest
3from gpaw.core.matrix import Matrix
4from gpaw.gpu import cupy as cp, as_np, as_xp
5from gpaw.mpi import world
6from gpaw.new.c import GPU_AWARE_MPI
7from gpaw.gpu.mpi import CuPyMPI
@pytest.mark.gpu
@pytest.mark.serial
def test_zyrk():
    """Compare the symmetric ``M @ M^H`` product on CPU and GPU.

    The same 2x3 complex array is wrapped in a ``Matrix`` once as a NumPy
    array and once as a CuPy array.  Both products are expanded with
    ``tril2full`` and must agree element for element.
    """
    values = np.array([[1, 1 + 2j, 2], [1, 0.5j, -1 - 0.5j]])

    # Reference product from host (NumPy) data.
    host_matrix = Matrix(2, 3, data=values)
    host_product = host_matrix.multiply(
        host_matrix, opb='C', beta=0.0, symmetric=True)
    host_product.tril2full()

    # Same product from device (CuPy) data.
    device_matrix = Matrix(2, 3, data=cp.asarray(values))
    device_product = device_matrix.multiply(
        device_matrix, opb='C', beta=0.0, symmetric=True)
    device_product.tril2full()

    # Bring the GPU result back to the host; exact equality is required.
    device_on_host = device_product.to_cpu()
    assert (device_on_host.data == host_product.data).all()
@pytest.mark.gpu
@pytest.mark.serial
def test_eigh():
    """Compare ``Matrix.eigh`` with an overlap matrix on CPU and GPU.

    The eigenvalues must agree, and the eigenvectors left in ``data`` by
    the two runs must be parallel with respect to the overlap matrix.
    """
    h_values = np.array([[2, 42.1 + 42.1j], [0.1 - 0.1j, 3]])
    s_values = np.array([[1, 42.1 + 42.2j], [0.1 - 0.2j, 0.9]])

    h_host = Matrix(2, 2, data=h_values)
    s_host = Matrix(2, 2, data=s_values)

    # The device copies must be taken before eigh() is called on the host
    # matrices, since the eigenvectors are read from ``data`` afterwards.
    h_device = Matrix(2, 2, data=cp.asarray(h_host.data))
    s_device = Matrix(2, 2, data=cp.asarray(s_host.data))

    eigs_host = h_host.eigh(s_host)

    # Full overlap matrix for the check below.
    # NOTE(review): presumably only the lower triangle of S is meaningful
    # until tril2full() is called -- confirm against Matrix.eigh.
    s_full = s_host.copy()
    s_full.tril2full()

    eigs_device = h_device.eigh(s_device)
    assert as_np(eigs_device) == pytest.approx(eigs_host)

    vecs_host = h_host.data
    vecs_device = h_device.to_cpu().data

    # Check that eigenvectors are parallel:
    overlap = vecs_host.conj() @ s_full.data @ vecs_device.T
    assert abs(overlap) == pytest.approx(np.eye(2))
def op(a: np.ndarray, o: str) -> np.ndarray:
    """Apply the operation code *o* to the array *a*.

    Parameters
    ----------
    a:
        Input array.
    o:
        ``'N'`` returns *a* itself (no copy); ``'C'`` returns the
        conjugate transpose ``a.T.conj()``.

    Raises
    ------
    ValueError
        If *o* is neither ``'N'`` nor ``'C'``.
    """
    if o == 'N':
        return a
    if o == 'C':
        return a.T.conj()
    # Fail with a message naming the bad code instead of an opaque
    # ZeroDivisionError.
    raise ValueError(f"Unknown operation {o!r}; expected 'N' or 'C'")
@pytest.mark.gpu
@pytest.mark.parametrize(
    'shape1, shape2, op1, op2, sym, same',
    [((5, 9), (5, 9), 'N', 'C', 1, 1),
     ((2, 3), (2, 3), 'N', 'C', 1, 0),
     ((5, 9), (5, 9), 'N', 'C', 0, 0),
     ((5, 9), (5, 9), 'C', 'N', 0, 0),
     ((5, 9), (9, 5), 'C', 'C', 0, 0),
     ((5, 5), (5, 9), 'N', 'N', 0, 0)])
@pytest.mark.parametrize('beta', [0.0, 1.0])
@pytest.mark.parametrize('dtype', [float, complex])
@pytest.mark.parametrize('xp', [np, cp])
def test_mul(shape1, shape2, op1, op2, beta, sym, same, dtype, xp, rng):
    """Check ``Matrix.multiply`` against a plain NumPy reference.

    Computes ``out = beta * out + alpha * op1(m1) @ op2(m2)`` with
    ``Matrix.multiply`` and compares the gathered result on rank 0 with
    the same expression evaluated on NumPy arrays.

    ``sym`` requests the symmetric code path (the second operand is then
    a copy of the first and the output starts out Hermitian); ``same``
    passes the very same ``Matrix`` object as both operands.
    """
    running_distributed_gpu = world.size > 1 and xp is cp
    if running_distributed_gpu:
        symmetric_overwrite = (op1 == 'N' and op2 == 'C'
                               and sym and beta == 0.0)
        if op1 == 'C' or symmetric_overwrite:
            pytest.skip('Not implemented!')

    alpha = 1.234
    comm = world if GPU_AWARE_MPI else CuPyMPI(world)

    # Shape of the product op1(m1) @ op2(m2).
    n_rows = shape1[0] if op1 == 'N' else shape1[1]
    n_cols = shape2[1] if op2 == 'N' else shape2[0]
    shape3 = (n_rows, n_cols)
    shapes = [shape1, shape2, shape3]

    # NOTE(review): dist=(comm, 1, 1) presumably keeps all data on rank 0;
    # only rank 0 fills these matrices below -- confirm against Matrix.
    m1, m2, m3 = [Matrix(*shape, dtype=dtype, dist=(comm, 1, 1), xp=xp)
                  for shape in shapes]

    if world.rank == 0:
        for matrix in (m1, m2, m3):
            # Viewing as float also randomizes the imaginary parts of
            # complex matrices.
            raw = matrix.data.view(float)
            raw[:] = as_xp(rng.random(raw.shape), xp)
        if sym:
            m2.data[:] = m1.data
            m3.data += m3.data.T.conj()

    # Correct result:
    a1, a2, a3 = [as_np(matrix.data) for matrix in (m1, m2, m3)]
    a3 = beta * a3 + alpha * op(a1, op1) @ op(a2, op2)

    # Redistribute the inputs onto the (comm, -1, 1) layout.
    distributed = [Matrix(*shape, dtype=dtype, dist=(comm, -1, 1), xp=xp)
                   for shape in shapes]
    for source, target in zip((m1, m2, m3), distributed):
        source.redist(target)
    lhs, rhs, out = distributed

    # Optionally use the identical object for both operands.
    rhs = lhs if same else rhs

    lhs.multiply(rhs, alpha=alpha, opa=op1, opb=op2, beta=beta,
                 out=out, symmetric=sym)

    gathered = out.gather()
    error = 0.0
    if world.rank == 0:
        if sym:
            gathered.tril2full()
        error = abs(a3 - as_np(gathered.data)).max()
    # Only rank 0 contributes a non-zero error; the sum shares it with
    # every rank.
    error = world.sum_scalar(error)
    assert error < 1e-13
if __name__ == '__main__':
    # Manual run of one complex CuPy case outside of pytest.
    test_mul(shape1=(1, 1),
             shape2=(1, 19),
             op1='N',
             op2='N',
             beta=0.0,
             sym=0,
             same=0,
             dtype=complex,
             xp=cp,
             rng=np.random.default_rng(42))