Coverage for gpaw/utilities/blas.py: 67% (220 statements)


# Copyright (C) 2003 CAMP
# Please see the accompanying LICENSE file for further information.

"""
Python wrapper functions for the ``C`` package:
Basic Linear Algebra Subroutines (BLAS)

See also:
https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms
and
https://www.netlib.org/lapack/lug/node145.html
"""
from typing import TypeVar

import gpaw.cgpaw as cgpaw
import numpy as np
import scipy.linalg.blas as blas
from gpaw import debug
from gpaw.new import prod
from gpaw.typing import Array2D, ArrayND
from gpaw.utilities import is_contiguous


def is_finite(array, tril=False):
    if isinstance(array, np.ndarray):
        xp = np
    else:
        from gpaw.gpu import cupy as xp
    if tril:
        array = xp.tril(array)
    return xp.isfinite(array).all()


__all__ = ['mmm']

T = TypeVar('T', float, complex)


def mmm(alpha: T,
        a: Array2D,
        opa: str,
        b: Array2D,
        opb: str,
        beta: T,
        c: Array2D) -> None:
    """Matrix-matrix multiplication using dgemm or zgemm.

    For opa='N' and opb='N', we have::

        c <- αab + βc.

    Use 'T' to transpose matrices and 'C' to transpose and complex conjugate
    matrices.
    """
    assert opa in 'NTC'
    assert opb in 'NTC'

    if opa == 'N':
        a1, a2 = a.shape
    else:
        a2, a1 = a.shape
    if opb == 'N':
        b1, b2 = b.shape
    else:
        b2, b1 = b.shape
    assert a2 == b1
    assert c.shape == (a1, b2)

    assert a.dtype == b.dtype == c.dtype
    assert a.strides[1] == c.itemsize or a.size == 0
    assert b.strides[1] == c.itemsize or b.size == 0
    assert c.strides[1] == c.itemsize or c.size == 0
    if a.dtype == float:
        assert not isinstance(alpha, complex)
        assert not isinstance(beta, complex)
    else:
        assert a.dtype == complex

    cgpaw.mmm(alpha, a, opa, b, opb, beta, c)


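# Illustrative example (assuming a BLAS-enabled build of cgpaw): for
# opa='N' and opb='N', mmm() behaves like ``c[:] = alpha * (a @ b) + beta * c``:
#
#     a = np.array([[1.0, 2.0], [3.0, 4.0]])
#     b = np.eye(2)
#     c = np.zeros((2, 2))
#     mmm(1.0, a, 'N', b, 'N', 0.0, c)  # c now equals a
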

def gpu_mmm(alpha, a, opa, b, opb, beta, c):
    """GPU version of mmm()."""
    m = b.shape[1] if opb == 'N' else b.shape[0]
    n = a.shape[0] if opa == 'N' else a.shape[1]
    k = b.shape[0] if opb == 'N' else b.shape[1]
    lda = a.strides[0] // a.itemsize
    ldb = b.strides[0] // b.itemsize
    ldc = c.strides[0] // c.itemsize
    cgpaw.mmm_gpu(alpha, a.data.ptr, lda, opa,
                  b.data.ptr, ldb, opb, beta,
                  c.data.ptr, ldc, c.itemsize,
                  m, n, k)


def gpu_scal(alpha, x):
    """alpha x.

    Performs the operation::

        x <- alpha * x
    """
    if debug:
        if isinstance(alpha, complex):
            assert is_contiguous(x, complex)
        else:
            assert isinstance(alpha, float)
            assert x.dtype in [float, complex]
        assert x.flags.c_contiguous
    cgpaw.scal_gpu(alpha, x.data.ptr, x.shape, x.dtype)


def to2d(array: ArrayND) -> Array2D:
    """2D view of ndarray.

    >>> to2d(np.zeros((2, 3, 4))).shape
    (2, 12)
    """
    shape = array.shape
    return array.reshape((shape[0], prod(shape[1:])))


def mmmx(alpha: T,
         a: ArrayND,
         opa: str,
         b: ArrayND,
         opb: str,
         beta: T,
         c: ArrayND) -> None:
    """Matrix-matrix multiplication using dgemm or zgemm.

    Arrays a, b and c are converted to 2D arrays before calling mmm().
    """
    mmm(alpha, to2d(a), opa, to2d(b), opb, beta, to2d(c))


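# Illustrative example: mmmx() contracts along the leading axis of b, so a
# (2, 3) matrix times a (3, 4, 5) array gives a (2, 4, 5) result:
#
#     a = np.random.rand(2, 3)
#     b = np.random.rand(3, 4, 5)
#     c = np.empty((2, 4, 5))
#     mmmx(1.0, a, 'N', b, 'N', 0.0, c)  # c[:] = np.tensordot(a, b, 1)
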

def gpu_gemm(alpha, a, b, beta, c, transa='n'):
    r"""General Matrix Multiply.

    Performs the operation::

        c <- alpha * b.a + beta * c

    If transa is "n", ``b.a`` denotes the matrix multiplication defined by::

                        _
                       \
        (b.a)        =  ) b   * a
             ijkl...   /_  ip    pjkl...
                        p

    If transa is "t" or "c", ``b.a`` denotes the matrix multiplication
    defined by::

                  _
                 \
        (b.a)  =  ) b        * a
             ij  /_  iklm...    jklm...
               klm...

    where, for "c", the complex conjugate of a is also taken.
    """
    if debug:
        assert beta == 0.0 or is_finite(c)

        assert (a.dtype == float and b.dtype == float and c.dtype == float and
                isinstance(alpha, float) and isinstance(beta, float) or
                a.dtype == complex and b.dtype == complex and
                c.dtype == complex)
        assert a.flags.c_contiguous
        if transa == 'n':
            assert c.flags.c_contiguous or (c.ndim == 2
                                            and c.strides[1] == c.itemsize)
            assert b.ndim == 2
            assert b.strides[1] == b.itemsize
            assert a.shape[0] == b.shape[1]
            assert c.shape == b.shape[0:1] + a.shape[1:]
        else:
            assert b.size == 0 or b[0].flags.c_contiguous
            assert c.strides[1] == c.itemsize
            assert a.shape[1:] == b.shape[1:]
            assert c.shape == (b.shape[0], a.shape[0])

    cgpaw.gemm_gpu(alpha, a.data.ptr, a.shape,
                   b.data.ptr, b.shape, beta,
                   c.data.ptr, c.shape,
                   a.dtype, transa)


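# Illustrative sketch (hypothetical, assuming a GPU-enabled cgpaw build and
# cupy device arrays): the gpu_* wrappers take device arrays and pass raw
# pointers to the C extension:
#
#     from gpaw.gpu import cupy as cp
#     a = cp.ones((3, 4))
#     b = cp.ones((2, 3))
#     c = cp.zeros((2, 4))
#     gpu_gemm(1.0, a, b, 0.0, c, transa='n')  # c <- b @ a
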

def gpu_gemv(alpha, a, x, beta, y, trans='t'):
    """General Matrix Vector product.

    Performs the operation::

        y <- alpha * a.x + beta * y

    ``a.x`` denotes matrix multiplication, where the product-sum is
    over the entire length of the vector x and
    the first dimension of a (for trans='n'), or
    the last dimension of a (for trans='t' or 'c').

    If trans='c', the complex conjugate of a is used. The default is
    trans='t', i.e. behaviour like np.dot with a 2D matrix and a vector.
    """
    if debug:
        assert (a.dtype == float and x.dtype == float and y.dtype == float and
                isinstance(alpha, float) and isinstance(beta, float) or
                a.dtype == complex and x.dtype == complex and
                y.dtype == complex)
        assert a.flags.c_contiguous
        assert y.flags.c_contiguous
        assert x.ndim == 1
        assert y.ndim == a.ndim - 1
        if trans == 'n':
            assert a.shape[0] == x.shape[0]
            assert a.shape[1:] == y.shape
        else:
            assert a.shape[-1] == x.shape[0]
            assert a.shape[:-1] == y.shape

    cgpaw.gemv_gpu(alpha, a.data.ptr, a.shape,
                   x.data.ptr, x.shape, beta,
                   y.data.ptr, a.dtype,
                   trans)


which_axpy = {
    np.float32: blas.saxpy,
    np.float64: blas.daxpy,
    np.complex64: blas.caxpy,
    np.complex128: blas.zaxpy
}


def axpy(alpha, x, y):
    """alpha x plus y.

    Performs the operation::

        y <- alpha * x + y
    """
    if x.size == 0:
        return
    assert x.flags.contiguous
    assert y.flags.contiguous
    x = x.ravel()
    y = y.ravel()
    z = which_axpy[np.dtype(x.dtype).type](x, y, a=alpha)
    assert z is y, (x, y, x.shape, y.shape)



def gpu_axpy(alpha, x, y):
    """alpha x plus y.

    Performs the operation::

        y <- alpha * x + y
    """
    if debug:
        if isinstance(alpha, complex):
            assert is_contiguous(x, complex) and is_contiguous(y, complex)
        else:
            assert isinstance(alpha, float)
            assert x.dtype in [float, complex]
            assert x.dtype == y.dtype
            assert x.flags.c_contiguous and y.flags.c_contiguous
        assert x.shape == y.shape

    cgpaw.axpy_gpu(alpha, x.data.ptr, x.shape,
                   y.data.ptr, y.shape,
                   x.dtype)


def rk(alpha, a, beta, c, trans='c'):
    """Rank-k update of a matrix.

    For ``trans='c'`` the following operation is performed::

                †
        c <- αaa  + βc,

    and for ``trans='n'`` we get::

               †
        c <- αa a + βc

    If the ``a`` array has more than 2 dimensions, the second, third, ...
    axes are combined into one.

    Only the lower triangle of ``c`` will contain sensible numbers.
    """
    if debug:
        assert beta == 0.0 or is_finite(c, tril=True)

        assert (a.dtype == float and c.dtype == float or
                a.dtype == complex and c.dtype == complex)
        assert a.flags.c_contiguous, (a.shape, a.strides, a.dtype)
        assert a.ndim > 1
        if trans == 'n':
            assert c.shape == (a.shape[1], a.shape[1])
        else:
            assert c.shape == (a.shape[0], a.shape[0])
        assert c.strides[1] == c.itemsize or c.size == 0

    cgpaw.rk(alpha, a, beta, c, trans)



def gpu_rk(alpha, a, beta, c, trans='c'):
    """GPU version of rk()."""
    cgpaw.rk_gpu(alpha, a.data.ptr, a.shape,
                 beta, c.data.ptr, c.shape,
                 a.dtype)


def r2k(alpha, a, b, beta, c, trans='c'):
    r"""Rank-2k update of a matrix.

    Performs the operation::

                          dag        cc        dag
        c <- alpha * a . b    + alpha   * b . a    + beta * c

    or if trans='n'::

                      dag            cc    dag
        c <- alpha * a    . b + alpha   * b    . a + beta * c

    where ``a.b`` denotes the matrix multiplication defined by::

                     _
                    \
        (a.b)    =   ) a        * b
             ij     /_  ipklm...   pjklm...
                  pklm...

    ``cc`` denotes complex conjugation.

    ``dag`` denotes the hermitian conjugate (complex conjugation plus a
    swap of axis 0 and 1).

    Only the lower triangle of ``c`` will contain sensible numbers.
    """
    if debug:
        assert beta == 0.0 or is_finite(c, tril=True)
        assert (a.dtype == float and b.dtype == float and c.dtype == float or
                a.dtype == complex and b.dtype == complex and
                c.dtype == complex)
        assert a.flags.c_contiguous and b.flags.c_contiguous
        assert a.ndim > 1
        assert a.shape == b.shape
        if trans == 'c':
            assert c.shape == (a.shape[0], a.shape[0])
        else:
            assert c.shape == (a.shape[1], a.shape[1])
        assert c.strides[1] == c.itemsize or c.size == 0

    cgpaw.r2k(alpha, a, b, beta, c, trans)



def gpu_r2k(alpha, a, b, beta, c, trans='c'):
    """GPU version of r2k()."""
    cgpaw.r2k_gpu(alpha, a.data.ptr, a.shape,
                  b.data.ptr, b.shape, beta,
                  c.data.ptr, c.shape,
                  a.dtype)


def gpu_dotc(a, b):
    r"""Dot product, conjugating the first vector with complex arguments.

    Returns the value of the operation::

             _
            \   cc
             ) a      * b
            /_  ijk...   ijk...
           ijk...

    ``cc`` denotes complex conjugation.
    """
    if debug:
        assert ((is_contiguous(a, float) and is_contiguous(b, float)) or
                (is_contiguous(a, complex) and is_contiguous(b, complex)))
        assert a.shape == b.shape

    return cgpaw.dotc_gpu(a.data.ptr, a.shape,
                          b.data.ptr, a.dtype)


def gpu_dotu(a, b):
    r"""Dot product, NOT conjugating the first vector with complex arguments.

    Returns the value of the operation::

             _
            \
             ) a      * b
            /_  ijk...   ijk...
           ijk...
    """
    if debug:
        assert ((is_contiguous(a, float) and is_contiguous(b, float)) or
                (is_contiguous(a, complex) and is_contiguous(b, complex)))
        assert a.shape == b.shape

    return cgpaw.dotu_gpu(a.data.ptr, a.shape,
                          b.data.ptr, a.dtype)



def _gemmdot(a, b, alpha=1.0, beta=1.0, out=None, trans='n'):
    """Matrix multiplication using gemm.

    Return a reference to out, where::

        out <- alpha * a . b + beta * out

    If out is None, a suitably sized zero array will be created.

    ``a.b`` denotes matrix multiplication, where the product-sum is
    over the last dimension of a, and either
    the first dimension of b (for trans='n'), or
    the last dimension of b (for trans='t' or 'c').

    If trans='c', the complex conjugate of b is used.
    """
    # Store original shapes
    ashape = a.shape
    bshape = b.shape

    # Vector-vector multiplication is handled by dotu
    if a.ndim == 1 and b.ndim == 1:
        assert out is None
        if trans == 'c':
            return alpha * np.vdot(b, a)  # dotc conjugates *first* argument
        else:
            return alpha * a.dot(b)

    # Map all arrays to 2D arrays
    a = a.reshape(-1, a.shape[-1])
    if trans == 'n':
        b = b.reshape(b.shape[0], -1)
    else:  # 't' or 'c'
        b = b.reshape(-1, b.shape[-1])

    # Apply BLAS gemm routine
    # (out rows follow a; the out-column axis of b depends on trans)
    outshape = a.shape[0], b.shape[trans == 'n']
    if out is None:
        # (ATLAS can't handle uninitialized output array)
        out = np.zeros(outshape, a.dtype)
    else:
        out = out.reshape(outshape)
    mmmx(alpha, a, 'N', b, trans.upper(), beta, out)

    # Determine actual shape of result array
    if trans == 'n':
        outshape = ashape[:-1] + bshape[1:]
    else:  # 't' or 'c'
        outshape = ashape[:-1] + bshape[:-1]
    return out.reshape(outshape)


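# Illustrative example: gemmdot() generalizes np.dot with scaling and
# in-place accumulation into out:
#
#     a = np.random.rand(2, 3)
#     b = np.random.rand(3, 4)
#     out = np.zeros((2, 4))
#     gemmdot(a, b, alpha=2.0, beta=0.0, out=out)  # out ~= 2.0 * a @ b
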

if not hasattr(cgpaw, 'mmm'):
    # These are the functions used with noblas=True
    # TODO: move these functions elsewhere so that
    # they can be used for unit tests

    def op(o, m):
        if o.upper() == 'N':
            return m
        if o.upper() == 'T':
            return m.T
        if o.upper() == 'C':
            return m.conj().T
        raise ValueError(f'unknown op: {o}')

    def rk(alpha, a, beta, c, trans='c'):  # noqa
        if c.size == 0:
            return
        if beta == 0:
            c[:] = 0.0
        else:
            c *= beta
        if trans == 'n':
            c += alpha * a.conj().T.dot(a)
        else:
            a = a.reshape((len(a), -1))
            c += alpha * a.dot(a.conj().T)

    def r2k(alpha, a, b, beta, c, trans='c'):  # noqa
        if c.size == 0:
            return
        if beta == 0.0:
            c[:] = 0.0
        else:
            c *= beta
        if trans == 'c':
            c += (alpha * a.reshape((len(a), -1))
                  .dot(b.reshape((len(b), -1)).conj().T) +
                  alpha * b.reshape((len(b), -1))
                  .dot(a.reshape((len(a), -1)).conj().T))
        else:
            c += alpha * (a.conj().T @ b + b.conj().T @ a)

    def mmm(alpha: T, a: np.ndarray, opa: str,  # noqa
            b: np.ndarray, opb: str,
            beta: T, c: np.ndarray) -> None:
        if beta == 0.0:
            c[:] = 0.0
        else:
            c *= beta
        c += alpha * op(opa, a).dot(op(opb, b))

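    # Illustrative example: the pure-numpy fallback reproduces mmm(), with
    # op() mapping 'N'/'T'/'C' to identity/transpose/conjugate transpose:
    #
    #     a = np.array([[1.0, 2.0]])        # shape (1, 2)
    #     c = np.zeros((2, 2))
    #     mmm(1.0, a, 'T', a, 'N', 0.0, c)  # c == a.T @ a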

    gemmdot = _gemmdot

elif not debug:
    mmm = cgpaw.mmm  # noqa
    rk = cgpaw.rk  # noqa
    r2k = cgpaw.r2k  # noqa
    gemmdot = _gemmdot

else:
    def gemmdot(a, b, alpha=1.0, beta=1.0, out=None, trans='n'):
        assert a.flags.c_contiguous
        assert b.flags.c_contiguous
        assert a.dtype == b.dtype
        if trans == 'n':
            assert a.shape[-1] == b.shape[0]
        else:
            assert a.shape[-1] == b.shape[-1]
        if out is not None:
            assert out.flags.c_contiguous
            assert a.dtype == out.dtype
            assert a.ndim > 1 or b.ndim > 1
            if trans == 'n':
                assert out.shape == a.shape[:-1] + b.shape[1:]
            else:
                assert out.shape == a.shape[:-1] + b.shape[:-1]
        return _gemmdot(a, b, alpha, beta, out, trans)