Coverage for gpaw/gpu/__init__.py: 32%
201 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
2import contextlib
3from time import time
4from typing import TYPE_CHECKING
5from types import ModuleType
6from collections.abc import Iterable
7from gpaw.new.timer import trace
9import numpy as np
11from gpaw.cgpaw import have_magma
# Module-level state describing the active array backend.  The values below
# are the CPU-only defaults; the backend-selection code further down in this
# module reassigns them when the real CuPy backend is set up.

# Reassigned to False once the real ``cupy`` import path has succeeded.
cupy_is_fake = True
"""True if :mod:`cupy` has been replaced by ``gpaw.gpu.cpupy``"""

# Reassigned from ``cupy.cuda.runtime.is_hip`` when real CuPy is in use.
is_hip = False
"""True if we are using HIP"""

# Reassigned to the string '<nodename>:<PCI bus id>' during GPU
# initialisation further down in this module.
device_id = None
"""Device id"""
def gpu_gemm(*args, **kwargs):
    """Placeholder ``gemm`` used until a backend-specific one is bound.

    Accepts any arguments and always raises ``NotImplementedError``.
    The name is rebound further down in this module when a backend
    (real or fake CuPy) is imported.
    """
    message = 'gpu_gemm: You are not using GPAW with GPUs.'
    raise NotImplementedError(message)
27if TYPE_CHECKING:
28 import gpaw.gpu.cpupy as cupy
29 import gpaw.gpu.cpupyx as cupyx
30else:
31 try:
32 import gpaw.cgpaw as cgpaw
33 if not hasattr(cgpaw, 'gpaw_gpu_init'):
34 raise ImportError
36 import cupy
37 # Cupy gemm wrapper (does extra copying):
38 # from cupy import cublas
39 # gpu_gemm = trace(gpu=True)(cublas.gemm) # noqa: F811
41 # Homerolled gemm wrapper and helper functions:
42 from cupy.cublas import (_get_scalar_ptr, _trans_to_cublas_op,
43 _change_order_if_necessary, device)
44 from cupy_backends.cuda.libs import cublas as _cublas
def _decide_ld_and_trans(a, trans):
    """Pick a leading dimension for ``a`` and adjust its transpose flag.

    Parameters
    ==========
    a:
        2-d array exposing ``_c_contiguous``, ``_f_contiguous``,
        ``shape``, ``strides`` and ``dtype.itemsize``.
    trans:
        Integer transpose flag; it is replaced by ``1 - trans`` whenever
        the array is laid out row by row.

    Returns ``(ld, trans)``.  ``ld`` is None when no supported memory
    layout is recognised, in which case ``trans`` is returned unchanged.
    """
    if a._c_contiguous:
        return a.shape[1], 1 - trans
    if a._f_contiguous:
        return a.shape[0], trans

    last_stride = a.strides[-1]
    if last_stride == a.dtype.itemsize:
        # Semi C-contiguous (sliced along second dim): rows are still
        # densely packed, so the row stride in elements serves as ld.
        return a.strides[-2] // last_stride, 1 - trans

    return None, trans
