Coverage for gpaw/utilities/mblas.py: 71%
17 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1import numpy as np
3import gpaw.cgpaw as cgpaw
4from gpaw.utilities.blas import axpy
5from gpaw import gpu
def multi_axpy_cpu(a, x, y):
    """Host-side batched axpy: one ``axpy`` call per row.

    The coefficients ``a``, the sources ``x`` and the targets ``y`` are
    walked in lockstep.  For each position ``i``, ``axpy(a[i], x[i], y[i])``
    is called.  Iteration ends at the shortest of the three iterables
    (plain ``zip`` semantics).  Nothing is returned; whatever update
    ``axpy`` performs on ``y[i]`` is the only effect (presumably an in-place
    ``y[i] += a[i] * x[i]`` -- confirm in ``gpaw.utilities.blas``).
    """
    rows = zip(a, x, y)
    for coeff, src, dst in rows:
        axpy(coeff, src, dst)
def multi_axpy(a, x, y):
    """Batched ``axpy`` dispatching between scalar, GPU and CPU code paths.

    Dispatch rules
    --------------
    * ``a`` is a single scalar -> one ``axpy(a, x, y)`` call on the whole
      arrays.  Numeric scalars that are not ``float``/``complex`` subclasses
      (e.g. ``int``, ``np.float32``, ``np.complex64``) are first converted
      to ``float`` (or ``complex`` when complex-valued).  This keeps them
      on the scalar path instead of being treated as a sequence of
      per-row coefficients.
    * ``x`` is not a ``numpy.ndarray`` (presumably a device array managed
      by ``gpaw.gpu`` -- confirm) -> a single ``cgpaw.multi_axpy_gpu``
      call.  Host coefficients (a ``numpy.ndarray``, or a list/tuple,
      which is converted to one) are copied to the device first.  Any
      other ``a`` is assumed to already live on the device.
    * otherwise -> ``multi_axpy_cpu``, i.e. one ``axpy`` per
      ``(a[i], x[i], y[i])`` row.

    Parameters
    ----------
    a : scalar, or sequence/array of per-row coefficients
    x, y : arrays of exactly the same type.  ``y`` is the target argument
        of ``axpy`` / the GPU kernel (presumably updated in place).

    Raises
    ------
    AssertionError
        If ``type(x)`` is not ``type(y)``.
    """
    import numbers

    assert type(x) is type(y), (type(x), type(y))

    # int / np.float32 / np.complex64 ... are not float/complex subclasses.
    # Without normalization they fell through to the per-row branches and
    # failed there (e.g. zip() over an int).
    if (isinstance(a, numbers.Number)
            and not isinstance(a, (float, complex))):
        a = complex(a) if np.iscomplexobj(a) else float(a)

    if isinstance(a, (float, complex)):
        axpy(a, x, y)
        return

    if isinstance(x, np.ndarray):
        multi_axpy_cpu(a, x, y)
        return

    # GPU path.  A plain list/tuple has no ``.dtype`` and is not a device
    # buffer, so make it a host array that can be copied to the device.
    if isinstance(a, (list, tuple)):
        a = np.asarray(a)
    if isinstance(a, np.ndarray):
        a_gpu = gpu.copy_to_device(a)
    else:
        # Assumed to already be a device array -- TODO confirm.
        a_gpu = a
    cgpaw.multi_axpy_gpu(gpu.get_pointer(a_gpu),
                         a.dtype,
                         gpu.get_pointer(x),
                         x.shape,
                         gpu.get_pointer(y),
                         y.shape,
                         x.dtype)