Coverage for gpaw/utilities/mblas.py: 71%

17 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1import numpy as np 

2 

3import gpaw.cgpaw as cgpaw 

4from gpaw.utilities.blas import axpy 

5from gpaw import gpu 

6 

7 

def multi_axpy_cpu(a, x, y):
    """Apply ``axpy`` element-wise over three parallel sequences on the CPU.

    For every position ``i`` this calls ``axpy(a[i], x[i], y[i])``.
    Iteration stops at the shortest of *a*, *x* and *y* (``zip``
    semantics). Any in-place update of ``y[i]`` is whatever ``axpy``
    itself performs.
    """
    for args in zip(a, x, y):
        axpy(*args)

11 

12 

def multi_axpy(a, x, y):
    """Perform ``y <- a * x + y`` for a scalar or a batch of coefficients.

    If *a* is a single number, one ``axpy(a, x, y)`` call is made.
    Otherwise *a* holds one coefficient per leading entry of *x*/*y*. The
    update then goes through the ``cgpaw.multi_axpy_gpu`` kernel when *x*
    is not a NumPy array (presumably a device array — confirm against
    ``gpaw.gpu``), or through ``multi_axpy_cpu`` when it is.

    Parameters
    ----------
    a : number or sequence/array of numbers
        Scalar coefficient, or per-entry coefficients. Python ints and
        NumPy scalars (e.g. ``np.float32``, ``np.complex64``) are
        converted to the Python ``float``/``complex`` handed to ``axpy``.
    x : array
        Input array; must be of exactly the same type as *y*.
    y : array
        Output array; passed to ``axpy`` / the GPU kernel for the update.
    """
    assert type(x) is type(y), (
        f'x and y must have the same type, got {type(x)} and {type(y)}')

    if isinstance(a, (int, float, complex, np.number)):
        if not isinstance(a, (float, complex)):
            # ints and NumPy scalars such as np.float32 / np.complex64 are
            # not float/complex subclasses; previously they fell through to
            # the batch branch and crashed there.  Normalise them to the
            # Python scalar type instead.
            a = complex(a) if np.iscomplexobj(a) else float(a)
        axpy(a, x, y)
        return

    if isinstance(x, np.ndarray):
        multi_axpy_cpu(a, x, y)
        return

    # Non-NumPy x: use the GPU kernel.  Coefficients held in a host NumPy
    # array are copied to the device first; any other container is passed
    # through unchanged (assumed to already be a device array).
    a_gpu = gpu.copy_to_device(a) if isinstance(a, np.ndarray) else a
    cgpaw.multi_axpy_gpu(gpu.get_pointer(a_gpu),
                         a.dtype,
                         gpu.get_pointer(x),
                         x.shape,
                         gpu.get_pointer(y),
                         y.shape,
                         x.dtype)