Coverage for gpaw/matrix.py: 34%
397 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1"""BLACS distributed matrix object."""
2from typing import Dict, Tuple
3import numpy as np
4import scipy.linalg as linalg
6import gpaw.cgpaw as cgpaw
7from gpaw import debug
8from gpaw.mpi import serial_comm, _Communicator
9import gpaw.utilities.blas as blas
# Cache of BLACS context handles, keyed by (communicator, rows, columns).
# BLACSDistribution looks a context up here before asking cgpaw to create a
# new one, so each process grid is only created once per communicator.
_global_blacs_context_store: Dict[Tuple[_Communicator, int, int], int] = {}
def matrix_matrix_multiply(alpha, a, opa, b, opb, beta=0.0, c=None,
                           symmetric=False):
    """Multiply two matrices, BLAS style.

    The work is delegated to ``Matrix.multiply``, which picks
    dgemm/zgemm/dsyrk/zherk/dsyr2k/zher2k as appropriate, or the
    equivalent PBLAS routines when the matrices are distributed.

    alpha, beta: float
        Scaling coefficients.
    a, b, c: Matrix or wrapper around a Matrix
        Operands and (optional) result.  All must have the same type
        (float or complex).
    opa, opb: str
        One of 'N', 'T' or 'C'.

    With opa='N' and opb='N' the operation is equivalent to::

        c.array[:] = alpha * np.dot(a.array, b.array) + beta * c.array

    For opa='T' substitute a.array.T, and for opa='C' substitute
    a.array.T.conj() (and likewise for opb).

    Pass symmetric=True when the result is symmetric/hermitian; only the
    lower half of c is then evaluated.
    """
    left = _matrix(a)
    right = _matrix(b)
    # Leave a missing result matrix as None so that multiply() allocates it.
    result = None if c is None else _matrix(c)
    return left.multiply(alpha, opa, right, opb, beta, result, symmetric)
