Coverage for gpaw/gpu/__init__.py: 32%

201 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from __future__ import annotations 

2import contextlib 

3from time import time 

4from typing import TYPE_CHECKING 

5from types import ModuleType 

6from collections.abc import Iterable 

7from gpaw.new.timer import trace 

8 

9import numpy as np 

10 

11from gpaw.cgpaw import have_magma 

12 

cupy_is_fake: bool = True
"""True when :mod:`cupy` is the fake module ``gpaw.gpu.cpupy``, not CuPy."""

is_hip: bool = False
"""True when the GPU runtime is HIP (``cupy.cuda.runtime.is_hip``)."""

device_id: str | None = None
"""``'<nodename>:<PCI bus id>'`` of the selected GPU, or None."""

21 

22 

def gpu_gemm(*args, **kwargs):
    """Placeholder GEMM that always raises :class:`NotImplementedError`.

    The name is rebound further down in this module, either to the
    cuBLAS-backed wrapper or to the fake-cupy implementation.
    """
    message = 'gpu_gemm: You are not using GPAW with GPUs.'
    raise NotImplementedError(message)

25 

26 

if TYPE_CHECKING:
    import gpaw.gpu.cpupy as cupy
    import gpaw.gpu.cpupyx as cupyx
else:
    try:
        import gpaw.cgpaw as cgpaw
        if not hasattr(cgpaw, 'gpaw_gpu_init'):
            # cgpaw lacks the GPU entry point (presumably built without
            # GPU support): use the fake cupy instead.
            raise ImportError

        import cupy
        # CuPy's own ``cupy.cublas.gemm`` could be used instead
        # (``trace(gpu=True)(cublas.gemm)``), but it does extra copying.
        # The hand-rolled wrapper below avoids that.
        from cupy.cublas import (_get_scalar_ptr, _trans_to_cublas_op,
                                 _change_order_if_necessary, device)
        from cupy_backends.cuda.libs import cublas as _cublas

        # The transpose operations cuBLAS accepts (N, T, conj-transpose).
        _GEMM_OPS = (_cublas.CUBLAS_OP_N, _cublas.CUBLAS_OP_T,
                     _cublas.CUBLAS_OP_C)

        def _semi_c_ld(x):
            """Return the row stride of the 2-d array ``x`` in elements.

            ``x`` qualifies when its columns have unit stride and its rows
            are evenly spaced without overlapping (e.g. a C-ordered array
            sliced along the second dimension).  It can then be handed to
            cuBLAS as a row-major matrix with that leading dimension.
            Returns None when ``x`` does not qualify.
            """
            itemsize = x.dtype.itemsize
            row_stride, col_stride = x.strides
            if (col_stride == itemsize
                    and row_stride % itemsize == 0
                    and row_stride >= itemsize * max(x.shape[1], 1)):
                return row_stride // itemsize
            return None

        def _decide_ld_and_trans(a, trans):
            """Choose leading dimension and cuBLAS op for operand ``a``.

            cuBLAS works on column-major matrices.  An F-ordered ``a`` is
            used as-is.  A (semi) C-ordered ``a`` is seen by cuBLAS as its
            transpose, so the op is flipped with ``1 - trans`` (N <-> T).
            For ``trans == CUBLAS_OP_C`` this gives -1, which is only
            meaningful when flipped back again (row-major output path).
            Callers must check the final op against ``_GEMM_OPS``.

            Returns ``(ld, trans)``.  ``ld`` is None if ``a`` cannot be
            used without a copy.
            """
            if a._c_contiguous:
                return a.shape[1], 1 - trans
            if a._f_contiguous:
                return a.shape[0], trans
            ld = _semi_c_ld(a)
            if ld is not None:
                return ld, 1 - trans
            return None, trans

        @trace(gpu=True)
        def gpu_gemm(  # noqa: F811
                transa, transb, a, b, out=None, alpha=1.0, beta=0.0):
            """Computes out = alpha * op(a) @ op(b) + beta * out

            op(a) = a if transa is 'N', op(a) = a.T if transa is 'T',
            op(a) = a.T.conj() if transa is 'H'.
            op(b) = b if transb is 'N', op(b) = b.T if transb is 'T',
            op(b) = b.T.conj() if transb is 'H'.

            This is pretty much a copy of the code from cupy.cublas.gemm,
            with _decide_ld_and_trans modified to not be afraid of
            hermitian transposes and a preference to C-contiguous arrays.
            Some combinations have no valid direct cuBLAS op, e.g. 'H' on
            a C-ordered operand with an F-ordered output.  Those, and
            non-contiguous layouts, fall back to F-ordered copies.

            If ``out`` is None a new C-ordered array is returned and
            ``beta`` is ignored.  Otherwise ``out`` is updated in place
            and returned.
            """
            assert a.ndim == b.ndim == 2
            assert a.dtype == b.dtype
            dtype = a.dtype.char
            if dtype == 'f':
                func = _cublas.sgemm
            elif dtype == 'd':
                func = _cublas.dgemm
            elif dtype == 'F':
                func = _cublas.cgemm
            elif dtype == 'D':
                func = _cublas.zgemm
            else:
                raise TypeError('invalid dtype')

            opa = _trans_to_cublas_op(transa)
            opb = _trans_to_cublas_op(transb)
            if opa == _cublas.CUBLAS_OP_N:
                m, k = a.shape
            else:
                k, m = a.shape
            if opb == _cublas.CUBLAS_OP_N:
                n = b.shape[1]
                assert b.shape[0] == k
            else:
                n = b.shape[0]
                assert b.shape[1] == k
            if out is None:
                out = cupy.empty((m, n), dtype=dtype, order='C')
                beta = 0.0
            else:
                assert out.ndim == 2
                assert out.shape == (m, n)
                assert out.dtype == dtype

            alpha, alpha_ptr = _get_scalar_ptr(alpha, a.dtype)
            beta, beta_ptr = _get_scalar_ptr(beta, a.dtype)
            if isinstance(alpha, cupy.ndarray) or \
                    isinstance(beta, cupy.ndarray):
                # Device pointer mode: both scalars must live on the GPU.
                if not isinstance(alpha, cupy.ndarray):
                    alpha = cupy.array(alpha)
                    alpha_ptr = alpha.data.ptr
                if not isinstance(beta, cupy.ndarray):
                    beta = cupy.array(beta)
                    beta_ptr = beta.data.ptr
                pointer_mode = _cublas.CUBLAS_POINTER_MODE_DEVICE
            else:
                pointer_mode = _cublas.CUBLAS_POINTER_MODE_HOST

            lda, transa = _decide_ld_and_trans(a, opa)
            ldb, transb = _decide_ld_and_trans(b, opb)

            c = out  # the array cuBLAS writes into
            args = None
            if lda is not None and ldb is not None:
                if (1 - transa) in _GEMM_OPS and (1 - transb) in _GEMM_OPS:
                    ld_out = n if out._c_contiguous else _semi_c_ld(out)
                    if ld_out is not None:
                        # Row-major output: compute
                        # out.T = alpha * b.T @ a.T + beta * out.T
                        args = (1 - transb, 1 - transa, n, m, k,
                                alpha_ptr, b.data.ptr, ldb,
                                a.data.ptr, lda,
                                beta_ptr, out.data.ptr, ld_out)
                if (args is None and out._f_contiguous
                        and transa in _GEMM_OPS and transb in _GEMM_OPS):
                    args = (transa, transb, m, n, k,
                            alpha_ptr, a.data.ptr, lda, b.data.ptr, ldb,
                            beta_ptr, out.data.ptr, m)

            if args is None:
                # Backup plan with copies.  An operand whose column-major
                # op would be invalid (C-ordered with 'H') is re-laid out
                # in F order and used with its original op.
                if transa not in _GEMM_OPS:
                    lda, transa = None, opa
                if transb not in _GEMM_OPS:
                    ldb, transb = None, opb
                a, lda = _change_order_if_necessary(a, lda)
                b, ldb = _change_order_if_necessary(b, ldb)
                if not out._f_contiguous:
                    c = out.copy(order='F')
                args = (transa, transb, m, n, k,
                        alpha_ptr, a.data.ptr, lda, b.data.ptr, ldb,
                        beta_ptr, c.data.ptr, m)

            handle = device.get_cublas_handle()
            orig_mode = _cublas.getPointerMode(handle)
            _cublas.setPointerMode(handle, pointer_mode)
            try:
                func(handle, *args)
            finally:
                # Always restore the handle's pointer mode.
                _cublas.setPointerMode(handle, orig_mode)
            if c is not out:
                cupy._core.elementwise_copy(c, out)
            return out

        import cupyx
        from cupy.cuda import runtime
        # NOTE(review): cupy's fftshift/ifftshift are replaced with the
        # patches below for NumPy >= 2 — presumably CuPy's own versions
        # misbehave there; confirm.
        numpy2 = int(np.__version__.split('.')[0]) >= 2

        def _shift_axes(x, axes):
            """Normalize the ``axes`` argument of (i)fftshift."""
            if axes is None:
                return list(range(x.ndim))
            if not isinstance(axes, Iterable):
                return (axes,)
            return axes

        def fftshift_patch(x, axes=None):
            """Roll each axis in ``axes`` forward by half its length."""
            x = cupy.asarray(x)
            axes = _shift_axes(x, axes)
            return cupy.roll(x, [x.shape[axis] // 2 for axis in axes],
                             axes)

        def ifftshift_patch(x, axes=None):
            """Inverse of :func:`fftshift_patch`."""
            x = cupy.asarray(x)
            axes = _shift_axes(x, axes)
            return cupy.roll(x, [-(x.shape[axis] // 2) for axis in axes],
                             axes)

        if numpy2:
            cupy.fft.fftshift = fftshift_patch
            cupy.fft.ifftshift = ifftshift_patch

        is_hip = runtime.is_hip
        cupy_is_fake = False

        # Check the number of devices
        # Do not fail when calling `gpaw info` on a login node without GPUs
        try:
            device_count = runtime.getDeviceCount()
        except runtime.CUDARuntimeError as e:
            # Likely no device present
            if 'ErrorNoDevice' not in str(e):
                # Re-raise any other runtime error
                raise
            device_count = 0

        if device_count > 0:
            # select GPU device (round-robin based on MPI rank)
            # if not set, all MPI ranks will use the same default device
            from gpaw.mpi import rank
            runtime.setDevice(rank % device_count)

            # initialise C parameters and memory buffers
            # (cgpaw was imported at the top of this try block)
            cgpaw.gpaw_gpu_init()

            # Generate a device id: "<nodename>:<PCI bus id>"
            import os
            nodename = os.uname()[1]
            bus_id = runtime.deviceGetPCIBusId(runtime.getDevice())
            device_id = f'{nodename}:{bus_id}'

    except ImportError:
        import gpaw.gpu.cpupy as cupy
        import gpaw.gpu.cpupyx as cupyx
        from gpaw.gpu.cpupy.cublas import gemm as gpu_gemm  # noqa

230 

231 

232__all__ = ['cupy', 'cupyx', 'as_xp', 'as_np', 'synchronize'] 

233 

234 

def synchronize():
    """Wait for all work queued on the current CuPy stream to finish.

    A no-op when :mod:`cupy` is the fake ``gpaw.gpu.cpupy`` module.
    """
    if cupy_is_fake:
        return
    cupy.cuda.get_current_stream().synchronize()

238 

239 

def as_np(array: np.ndarray | cupy.ndarray) -> np.ndarray:
    """Return *array* as a NumPy array.

    A NumPy array is returned unchanged.  Anything else is converted with
    :func:`cupy.asnumpy`, which transfers CuPy arrays to the host.

    Parameters
    ==========
    array:
        Numpy or CuPy array.
    """
    return array if isinstance(array, np.ndarray) else cupy.asnumpy(array)

251 

252 

def as_xp(array, xp):
    """Transfer array to CPU or GPU (if not already there).

    Parameters
    ==========
    array:
        Numpy or CuPy array.
    xp:
        :mod:`numpy` or :mod:`cupy`.

    Returns ``array`` itself when it already lives where ``xp`` says.
    Otherwise returns a converted copy made with :func:`cupy.asnumpy`
    (to host) or :func:`cupy.asarray` (to device).
    """
    if xp is np:
        if isinstance(array, np.ndarray):
            return array
        return cupy.asnumpy(array)
    if isinstance(array, np.ndarray):
        return cupy.asarray(array)
    # Already on the device: nothing to transfer.
    return array

271 

272 

def einsum(subscripts, *operands, out):
    """Evaluate ``einsum(subscripts, *operands)`` into the existing ``out``.

    NumPy arrays are filled in place by :func:`numpy.einsum`.  For any
    other ``out`` (CuPy arrays), the result of :func:`cupy.einsum` is
    copied into ``out``.  ``out[...]`` is used rather than ``out[:]`` so
    0-d outputs (full contractions) work too.  Returns None.
    """
    if isinstance(out, np.ndarray):
        np.einsum(subscripts, *operands, out=out)
    else:
        out[...] = cupy.einsum(subscripts, *operands)

278 

279 

@trace(gpu=True)
def cupy_eigh(a: cupy.ndarray, UPLO: str) -> tuple[cupy.ndarray, cupy.ndarray]:
    """Wrapper for ``eigh()``.

    Usually CUDA > MAGMA > HIP, so we try to choose the best one.
    HIP native solver is questionably slow so for now do it on the CPU if
    MAGMA is not available.

    Returns ``(eigenvalues, eigenvectors)`` as :mod:`cupy` arrays, in the
    same order as :func:`cupy.linalg.eigh`.  ``UPLO == 'L'`` reads the
    lower triangle of ``a``.
    """
    if not is_hip:
        return cupy.linalg.eigh(a, UPLO=UPLO)

    if have_magma and a.ndim == 2 and a.shape[0] > 128:
        # import here to avoid circular import.
        # magma needs cupy (possibly fake),
        # which must be imported from this file
        from gpaw.new.magma import eigh_magma_gpu
        return eigh_magma_gpu(a, UPLO)

    # Fallback to CPU.
    a_cpu = cupy.asnumpy(a)
    if a_cpu.ndim == 2:
        from scipy.linalg import eigh
        eig_vals, eig_vecs = eigh(a_cpu,
                                  lower=(UPLO == 'L'),
                                  check_finite=False)
    else:
        # Stacked matrices: numpy.linalg.eigh broadcasts over leading axes.
        eig_vals, eig_vecs = np.linalg.eigh(a_cpu, UPLO=UPLO)
    return cupy.asarray(eig_vals), cupy.asarray(eig_vecs)

307 

308 

class XP:
    """Class for adding xp attribute (numpy or cupy).

    Also implements pickling which will not work out of the box
    because a module can't be pickled.  Only objects with ``xp is np``
    can be pickled.  The ``xp`` attribute is left out of the pickled
    state and set back to :mod:`numpy` when unpickling.
    """
    def __init__(self, xp: ModuleType):
        self.xp = xp

    def __getstate__(self):
        # Explicit check instead of ``assert`` so that GPU objects are
        # refused even when Python runs with -O.
        if self.xp is not np:
            raise TypeError(
                f'cannot pickle {type(self).__name__} with xp={self.xp!r}: '
                'only numpy-based objects can be pickled')
        state = self.__dict__.copy()
        del state['xp']
        return state

    def __setstate__(self, state):
        state['xp'] = np
        self.__dict__.update(state)

327 

328 

@contextlib.contextmanager
def T():
    """Debug helper: print how long the ``with`` body took, in ns.

    :func:`synchronize` is called before the clock is read the second
    time, so work still queued on the current CuPy stream is included.
    """
    start = time()
    yield
    synchronize()
    elapsed_ns = (time() - start) * 1e9
    print(f'{elapsed_ns:_.3f} ns')