Coverage for gpaw/utilities/blas.py: 67%
220 statements
# Copyright (C) 2003 CAMP
# Please see the accompanying LICENSE file for further information.

"""
Python wrapper functions for the ``C`` package:
Basic Linear Algebra Subroutines (BLAS)

See also:
https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms
and
https://www.netlib.org/lapack/lug/node145.html
"""
from typing import TypeVar

import gpaw.cgpaw as cgpaw
import numpy as np
import scipy.linalg.blas as blas
from gpaw import debug
from gpaw.new import prod
from gpaw.typing import Array2D, ArrayND
from gpaw.utilities import is_contiguous


def is_finite(array, tril=False):
    if isinstance(array, np.ndarray):
        xp = np
    else:
        from gpaw.gpu import cupy as xp
    if tril:
        array = xp.tril(array)
    return xp.isfinite(array).all()


__all__ = ['mmm']

T = TypeVar('T', float, complex)


def mmm(alpha: T,
        a: Array2D,
        opa: str,
        b: Array2D,
        opb: str,
        beta: T,
        c: Array2D) -> None:
    """Matrix-matrix multiplication using dgemm or zgemm.

    For opa='N' and opb='N', we have::

        c <- αab + βc.

    Use 'T' to transpose matrices and 'C' to transpose and complex conjugate
    matrices.
    """
    assert opa in 'NTC'
    assert opb in 'NTC'

    if opa == 'N':
        a1, a2 = a.shape
    else:
        a2, a1 = a.shape
    if opb == 'N':
        b1, b2 = b.shape
    else:
        b2, b1 = b.shape
    assert a2 == b1
    assert c.shape == (a1, b2)

    assert a.dtype == b.dtype == c.dtype
    assert a.strides[1] == c.itemsize or a.size == 0
    assert b.strides[1] == c.itemsize or b.size == 0
    assert c.strides[1] == c.itemsize or c.size == 0
    if a.dtype == float:
        assert not isinstance(alpha, complex)
        assert not isinstance(beta, complex)
    else:
        assert a.dtype == complex

    cgpaw.mmm(alpha, a, opa, b, opb, beta, c)
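

# Editor's sketch, not part of the original module: a minimal usage example,
# assuming cgpaw was compiled with BLAS support.  Computes c = 2.0 * a @ b.T.
def _example_mmm():
    a = np.arange(6.0).reshape(2, 3)
    b = np.arange(12.0).reshape(4, 3)
    c = np.empty((2, 4))
    mmm(2.0, a, 'N', b, 'T', 0.0, c)
    assert np.allclose(c, 2.0 * a @ b.T)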


def gpu_mmm(alpha, a, opa, b, opb, beta, c):
    """GPU version of mmm()."""
    m = b.shape[1] if opb == 'N' else b.shape[0]
    n = a.shape[0] if opa == 'N' else a.shape[1]
    k = b.shape[0] if opb == 'N' else b.shape[1]
    lda = a.strides[0] // a.itemsize
    ldb = b.strides[0] // b.itemsize
    ldc = c.strides[0] // c.itemsize
    cgpaw.mmm_gpu(alpha, a.data.ptr, lda, opa,
                  b.data.ptr, ldb, opb, beta,
                  c.data.ptr, ldc, c.itemsize,
                  m, n, k)


def gpu_scal(alpha, x):
    """alpha x.

    Performs the operation::

        x <- alpha * x

    """
    if debug:
        if isinstance(alpha, complex):
            assert is_contiguous(x, complex)
        else:
            assert isinstance(alpha, float)
            assert x.dtype in [float, complex]
        assert x.flags.c_contiguous
    cgpaw.scal_gpu(alpha, x.data.ptr, x.shape, x.dtype)


def to2d(array: ArrayND) -> Array2D:
    """2D view of ndarray.

    >>> to2d(np.zeros((2, 3, 4))).shape
    (2, 12)
    """
    shape = array.shape
    return array.reshape((shape[0], prod(shape[1:])))


def mmmx(alpha: T,
         a: ArrayND,
         opa: str,
         b: ArrayND,
         opb: str,
         beta: T,
         c: ArrayND) -> None:
    """Matrix-matrix multiplication using dgemm or zgemm.

    Arrays a, b and c are converted to 2D arrays before calling mmm().
    """
    mmm(alpha, to2d(a), opa, to2d(b), opb, beta, to2d(c))
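

# Editor's sketch, not part of the original module: the extra axes of a ride
# along as one flattened axis, so with a of shape (2, 3, 4) this computes
# c = 2.0 * a.reshape(2, 12) @ b (assuming cgpaw was compiled with BLAS).
def _example_mmmx():
    a = np.arange(24.0).reshape(2, 3, 4)
    b = np.arange(60.0).reshape(12, 5)
    c = np.zeros((2, 5))
    mmmx(2.0, a, 'N', b, 'N', 0.0, c)
    assert np.allclose(c, 2.0 * a.reshape(2, 12) @ b)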


def gpu_gemm(alpha, a, b, beta, c, transa='n'):
    r"""General Matrix Multiply.

    Performs the operation::

        c <- alpha * b.a + beta * c

    If transa is "n", ``b.a`` denotes the matrix multiplication defined by::

                        _
                       \
        (b.a)        =  ) b    * a
             ijkl...   /_  ip     pjkl...
                       p

    If transa is "t" or "c", ``b.a`` denotes the matrix multiplication
    defined by::

                    _
                   \
        (b.a)   =   ) b          * a
             ij    /_  iklm...      jklm...
                   klm...

    where in the case of "c" the complex conjugate of a is also taken.
    """
    if debug:
        assert beta == 0.0 or is_finite(c)
        assert (a.dtype == float and b.dtype == float and c.dtype == float and
                isinstance(alpha, float) and isinstance(beta, float) or
                a.dtype == complex and b.dtype == complex and
                c.dtype == complex)
        assert a.flags.c_contiguous
        if transa == 'n':
            assert c.flags.c_contiguous or (c.ndim == 2
                                            and c.strides[1] == c.itemsize)
            assert b.ndim == 2
            assert b.strides[1] == b.itemsize
            assert a.shape[0] == b.shape[1]
            assert c.shape == b.shape[0:1] + a.shape[1:]
        else:
            assert b.size == 0 or b[0].flags.c_contiguous
            assert c.strides[1] == c.itemsize
            assert a.shape[1:] == b.shape[1:]
            assert c.shape == (b.shape[0], a.shape[0])

    cgpaw.gemm_gpu(alpha, a.data.ptr, a.shape,
                   b.data.ptr, b.shape, beta,
                   c.data.ptr, c.shape,
                   a.dtype, transa)
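

# Editor's sketch, not part of the original module: a pure-NumPy reference
# for the transa='n' case of gpu_gemm(), useful for checking GPU results
# against the CPU.
def _gemm_reference_n(alpha, a, b, beta, c):
    # (b.a)_ijkl... = sum over p of b_ip * a_pjkl...
    c *= beta
    c += alpha * np.tensordot(b, a, axes=(1, 0))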


def gpu_gemv(alpha, a, x, beta, y, trans='t'):
    """General Matrix Vector product.

    Performs the operation::

        y <- alpha * a.x + beta * y

    ``a.x`` denotes matrix multiplication, where the product-sum is
    over the entire length of the vector x and
    the first dimension of a (for trans='n'), or
    the last dimension of a (for trans='t' or 'c').

    If trans='c', the complex conjugate of a is used. The default is
    trans='t', i.e. behaviour like np.dot with a 2D matrix and a vector.
    """
    if debug:
        assert (a.dtype == float and x.dtype == float and y.dtype == float and
                isinstance(alpha, float) and isinstance(beta, float) or
                a.dtype == complex and x.dtype == complex and
                y.dtype == complex)
        assert a.flags.c_contiguous
        assert y.flags.c_contiguous
        assert x.ndim == 1
        assert y.ndim == a.ndim - 1
        if trans == 'n':
            assert a.shape[0] == x.shape[0]
            assert a.shape[1:] == y.shape
        else:
            assert a.shape[-1] == x.shape[0]
            assert a.shape[:-1] == y.shape

    cgpaw.gemv_gpu(alpha, a.data.ptr, a.shape,
                   x.data.ptr, x.shape, beta,
                   y.data.ptr, a.dtype,
                   trans)
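

# Editor's sketch, not part of the original module: a pure-NumPy reference
# for gpu_gemv(); which axis of a is contracted depends on trans.
def _gemv_reference(alpha, a, x, beta, y, trans='t'):
    axis = 0 if trans == 'n' else a.ndim - 1
    if trans == 'c':
        a = a.conj()
    y *= beta
    y += alpha * np.tensordot(a, x, axes=(axis, 0))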


which_axpy = {
    np.float32: blas.saxpy,
    np.float64: blas.daxpy,
    np.complex64: blas.caxpy,
    np.complex128: blas.zaxpy
}


def axpy(alpha, x, y):
    """alpha x plus y.

    Performs the operation::

        y <- alpha * x + y

    """
    if x.size == 0:
        return
    assert x.flags.contiguous
    assert y.flags.contiguous
    x = x.ravel()
    y = y.ravel()
    z = which_axpy[np.dtype(x.dtype).type](x, y, a=alpha)
    assert z is y, (x, y, x.shape, y.shape)
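

# Editor's sketch, not part of the original module: axpy() updates y in
# place through the scipy BLAS wrappers selected in which_axpy.
def _example_axpy():
    x = np.ones(4)
    y = np.full(4, 2.0)
    axpy(0.5, x, y)
    assert np.allclose(y, 2.5)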


def gpu_axpy(alpha, x, y):
    """alpha x plus y.

    Performs the operation::

        y <- alpha * x + y

    """
    if debug:
        if isinstance(alpha, complex):
            assert is_contiguous(x, complex) and is_contiguous(y, complex)
        else:
            assert isinstance(alpha, float)
            assert x.dtype in [float, complex]
        assert x.dtype == y.dtype
        assert x.flags.c_contiguous and y.flags.c_contiguous
        assert x.shape == y.shape

    cgpaw.axpy_gpu(alpha, x.data.ptr, x.shape,
                   y.data.ptr, y.shape,
                   x.dtype)


def rk(alpha, a, beta, c, trans='c'):
    """Rank-k update of a matrix.

    For ``trans='c'`` the following operation is performed::

                †
        c <- αaa + βc,

    and for ``trans='t'`` we get::

               †
        c <- αa a + βc

    If the ``a`` array has more than 2 dimensions then the 2., 3., ...
    axes are combined.

    Only the lower triangle of ``c`` will contain sensible numbers.
    """
    if debug:
        assert beta == 0.0 or is_finite(c, tril=True)
        assert (a.dtype == float and c.dtype == float or
                a.dtype == complex and c.dtype == complex)
        assert a.flags.c_contiguous, (a.shape, a.strides, a.dtype)
        assert a.ndim > 1
        if trans == 'n':
            assert c.shape == (a.shape[1], a.shape[1])
        else:
            assert c.shape == (a.shape[0], a.shape[0])
        assert c.strides[1] == c.itemsize or c.size == 0

    cgpaw.rk(alpha, a, beta, c, trans)
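

# Editor's sketch, not part of the original module: with trans='c' only the
# lower triangle of c is valid, so compare against the tril of a @ a.conj().T
# (assumes cgpaw was compiled with BLAS).
def _example_rk():
    a = np.arange(6.0).reshape(2, 3)
    c = np.empty((2, 2))
    rk(1.0, a, 0.0, c)
    assert np.allclose(np.tril(c), np.tril(a @ a.conj().T))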


def gpu_rk(alpha, a, beta, c, trans='c'):
    """GPU version of rk()."""
    cgpaw.rk_gpu(alpha, a.data.ptr, a.shape,
                 beta, c.data.ptr, c.shape,
                 a.dtype)


def r2k(alpha, a, b, beta, c, trans='c'):
    r"""Rank-2k update of a matrix.

    Performs the operation::

                          dag        cc        dag
        c <- alpha * a . b    + alpha   * b . a    + beta * c

    or if trans='n'::

                      dag            cc    dag
        c <- alpha * a    . b + alpha   * b    . a + beta * c

    where ``a.b`` denotes the matrix multiplication defined by::

                    _
                   \
        (a.b)   =   ) a           * b
             ij    /_  ipklm...      pjklm...
                   pklm...

    ``cc`` denotes complex conjugation.

    ``dag`` denotes the hermitian conjugate (complex conjugation plus a
    swap of axis 0 and 1).

    Only the lower triangle of ``c`` will contain sensible numbers.
    """
    if debug:
        assert beta == 0.0 or is_finite(c, tril=True)
        assert (a.dtype == float and b.dtype == float and c.dtype == float or
                a.dtype == complex and b.dtype == complex and
                c.dtype == complex)
        assert a.flags.c_contiguous and b.flags.c_contiguous
        assert a.ndim > 1
        assert a.shape == b.shape
        if trans == 'c':
            assert c.shape == (a.shape[0], a.shape[0])
        else:
            assert c.shape == (a.shape[1], a.shape[1])
        assert c.strides[1] == c.itemsize or c.size == 0

    cgpaw.r2k(alpha, a, b, beta, c, trans)
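

# Editor's sketch, not part of the original module: for trans='c' and real
# arrays, r2k() accumulates alpha * (a @ b.T + b @ a.T) into the lower
# triangle of c (assumes cgpaw was compiled with BLAS).
def _example_r2k():
    a = np.arange(6.0).reshape(2, 3)
    b = np.arange(6.0, 12.0).reshape(2, 3)
    c = np.zeros((2, 2))
    r2k(0.5, a, b, 0.0, c)
    ref = 0.5 * (a @ b.T + b @ a.T)
    assert np.allclose(np.tril(c), np.tril(ref))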


def gpu_r2k(alpha, a, b, beta, c, trans='c'):
    """GPU version of r2k()."""
    cgpaw.r2k_gpu(alpha, a.data.ptr, a.shape,
                  b.data.ptr, b.shape, beta,
                  c.data.ptr, c.shape,
                  a.dtype)


def gpu_dotc(a, b):
    r"""Dot product, conjugating the first vector with complex arguments.

    Returns the value of the operation::

             _
            \   cc
             ) a    * b
            /_  ijk... ijk...
           ijk...

    ``cc`` denotes complex conjugation.
    """
    if debug:
        assert ((is_contiguous(a, float) and is_contiguous(b, float)) or
                (is_contiguous(a, complex) and is_contiguous(b, complex)))
        assert a.shape == b.shape

    return cgpaw.dotc_gpu(a.data.ptr, a.shape,
                          b.data.ptr, a.dtype)
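

# Editor's note, not part of the original module: for CuPy arrays gpu_dotc()
# is the GPU analogue of np.vdot, which likewise conjugates its first
# argument after flattening:
def _dotc_reference(a, b):
    return np.vdot(a, b)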


def gpu_dotu(a, b):
    r"""Dot product, NOT conjugating the first vector with complex arguments.

    Returns the value of the operation::

             _
            \
             ) a    * b
            /_  ijk... ijk...
           ijk...

    """
    if debug:
        assert ((is_contiguous(a, float) and is_contiguous(b, float)) or
                (is_contiguous(a, complex) and is_contiguous(b, complex)))
        assert a.shape == b.shape

    return cgpaw.dotu_gpu(a.data.ptr, a.shape,
                          b.data.ptr, a.dtype)
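

# Editor's note, not part of the original module: the unconjugated analogue,
# equivalent to np.dot on the flattened arrays:
def _dotu_reference(a, b):
    return np.dot(a.ravel(), b.ravel())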


def _gemmdot(a, b, alpha=1.0, beta=1.0, out=None, trans='n'):
    """Matrix multiplication using gemm.

    Returns a reference to out, where::

        out <- alpha * a . b + beta * out

    If out is None, a suitably sized zero array will be created.

    ``a.b`` denotes matrix multiplication, where the product-sum is
    over the last dimension of a, and either
    the first dimension of b (for trans='n'), or
    the last dimension of b (for trans='t' or 'c').

    If trans='c', the complex conjugate of b is used.
    """
    # Store original shapes
    ashape = a.shape
    bshape = b.shape

    # Vector-vector multiplication is handled by dotu
    if a.ndim == 1 and b.ndim == 1:
        assert out is None
        if trans == 'c':
            return alpha * np.vdot(b, a)  # dotc conjugates *first* argument
        else:
            return alpha * a.dot(b)

    # Map all arrays to 2D arrays
    a = a.reshape(-1, a.shape[-1])
    if trans == 'n':
        b = b.reshape(b.shape[0], -1)
    else:  # 't' or 'c'
        b = b.reshape(-1, b.shape[-1])

    # Apply BLAS gemm routine
    outshape = a.shape[0], b.shape[trans == 'n']
    if out is None:
        # (ATLAS can't handle uninitialized output array)
        out = np.zeros(outshape, a.dtype)
    else:
        out = out.reshape(outshape)
    mmmx(alpha, a, 'N', b, trans.upper(), beta, out)

    # Determine actual shape of result array
    if trans == 'n':
        outshape = ashape[:-1] + bshape[1:]
    else:  # 't' or 'c'
        outshape = ashape[:-1] + bshape[:-1]
    return out.reshape(outshape)
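

# Editor's sketch, not part of the original module: for 2D inputs and
# trans='n', gemmdot() behaves like np.dot (assumes cgpaw was compiled with
# BLAS).
def _example_gemmdot():
    a = np.arange(6.0).reshape(2, 3)
    b = np.arange(12.0).reshape(3, 4)
    out = _gemmdot(a, b, alpha=1.0, beta=0.0)
    assert np.allclose(out, a @ b)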


if not hasattr(cgpaw, 'mmm'):
    # These are the functions used with noblas=True
    # TODO: move these functions elsewhere so that
    # they can be used for unit tests
    def op(o, m):
        if o.upper() == 'N':
            return m
        if o.upper() == 'T':
            return m.T
        if o.upper() == 'C':
            return m.conj().T
        raise ValueError(f'unknown op: {o}')

    def rk(alpha, a, beta, c, trans='c'):  # noqa
        if c.size == 0:
            return
        if beta == 0:
            c[:] = 0.0
        else:
            c *= beta
        if trans == 'n':
            c += alpha * a.conj().T.dot(a)
        else:
            a = a.reshape((len(a), -1))
            c += alpha * a.dot(a.conj().T)

    def r2k(alpha, a, b, beta, c, trans='c'):  # noqa
        if c.size == 0:
            return
        if beta == 0.0:
            c[:] = 0.0
        else:
            c *= beta
        if trans == 'c':
            c += (alpha * a.reshape((len(a), -1))
                  .dot(b.reshape((len(b), -1)).conj().T) +
                  alpha * b.reshape((len(b), -1))
                  .dot(a.reshape((len(a), -1)).conj().T))
        else:
            c += alpha * (a.conj().T @ b + b.conj().T @ a)

    def mmm(alpha: T, a: np.ndarray, opa: str,  # noqa
            b: np.ndarray, opb: str,
            beta: T, c: np.ndarray) -> None:
        if beta == 0.0:
            c[:] = 0.0
        else:
            c *= beta
        c += alpha * op(opa, a).dot(op(opb, b))

    gemmdot = _gemmdot

elif not debug:
    mmm = cgpaw.mmm  # noqa
    rk = cgpaw.rk  # noqa
    r2k = cgpaw.r2k  # noqa
    gemmdot = _gemmdot

else:
    def gemmdot(a, b, alpha=1.0, beta=1.0, out=None, trans='n'):
        assert a.flags.c_contiguous
        assert b.flags.c_contiguous
        assert a.dtype == b.dtype
        if trans == 'n':
            assert a.shape[-1] == b.shape[0]
        else:
            assert a.shape[-1] == b.shape[-1]
        if out is not None:
            assert out.flags.c_contiguous
            assert a.dtype == out.dtype
            assert a.ndim > 1 or b.ndim > 1
            if trans == 'n':
                assert out.shape == a.shape[:-1] + b.shape[1:]
            else:
                assert out.shape == a.shape[:-1] + b.shape[:-1]
        return _gemmdot(a, b, alpha, beta, out, trans)