def suggest_blocking(N, ncpus):
    """Propose a process grid and block size for an NxN matrix.

    Returns a (rows, columns, blocksize) tuple where rows * columns equals
    ncpus and blocksize is a power of two between 1 and 64.
    """
    nprow, npcol = ncpus, 1

    # Walk through candidate column counts, keeping the largest divisor of
    # ncpus seen before the candidate catches up with the row count.  This
    # makes the grid as close to square as possible.
    candidate = npcol
    while candidate < nprow:
        if ncpus % candidate == 0:
            npcol = candidate
            nprow = ncpus // npcol
        candidate += 1

    assert npcol * nprow == ncpus

    # ScaLAPACK creates trouble if there aren't at least a few whole blocks,
    # so pick a block size that always leaves at least one whole block and
    # at least two blocks in total.
    # (N // max(nprow, npcol) - 2 would give more whole blocks.)
    longest_side = max(nprow, npcol)
    blocksize = max((N - 2) // longest_side, 1)

    # Round down to a power of two and clamp to the range [1, 64].
    exponent = int(np.log2(blocksize))
    blocksize = min(2**exponent, 64)
    blocksize = max(blocksize, 1)

    return nprow, npcol, blocksize
class Matrix:
    """Dense matrix that may be distributed over a BLACS process grid.

    Attributes set by the constructor:

    * ``shape``: global (M, N) shape.
    * ``dtype``: numpy dtype of the elements.
    * ``dist``: distribution object (see create_distribution()).
    * ``array``: local ndarray with the shape given by ``dist.shape``.
    * ``comm``: communicator over which partial results must be summed;
      starts out as the serial communicator.
    * ``state``: either 'everything is fine' or 'a sum is needed'.  The
      latter means that ``array`` must be summed over ``comm`` before use.
    """

    def __init__(self, M, N, dtype=None, data=None, dist=None):
        """Matrix object.

        M: int
            Rows.
        N: int
            Columns.
        dtype: type
            Data type (float or complex).  Defaults to the dtype of data,
            or float when no data is given.
        dist: tuple or None
            BLACS distribution given as (communicator, rows, columns,
            blocksize) tuple.  Default is None meaning no distribution.
            An already constructed distribution object is used as is.
        data: ndarray or None.
            Numpy ndarray to use for storage.  By default, a new ndarray
            will be allocated.
        """
        self.shape = (M, N)

        if dtype is None:
            if data is None:
                dtype = float
            else:
                dtype = data.dtype
        self.dtype = np.dtype(dtype)

        dist = dist or ()
        if isinstance(dist, tuple):
            dist = create_distribution(M, N, *dist)
        self.dist = dist

        if data is None:
            self.array = np.empty(dist.shape, self.dtype)
        else:
            self.array = data.reshape(dist.shape)

        self.comm = serial_comm
        self.state = 'everything is fine'

    def __len__(self):
        """Return the global number of rows."""
        return self.shape[0]

    def __repr__(self):
        # Reuse the text after the opening parenthesis of the distribution's
        # str(), which already ends with the closing parenthesis.
        dist = str(self.dist).split('(')[1]
        return f'Matrix({self.dtype.name}: {dist}'

    def new(self, dist='inherit'):
        """Create new matrix of same shape and dtype.

        Default is to use same BLACS distribution.  Use dist to use another
        distribution.
        """
        return Matrix(*self.shape, dtype=self.dtype,
                      dist=self.dist if dist == 'inherit' else dist)

    def __setitem__(self, i, x):
        """Evaluate x into this matrix (``matrix[:] = x``).

        The index i is ignored.  x must provide an ``eval(matrix)`` method.
        """
        # assert i == slice(None)
        if isinstance(x, np.ndarray):
            # Direct assignment from an ndarray is deliberately disabled;
            # this line raises ZeroDivisionError.
            1 / 0  # self.array[:] = x
        else:
            x.eval(self)

    def __iadd__(self, x):
        """Add x to this matrix in place via ``x.eval(self, 1.0)``."""
        x.eval(self, 1.0)
        return self

    def multiply(self, alpha, opa, b, opb, beta=0.0, out=None,
                 symmetric=False):
        """BLAS-style matrix-matrix multiplication.

        See matrix_matrix_multiply() for details.

        When out is None a new result matrix is allocated (beta must then
        be 0.0).  Returns the result matrix.
        """
        dist = self.dist
        if out is None:
            assert beta == 0.0
            if opa == 'N':
                M = self.shape[0]
            else:
                M = self.shape[1]
            if opb == 'N':
                N = b.shape[1]
            else:
                N = b.shape[0]
            out = Matrix(M, N, self.dtype,
                         dist=(dist.comm, dist.rows, dist.columns))
        if dist.comm.size > 1:
            # Special cases that don't need scalapack - most likely also
            # faster:
            if alpha == 1.0 and opa == 'N' and opb == 'N':
                return fastmmm(self, b, out, beta)
            if alpha == 1.0 and beta == 1.0 and opa == 'N' and opb == 'C':
                if symmetric:
                    return fastmmm2(self, b, out)
                else:
                    return fastmmm2notsym(self, b, out)

        dist.multiply(alpha, self, opa, b, opb, beta, out, symmetric)
        return out

    def redist(self, other):
        """Redistribute to other BLACS layout.

        Copies the content of this matrix into other, which may use a
        different distribution.
        """
        if self is other:
            return
        d1 = self.dist
        d2 = other.dist
        n1 = d1.rows * d1.columns
        n2 = d2.rows * d2.columns
        if n1 == n2 == 1:
            other.array[:] = self.array
            return

        if n2 == 1 and d1.blocksize is None:
            # Gather row blocks onto rank 0 with plain send/receive.
            assert d2.blocksize is None
            comm = d1.comm
            if comm.rank == 0:
                M = len(self)
                m = (M + comm.size - 1) // comm.size
                other.array[:m] = self.array
                for r in range(1, comm.size):
                    m1 = min(r * m, M)
                    m2 = min(m1 + m, M)
                    comm.receive(other.array[m1:m2], r)
            else:
                comm.send(self.array, 0)
            return

        if n1 == 1 and d2.blocksize is None:
            # Scatter row blocks from rank 0 with plain send/receive.
            assert d1.blocksize is None
            comm = d1.comm
            if comm.rank == 0:
                M = len(self)
                m = (M + comm.size - 1) // comm.size
                other.array[:] = self.array[:m]
                for r in range(1, comm.size):
                    m1 = min(r * m, M)
                    m2 = min(m1 + m, M)
                    comm.send(self.array[m1:m2], r)
            else:
                comm.receive(other.array, 0)
            return

        # General case: let ScaLAPACK do the work on a communicator that is
        # just large enough to hold the bigger of the two process grids.
        c = d1.comm if d1.comm.size > d2.comm.size else d2.comm
        n = max(n1, n2)
        if n < c.size:
            c = c.new_communicator(np.arange(n))
        if c is not None:
            M, N = self.shape
            d1 = create_distribution(M, N, c,
                                     d1.rows, d1.columns, d1.blocksize)
            d2 = create_distribution(M, N, c,
                                     d2.rows, d2.columns, d2.blocksize)
            # Use the BLACS context of the larger process grid.
            if n1 == n:
                ctx = d1.desc[1]
            else:
                ctx = d2.desc[1]
            redist(d1, self.array, d2, other.array, ctx)

    def invcholesky(self):
        """Inverse of Cholesky decomposition.

        Only the lower part is used.  The matrix is overwritten by the
        inverse of its lower Cholesky factor.
        """
        if self.state == 'a sum is needed':
            self.comm.sum(self.array, 0)

        if self.comm.rank == 0:
            if self.dist.comm.size > 1:
                # Collect the matrix on a 1x1 grid for the serial LAPACK
                # call.
                S = self.new(dist=(self.dist.comm, 1, 1))
                self.redist(S)
            else:
                S = self
            if self.dist.comm.rank == 0:
                if debug:
                    # Poison the upper triangle to verify it is not used.
                    S.array[np.triu_indices(S.shape[0], 1)] = 42.0
                L_nn = linalg.cholesky(S.array,
                                       lower=True,
                                       overwrite_a=True,
                                       check_finite=debug)
                S.array[:] = linalg.inv(L_nn,
                                        overwrite_a=True,
                                        check_finite=debug)
            if S is not self:
                S.redist(self)

        if self.comm.size > 1:
            self.comm.broadcast(self.array, 0)
            # Fixed: this used to be a no-op comparison (==), so the state
            # was never reset after the sum and broadcast.
            self.state = 'everything is fine'

    def eigh(self, cc=False, scalapack=(None, 1, 1, None)):
        """Calculate eigenvectors and eigenvalues.

        Matrix must be symmetric/hermitian and stored in lower half.

        cc: bool
            Complex conjugate matrix before finding eigenvalues.
        scalapack: tuple
            BLACS distribution for ScaLapack to use.  Default is to do serial
            diagonalization.

        Returns an array with the eigenvalues.  The matrix itself is
        overwritten by the eigenvectors.
        """
        slcomm, rows, columns, blocksize = scalapack

        if self.state == 'a sum is needed':
            self.comm.sum(self.array, 0)

        slcomm = slcomm or self.dist.comm
        dist = (slcomm, rows, columns, blocksize)

        needs_redist = (rows != self.dist.rows or
                        columns != self.dist.columns or
                        blocksize != self.dist.blocksize)

        if needs_redist:
            H = self.new(dist=dist)
            self.redist(H)
        else:
            assert self.dist.comm.size == slcomm.size
            H = self

        eps = np.empty(H.shape[0])

        if rows * columns == 1:
            if self.comm.rank == 0 and self.dist.comm.rank == 0:
                if cc and H.dtype == complex:
                    np.negative(H.array.imag, H.array.imag)
                if debug:
                    # Poison the upper triangle to verify it is not used.
                    H.array[np.triu_indices(H.shape[0], 1)] = 42.0
                # Assigning through the transpose stores one eigenvector
                # per row of H.array.
                eps[:], H.array.T[:] = linalg.eigh(H.array,
                                                   lower=True,  # ???
                                                   overwrite_a=True,
                                                   check_finite=debug)
            self.dist.comm.broadcast(eps, 0)
        else:
            if slcomm.rank < rows * columns:
                assert cc
                array = H.array.copy()
                info = cgpaw.scalapack_diagonalize_dc(array, H.dist.desc, 'U',
                                                      H.array, eps)
                assert info == 0, info

            # necessary to broadcast eps when some ranks are not used
            # in current scalapack parameter set
            # eg. (2, 1, 2) with 4 processes
            if rows * columns < slcomm.size:
                H.dist.comm.broadcast(eps, 0)

        if needs_redist:
            H.redist(self)

        assert (self.state == 'a sum is needed') == (
            self.comm.size > 1)
        if self.comm.size > 1:
            self.comm.broadcast(self.array, 0)
            self.comm.broadcast(eps, 0)
            # Fixed: this used to be a no-op comparison (==), so the state
            # was never reset after the sum and broadcast.
            self.state = 'everything is fine'

        return eps

    def complex_conjugate(self):
        """Inplace complex conjugation."""
        if self.dtype == complex:
            np.negative(self.array.imag, self.array.imag)
def _matrix(M):
    """Unwrap M until the underlying Matrix object is reached.

    Anything that is not a Matrix is expected to expose the next layer
    through a ``matrix`` attribute.
    """
    current = M
    while not isinstance(current, Matrix):
        current = current.matrix
    return current
class NoDistribution:
    """Trivial distribution: the whole matrix lives on a single process."""

    comm = serial_comm
    rows = 1
    columns = 1
    blocksize = None

    def __init__(self, M, N):
        # The local shape equals the global shape.
        self.shape = (M, N)

    def __str__(self):
        nrows, ncols = self.shape
        return 'NoDistribution({}x{})'.format(nrows, ncols)

    def global_index(self, n):
        """Map a local row index to the global one (identical here)."""
        return n

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        """Multiply a and b into c with the serial BLAS wrappers."""
        if not symmetric:
            blas.mmm(alpha, a.array, opa, b.array, opb, beta, c.array)
            return

        assert opa == 'N'
        assert opb == 'C' or (opb == 'T' and a.dtype == float)

        if a is b:
            # Rank-k update.
            blas.rk(alpha, a.array, beta, c.array)
            return

        if beta == 1.0 and a.shape[1] == 0:
            # Nothing to add to c.
            return

        # Rank-2k update.
        blas.r2k(0.5 * alpha, a.array, b.array, beta, c.array)
class BLACSDistribution:
    """Block-cyclic distribution of an MxN matrix over a process grid.

    Attributes:

    * ``comm``, ``rows``, ``columns``, ``blocksize``: as given.
    * ``shape``: local (rows, columns) shape on this rank.
    * ``desc``: array descriptor passed to the cgpaw ScaLAPACK/PBLAS
      wrappers.  Only present when a BLACS context could be created.
    """

    serial = False

    def __init__(self, M, N, comm, r, c, b):
        """Set up the distribution.

        M, N: int
            Global number of rows and columns.
        comm:
            Communicator holding the process grid.
        r, c: int
            Number of rows and columns of the process grid.
        b: int or None
            Block size.  None is only allowed for a grid with a single
            row or a single column; the matrix is then split into one
            contiguous block per process.
        """
        self.comm = comm
        self.rows = r
        self.columns = c
        self.blocksize = b

        key = (comm, r, c)
        context = _global_blacs_context_store.get(key)
        if context is None:
            try:
                context = cgpaw.new_blacs_context(comm.get_c_object(),
                                                  c, r, 'R')
            except AttributeError:
                # Deliberate best-effort: without a BLACS context we fall
                # back to the context-free setup below.
                pass
            else:
                _global_blacs_context_store[key] = context

        if b is None:
            if c == 1:
                br = (M + r - 1) // r
                bc = max(1, N)
            elif r == 1:
                br = M
                bc = (N + c - 1) // c
            else:
                raise ValueError('Please specify block size!')
        else:
            br = bc = b

        if context is None:
            assert b is None
            assert c == 1
            n = N
            m = min((comm.rank + 1) * br, M) - min(comm.rank * br, M)
        else:
            n, m = cgpaw.get_blacs_local_shape(context, N, M, bc, br, 0, 0)
        if n < 0 or m < 0:
            n = m = 0
        self.shape = (m, n)

        # Keep the global and block shapes around so that __str__() and
        # global_index() also work when there is no descriptor.
        self._global_shape = (M, N)
        self._block_shape = (br, bc)

        lld = max(1, n)
        if context is not None:
            self.desc = np.array([1, context, N, M, bc, br, 0, 0, lld],
                                 np.intc)

    def __str__(self):
        return ('BLACSDistribution(global={}, local={}, blocksize={})'
                .format(*('{}x{}'.format(*shape)
                          for shape in [self._global_shape,
                                        self.shape,
                                        self._block_shape])))

    def global_index(self, myi):
        """Return the global row index of local row myi."""
        return self.comm.rank * int(self._block_shape[0]) + myi

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        """Multiply a and b into c with the cgpaw PBLAS wrappers."""
        if symmetric:
            assert opa == 'N'
            assert opb == 'C' or opb == 'T' and a.dtype == float
            N, K = a.shape
            if a is b:
                cgpaw.pblas_rk(N, K, alpha, a.array,
                               beta, c.array,
                               a.dist.desc, c.dist.desc,
                               'U')
            else:
                cgpaw.pblas_r2k(N, K, 0.5 * alpha, b.array, a.array,
                                beta, c.array,
                                b.dist.desc, a.dist.desc, c.dist.desc,
                                'U')
        else:
            Ka, M = a.shape
            N, Kb = b.shape
            if opa == 'N':
                Ka, M = M, Ka
            if opb == 'N':
                N, Kb = Kb, N
            # Operands are passed in swapped order (b before a), matching
            # the swapped dimensions used in the descriptors.
            cgpaw.pblas_gemm(N, M, Ka, alpha, b.array, a.array,
                             beta, c.array,
                             b.dist.desc, a.dist.desc, c.dist.desc,
                             opb, opa)
def redist(dist1, M1, dist2, M2, context):
    """Copy local array M1 (layout dist1) into M2 (layout dist2).

    Thin wrapper around cgpaw.scalapack_redist() working on the given
    BLACS context.
    """
    source_desc = dist1.desc
    target_desc = dist2.desc
    # Entries 2 and 3 of the descriptor hold the global dimensions.
    dim_a = source_desc[2]
    dim_b = source_desc[3]
    cgpaw.scalapack_redist(source_desc, target_desc,
                           M1, M2,
                           dim_a, dim_b,
                           1, 1, 1, 1,  # 1-indexing
                           context, 'G')
def create_distribution(M, N, comm=None, r=1, c=1, b=None):
    """Build the distribution object for an MxN matrix.

    Without a communicator, or with a communicator of size one, a
    NoDistribution is returned.  Otherwise a BLACSDistribution is created,
    where -1 for r or c means "use all processes of comm".
    """
    is_serial = comm is None or comm.size == 1
    if is_serial:
        assert (r == 1 and abs(c) == 1) or (c == 1 and abs(r) == 1)
        return NoDistribution(M, N)

    grid_rows = comm.size if r == -1 else r
    grid_columns = comm.size if c == -1 else c
    return BLACSDistribution(M, N, comm, grid_rows, grid_columns, b)
def fastmmm(m1, m2, m3, beta):
    """Multiply distributed matrices without ScaLAPACK.

    Used by Matrix.multiply() for alpha == 1.0, opa == 'N' and opb == 'N'
    when the communicator has more than one rank.  The product of m1 and m2
    is accumulated into m3, with the existing content of m3 scaled by beta
    in the first step.

    In each step one block of rows of m2 is multiplied by the matching
    columns of the local part of m1, while the block needed for the next
    step is fetched from another rank.

    Returns m3.
    """
    comm = m1.dist.comm

    # The first block to multiply is the local part of m2.
    buf1 = m2.array

    N = len(m1)
    # Rows per rank, rounded up.
    n = (N + comm.size - 1) // comm.size

    for r in range(comm.size):
        if r == 0:
            buf2 = np.empty((n, buf1.shape[1]), dtype=buf1.dtype)

        rrequest = None
        srequest = None
        if r < comm.size - 1:
            # Start the transfers for the next step: receive the block owned
            # by rank (rank + r + 1) and send the local block to rank
            # (rank - r - 1).
            # NOTE(review): the trailing False presumably selects
            # non-blocking communication - confirm against gpaw.mpi.
            rrank = (comm.rank + r + 1) % comm.size
            rn1 = min(rrank * n, N)
            rn2 = min(rn1 + n, N)
            if rn2 > rn1:
                rrequest = comm.receive(buf2[:rn2 - rn1], rrank, 21, False)
            srank = (comm.rank - r - 1) % comm.size
            if len(m2.array) > 0:
                srequest = comm.send(m2.array, srank, 21, False)

        # Rows n1:n2 of m2 are currently held in buf1; they belong to rank
        # r0.
        r0 = (comm.rank + r) % comm.size
        n1 = min(r0 * n, N)
        n2 = min(n1 + n, N)
        blas.mmm(1.0, m1.array[:, n1:n2], 'N', buf1[:n2 - n1], 'N',
                 beta, m3.array)

        # After the first contribution the remaining ones are accumulated.
        beta = 1.0

        if r == 0:
            # buf1 was an alias for m2.array; replace it by a scratch buffer
            # so that later receives do not overwrite m2.
            buf1 = np.empty_like(buf2)

        buf1, buf2 = buf2, buf1

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

    return m3
def fastmmm2(a, b, out):
    """Symmetric variant of the ScaLAPACK-free multiplication.

    Used by Matrix.multiply() for alpha == 1.0, beta == 1.0, opa == 'N',
    opb == 'C' and symmetric=True when the communicator has more than one
    rank.  Only about half of the blocks of b are exchanged between ranks
    (the loop runs over comm.size // 2 + 1 steps); afterwards a second
    round of messages sends conjugate-transposed blocks of out to other
    ranks, which add them to their local part.

    Returns out.
    """
    if a.comm:
        assert b.comm is a.comm
        if a.comm.size > 1:
            assert out.comm == a.comm
            assert out.state == 'a sum is needed'

    comm = a.dist.comm
    M, N = a.shape
    # Rows per rank, rounded up.
    m = (M + comm.size - 1) // comm.size
    # Number of rows held locally.
    mym = len(a.array)

    buf1 = np.empty((m, N), dtype=a.dtype)
    buf2 = np.empty((m, N), dtype=a.dtype)
    half = comm.size // 2
    aa = a.array
    bb = b.array

    for r in range(half + 1):
        rrequest = None
        srequest = None

        if r < half:
            # Start the transfers for the next step.
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            # For an even number of ranks, the last exchange is only done
            # in one direction.
            skip = (comm.size % 2 == 0 and r == half - 1)
            m1 = min(rrank * m, M)
            m2 = min(m1 + m, M)
            if not (skip and comm.rank < half) and m2 > m1:
                rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False)
            if not (skip and comm.rank >= half) and mym > 0:
                srequest = comm.send(b.array, srank, 11, False)

        if not (comm.size % 2 == 0 and r == half and comm.rank < half):
            # Columns m1:m2 of out correspond to the block of b that
            # belongs to rank (rank - r).
            m1 = min(((comm.rank - r) % comm.size) * m, M)
            m2 = min(m1 + m, M)
            if r == 0:
                # Local (diagonal) block: accumulate into out.
                # NOTE(review): the original marker comment here only said
                # "symmetric" - presumably a reminder that this block could
                # use a symmetric update; confirm.
                blas.mmm(1.0, aa, 'N', bb, 'C', 1.0, out.array[:, m1:m2])
            else:
                # Accumulate for r <= rank, otherwise overwrite the block.
                beta = 1.0 if r <= comm.rank else 0.0
                blas.mmm(1.0, aa, 'N', buf2[:m2 - m1], 'C',
                         beta, out.array[:, m1:m2])

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        bb = buf1
        buf1, buf2 = buf2, buf1

    # Second phase: the first (comm.size - 1) // 2 ranks send the
    # conjugate transpose of some of their blocks to the ranks at the end,
    # which add them to their own result.
    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                m1 = min(column * m, M)
                m2 = min(m1 + m, M)
                if mym > 0 and m2 > m1:
                    requests.append(
                        comm.send(out.array[:, m1:m2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                m1 = min(row * m, M)
                m2 = min(m1 + m, M)
                if mym > 0 and m2 > m1:
                    block = np.empty((mym, m2 - m1), out.dtype)
                    blocks.append((m1, m2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for m1, m2, block in blocks:
        out.array[:, m1:m2] += block

    return out
def fastmmm2notsym(a, b, out):
    """Non-symmetric variant of the ScaLAPACK-free multiplication.

    Used by Matrix.multiply() for alpha == 1.0, beta == 1.0, opa == 'N',
    opb == 'C' and symmetric=False when the communicator has more than one
    rank.  Every block of b is passed around to every rank, and each step
    adds one block of columns to out.

    Returns out.
    """
    if a.comm:
        assert b.comm is a.comm
        if a.comm.size > 1:
            assert out.comm == a.comm
            assert out.state == 'a sum is needed'

    comm = a.dist.comm
    M, N = a.shape
    # Rows per rank, rounded up.
    m = (M + comm.size - 1) // comm.size
    # Number of rows held locally.
    mym = len(a.array)

    buf1 = np.empty((m, N), dtype=a.dtype)
    buf2 = np.empty((m, N), dtype=a.dtype)
    aa = a.array
    bb = b.array

    for r in range(comm.size):
        rrequest = None
        srequest = None

        if r < comm.size - 1:
            # Start the transfers for the next step: receive the block owned
            # by rank (rank - r - 1) and send the local block to rank
            # (rank + r + 1).
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            m1 = min(rrank * m, M)
            m2 = min(m1 + m, M)
            if m2 > m1:
                rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False)
            if mym > 0:
                srequest = comm.send(b.array, srank, 11, False)

        # Columns m1:m2 of out correspond to the block of b currently held
        # in bb, which belongs to rank (rank - r).
        m1 = min(((comm.rank - r) % comm.size) * m, M)
        m2 = min(m1 + m, M)
        # NOTE(review): the original marker comment here only said
        # "symmetric ??" - its intent is unclear; confirm.
        blas.mmm(1.0, aa, 'N', bb[:m2 - m1], 'C', 1.0, out.array[:, m1:m2])

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # The block just received becomes the operand of the next step.
        bb = buf1
        buf1, buf2 = buf2, buf1

    return out