@trace(gpu=True)
def gpu_gemm(  # noqa: F811
        transa, transb, a, b, out=None, alpha=1.0, beta=0.0):
    """Computes out = alpha * op(a) @ op(b) + beta * out

    op(a) = a if transa is 'N', op(a) = a.T if transa is 'T',
    op(a) = a.T.conj() if transa is 'H'.
    op(b) = b if transb is 'N', op(b) = b.T if transb is 'T',
    op(b) = b.T.conj() if transb is 'H'.

    This is pretty much a copy of the code from cupy.cublas.gemm,
    with _decide_ld_and_trans modified to not be afraid of
    hermitian transposes and a preference to C-contiguous arrays.

    Parameters
    ==========
    transa, transb:
        Operation selectors for ``a`` and ``b`` (see above); converted
        with ``_trans_to_cublas_op``.
    a, b:
        2-d arrays of identical dtype.  Supported dtype chars are
        'f', 'd', 'F' and 'D'; anything else raises ``TypeError``.
    out:
        Optional 2-d output array of shape (m, n) and the same dtype.
        If None, a new C-ordered array is allocated and ``beta`` is
        forced to 0.0.
    alpha, beta:
        Scaling factors.  If either one is a ``cupy.ndarray`` the other
        is moved to the device too and cuBLAS is switched to
        device-pointer mode for the call.

    Returns ``out`` (the array passed in, or the newly allocated one).

    Shape and dtype checks are done with ``assert`` and are therefore
    skipped when Python runs with ``-O``.
    """
    assert a.ndim == b.ndim == 2
    assert a.dtype == b.dtype
    # Select the cuBLAS routine matching the element type.
    dtype = a.dtype.char
    if dtype == 'f':
        func = _cublas.sgemm
    elif dtype == 'd':
        func = _cublas.dgemm
    elif dtype == 'F':
        func = _cublas.cgemm
    elif dtype == 'D':
        func = _cublas.zgemm
    else:
        raise TypeError('invalid dtype')

    # Work out the logical problem size: op(a) is (m, k), op(b) is (k, n).
    transa = _trans_to_cublas_op(transa)
    transb = _trans_to_cublas_op(transb)
    if transa == _cublas.CUBLAS_OP_N:
        m, k = a.shape
    else:
        k, m = a.shape
    if transb == _cublas.CUBLAS_OP_N:
        n = b.shape[1]
        assert b.shape[0] == k
    else:
        n = b.shape[0]
        assert b.shape[1] == k
    if out is None:
        out = cupy.empty((m, n), dtype=dtype, order='C')
        # Freshly allocated memory is uninitialised, so it must not
        # contribute to the result.
        beta = 0.0
    else:
        assert out.ndim == 2
        assert out.shape == (m, n)
        assert out.dtype == dtype

    alpha, alpha_ptr = _get_scalar_ptr(alpha, a.dtype)
    beta, beta_ptr = _get_scalar_ptr(beta, a.dtype)
    handle = device.get_cublas_handle()
    # Remember the current pointer mode; every call path below restores
    # it in a ``finally`` clause.
    orig_mode = _cublas.getPointerMode(handle)
    if isinstance(alpha, cupy.ndarray) or \
            isinstance(beta, cupy.ndarray):
        # Mixed host/device scalars: put both on the device.
        if not isinstance(alpha, cupy.ndarray):
            alpha = cupy.array(alpha)
            alpha_ptr = alpha.data.ptr
        if not isinstance(beta, cupy.ndarray):
            beta = cupy.array(beta)
            beta_ptr = beta.data.ptr
        _cublas.setPointerMode(handle,
                               _cublas.CUBLAS_POINTER_MODE_DEVICE)
    else:
        _cublas.setPointerMode(handle,
                               _cublas.CUBLAS_POINTER_MODE_HOST)

    # NOTE(review): the pointer mode has already been changed here, but
    # the two calls below run outside any try/finally, so an exception
    # raised by them would leave the handle in the new mode.
    # NOTE(review): the transpose flags are flipped with ``1 - trans``
    # (here and in _decide_ld_and_trans); this presumably relies on the
    # numeric values of the cuBLAS op constants -- confirm that every
    # layout combination is valid for 'H'.
    lda, transa = _decide_ld_and_trans(a, transa)
    ldb, transb = _decide_ld_and_trans(b, transb)
    if not (lda is None or ldb is None):
        # Fast paths: both inputs can be used in place without copies.
        if out._c_contiguous:
            # Computes out.T = alpha * b.T @ a.T + beta * out.T
            try:
                func(handle, 1 - transb, 1 - transa, n, m, k,
                     alpha_ptr, b.data.ptr, ldb, a.data.ptr, lda,
                     beta_ptr, out.data.ptr, n)
            finally:
                _cublas.setPointerMode(handle, orig_mode)
            return out
        elif out._f_contiguous:
            try:
                func(handle, transa, transb, m, n, k, alpha_ptr,
                     a.data.ptr, lda, b.data.ptr, ldb, beta_ptr,
                     out.data.ptr, m)
            finally:
                _cublas.setPointerMode(handle, orig_mode)
            return out
        elif out.strides[-1] == out.dtype.itemsize:
            # Semi C-contiguous (sliced along second dim)
            # Computes out.T = alpha * b.T @ a.T + beta * out.T
            try:
                ld_out = out.strides[-2] // out.strides[-1]
                func(handle, 1 - transb, 1 - transa, n, m, k,
                     alpha_ptr, b.data.ptr, ldb, a.data.ptr, lda,
                     beta_ptr, out.data.ptr, ld_out)
            finally:
                _cublas.setPointerMode(handle, orig_mode)
            return out

    # Backup plan with copies
    a, lda = _change_order_if_necessary(a, lda)
    b, ldb = _change_order_if_necessary(b, ldb)
    c = out
    if not out._f_contiguous:
        # Compute into a Fortran-ordered scratch copy and copy the
        # result back into ``out`` afterwards.
        c = out.copy(order='F')
    try:
        func(handle, transa, transb, m, n, k, alpha_ptr, a.data.ptr,
             lda, b.data.ptr, ldb, beta_ptr, c.data.ptr, m)
    finally:
        _cublas.setPointerMode(handle, orig_mode)
    if not out._f_contiguous:
        cupy._core.elementwise_copy(c, out)
    return out
