Coverage for gpaw/core/matrix.py: 64%
681 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1"""BLACS distributed matrix object."""
2from __future__ import annotations
4from types import ModuleType
5from typing import Dict, Tuple
6import gpaw.cgpaw as cgpaw
7import numpy as np
8import scipy.linalg as sla
10import gpaw.utilities.blas as blas
11from gpaw import debug, get_scipy_version
12from gpaw.gpu import cupy as cp, cupy_eigh, XP, gpu_gemm
13from gpaw.mpi import MPIComm, _Communicator, serial_comm
14from gpaw.typing import Array1D, ArrayLike1D, ArrayLike2D, Array2D
16_global_blacs_context_store: Dict[Tuple[_Communicator, int, int], int] = {}
def suggest_blocking(N: int, ncpus: int) -> tuple[int, int, int]:
    """Suggest a process grid and block size for an ``NxN`` matrix.

    Returns a ``(rows, columns, blocksize)`` tuple where
    ``rows * columns == ncpus``.

    >>> suggest_blocking(10, 6)
    (3, 2, 2)
    """
    # Begin with a single column of processes.  Each divisor of ncpus that
    # is still smaller than the current number of rows makes the grid
    # closer to square.
    nprow, npcol = ncpus, 1
    candidate = npcol
    while candidate < nprow:
        quotient, remainder = divmod(ncpus, candidate)
        if remainder == 0:
            nprow, npcol = quotient, candidate
        candidate += 1

    assert npcol * nprow == ncpus

    # ScaLAPACK creates trouble if there aren't at least a few whole blocks.
    # Pick the block size so that there is always at least one whole block
    # and at least two blocks in total.
    # (N // max(nprow, npcol) - 2 would give more whole blocks.)
    longest_side = max(nprow, npcol)
    blocksize = max((N - 2) // longest_side, 1)

    # Round down to a power of two and clamp to the interval [1, 64].
    blocksize = 2**int(np.log2(blocksize))
    blocksize = max(min(blocksize, 64), 1)

    return nprow, npcol, blocksize
class MatrixWithNoData:
    """Placeholder that records shape, dtype and distribution only.

    The ``data`` attribute is an empty ``(0, 0)`` array.  Use
    :meth:`create` to obtain a real :class:`Matrix` with the recorded
    shape, dtype and distribution.
    """

    def __init__(self,
                 M: int,
                 N: int,
                 *,
                 dtype=None,
                 dist: MatrixDistribution | tuple | None = None):
        self.shape = (M, N)
        self.dtype = dtype
        self.data = np.empty((0, 0), dtype)
        if not dist:
            dist = ()
        if isinstance(dist, tuple):
            # A tuple is read as (comm, r, c, b); missing trailing entries
            # fall back to the defaults of create_distribution().
            options = dict(zip(['comm', 'r', 'c', 'b'], dist))
            dist = create_distribution(M, N, **options)
        self.dist = dist

    def create(self) -> Matrix:
        """Return a new Matrix with the stored shape, dtype and dist."""
        M, N = self.shape
        return Matrix(M, N, dtype=self.dtype, dist=self.dist)
76class Matrix(XP):
def __init__(self,
             M: int,
             N: int,
             *,
             dtype=None,
             data: ArrayLike2D | None = None,
             dist: MatrixDistribution | tuple | None = None,
             xp=None):
    """Matrix object.

    Parameters
    ----------
    M:
        Rows.
    N:
        Columns.
    dtype:
        Data type (float or complex).  Defaults to the dtype of ``data``
        or, when no data is given, to float.
    dist:
        BLACS distribution given as
        (communicator, rows, columns, blocksize)
        tuple.  Default is None meaning no distribution.
    data:
        Numpy ndarray to use for storage.  By default, a new ndarray
        will be allocated.
    xp:
        Array module.  When None, ``cp`` is picked if ``dist`` is a
        CuPyDistribution or ``data`` is not a numpy array; otherwise
        ``np`` is used.
    """
    self.shape = (M, N)

    # Anything that is not already a numpy/cupy array is converted to numpy.
    if data is not None and not isinstance(data, (np.ndarray, cp.ndarray)):
        data = np.asarray(data)

    if dtype is None:
        dtype = float if data is None else data.dtype
    self.dtype = np.dtype(dtype)
    assert np.dtype(self.dtype) in \
        [np.float32, np.float64, np.complex64, np.complex128], dtype

    self.xp: ModuleType
    if xp is None:
        wants_gpu = (isinstance(dist, CuPyDistribution) or
                     (data is not None and
                      not isinstance(data, np.ndarray)))
        xp = cp if wants_gpu else np
    XP.__init__(self, xp)

    if not dist:
        dist = ()
    if isinstance(dist, tuple):
        # Tuple form: (comm, r, c, b) with optional trailing entries.
        options = dict(zip(['comm', 'r', 'c', 'b'], dist))
        dist = create_distribution(M, N, xp=self.xp, **options)
    else:
        assert self.shape == dist.full_shape
    self.dist = dist

    self.data: Array2D
    if data is not None:
        assert data.shape == dist.shape, (data.shape, dist.shape, dist)
        self.data = data
    else:
        self.data = self.xp.empty(dist.shape, self.dtype)
def __repr__(self):
    """Return ``Matrix(<dtype>: [xp=cp, ]<distribution details>)``."""
    # Everything after the first '(' of str(dist), including the closing
    # parenthesis, is reused as the tail of the representation.
    pieces = str(self.dist).split('(')
    details = pieces[1]
    prefix = 'xp=cp, ' if self.xp is cp else ''
    return f'Matrix({self.dtype.name}: {prefix}{details}'
def new(self, dist='inherit', data=None) -> Matrix:
    """Create new matrix of same shape and dtype.

    The BLACS distribution of this matrix is reused unless another one
    is given through ``dist``.  ``data`` is handed on to the Matrix
    constructor as storage.
    """
    if dist == 'inherit':
        dist = self.dist
    M, N = self.shape
    return Matrix(M, N,
                  dtype=self.dtype,
                  dist=dist,
                  data=data,
                  xp=self.xp)
def copy(self) -> Matrix:
    """Return a new matrix (made by ``self.new()``) holding the same values."""
    duplicate = self.new()
    duplicate.data[:] = self.data
    return duplicate
def __setitem__(self, item, value):
    """Support ``matrix[:] = other_matrix`` (full-slice assignment only)."""
    assert item == slice(None)
    assert isinstance(value, Matrix)
    source = value.data
    self.data[:] = source
def __iadd__(self, other):
    """In-place addition of a Matrix (its local data) or any other addend."""
    addend = other.data if isinstance(other, Matrix) else other
    self.data += addend
    return self
def multiply(self,
             other,
             alpha=1.0,
             opa='N',
             opb='N',
             out=None,
             data_buffer=None,
             beta=0.0,
             symmetric=False) -> Matrix:
    """BLAS matrix-multiplication with other matrix.

    The actual work is delegated to ``self.dist.multiply``.  Presumably
    the result is ``out <- alpha * op(self) * op(other) + beta * out``
    as in BLAS gemm -- confirm against the distribution classes.

    Parameters
    ----------
    other:
        Matrix, or an object with a ``matrix`` attribute.
    alpha, beta:
        Scalars passed on to the distribution's ``multiply``.
    opa, opb:
        Operation flags for self and other ('N', 'T' or 'C').
    out:
        Matrix (or wrapper with a ``matrix`` attribute) receiving the
        result.  When None, a new Matrix is allocated and ``beta`` must
        be 0.0.
    data_buffer:
        Scratch array; required when ``out`` shares its data with
        ``other``.  Only ``opa == opb == 'N'`` and a false ``beta`` are
        supported in that case.
    symmetric:
        Passed on to the distribution's ``multiply``.  The sliced
        in-place path always passes ``symmetric=False``.

    Returns
    -------
    The ``out`` matrix.
    """
    if not isinstance(other, Matrix):
        other = other.matrix
    A = self
    B = other
    dist = self.dist
    if out is None:
        assert beta == 0.0
        # Shape of the product taking the operation flags into account.
        M = A.shape[0] if opa == 'N' else A.shape[1]
        N = B.shape[1] if opb == 'N' else B.shape[0]
        out = Matrix(M, N, dtype=A.dtype, dist=dist.new(M, N))
    elif not isinstance(out, Matrix):
        out = out.matrix
    if out.data is other.data:
        # The result overwrites ``other``: copy (slices of) other.data into
        # data_buffer first and call multiply repeatedly on those copies.
        assert opa == 'N', 'Not implemented'
        assert opb == 'N', 'Not implemented'
        assert not beta, 'Not implemented'
        assert other.shape[0] == self.shape[0]

        # Assert simple (only row distributed) distributions:
        assert self.shape[1] == self.data.shape[1]
        assert other.shape[1] == other.data.shape[1]
        assert out.shape[1] == out.data.shape[1]

        if data_buffer is None:
            raise ValueError('other is out, and data_buffer is None')

        assert isinstance(data_buffer, other.xp.ndarray)
        dtype = other.data.dtype
        # Reinterpret the scratch memory with the dtype of other.data.
        data_buffer = data_buffer.view(dtype)
        if other.data.shape[0] > 0:
            # Obtain buffer size s.t. the maximum number of
            # columns in other.data fits into data_buffer
            buffer_size = min(
                data_buffer.size // other.data.shape[0],
                other.data.shape[1])
        else:
            # There is no data in other. Thus buffer_size
            # fits all.
            buffer_size = other.data.shape[1]
        # Use the smallest buffer size found on any rank of the communicator.
        buffer_size = dist.comm.min_scalar(buffer_size)
        max_B = other.data.shape[1]

        if buffer_size >= max_B:
            # No need for sliced multiply
            other_buffer = other.new(
                data=data_buffer[:other.data.size].reshape(
                    other.data.shape))
            other_buffer.data[:] = other.data
            dist.multiply(alpha, A, opa, other_buffer, opb, beta, out,
                          symmetric=symmetric)
            return out

        # Sliced multiply
        for i in range(0, max_B, buffer_size):
            # Number of columns in this slice (the last one may be shorter).
            r_buffer_size = min(max(other.data.shape[1] - i, 0),
                                buffer_size)
            # Number of scratch elements needed for this slice.
            l_buffer_size = r_buffer_size * other.data.shape[0]
            buffer = Matrix(
                M=other.shape[0],
                N=r_buffer_size,
                data=data_buffer[
                    :l_buffer_size].reshape(
                        (other.data.shape[0], r_buffer_size)
                ),
                dist=dist.new(M=other.shape[0], N=r_buffer_size),
                xp=other.xp)
            buffer.data[:] \
                = other.data[:, i:i + buffer_size]
            # View of the matching columns of the output.
            out_view = buffer.new(
                data=out.data[:, i:i + buffer_size])
            dist.multiply(alpha, A, opa, buffer,
                          opb, beta, out_view, symmetric=False)
        return out

    # General case: out does not alias other.
    dist.multiply(alpha, A, opa, B, opb, beta, out, symmetric=symmetric)
    return out
def redist(self, other: Matrix) -> None:
    """Redistribute to other BLACS layout.

    Copies the values of this matrix into ``other``, which may have a
    different distribution.  Nothing happens when ``other`` is ``self``.

    Three shortcuts avoid ScaLAPACK:

    * both matrices live on a single process: plain copy,
    * row-distributed (no blocksize) -> single process: gather with
      point-to-point messages to rank 0,
    * single process -> row-distributed (no blocksize): scatter with
      point-to-point messages from rank 0.

    Everything else goes through the module-level ``redist()`` helper.
    """
    if self is other:
        return
    d1 = self.dist
    d2 = other.dist
    # Number of processes holding data in the source and target layouts.
    n1 = d1.rows * d1.columns
    n2 = d2.rows * d2.columns
    if n1 == n2 == 1:
        other.data[:] = self.data
        return

    if n2 == 1 and d1.blocksize is None:
        # Gather: rank 0 collects contiguous row chunks from all ranks.
        assert d2.blocksize is None
        assert d1.columns == 1
        comm = d1.comm
        if comm.rank == 0:
            M = self.shape[0]
            m = (M + comm.size - 1) // comm.size  # rows per rank (ceil)
            other.data[:m] = self.data
            for r in range(1, comm.size):
                m1 = min(r * m, M)
                m2 = min(m1 + m, M)
                comm.receive(other.data[m1:m2], r)
        else:
            comm.send(self.data, 0)
        return

    if n1 == 1 and d2.blocksize is None:
        # Scatter: rank 0 sends contiguous row chunks to all ranks.
        assert d1.blocksize is None
        # The row-wise scatter below only makes sense when the *target*
        # is distributed over rows only.  (The source has a single
        # process here, so checking d1.columns would always pass.)
        assert d2.columns == 1
        comm = d1.comm
        if comm.rank == 0:
            M = self.shape[0]
            m = (M + comm.size - 1) // comm.size  # rows per rank (ceil)
            other.data[:] = self.data[:m]
            for r in range(1, comm.size):
                m1 = min(r * m, M)
                m2 = min(m1 + m, M)
                comm.send(self.data[m1:m2], r)
        else:
            comm.receive(other.data, 0)
        return

    # General case: use the larger of the two communicators, shrunk to the
    # number of processes that actually take part.
    c = d1.comm if d1.comm.size > d2.comm.size else d2.comm
    n = max(n1, n2)
    if n < c.size:
        c = c.new_communicator(np.arange(n))
    if c is not None:
        M, N = self.shape
        d1 = create_distribution(M, N, c,
                                 d1.rows, d1.columns, d1.blocksize)
        d2 = create_distribution(M, N, c,
                                 d2.rows, d2.columns, d2.blocksize)
        # Use the context of the layout that spans all n processes.
        if n1 == n:
            ctx = d1.desc[1]
        else:
            ctx = d2.desc[1]
        # Inside the class body this bare name refers to the module-level
        # redist() function, not to this method.
        redist(d1, self.data, d2, other.data, ctx)
def gather(self, root: int = 0, broadcast=False) -> Matrix:
    """Gather the Matrix on the root rank.

    Returns a new Matrix distributed so that all data is on the root
    rank.  With ``broadcast=True`` the other ranks get an undistributed
    matrix that is filled by a broadcast from rank 0.  On a communicator
    of size one, ``self`` is returned.  Only ``root == 0`` is supported.
    """
    assert root == 0
    comm = self.dist.comm
    if not comm.size > 1:
        return self

    S = self.new(dist=(comm, 1, 1))
    self.redist(S)
    if broadcast:
        if comm.rank > 0:
            S = self.new(dist=None)
        comm.broadcast(S.data, 0)
    return S
def inv(self, uplo='L'):
    """Inplace inversion.

    Only ``uplo='L'`` is accepted and the matrix must be square.  On a
    single process the upper triangle is filled in from the lower one
    before inverting with scipy; otherwise ``cgpaw.scalapack_inverse``
    is used and a non-zero return code raises ValueError.
    """
    assert uplo == 'L'
    M, N = self.shape
    assert M == N
    dist = self.dist

    if dist.comm.size != 1:
        bc, br = dist.desc[4:6]
        assert bc == br
        info = cgpaw.scalapack_inverse(self.data, dist.desc, 'U')
        if info != 0:
            raise ValueError(f'scalapack_inverse error: {info}')
        return

    self.tril2full()
    inverse = sla.inv(self.data,
                      overwrite_a=True,
                      check_finite=debug)
    self.data[:] = inverse
def invcholesky(self) -> None:
    """In-place inverse of Cholesky decomposition.

    Calculate a lower triangle matrix `L` where:::

           †
        LSL = 1,

    and `S` is self.  Only the lower part of `S` is used.

    >>> S = Matrix(2, 2, data=[[1.0, np.nan],
    ...                        [0.1, 1.0]])
    >>> S.invcholesky()
    >>> S.data
    array([[ 1.        , -0.        ],
           [-0.10050378,  1.00503782]])
    """
    S = self.gather()
    if self.dist.comm.rank == 0:
        if not isinstance(S.data, np.ndarray):
            # Non-numpy (GPU) data: the cupy routines get the full matrix.
            S.tril2full()
            L_nn = cp.linalg.cholesky(S.data)
            S.data[:] = cp.linalg.inv(L_nn)
        else:
            if debug:
                # Put junk in the upper triangle to show it is never read.
                upper = np.triu_indices(S.shape[0], 1)
                S.data[upper] = 42.0
            L_nn = sla.cholesky(S.data,
                                lower=True,
                                overwrite_a=True,
                                check_finite=debug)
            S.data[:] = sla.inv(L_nn,
                                overwrite_a=True,
                                check_finite=debug)

    # Hand the result back to the original distribution if we gathered.
    if S is not self:
        S.redist(self)
def eigh(self,
         S=None,
         *,
         cc=False,
         scalapack=(None, 1, 1, None),
         limit: int | None = None) -> Array1D:
    """Calculate eigenvectors and eigenvalues.

    Matrix must be symmetric/hermitian and stored in lower half.
    If ``S`` is given, solve a generalized eigenvalue problem.

    The eigenvectors overwrite the matrix; in the serial branches they
    are written through ``H.data.T``, i.e. each eigenvector ends up in a
    row of the data array.

    Parameters
    ----------
    cc: bool
        Complex conjugate matrix before finding eigenvalues.
    scalapack: tuple
        BLACS distribution for ScaLapack to use.  Default is to do serial
        diagonalization.
    limit:
        Number of eigenvector and values to find.  Defaults to all.

    Returns
    -------
    Array with the eigenvalues.
    """
    slcomm, rows, columns, blocksize = scalapack
    slcomm = slcomm or self.dist.comm
    dist = (slcomm, rows, columns, blocksize)

    # NOTE: this local flag shadows the module-level redist() function
    # inside this method.
    redist = (rows != self.dist.rows or
              columns != self.dist.columns or
              blocksize != self.dist.blocksize)

    if redist:
        # Work on copies in the layout requested for the diagonalization.
        H = self.new(dist=dist)
        self.redist(H)
        if S is not None:
            S0 = S
            S = S0.new(dist=dist)
            S0.redist(S)
    else:
        assert self.dist.comm.size == slcomm.size
        H = self

    # Asking for all eigenpairs is the same as giving no limit.
    if limit == H.shape[0]:
        limit = None

    if limit:
        eps = self.xp.empty(limit)
    else:
        eps = self.xp.empty(H.shape[0])

    if rows * columns == 1:
        # Serial diagonalization on rank 0.
        if self.dist.comm.rank == 0:
            if cc and np.issubdtype(H.dtype, np.complexfloating):
                # In-place conjugation by negating the imaginary part.
                # NOTE(review): np is used here regardless of self.xp --
                # confirm that this is fine for GPU data.
                np.negative(H.data.imag, H.data.imag)
            if debug:
                # Junk in the upper triangle: it must not be referenced.
                H.data[np.triu_indices(H.shape[0], 1)] = 42.0
            if S is None:
                # NOTE(review): ``limit`` is not passed to the solvers in
                # this branch although ``eps`` was sized by it; looks like
                # a shape mismatch for limit < H.shape[0] -- confirm.
                if self.xp is not np:
                    assert isinstance(H.data, cp.ndarray)
                    eps[:], H.data.T[:] = cupy_eigh(H.data, UPLO='L')
                else:
                    eps[:], H.data.T[:] = sla.eigh(
                        H.data,
                        lower=True,
                        overwrite_a=True,
                        check_finite=debug,
                        driver='evx' if H.data.size == 1 else 'evd')
            else:
                if self.xp is cp:
                    # GPU path: go through eighg() with the inverse
                    # Cholesky factor of S; ``limit`` is not used here.
                    assert self.dist.comm.size == 1
                    S.invcholesky()
                    self.tril2full()
                    eigs = self.eighg(S)
                    self.data[:] = self.data.T.copy()
                    return eigs
                if debug:
                    S.data[self.xp.triu_indices(H.shape[0], 1)] = 42.0
                eps, evecs = sla.eigh(
                    H.data,
                    S.data,
                    lower=True,
                    overwrite_a=True,
                    overwrite_b=True,
                    check_finite=debug,
                    subset_by_index=(0, limit - 1) if limit else None)
                limit = limit or len(eps)
                H.data.T[:, :limit] = evecs
        self.dist.comm.broadcast(eps, 0)
    else:
        # Parallel diagonalization; only the first rows * columns ranks
        # of slcomm take part.
        if slcomm.rank < rows * columns:
            # NOTE(review): this branch requires cc to be true and does
            # not support a generalized problem -- reason not visible here.
            assert cc
            assert S is None
            array = H.data.copy()
            info = cgpaw.scalapack_diagonalize_dc(array, H.dist.desc, 'U',
                                                  H.data, eps)
            assert info == 0, info

        # necessary to broadcast eps when some ranks are not used
        # in current scalapack parameter set
        # eg. (2, 1, 2) with 4 processes
        if rows * columns < slcomm.size:
            H.dist.comm.broadcast(eps, 0)

    if redist:
        # Copy the result back into the original layout.
        H.redist(self)

    return eps
def eighg(self, L: Matrix, comm2: MPIComm = serial_comm) -> Array1D:
    """Solve generalized eigenvalue problem.

    With `H` being self, we solve for the eigenvectors `C` and the
    eigenvalues `Λ` (a diagonal matrix):::

        HC = SCΛ,

    where `L` is a lower triangle matrix such that:::

           †
        LSL = 1.

    The solution has these three steps:::

        ~      †   ~~   ~         †~
        H = LHL ,  HC = CΛ,  C = L C.

    Note that `H` must be the full matrix not just half of it!

    Parameters
    ----------
    L:
        Matrix with the property shown above.
    comm2:
        Second communicator.  Only its rank 0 does the calculation; the
        eigenvalues and ``self.data`` are then broadcast over it.

    Returns
    -------
    Array with the eigenvalues.  The matrix data is overwritten.
    """
    M, N = self.shape
    assert M == N
    comm = self.dist.comm

    if comm2.rank == 0:
        if comm.size == 1:
            H = self
            L0 = L
        else:
            # TODO: Use scalapack
            # Collect both matrices in a (comm, 1, 1) layout first.
            H = self.new(dist=(comm,))
            self.redist(H)
            L0 = self.new(dist=(comm,))
            L.redist(L0)
        if comm.rank == 0:
            if self.xp is not np:
                # Non-numpy data: let the distribution object do the work.
                return self.dist.eighg(self, L0)
            tmp_MM = np.empty_like(H.data)
            L_MM = L0.data
            blas.mmm(1.0, L_MM, 'N', H.data, 'N', 0.0, tmp_MM)
            blas.r2k(0.5, tmp_MM, L_MM, 0.0, H.data)
            # Ht_MM = L_MM @ H.data @ L_MM.conj().T
            # The driver is only chosen explicitly for scipy >= 1.9.
            if get_scipy_version() >= [1, 9]:
                driver = 'evx' if M == 1 else 'evd'
            else:
                driver = None
            eig_n, Ct_Mn = sla.eigh(
                H.data,
                overwrite_a=True,
                check_finite=debug,
                driver=driver)
            assert Ct_Mn.flags.f_contiguous
            blas.mmm(1.0, L_MM, 'C', Ct_Mn.T, 'T', 0.0, H.data)
            # H.data[:] = L_MM.T.conj() @ Ct_Mn
        else:
            # Receive buffer for the broadcast below.
            eig_n = np.empty(M)

        if comm.size > 1:
            H.redist(self)
            comm.broadcast(eig_n, 0)

    if comm2.rank > 0:
        # Receive buffer for the broadcast over comm2.
        eig_n = np.empty(M)
    comm2.broadcast(eig_n, 0)
    comm2.broadcast(self.data, 0)

    return eig_n
def complex_conjugate(self) -> None:
    """Inplace complex conjugation (no-op for non-complex dtypes)."""
    if not np.issubdtype(self.dtype, np.complexfloating):
        return
    # Flip the sign of the imaginary part in place.
    imaginary = self.data.imag
    self.xp.negative(imaginary, imaginary)
def add_hermitian_conjugate(self, scale: float = 1.0) -> None:
    """Add hermitian conjugate to myself.

    On a single process the data is first multiplied by ``scale`` and
    then its conjugate transpose is added.  Otherwise the work is handed
    to ``cgpaw.pblas_tran`` with ``scale`` as both scalar arguments.
    """
    if self.dist.comm.size != 1:
        tmp = self.copy()
        desc = self.dist.desc
        cgpaw.pblas_tran(*self.shape, scale, tmp.data, scale, self.data,
                         desc, desc, True)
        return

    if scale != 1.0:
        self.data *= scale
    self.data += self.data.conj().T
def tril2full(self) -> None:
    """Fill in upper triangle from lower triangle.

    For a real matrix::

        a ? ?    a b d
        b c ? -> b c e
        d e f    d e f

    For a complex matrix, the complex conjugate of the lower part will
    be inserted into the upper part.
    """
    M, N = self.shape
    assert M == N

    dist = self.dist

    undistributed = (dist.comm.size == 1 or
                     (dist.rows == 1 and dist.columns == 1))
    if undistributed:
        # Only rank 0 touches its data in this case.
        if dist.comm.rank == 0:
            below_diagonal = self.xp.tri(M, k=-1, dtype=bool)
            self.data.T[below_diagonal] = self.data[below_diagonal].conj()
        return

    desc = dist.desc
    cgpaw.scalapack_set(self.data, desc, 0.0, 0.0, 'L', M - 1, M - 1, 2, 1)
    buf = self.data.copy()
    # Set diagonal to zero in the copy:
    cgpaw.scalapack_set(buf, desc, 0.0, 0.0, 'L', M, M, 1, 1)
    # Now transpose the copy, adding the result to the original matrix:
    cgpaw.pblas_tran(M, M, 1.0, buf, 1.0, self.data, desc, desc, True)
def add_to_diagonal(self, d: ArrayLike1D | float) -> None:
    """Add list of numbers or single number to diagonal of matrix.

    Works on the local rows only: the first local row is taken from
    ``self.dist.my_row_range()``.  The matrix must be square.
    """
    first_row, _ = self.dist.my_row_range()
    M, N = self.shape
    assert M == N
    # In the flattened local array the diagonal starts at index first_row
    # and continues with stride N + 1.
    # NOTE(review): presumably self.data is contiguous so that ravel()
    # gives a view -- confirm.
    flat = self.data.ravel()
    flat[first_row::N + 1] += d
def to_cpu(self) -> Matrix:
    """Create new matrix object with values transferred from GPU to CPU.

    Shorthand for ``self.to_xp(np)``.
    """
    cpu_matrix = self.to_xp(np)
    return cpu_matrix
def to_xp(self, xp) -> Matrix:
    """Create new matrix object with data on GPU or CPU.

    When ``xp`` already is the array module of this matrix, ``self`` is
    returned (only expected for numpy).
    """
    if xp is self.xp:
        assert xp is np, 'cp -> cp should not be needed!'
        return self
    # Pick the transfer direction: device -> host or host -> device.
    transfer = cp.asnumpy if xp is np else cp.asarray
    return self.dist.matrix(data=transfer(self.data))
def to_dtype(self, dtype) -> Matrix:
    """Convert to new data type.

    Returns ``self`` when the dtype already matches; otherwise a new
    matrix with the same distribution and converted data.
    """
    if dtype == self.dtype:
        return self
    converted = self.data.astype(dtype)
    return self.dist.matrix(data=converted)
def _matrix(M):
    """Dig out Matrix object from wrapper(s).

    Follows the ``matrix`` attribute until a Matrix instance is found.
    """
    while not isinstance(M, Matrix):
        M = M.matrix
    return M
def redist(dist1, M1, dist2, M2, context):
    """Copy array ``M1`` (layout ``dist1``) into ``M2`` (layout ``dist2``).

    Thin wrapper around ``cgpaw.scalapack_redist`` for the whole matrix,
    using the given BLACS context.
    """
    source_desc = dist1.desc
    target_desc = dist2.desc
    # Entries 2 and 3 of the source descriptor give the extent to copy.
    cgpaw.scalapack_redist(source_desc, target_desc,
                           M1, M2,
                           source_desc[2], source_desc[3],
                           1, 1, 1, 1,  # 1-indexing
                           context, 'G')
def create_distribution(M: int,
                        N: int,
                        comm: MPIComm | None = None,
                        r: int = 1,
                        c: int = 1,
                        b: int | None = None,
                        xp=None) -> MatrixDistribution:
    """Create the distribution object for an ``MxN`` matrix.

    * ``xp is cp``: CuPyDistribution (``b`` must be None),
    * no communicator or a communicator of size one: NoDistribution,
    * otherwise: BLACSDistribution.

    A value of -1 for ``r`` or ``c`` means "size of the communicator".
    """
    if xp is cp:
        assert b is None
        comm = comm or serial_comm
        distribution_class = CuPyDistribution
    elif comm is None or comm.size == 1:
        assert r == 1 and abs(c) == 1 or c == 1 and abs(r) == 1
        return NoDistribution(M, N)
    else:
        distribution_class = BLACSDistribution

    rows = comm.size if r == -1 else r
    columns = comm.size if c == -1 else c
    return distribution_class(M, N, comm, rows, columns, b)
class MatrixDistribution:
    """Base class describing how a matrix is laid out over processes."""

    comm: MPIComm
    rows: int
    columns: int
    blocksize: int | None  # None means everything on rank=0
    shape: tuple[int, int]  # shape of the local array
    full_shape: tuple[int, int]  # shape of the whole matrix
    desc: Array1D

    def matrix(self, dtype=None, data=None):
        """Create a Matrix of the full shape using this distribution."""
        M, N = self.full_shape
        return Matrix(M, N, dtype=dtype, data=data, dist=self)

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        """Matrix-matrix multiplication; provided by subclasses."""
        raise NotImplementedError

    def eighg(self, H, L):
        """Generalized eigenvalue solver; provided by subclasses."""
        raise NotImplementedError

    def new(self, M, N):
        """Same kind of distribution for another shape; see subclasses."""
        raise NotImplementedError

    def my_row_range(self) -> tuple[int, int]:
        """Return indices for range of my rows.

        Raises ValueError unless the matrix is distributed over rows
        only, one block of rows per rank of the communicator and no
        blocksize.

        >>> Matrix(2, 2).dist.my_row_range()
        (0, 2)
        """
        if (self.rows != self.comm.size or
                self.columns != 1 or
                self.blocksize is not None):
            raise ValueError(f'Can not create slice of distribution: {self}')
        total_rows = self.full_shape[0]
        # Rows per rank, rounded up.
        chunk = (total_rows + self.rows - 1) // self.rows
        start = self.comm.rank * chunk
        stop = min(start + chunk, total_rows)
        return start, stop
class NoDistribution(MatrixDistribution):
    """Trivial distribution: the whole matrix lives on one process."""

    comm = serial_comm
    rows = 1
    columns = 1
    blocksize = None

    def __init__(self, M, N):
        # Local and full shapes are the same.
        self.shape = self.full_shape = (M, N)

    def __str__(self):
        return 'NoDistribution({}x{})'.format(self.shape[0], self.shape[1])

    def global_index(self, n):
        """Local and global indices are identical."""
        return n

    def new(self, M, N):
        """Return a NoDistribution for an ``MxN`` matrix."""
        return NoDistribution(M, N)

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        """Multiply ``a`` and ``b`` into ``c`` using the blas wrappers.

        The general case calls ``blas.mmm``.  With ``symmetric`` true,
        ``blas.rk`` (``a is b``) or ``blas.r2k`` is used; only
        ``opa == 'N'`` is available.
        """
        if not symmetric:
            blas.mmm(alpha, a.data, opa, b.data, opb, beta, c.data)
            return

        if opa != 'N':
            # Deliberate crash marker: this combination is switched off.
            # The statements after it are unreachable.
            1 / 0
            assert opa == 'C' and opb == 'N'
            assert a is not b
            blas.r2k(0.5 * alpha, a.data, b.data, beta, c.data, 'n')
            return

        assert opb == 'C' or opb == 'T' and a.dtype == float
        if a is b:
            blas.rk(alpha, a.data, beta, c.data)
            return
        if beta == 1.0 and a.shape[1] == 0:
            # Nothing would be added to c.
            return
        blas.r2k(0.5 * alpha, a.data, b.data, beta, c.data)
766 else:
767 blas.mmm(alpha, a.data, opa, b.data, opb, beta, c.data)
class BLACSDistribution(MatrixDistribution):
    """Block-cyclic (BLACS) distribution over an ``r x c`` process grid.

    Parameters
    ----------
    M, N:
        Shape of the full matrix.
    comm:
        Communicator.
    r, c:
        Number of rows and columns in the process grid.
    b:
        Block size.  None is only allowed when ``c == 1`` (blocks of rows,
        marked as ``simple``) or ``r == 1`` (blocks of columns).

    Raises
    ------
    ValueError
        If ``b`` is None while both ``r`` and ``c`` differ from 1.

    Notes
    -----
    The ``desc`` attribute only exists when a BLACS context could be
    created.  ``__str__`` and ``global_index`` therefore rely on
    ``full_shape`` and the stored block shape instead of ``desc``.
    """
    serial = False

    def __init__(self, M, N, comm, r, c, b):
        self.comm = comm
        self.rows = r
        self.columns = c
        self.blocksize = b
        self.full_shape = (M, N)
        self.simple = False

        # BLACS contexts are cached per (communicator, rows, columns).
        key = (comm, r, c)
        context = _global_blacs_context_store.get(key)
        if context is None:
            try:
                context = cgpaw.new_blacs_context(comm.get_c_object(),
                                                  c, r, 'R')
            except AttributeError:
                # Deliberate best effort: carry on without a context.
                # Presumably this happens when cgpaw was built without
                # ScaLAPACK support -- TODO confirm.
                pass
            else:
                _global_blacs_context_store[key] = context

        if b is None:
            if c == 1:
                # Contiguous blocks of rows, one per process row.
                br = (M + r - 1) // r
                bc = max(1, N)
                self.simple = True
            elif r == 1:
                # Contiguous blocks of columns, one per process column.
                br = M
                bc = (N + c - 1) // c
            else:
                raise ValueError('Please specify block size!')
        else:
            br = bc = b
        # Kept separately so that it is available without a descriptor.
        self._block_shape = (br, bc)

        if context is None:
            # Without BLACS only the simple row distribution is possible.
            assert b is None
            assert c == 1
            n = N
            m = min((comm.rank + 1) * br, M) - min(comm.rank * br, M)
        else:
            n, m = cgpaw.get_blacs_local_shape(context, N, M, bc, br, 0, 0)
        if n < 0 or m < 0:
            n = m = 0
        self.shape = (m, n)
        lld = max(1, n)
        if context is not None:
            # Array descriptor; note that columns come before rows.
            self.desc = np.array([1, context, N, M, bc, br, 0, 0, lld],
                                 np.intc)

    def __str__(self):
        # Same text as before, but built from attributes that always
        # exist (desc is missing when there is no BLACS context).
        shapes = [self.full_shape, self.shape, self._block_shape]
        return ('BLACSDistribution(global={}, local={}, blocksize={})'
                .format(*('{}x{}'.format(*shape) for shape in shapes)))

    def global_index(self, myi):
        """Global row index: ``rank * (rows per block) + myi``."""
        return self.comm.rank * int(self._block_shape[0]) + myi

    def new(self, M, N):
        """Same communicator, grid and block size for an ``MxN`` matrix."""
        return BLACSDistribution(M, N,
                                 self.comm,
                                 self.rows, self.columns,
                                 self.blocksize)

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        """Multiply ``a`` and ``b`` into ``c``; returns ``c``.

        For simple (row-only) distributions on more than one process some
        combinations are handled by the pure-MPI helpers ``mmm_nn``,
        ``mmm_nc_sym`` and ``mmm_nc``.  Everything else goes to the PBLAS
        wrappers in cgpaw.
        """
        if self.comm.size > 1:
            ok = a.dist.simple and b.dist.simple and c.dist.simple
            if ok:
                # Special cases that don't need scalapack - most likely also
                # faster:
                if opa == 'N' and opb == 'N':
                    return mmm_nn(a, b, c, alpha, beta, blas.mmm)
                if opa == 'N' and opb == 'C':
                    if symmetric:
                        if beta == 1.0:
                            return mmm_nc_sym(a, b, c, alpha, blas.mmm)
                    else:
                        return mmm_nc(a, b, c, alpha, beta, blas.mmm)

        if symmetric:
            assert opa == 'N'
            assert opb == 'C' or opb == 'T' and a.dtype == float
            N, K = a.shape
            if a is b:
                cgpaw.pblas_rk(N, K, alpha, a.data,
                               beta, c.data,
                               a.dist.desc, c.dist.desc,
                               'U')
            else:
                cgpaw.pblas_r2k(N, K, 0.5 * alpha, b.data, a.data,
                                beta, c.data,
                                b.dist.desc, a.dist.desc, c.dist.desc,
                                'U')
        else:
            # The operands are passed in swapped order (b before a).
            Ka, M = a.shape
            N, Kb = b.shape
            if opa == 'N':
                Ka, M = M, Ka
            if opb == 'N':
                N, Kb = Kb, N
            cgpaw.pblas_gemm(N, M, Ka, alpha, b.data, a.data,
                             beta, c.data,
                             b.dist.desc, a.dist.desc, c.dist.desc,
                             opb, opa)
        return c
def cublas_mmm(alpha, a, opa, b, opb, beta, c):
    """GPU matrix-matrix multiplication through ``gpu_gemm``.

    Nothing is done when ``c`` is empty, or when ``a`` is empty and
    ``beta`` is 1.0.  The 'C' operation flag is translated to 'H'.
    """
    nothing_to_do = c.size == 0 or (a.size == 0 and beta == 1.0)
    if nothing_to_do:
        return
    trans_a = opa.replace('C', 'H')
    trans_b = opb.replace('C', 'H')
    gpu_gemm(trans_a, trans_b, a, b, c, alpha, beta)
class CuPyDistribution(MatrixDistribution):
    """Row distribution used for matrices with GPU (CuPy) data.

    Only distributions over rows are supported (``c`` must be 1); each
    rank holds one contiguous chunk of rows.
    """

    def __init__(self, M, N, comm, r, c, b):
        self.comm = comm
        self.rows = r
        self.columns = c
        self.blocksize = b
        self.full_shape = (M, N)
        # assert r == comm.size, (M, N, comm, r, c, b)
        assert c == 1
        # Rows per rank (rounded up) and number of rows on this rank.
        br = (M + r - 1) // r
        m = min((comm.rank + 1) * br, M) - min(comm.rank * br, M)
        self.shape = (m, N)

    def __str__(self):
        M, N = self.full_shape
        m, N = self.shape
        return f'CuPyDistribution(global={M}x{N}, local={m}x{N})'

    def global_index(self, n):
        # Deliberate crash marker (ZeroDivisionError); the return below is
        # never reached.
        1 / 0
        return n

    def new(self, M, N):
        """Same communicator and grid for an ``MxN`` matrix."""
        return CuPyDistribution(M, N,
                                self.comm,
                                self.rows, self.columns,
                                self.blocksize)

    def multiply(self, alpha, a, opa, b, opb, beta, c, *, symmetric=False):
        """Multiply ``a`` and ``b`` into ``c`` on the GPU.

        On more than one process only the combinations handled by
        ``mmm_nn``, ``mmm_nc_sym`` and ``mmm_nc`` are available; anything
        else hits a deliberate ``1 / 0``.  Note that ``symmetric`` is
        keyword-only here, unlike in the other distribution classes.
        """
        if self.comm.size > 1:
            if opa == 'N' and opb == 'N':
                return mmm_nn(a, b, c, alpha, beta, cublas_mmm)
            if opa == 'N' and opb == 'C':
                if symmetric:
                    if beta == 1.0:
                        return mmm_nc_sym(a, b, c, alpha, cublas_mmm)
                else:
                    return mmm_nc(a, b, c, alpha, beta, cublas_mmm)
            # Deliberate crash marker: unsupported parallel combination.
            1 / 0

        if symmetric:
            if opa == 'N':
                assert opb == 'C' or opb == 'T' and a.dtype == float
                if a is b:
                    gpu_gemm('N', 'H',
                             a.data, a.data, c.data,
                             alpha, beta)
                    # cp.cublas.syrk('N', a.data, c.data, alpha, beta, True)
                else:
                    if beta == 1.0 and a.shape[1] == 0:
                        # Nothing would be added to c.
                        return
                    if c.data.size > 0:
                        assert beta in [0.0, 1.0]
                        # CuPy doesn't have dsyrk, so we roll our own:
                        gpu_gemm('N', 'H',
                                 a.data, b.data, c.data,
                                 0.5 * alpha, beta)
                        gpu_gemm('N', 'H',
                                 b.data, a.data, c.data,
                                 0.5 * alpha, 1.0)
            else:
                # Deliberate crash marker; everything below it in this
                # branch is unreachable.
                1 / 0
                assert opa == 'C' and opb == 'N'
                assert a is not b
                raise NotImplementedError
                blas.r2k(0.5 * alpha, a.data, b.data, beta, c.data, 'n')

        else:
            cublas_mmm(alpha, a.data, opa, b.data, opb, beta, c.data)

    def eighg(self, H, L):
        """Solve the generalized eigenvalue problem on a single process.

        :::

            ~      †   ~~   ~         †~
            H = LHL ,  HC = CΛ,  C = L C.

        ``H`` is overwritten and the eigenvalues are returned.
        """
        assert self.comm.size == 1
        tmp = H.new()
        self.multiply(1.0, L, 'N', H, 'N', 0.0, tmp)
        self.multiply(1.0, tmp, 'N', L, 'C', 0.0, H, symmetric=True)
        eig_M, Ct_MM = cupy_eigh(H.data, UPLO='L')
        assert Ct_MM.flags.f_contiguous
        # Wrap the transposed eigenvector array in a Matrix.
        Ct = H.new(data=Ct_MM.T)
        self.multiply(1.0, L, 'C', Ct, 'T', 0.0, H)
        # H.complex_conjugate()
        return eig_M
def mmm_nn(m1, m2, m3, alpha, beta, mmm):
    """Parallel matrix-matrix multiplication.

    :::

        m  <- αm m  + βm
         3      1 2     3

    The row blocks of ``m2`` are passed around between the ranks with
    non-blocking messages while each rank multiplies the matching column
    block of ``m1`` with the row block it currently holds.

    Parameters
    ----------
    m1, m2, m3:
        Matrix objects (``m3`` receives the result and is returned).
    alpha, beta:
        Scalars; ``beta`` is only used for the first partial product.
    mmm:
        Function doing the local multiplication, called as
        ``mmm(alpha, a, opa, b, opb, beta, c)``.
    """
    comm = m1.dist.comm
    buf1 = m2.data
    xp = m1.xp

    N = m1.shape[1]
    assert N == m2.shape[0], f'{N}, {m2.shape[0]}'
    # Rows of m2 per rank, rounded up.
    n = (N + comm.size - 1) // comm.size

    for r in range(comm.size):
        if r == 0:
            # Buffers...
            buf2 = xp.empty((n, buf1.shape[1]), dtype=buf1.dtype)

        rrequest = None
        srequest = None
        if r < comm.size - 1:
            # Start receiving the block needed in the next step and send
            # our own block to the rank that needs it then.
            rrank = (comm.rank + r + 1) % comm.size
            rn1 = min(rrank * n, N)
            rn2 = min(rn1 + n, N)
            if rn2 > rn1:
                rrequest = comm.receive(buf2[:rn2 - rn1], rrank, 21, False)
            srank = (comm.rank - r - 1) % comm.size
            if len(m2.data) > 0:
                srequest = comm.send(m2.data, srank, 21, False)

        # Rows n1:n2 of m2 are the ones currently held in buf1.
        r0 = (comm.rank + r) % comm.size
        n1 = min(r0 * n, N)
        n2 = min(n1 + n, N)
        # Contiguity...
        mmm(alpha, m1.data[:, n1:n2], 'N', buf1[:n2 - n1], 'N', beta, m3.data)

        # Later partial products are added to what is already in m3.
        beta = 1.0

        if r == 0:
            # Buffers...
            # buf1 was m2.data itself until now; give it its own memory.
            buf1 = xp.empty_like(buf2)

        buf1, buf2 = buf2, buf1

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

    return m3
def mmm_nc_sym(a, b, out, alpha, mmm):
    """Symmetric parallel matrix-matrix multiplication.

    :::

                †
        c <- αab  + c

    This function utilizes the fact that c is symmetric, s.t.:

                        †      †
        c <- 0.5 * (αab  + αba ) + c

    Only lower half of c is updated.

    Parameters
    ----------
    a, b:
        Matrix objects to multiply.
    out:
        Matrix object that is updated and returned.
    alpha:
        Scalar factor.
    mmm:
        Function doing the local multiplication, called as
        ``mmm(alpha, a, opa, b, opb, beta, c)``.
    """
    comm = a.dist.comm
    M, N = b.shape
    # Rows of b per rank (rounded up) and rows on this rank.
    m = (M + comm.size - 1) // comm.size
    mym = len(b.data)
    xp = a.xp

    # Buffers...
    buf1 = xp.empty((m, N), dtype=a.dtype)
    buf2 = xp.empty((m, N), dtype=a.dtype)
    # Only about half of the rotation steps are done here.
    half = comm.size // 2
    aa = a.data
    bb = b.data

    for r in range(half + 1):
        rrequest = None
        srequest = None

        if r < half:
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            # For an even number of ranks, the last exchange is needed by
            # one half of the ranks only.
            skip = (comm.size % 2 == 0 and r == half - 1)
            m1 = min(rrank * m, M)
            m2 = min(m1 + m, M)
            if not (skip and comm.rank < half) and m2 > m1:
                rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False)
            if not (skip and comm.rank >= half) and mym > 0:
                srequest = comm.send(b.data, srank, 11, False)

        if not (comm.size % 2 == 0 and r == half and comm.rank < half):
            # Columns m1:m2 of out belong to the block of b held now.
            m1 = min(((comm.rank - r) % comm.size) * m, M)
            m2 = min(m1 + m, M)
            if r == 0:
                # Diagonal block: local a times local b, added to out.
                # Contiguity...
                mmm(alpha, aa, 'N', bb, 'C', 1.0, out.data[:, m1:m2])
            else:
                # Accumulate (beta=1) or overwrite (beta=0) depending on
                # the position of the block relative to this rank.
                beta = 1.0 if r <= comm.rank else 0.0
                mmm(alpha, aa, 'N', buf2[:m2 - m1], 'C',
                    beta, out.data[:, m1:m2])
            # out.data[:, m1:m2] = m12.data[:, :m2 - m1]

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        bb = buf1
        buf1, buf2 = buf2, buf1

    # Second stage: blocks that were computed on another rank are sent
    # (conjugate-transposed) to the rank where they are added to out.
    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                m1 = min(column * m, M)
                m2 = min(m1 + m, M)
                if mym > 0 and m2 > m1:
                    requests.append(
                        comm.send(out.data[:, m1:m2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                m1 = min(row * m, M)
                m2 = min(m1 + m, M)
                if mym > 0 and m2 > m1:
                    block = xp.empty((mym, m2 - m1), out.dtype)
                    blocks.append((m1, m2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for m1, m2, block in blocks:
        out.data[:, m1:m2] += block

    return out
def mmm_nc(a, b, out, alpha, beta, mmm):
    """Parallel matrix-matrix multiplication.

    :::

                †
        c <- αab  + βc

    The row blocks of ``b`` are passed around between the ranks with
    non-blocking messages; in every step each rank fills the block of
    columns of ``out`` that matches the block of ``b`` it currently holds.

    Parameters
    ----------
    a, b:
        Matrix objects to multiply.
    out:
        Matrix object receiving the result (also returned).
    alpha, beta:
        Scalars passed to ``mmm`` in every step.
    mmm:
        Function doing the local multiplication, called as
        ``mmm(alpha, a, opa, b, opb, beta, c)``.
    """
    comm = a.dist.comm
    M, N = b.shape
    # Rows of b per rank (rounded up) and rows on this rank.
    m = (M + comm.size - 1) // comm.size
    mym = len(b.data)
    xp = a.xp

    # Two buffers: one is used for computing while the other one receives.
    buf1 = xp.empty((m, N), dtype=a.dtype)
    buf2 = xp.empty((m, N), dtype=a.dtype)
    aa = a.data
    bb = b.data

    for r in range(comm.size):
        rrequest = None
        srequest = None

        if r < comm.size - 1:
            # Start the exchange for the next step.
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            m1 = min(rrank * m, M)
            m2 = min(m1 + m, M)
            if m2 > m1:
                rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False)
            if mym > 0:
                srequest = comm.send(b.data, srank, 11, False)

        # Columns m1:m2 of out belong to the block of b held in bb.
        m1 = min(((comm.rank - r) % comm.size) * m, M)
        m2 = min(m1 + m, M)
        # NOTE(review): the original comment here asked whether symmetry
        # could be exploited in this step.
        mmm(alpha, aa, 'N', bb[:m2 - m1], 'C', beta, out.data[:, m1:m2])

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # The block just received becomes the one used in the next step.
        bb = buf1
        buf1, buf2 = buf2, buf1

    return out