172 import cupyx
173 from cupy.cuda import runtime
174 numpy2 = np.__version__.split('.')[0] == '2'
def fftshift_patch(x, axes=None):
    """Replacement for ``cupy.fft.fftshift`` built on ``cupy.roll``.

    Parameters
    ==========
    x:
        Input array (anything accepted by ``cupy.asarray``).
    axes:
        Single axis, iterable of axes, or None for all axes.

    Each selected axis is rolled forward by half of its length
    (integer division).
    """
    arr = cupy.asarray(x)
    if axes is None:
        axes = list(range(arr.ndim))
    elif not isinstance(axes, Iterable):
        axes = (axes,)
    half_lengths = [arr.shape[ax] // 2 for ax in axes]
    return cupy.roll(arr, half_lengths, axes)
def ifftshift_patch(x, axes=None):
    """Replacement for ``cupy.fft.ifftshift`` built on ``cupy.roll``.

    Parameters
    ==========
    x:
        Input array (anything accepted by ``cupy.asarray``).
    axes:
        Single axis, iterable of axes, or None for all axes.

    Each selected axis is rolled backward by half of its length
    (integer division).
    """
    arr = cupy.asarray(x)
    if axes is None:
        axes = list(range(arr.ndim))
    elif not isinstance(axes, Iterable):
        axes = (axes,)
    back_shifts = [-(arr.shape[ax] // 2) for ax in axes]
    return cupy.roll(arr, back_shifts, axes)
# With NumPy major version 2, replace CuPy's fftshift/ifftshift by the
# roll-based versions defined above.
if numpy2:
    cupy.fft.fftshift = fftshift_patch
    cupy.fft.ifftshift = ifftshift_patch

# The real CuPy import succeeded: publish that in the module-level flags.
is_hip = runtime.is_hip
cupy_is_fake = False

# Check the number of devices
# Do not fail when calling `gpaw info` on a login node without GPUs
try:
    device_count = runtime.getDeviceCount()
except runtime.CUDARuntimeError as e:
    # Likely no device present
    if 'ErrorNoDevice' not in str(e):
        # Raise error in case of some other error
        raise e
    device_count = 0

# NOTE(review): the nesting below was reconstructed on the assumption that
# everything touching a device only runs when one is present -- confirm
# against the original file.
if device_count > 0:
    # select GPU device (round-robin based on MPI rank)
    # if not set, all MPI ranks will use the same default device
    from gpaw.mpi import rank
    runtime.setDevice(rank % device_count)

    # initialise C parameters and memory buffers
    import gpaw.cgpaw as cgpaw
    cgpaw.gpaw_gpu_init()

    # Generate a device id
    # (node name plus PCI bus id of the device selected above)
    import os
    nodename = os.uname()[1]
    bus_id = runtime.deviceGetPCIBusId(runtime.getDevice())
    device_id = f'{nodename}:{bus_id}'
226 except ImportError:
227 import gpaw.gpu.cpupy as cupy
228 import gpaw.gpu.cpupyx as cupyx
229 from gpaw.gpu.cpupy.cublas import gemm as gpu_gemm # noqa
232__all__ = ['cupy', 'cupyx', 'as_xp', 'as_np', 'synchronize']
def synchronize():
    """Wait for the current CuPy stream to finish.

    Does nothing when the fake (CPU) cupy replacement is in use.
    """
    if cupy_is_fake:
        return
    stream = cupy.cuda.get_current_stream()
    stream.synchronize()
def as_np(array: np.ndarray | cupy.ndarray) -> np.ndarray:
    """Return *array* as a NumPy array.

    Parameters
    ==========
    array:
        Numpy or CuPy array.  A NumPy array is returned as is (same
        object); anything else is converted with ``cupy.asnumpy``.
    """
    already_on_cpu = isinstance(array, np.ndarray)
    return array if already_on_cpu else cupy.asnumpy(array)
def as_xp(array, xp):
    """Transfer array to CPU or GPU (if not already there).

    Parameters
    ==========
    array:
        Numpy or CuPy array.
    xp:
        :mod:`numpy` or :mod:`cupy`.

    Returns the array on the side selected by ``xp``.  If the array
    already lives there, the very same object is returned and nothing
    is copied.
    """
    on_cpu = isinstance(array, np.ndarray)
    if xp is np:
        # Target is the CPU.
        return array if on_cpu else cupy.asnumpy(array)
    if on_cpu:
        # Target is the GPU and the data is still on the CPU.
        return cupy.asarray(array)
    # Already on the GPU: nothing to transfer.  (A leftover ``1 / 0``
    # debugging statement used to raise ZeroDivisionError here.)
    return array
def einsum(subscripts, *operands, out):
    """Evaluate an Einstein summation and store the result in ``out``.

    Parameters
    ==========
    subscripts:
        Subscript string as understood by ``numpy.einsum``.
    operands:
        Input arrays.
    out:
        Output array (keyword only).  A NumPy array is filled directly
        by ``numpy.einsum``; any other array receives the result of
        ``cupy.einsum`` through slice assignment.

    Nothing is returned.
    """
    if not isinstance(out, np.ndarray):
        out[:] = cupy.einsum(subscripts, *operands)
        return
    np.einsum(subscripts, *operands, out=out)
@trace(gpu=True)
def cupy_eigh(a: cupy.ndarray, UPLO: str) -> tuple[cupy.ndarray, cupy.ndarray]:
    """Wrapper for ``eigh()``.

    Usually CUDA > MAGMA > HIP, so we try to choose the best one.
    HIP native solver is questionably slow so for now do it on the CPU if
    MAGMA is not available.

    Parameters
    ==========
    a:
        Matrix to diagonalize.
    UPLO:
        Passed on to the chosen solver; the CPU fallback uses the lower
        triangle exactly when ``UPLO == 'L'``.
    """
    from scipy.linalg import eigh as cpu_eigh

    if not is_hip:
        return cupy.linalg.eigh(a, UPLO=UPLO)

    if have_magma and a.ndim == 2 and a.shape[0] > 128:
        # import here to avoid circular import.
        # magma needs cupy (possibly fake),
        # which must be imported from this file
        from gpaw.new.magma import eigh_magma_gpu
        return eigh_magma_gpu(a, UPLO)

    # fallback to CPU: solve on the host, then move both results back
    host_matrix = cupy.asnumpy(a)
    eigenvalues, eigenvectors = cpu_eigh(host_matrix,
                                         lower=(UPLO == 'L'),
                                         check_finite=False)
    return cupy.asarray(eigenvalues), cupy.asarray(eigenvectors)
class XP:
    """Mixin-style holder of an ``xp`` attribute (numpy or cupy).

    Modules cannot be pickled, so the ``xp`` attribute is left out of the
    pickled state and put back as :mod:`numpy` when unpickling.  Pickling
    is only allowed while ``xp`` is numpy.
    """

    def __init__(self, xp: ModuleType):
        self.xp = xp

    def __getstate__(self):
        """Return a copy of the instance dict without the ``xp`` entry."""
        attributes = dict(self.__dict__)
        assert self.xp is np
        attributes.pop('xp')
        return attributes

    def __setstate__(self, state):
        """Restore attributes from ``state`` and set ``xp`` to numpy.

        The ``state`` dict itself also gets the ``xp`` entry added.
        """
        state['xp'] = np
        vars(self).update(state)
@contextlib.contextmanager
def T():
    """Context manager printing the wall-clock time spent in its body.

    ``synchronize()`` is called after the body, before the clock is
    read the second time.  The elapsed time is printed in nanoseconds.
    Nothing is printed if the body raises.
    """
    started = time()
    yield
    synchronize()
    finished = time()
    print(f'{(finished - started) * 1e9:_.3f} ns')