# gpaw/core/matrix.py
1"""BLACS distributed matrix object.""" 

2from __future__ import annotations 

3 

4from types import ModuleType 

5from typing import Dict, Tuple 

6import gpaw.cgpaw as cgpaw 

7import numpy as np 

8import scipy.linalg as sla 

9 

10import gpaw.utilities.blas as blas 

11from gpaw import debug, get_scipy_version 

12from gpaw.gpu import cupy as cp, cupy_eigh, XP, gpu_gemm 

13from gpaw.mpi import MPIComm, _Communicator, serial_comm 

14from gpaw.typing import Array1D, ArrayLike1D, ArrayLike2D, Array2D 

15 

16_global_blacs_context_store: Dict[Tuple[_Communicator, int, int], int] = {} 

17 

18 



def suggest_blocking(N: int, ncpus: int) -> tuple[int, int, int]:
    """Suggest blocking of ``NxN`` matrix.

    Returns rows, columns, blocksize tuple.

    >>> suggest_blocking(10, 6)
    (3, 2, 2)
    """
    nprow = ncpus
    npcol = 1

    # Make npcol and nprow as close to each other as possible:
    npcol_try = npcol
    while npcol_try < nprow:
        if ncpus % npcol_try == 0:
            npcol = npcol_try
            nprow = ncpus // npcol
        npcol_try += 1

    assert npcol * nprow == ncpus

    # ScaLAPACK creates trouble if there aren't at least a few whole blocks.
    # Choose the block size so that there will always be at least one whole
    # block and at least two blocks in total:
    blocksize = max((N - 2) // max(nprow, npcol), 1)
    # The next (commented) line would give more whole blocks:
    # blocksize = max(N // max(nprow, npcol) - 2, 1)

    # Use a block size that is a power of 2 and at most 64:
    blocksize = 2**int(np.log2(blocksize))
    blocksize = max(min(blocksize, 64), 1)

    return nprow, npcol, blocksize
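

# Worked example of the logic above (a sketch): for suggest_blocking(10, 6)
# the most square 6-process grid is nprow x npcol = 3 x 2.  The raw block
# size is then (10 - 2) // 3 = 2, which is already a power of 2 and below
# 64, so the result is (3, 2, 2):
#
#   assert suggest_blocking(10, 6) == (3, 2, 2)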


class MatrixWithNoData:
    def __init__(self,
                 M: int,
                 N: int,
                 *,
                 dtype=None,
                 dist: MatrixDistribution | tuple | None = None):
        self.shape = (M, N)
        self.dtype = dtype
        self.data = np.empty((0, 0), dtype)
        dist = dist or ()
        if isinstance(dist, tuple):
            kwargs = {key: val for key, val in zip(['comm', 'r', 'c', 'b'],
                                                   dist)}
            dist = create_distribution(M, N, **kwargs)
        self.dist = dist

    def create(self) -> Matrix:
        return Matrix(*self.shape, dtype=self.dtype, dist=self.dist)
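

# Hedged usage sketch: MatrixWithNoData records shape, dtype and
# distribution without allocating storage; create() allocates a real
# Matrix later:
#
#   template = MatrixWithNoData(100, 100, dtype=complex)
#   mat = template.create()  # the 100x100 array is allocated here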


class Matrix(XP):
    def __init__(self,
                 M: int,
                 N: int,
                 *,
                 dtype=None,
                 data: ArrayLike2D | None = None,
                 dist: MatrixDistribution | tuple | None = None,
                 xp=None):
        """Matrix object.

        Parameters
        ----------
        M:
            Rows.
        N:
            Columns.
        dtype:
            Data type (float or complex).
        dist:
            BLACS distribution given as a
            (communicator, rows, columns, blocksize)
            tuple.  Default is None, meaning no distribution.
        data:
            Numpy ndarray to use for storage.  By default, a new ndarray
            will be allocated.
        """
        self.shape = (M, N)

        if data is None or isinstance(data, (np.ndarray, cp.ndarray)):
            pass
        else:
            data = np.asarray(data)

        if dtype is None:
            if data is None:
                dtype = float
            else:
                dtype = data.dtype
        self.dtype = np.dtype(dtype)
        assert np.dtype(self.dtype) in \
            [np.float32, np.float64, np.complex64, np.complex128], dtype

        self.xp: ModuleType
        if xp is None:
            if isinstance(dist, CuPyDistribution):
                xp = cp
            elif data is not None and not isinstance(data, np.ndarray):
                xp = cp
            else:
                xp = np
        XP.__init__(self, xp)

        dist = dist or ()
        if isinstance(dist, tuple):
            kwargs = {key: val for key, val in zip(['comm', 'r', 'c', 'b'],
                                                   dist)}
            dist = create_distribution(M, N, xp=self.xp, **kwargs)
        else:
            assert self.shape == dist.full_shape
        self.dist = dist

        self.data: Array2D
        if data is None:
            self.data = self.xp.empty(dist.shape, self.dtype)
        else:
            assert data.shape == dist.shape, (data.shape, dist.shape, dist)
            self.data = data
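
    # Hedged construction sketch (serial; no MPI needed):
    #
    #   A = Matrix(2, 2)                                 # empty float64
    #   B = Matrix(2, 2, data=[[1.0, 0.0], [0.0, 1.0]])  # storage given
    #   C = Matrix(2, 2, dtype=complex)
    #
    # With an MPI communicator comm, dist=(comm, comm.size, 1) would give
    # a simple row distribution instead.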

    def __repr__(self):
        dist = str(self.dist).split('(')[1]
        if self.xp is cp:
            dist = 'xp=cp, ' + dist
        return f'Matrix({self.dtype.name}: {dist}'

    def new(self, dist='inherit', data=None) -> Matrix:
        """Create new matrix of same shape and dtype.

        Default is to use the same BLACS distribution.  Use the ``dist``
        argument to select another distribution.
        """
        return Matrix(*self.shape,
                      dtype=self.dtype,
                      dist=self.dist if dist == 'inherit' else dist,
                      data=data,
                      xp=self.xp)

    def copy(self) -> Matrix:
        """Create a copy."""
        M = self.new()
        M.data[:] = self.data
        return M

    def __setitem__(self, item, value):
        assert item == slice(None)
        assert isinstance(value, Matrix)
        self.data[:] = value.data

    def __iadd__(self, other):
        if isinstance(other, Matrix):
            other = other.data
        self.data += other
        return self

    def multiply(self,
                 other,
                 alpha=1.0,
                 opa='N',
                 opb='N',
                 out=None,
                 data_buffer=None,
                 beta=0.0,
                 symmetric=False) -> Matrix:
        """BLAS matrix-multiplication with other matrix."""
        if not isinstance(other, Matrix):
            other = other.matrix
        A = self
        B = other
        dist = self.dist
        if out is None:
            assert beta == 0.0
            M = A.shape[0] if opa == 'N' else A.shape[1]
            N = B.shape[1] if opb == 'N' else B.shape[0]
            out = Matrix(M, N, dtype=A.dtype, dist=dist.new(M, N))
        elif not isinstance(out, Matrix):
            out = out.matrix
        if out.data is other.data:
            # Repeatedly call multiply using data_buffer:
            assert opa == 'N', 'Not implemented'
            assert opb == 'N', 'Not implemented'
            assert not beta, 'Not implemented'
            assert other.shape[0] == self.shape[0]

            # Assert simple (only row-distributed) distributions:
            assert self.shape[1] == self.data.shape[1]
            assert other.shape[1] == other.data.shape[1]
            assert out.shape[1] == out.data.shape[1]

            if data_buffer is None:
                raise ValueError('other is out, and data_buffer is None')

            assert isinstance(data_buffer, other.xp.ndarray)
            dtype = other.data.dtype
            data_buffer = data_buffer.view(dtype)
            if other.data.shape[0] > 0:
                # Choose the buffer size such that the maximum number of
                # columns of other.data fits into data_buffer:
                buffer_size = min(
                    data_buffer.size // other.data.shape[0],
                    other.data.shape[1])
            else:
                # There is no local data in other, so any buffer size
                # will do:
                buffer_size = other.data.shape[1]
            buffer_size = dist.comm.min_scalar(buffer_size)
            max_B = other.data.shape[1]

            if buffer_size >= max_B:
                # No need for a sliced multiply:
                other_buffer = other.new(
                    data=data_buffer[:other.data.size].reshape(
                        other.data.shape))
                other_buffer.data[:] = other.data
                dist.multiply(alpha, A, opa, other_buffer, opb, beta, out,
                              symmetric=symmetric)
                return out

            # Sliced multiply:
            for i in range(0, max_B, buffer_size):
                r_buffer_size = min(max(other.data.shape[1] - i, 0),
                                    buffer_size)
                l_buffer_size = r_buffer_size * other.data.shape[0]
                buffer = Matrix(
                    M=other.shape[0],
                    N=r_buffer_size,
                    data=data_buffer[:l_buffer_size].reshape(
                        (other.data.shape[0], r_buffer_size)),
                    dist=dist.new(M=other.shape[0], N=r_buffer_size),
                    xp=other.xp)
                buffer.data[:] = other.data[:, i:i + buffer_size]
                out_view = buffer.new(data=out.data[:, i:i + buffer_size])
                dist.multiply(alpha, A, opa, buffer,
                              opb, beta, out_view, symmetric=False)
            return out

        dist.multiply(alpha, A, opa, B, opb, beta, out, symmetric=symmetric)
        return out
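
    # Hedged usage sketch (serial): C = A B^T with undistributed 2x2
    # matrices:
    #
    #   A = Matrix(2, 2, data=[[1.0, 2.0], [3.0, 4.0]])
    #   B = Matrix(2, 2, data=[[1.0, 0.0], [0.0, 1.0]])
    #   C = A.multiply(B, opb='T')  # C.data == A.data @ B.data.T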

    def redist(self, other: Matrix) -> None:
        """Redistribute to other BLACS layout."""
        if self is other:
            return
        d1 = self.dist
        d2 = other.dist
        n1 = d1.rows * d1.columns
        n2 = d2.rows * d2.columns
        if n1 == n2 == 1:
            other.data[:] = self.data
            return

        if n2 == 1 and d1.blocksize is None:
            assert d2.blocksize is None
            assert d1.columns == 1
            comm = d1.comm
            if comm.rank == 0:
                M = self.shape[0]
                m = (M + comm.size - 1) // comm.size
                other.data[:m] = self.data
                for r in range(1, comm.size):
                    m1 = min(r * m, M)
                    m2 = min(m1 + m, M)
                    comm.receive(other.data[m1:m2], r)
            else:
                comm.send(self.data, 0)
            return

        if n1 == 1 and d2.blocksize is None:
            assert d1.blocksize is None
            assert d1.columns == 1
            comm = d1.comm
            if comm.rank == 0:
                M = self.shape[0]
                m = (M + comm.size - 1) // comm.size
                other.data[:] = self.data[:m]
                for r in range(1, comm.size):
                    m1 = min(r * m, M)
                    m2 = min(m1 + m, M)
                    comm.send(self.data[m1:m2], r)
            else:
                comm.receive(other.data, 0)
            return

        c = d1.comm if d1.comm.size > d2.comm.size else d2.comm
        n = max(n1, n2)
        if n < c.size:
            c = c.new_communicator(np.arange(n))
        if c is not None:
            M, N = self.shape
            d1 = create_distribution(M, N, c,
                                     d1.rows, d1.columns, d1.blocksize)
            d2 = create_distribution(M, N, c,
                                     d2.rows, d2.columns, d2.blocksize)
            if n1 == n:
                ctx = d1.desc[1]
            else:
                ctx = d2.desc[1]
            redist(d1, self.data, d2, other.data, ctx)

    def gather(self, root: int = 0, broadcast=False) -> Matrix:
        """Gather the matrix on the root rank.

        Returns a new Matrix distributed so that all data is on the root
        rank (or self, if the matrix is not distributed).
        """
        assert root == 0
        if self.dist.comm.size > 1:
            S = self.new(dist=(self.dist.comm, 1, 1))
            self.redist(S)
            if broadcast:
                if self.dist.comm.rank > 0:
                    S = self.new(dist=None)
                self.dist.comm.broadcast(S.data, 0)
        else:
            S = self

        return S

    def inv(self, uplo='L'):
        """Inplace inversion."""
        assert uplo == 'L'
        M, N = self.shape
        assert M == N
        dist = self.dist
        if dist.comm.size == 1:
            self.tril2full()
            self.data[:] = sla.inv(self.data,
                                   overwrite_a=True,
                                   check_finite=debug)
            return
        bc, br = dist.desc[4:6]
        assert bc == br
        info = cgpaw.scalapack_inverse(self.data, dist.desc, 'U')
        if info != 0:
            raise ValueError(f'scalapack_inverse error: {info}')

    def invcholesky(self) -> None:
        """In-place inverse of Cholesky decomposition.

        Calculate a lower triangular matrix `L` where:::

             †
          LSL  = 1,

        and `S` is self.  Only the lower part of `S` is used.

        >>> S = Matrix(2, 2, data=[[1.0, np.nan],
        ...                        [0.1, 1.0]])
        >>> S.invcholesky()
        >>> S.data
        array([[ 1.        , -0.        ],
               [-0.10050378,  1.00503782]])
        """
        S = self.gather()
        if self.dist.comm.rank == 0:
            if isinstance(S.data, np.ndarray):
                if debug:
                    S.data[np.triu_indices(S.shape[0], 1)] = 42.0
                L_nn = sla.cholesky(S.data,
                                    lower=True,
                                    overwrite_a=True,
                                    check_finite=debug)
                S.data[:] = sla.inv(L_nn,
                                    overwrite_a=True,
                                    check_finite=debug)
            else:
                S.tril2full()
                L_nn = cp.linalg.cholesky(S.data)
                S.data[:] = cp.linalg.inv(L_nn)

        if S is not self:
            S.redist(self)
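
    # Quick check of the identity above (a sketch; serial, numpy only):
    #
    #   S_nn = np.array([[1.0, 0.1], [0.1, 1.0]])
    #   S = Matrix(2, 2, data=S_nn.copy())
    #   S.invcholesky()  # S.data now holds L
    #   L_nn = S.data
    #   assert np.allclose(L_nn @ S_nn @ L_nn.conj().T, np.eye(2))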

    def eigh(self,
             S=None,
             *,
             cc=False,
             scalapack=(None, 1, 1, None),
             limit: int | None = None) -> Array1D:
        """Calculate eigenvectors and eigenvalues.

        Matrix must be symmetric/hermitian and stored in the lower half.
        If ``S`` is given, solve a generalized eigenvalue problem.

        Parameters
        ----------
        cc: bool
            Complex conjugate matrix before finding eigenvalues.
        scalapack: tuple
            BLACS distribution for ScaLAPACK to use.  Default is to do
            serial diagonalization.
        limit:
            Number of eigenvectors and eigenvalues to find.  Defaults
            to all.
        """
        slcomm, rows, columns, blocksize = scalapack
        slcomm = slcomm or self.dist.comm
        dist = (slcomm, rows, columns, blocksize)

        redist = (rows != self.dist.rows or
                  columns != self.dist.columns or
                  blocksize != self.dist.blocksize)

        if redist:
            H = self.new(dist=dist)
            self.redist(H)
            if S is not None:
                S0 = S
                S = S0.new(dist=dist)
                S0.redist(S)
        else:
            assert self.dist.comm.size == slcomm.size
            H = self

        if limit == H.shape[0]:
            limit = None

        if limit:
            eps = self.xp.empty(limit)
        else:
            eps = self.xp.empty(H.shape[0])

        if rows * columns == 1:
            if self.dist.comm.rank == 0:
                if cc and np.issubdtype(H.dtype, np.complexfloating):
                    np.negative(H.data.imag, H.data.imag)
                if debug:
                    H.data[np.triu_indices(H.shape[0], 1)] = 42.0
                if S is None:
                    if self.xp is not np:
                        assert isinstance(H.data, cp.ndarray)
                        eps[:], H.data.T[:] = cupy_eigh(H.data, UPLO='L')
                    else:
                        eps[:], H.data.T[:] = sla.eigh(
                            H.data,
                            lower=True,
                            overwrite_a=True,
                            check_finite=debug,
                            driver='evx' if H.data.size == 1 else 'evd')
                else:
                    if self.xp is cp:
                        assert self.dist.comm.size == 1
                        S.invcholesky()
                        self.tril2full()
                        eigs = self.eighg(S)
                        self.data[:] = self.data.T.copy()
                        return eigs
                    if debug:
                        S.data[self.xp.triu_indices(H.shape[0], 1)] = 42.0
                    eps, evecs = sla.eigh(
                        H.data,
                        S.data,
                        lower=True,
                        overwrite_a=True,
                        overwrite_b=True,
                        check_finite=debug,
                        subset_by_index=(0, limit - 1) if limit else None)
                    limit = limit or len(eps)
                    H.data.T[:, :limit] = evecs
            self.dist.comm.broadcast(eps, 0)
        else:
            if slcomm.rank < rows * columns:
                assert cc
                assert S is None
                array = H.data.copy()
                info = cgpaw.scalapack_diagonalize_dc(array, H.dist.desc,
                                                      'U', H.data, eps)
                assert info == 0, info

            # We need to broadcast eps when some ranks are not used by the
            # current ScaLAPACK parameter set, e.g. (2, 1, 2) with
            # 4 processes:
            if rows * columns < slcomm.size:
                H.dist.comm.broadcast(eps, 0)

        if redist:
            H.redist(self)

        return eps
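
    # Hedged serial example: eigenvalues of [[2, 1], [1, 2]] with only the
    # lower half stored (the upper half is never read):
    #
    #   H = Matrix(2, 2, data=np.array([[2.0, 0.0], [1.0, 2.0]]))
    #   eps = H.eigh()  # eps == [1.0, 3.0]
    #   # The rows of H.data now hold the eigenvectors (up to signs):
    #   # H.data[0] ~ [0.71, -0.71], H.data[1] ~ [0.71, 0.71]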

    def eighg(self, L: Matrix, comm2: MPIComm = serial_comm) -> Array1D:
        """Solve generalized eigenvalue problem.

        With `H` being self, we solve for the eigenvectors `C` and the
        eigenvalues `Λ` (a diagonal matrix):::

          HC = SCΛ,

        where `L` is a lower triangular matrix such that:::

             †
          LSL  = 1.

        The solution has these three steps:::

          ~     †   ~~   ~       †~
          H = LHL , HC = CΛ, C = L C.

        Note that `H` must be the full matrix, not just half of it!
        """
        M, N = self.shape
        assert M == N
        comm = self.dist.comm

        if comm2.rank == 0:
            if comm.size == 1:
                H = self
                L0 = L
            else:
                # TODO: Use ScaLAPACK
                H = self.new(dist=(comm,))
                self.redist(H)
                L0 = self.new(dist=(comm,))
                L.redist(L0)
            if comm.rank == 0:
                if self.xp is not np:
                    return self.dist.eighg(self, L0)
                tmp_MM = np.empty_like(H.data)
                L_MM = L0.data
                blas.mmm(1.0, L_MM, 'N', H.data, 'N', 0.0, tmp_MM)
                blas.r2k(0.5, tmp_MM, L_MM, 0.0, H.data)
                # Ht_MM = L_MM @ H.data @ L_MM.conj().T
                if get_scipy_version() >= [1, 9]:
                    driver = 'evx' if M == 1 else 'evd'
                else:
                    driver = None
                eig_n, Ct_Mn = sla.eigh(
                    H.data,
                    overwrite_a=True,
                    check_finite=debug,
                    driver=driver)
                assert Ct_Mn.flags.f_contiguous
                blas.mmm(1.0, L_MM, 'C', Ct_Mn.T, 'T', 0.0, H.data)
                # H.data[:] = L_MM.T.conj() @ Ct_Mn
            else:
                eig_n = np.empty(M)

            if comm.size > 1:
                H.redist(self)
                comm.broadcast(eig_n, 0)

        if comm2.rank > 0:
            eig_n = np.empty(M)
        comm2.broadcast(eig_n, 0)
        comm2.broadcast(self.data, 0)

        return eig_n
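
    # Hedged sketch of the full generalized solve HC = SCΛ (serial):
    #
    #   H = Matrix(2, 2, data=[[2.0, 1.0], [1.0, 2.0]])  # full matrix!
    #   S = Matrix(2, 2, data=[[1.0, 0.0], [0.1, 1.0]])  # lower half
    #   S.invcholesky()     # S now holds L with L S L† = 1
    #   eig_n = H.eighg(S)
    #   # The columns of H.data now hold the generalized eigenvectors.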

    def complex_conjugate(self) -> None:
        """Inplace complex conjugation."""
        if np.issubdtype(self.dtype, np.complexfloating):
            self.xp.negative(self.data.imag, self.data.imag)

    def add_hermitian_conjugate(self, scale: float = 1.0) -> None:
        """Add the hermitian conjugate to myself."""
        if self.dist.comm.size == 1:
            if scale != 1.0:
                self.data *= scale
            self.data += self.data.conj().T
            return
        tmp = self.copy()
        cgpaw.pblas_tran(*self.shape, scale, tmp.data, scale, self.data,
                         self.dist.desc, self.dist.desc, True)

    def tril2full(self) -> None:
        """Fill in the upper triangle from the lower triangle.

        For a real matrix::

          a ? ?    a b d
          b c ? -> b c e
          d e f    d e f

        For a complex matrix, the complex conjugate of the lower part will
        be inserted into the upper part.
        """
        M, N = self.shape
        assert M == N

        dist = self.dist

        if dist.comm.size == 1 or (dist.rows == 1 and dist.columns == 1):
            if dist.comm.rank == 0:
                lower = self.xp.tri(M, k=-1, dtype=bool)
                self.data.T[lower] = self.data[lower].conj()
            return

        desc = dist.desc
        cgpaw.scalapack_set(self.data, desc, 0.0, 0.0, 'L', M - 1, M - 1,
                            2, 1)
        buf = self.data.copy()
        # Set the diagonal to zero in the copy:
        cgpaw.scalapack_set(buf, desc, 0.0, 0.0, 'L', M, M, 1, 1)
        # Now transpose the copy, adding the result to the original matrix:
        cgpaw.pblas_tran(M, M, 1.0, buf, 1.0, self.data, desc, desc, True)
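
    # Hedged example (serial, real data):
    #
    #   A = Matrix(3, 3, data=[[1., 0., 0.],
    #                          [2., 3., 0.],
    #                          [4., 5., 6.]])
    #   A.tril2full()
    #   # A.data is now [[1, 2, 4], [2, 3, 5], [4, 5, 6]]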

    def add_to_diagonal(self, d: ArrayLike1D | float) -> None:
        """Add a list of numbers or a single number to the diagonal."""
        n1, n2 = self.dist.my_row_range()
        M, N = self.shape
        assert M == N
        self.data.ravel()[n1::N + 1] += d
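
    # Hedged example (serial):
    #
    #   A = Matrix(3, 3, data=np.zeros((3, 3)))
    #   A.add_to_diagonal(1.0)  # A.data is now the 3x3 identity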

    def to_cpu(self) -> Matrix:
        """Create new matrix object with values transferred from GPU to
        CPU."""
        return self.to_xp(np)

    def to_xp(self, xp) -> Matrix:
        """Create new matrix object with data on GPU or CPU."""
        if xp is self.xp:
            assert xp is np, 'cp -> cp should not be needed!'
            return self
        if xp is np:
            return self.dist.matrix(data=cp.asnumpy(self.data))
        return self.dist.matrix(data=cp.asarray(self.data))

    def to_dtype(self, dtype) -> Matrix:
        """Convert to new data type."""
        if dtype == self.dtype:
            return self
        return self.dist.matrix(data=self.data.astype(dtype))


def _matrix(M):
    """Dig out the Matrix object from wrapper(s)."""
    if isinstance(M, Matrix):
        return M
    return _matrix(M.matrix)


def redist(dist1, M1, dist2, M2, context):
    cgpaw.scalapack_redist(dist1.desc, dist2.desc,
                           M1, M2,
                           dist1.desc[2], dist1.desc[3],
                           1, 1, 1, 1,  # 1-indexing
                           context, 'G')



def create_distribution(M: int,
                        N: int,
                        comm: MPIComm | None = None,
                        r: int = 1,
                        c: int = 1,
                        b: int | None = None,
                        xp=None) -> MatrixDistribution:
    if xp is cp:
        assert b is None
        if r == 1 and c == 1:
            pass  # comm = None
        comm = comm or serial_comm
        return CuPyDistribution(M, N, comm,
                                r if r != -1 else comm.size,
                                c if c != -1 else comm.size,
                                b)

    if comm is None or comm.size == 1:
        assert (r == 1 and abs(c) == 1) or (c == 1 and abs(r) == 1)
        return NoDistribution(M, N)

    return BLACSDistribution(M, N, comm,
                             r if r != -1 else comm.size,
                             c if c != -1 else comm.size,
                             b)
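

# Hedged examples of the dispatch above:
#
#   create_distribution(4, 4)  # -> NoDistribution(4x4)  (serial)
#
# With an MPI communicator comm of 6 ranks, the result of
# suggest_blocking() can be passed straight in:
#
#   r, c, b = suggest_blocking(400, 6)  # (3, 2, 64)
#   dist = create_distribution(400, 400, comm, r, c, b)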


class MatrixDistribution:
    comm: MPIComm
    rows: int
    columns: int
    blocksize: int | None  # None means everything on rank 0
    shape: tuple[int, int]
    full_shape: tuple[int, int]
    desc: Array1D

    def matrix(self, dtype=None, data=None):
        return Matrix(*self.full_shape, dtype=dtype, data=data, dist=self)

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        raise NotImplementedError

    def eighg(self, H, L):
        raise NotImplementedError

    def new(self, M, N):
        raise NotImplementedError

    def my_row_range(self) -> tuple[int, int]:
        """Return indices for the range of my rows.

        >>> Matrix(2, 2).dist.my_row_range()
        (0, 2)
        """
        ok = (self.rows == self.comm.size and
              self.columns == 1 and
              self.blocksize is None)
        if not ok:
            raise ValueError(f'Can not create slice of distribution: {self}')
        M = self.full_shape[0]
        b = (M + self.rows - 1) // self.rows
        n1 = self.comm.rank * b
        n2 = min(n1 + b, M)
        return n1, n2


class NoDistribution(MatrixDistribution):
    comm = serial_comm
    rows = 1
    columns = 1
    blocksize = None

    def __init__(self, M, N):
        self.shape = (M, N)
        self.full_shape = (M, N)

    def __str__(self):
        return 'NoDistribution({}x{})'.format(*self.shape)

    def global_index(self, n):
        return n

    def new(self, M, N):
        return NoDistribution(M, N)

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        if symmetric:
            if opa == 'N':
                assert opb == 'C' or (opb == 'T' and a.dtype == float)
                if a is b:
                    blas.rk(alpha, a.data, beta, c.data)
                else:
                    if beta == 1.0 and a.shape[1] == 0:
                        return
                    blas.r2k(0.5 * alpha, a.data, b.data, beta, c.data)
            else:
                1 / 0  # deliberate crash: case not implemented
                assert opa == 'C' and opb == 'N'
                assert a is not b
                blas.r2k(0.5 * alpha, a.data, b.data, beta, c.data, 'n')
        else:
            blas.mmm(alpha, a.data, opa, b.data, opb, beta, c.data)


class BLACSDistribution(MatrixDistribution):
    serial = False

    def __init__(self, M, N, comm, r, c, b):
        self.comm = comm
        self.rows = r
        self.columns = c
        self.blocksize = b
        self.full_shape = (M, N)
        self.simple = False

        key = (comm, r, c)
        context = _global_blacs_context_store.get(key)
        if context is None:
            try:
                context = cgpaw.new_blacs_context(comm.get_c_object(),
                                                  c, r, 'R')
            except AttributeError:
                pass
            else:
                _global_blacs_context_store[key] = context

        if b is None:
            if c == 1:
                br = (M + r - 1) // r
                bc = max(1, N)
                self.simple = True
            elif r == 1:
                br = M
                bc = (N + c - 1) // c
            else:
                raise ValueError('Please specify a block size!')
        else:
            br = bc = b

        if context is None:
            assert b is None
            assert c == 1
            n = N
            m = min((comm.rank + 1) * br, M) - min(comm.rank * br, M)
        else:
            n, m = cgpaw.get_blacs_local_shape(context, N, M, bc, br, 0, 0)
            if n < 0 or m < 0:
                n = m = 0
        self.shape = (m, n)
        lld = max(1, n)
        if context is not None:
            self.desc = np.array([1, context, N, M, bc, br, 0, 0, lld],
                                 np.intc)

    def __str__(self):
        return ('BLACSDistribution(global={}, local={}, blocksize={})'
                .format(*('{}x{}'.format(*shape)
                          for shape in [self.desc[3:1:-1],
                                        self.shape,
                                        self.desc[5:3:-1]])))

    def global_index(self, myi):
        return self.comm.rank * int(self.desc[5]) + myi

    def new(self, M, N):
        return BLACSDistribution(M, N,
                                 self.comm,
                                 self.rows, self.columns,
                                 self.blocksize)

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        if self.comm.size > 1:
            ok = a.dist.simple and b.dist.simple and c.dist.simple
            if ok:
                # Special cases that don't need ScaLAPACK - most likely
                # also faster:
                if opa == 'N' and opb == 'N':
                    return mmm_nn(a, b, c, alpha, beta, blas.mmm)
                if opa == 'N' and opb == 'C':
                    if symmetric:
                        if beta == 1.0:
                            return mmm_nc_sym(a, b, c, alpha, blas.mmm)
                    else:
                        return mmm_nc(a, b, c, alpha, beta, blas.mmm)

        if symmetric:
            assert opa == 'N'
            assert opb == 'C' or (opb == 'T' and a.dtype == float)
            N, K = a.shape
            if a is b:
                cgpaw.pblas_rk(N, K, alpha, a.data,
                               beta, c.data,
                               a.dist.desc, c.dist.desc,
                               'U')
            else:
                cgpaw.pblas_r2k(N, K, 0.5 * alpha, b.data, a.data,
                                beta, c.data,
                                b.dist.desc, a.dist.desc, c.dist.desc,
                                'U')
        else:
            Ka, M = a.shape
            N, Kb = b.shape
            if opa == 'N':
                Ka, M = M, Ka
            if opb == 'N':
                N, Kb = Kb, N
            cgpaw.pblas_gemm(N, M, Ka, alpha, b.data, a.data,
                             beta, c.data,
                             b.dist.desc, a.dist.desc, c.dist.desc,
                             opb, opa)
        return c


def cublas_mmm(alpha, a, opa, b, opb, beta, c):
    if c.size == 0:
        return
    if a.size == 0 and beta == 1.0:
        return
    gpu_gemm(opa.replace('C', 'H'), opb.replace('C', 'H'),
             a, b, c, alpha, beta)


class CuPyDistribution(MatrixDistribution):
    def __init__(self, M, N, comm, r, c, b):
        self.comm = comm
        self.rows = r
        self.columns = c
        self.blocksize = b
        self.full_shape = (M, N)
        # assert r == comm.size, (M, N, comm, r, c, b)
        assert c == 1
        br = (M + r - 1) // r
        m = min((comm.rank + 1) * br, M) - min(comm.rank * br, M)
        self.shape = (m, N)

    def __str__(self):
        M, N = self.full_shape
        m, N = self.shape
        return f'CuPyDistribution(global={M}x{N}, local={m}x{N})'

    def global_index(self, n):
        1 / 0  # deliberate crash: not implemented
        return n

    def new(self, M, N):
        return CuPyDistribution(M, N,
                                self.comm,
                                self.rows, self.columns,
                                self.blocksize)

    def multiply(self, alpha, a, opa, b, opb, beta, c, *, symmetric=False):
        if self.comm.size > 1:
            if opa == 'N' and opb == 'N':
                return mmm_nn(a, b, c, alpha, beta, cublas_mmm)
            if opa == 'N' and opb == 'C':
                if symmetric:
                    if beta == 1.0:
                        return mmm_nc_sym(a, b, c, alpha, cublas_mmm)
                else:
                    return mmm_nc(a, b, c, alpha, beta, cublas_mmm)
            1 / 0  # deliberate crash: remaining parallel cases missing

        if symmetric:
            if opa == 'N':
                assert opb == 'C' or (opb == 'T' and a.dtype == float)
                if a is b:
                    gpu_gemm('N', 'H',
                             a.data, a.data, c.data,
                             alpha, beta)
                    # cp.cublas.syrk('N', a.data, c.data, alpha, beta, True)
                else:
                    if beta == 1.0 and a.shape[1] == 0:
                        return
                    if c.data.size > 0:
                        assert beta in [0.0, 1.0]
                        # CuPy doesn't have dsyrk, so we roll our own:
                        gpu_gemm('N', 'H',
                                 a.data, b.data, c.data,
                                 0.5 * alpha, beta)
                        gpu_gemm('N', 'H',
                                 b.data, a.data, c.data,
                                 0.5 * alpha, 1.0)
            else:
                1 / 0  # deliberate crash: case not implemented
                assert opa == 'C' and opb == 'N'
                assert a is not b
                raise NotImplementedError
                blas.r2k(0.5 * alpha, a.data, b.data, beta, c.data, 'n')
        else:
            cublas_mmm(alpha, a.data, opa, b.data, opb, beta, c.data)

    def eighg(self, H, L):
        """
        :::

          ~     †   ~~   ~       †~
          H = LHL , HC = CΛ, C = L C.
        """
        assert self.comm.size == 1
        tmp = H.new()
        self.multiply(1.0, L, 'N', H, 'N', 0.0, tmp)
        self.multiply(1.0, tmp, 'N', L, 'C', 0.0, H, symmetric=True)
        eig_M, Ct_MM = cupy_eigh(H.data, UPLO='L')
        assert Ct_MM.flags.f_contiguous
        Ct = H.new(data=Ct_MM.T)
        self.multiply(1.0, L, 'C', Ct, 'T', 0.0, H)
        # H.complex_conjugate()
        return eig_M


def mmm_nn(m1, m2, m3, alpha, beta, mmm):
    """Parallel matrix-matrix multiplication.

    :::

      m  <- αm m  + βm
       3      1 2     3
    """
    comm = m1.dist.comm
    buf1 = m2.data
    xp = m1.xp

    N = m1.shape[1]
    assert N == m2.shape[0], f'{N}, {m2.shape[0]}'
    n = (N + comm.size - 1) // comm.size

    for r in range(comm.size):
        if r == 0:
            # Allocate the receive buffer on first use:
            buf2 = xp.empty((n, buf1.shape[1]), dtype=buf1.dtype)

        rrequest = None
        srequest = None
        if r < comm.size - 1:
            rrank = (comm.rank + r + 1) % comm.size
            rn1 = min(rrank * n, N)
            rn2 = min(rn1 + n, N)
            if rn2 > rn1:
                rrequest = comm.receive(buf2[:rn2 - rn1], rrank, 21, False)
            srank = (comm.rank - r - 1) % comm.size
            if len(m2.data) > 0:
                srequest = comm.send(m2.data, srank, 21, False)

        r0 = (comm.rank + r) % comm.size
        n1 = min(r0 * n, N)
        n2 = min(n1 + n, N)
        # Note: m1.data[:, n1:n2] is a non-contiguous view.
        mmm(alpha, m1.data[:, n1:n2], 'N', buf1[:n2 - n1], 'N', beta,
            m3.data)

        beta = 1.0

        if r == 0:
            # Second buffer for the ring exchange:
            buf1 = xp.empty_like(buf2)

        buf1, buf2 = buf2, buf1

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

    return m3
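

# A minimal serial sketch (plain numpy, no MPI) of the block decomposition
# that the ring above realizes - m3 is accumulated from column slices of
# m1 times the matching row blocks of m2:
#
#   N, n = 5, 2  # matrix size and ring block size
#   m1 = np.random.rand(3, N)
#   m2 = np.random.rand(N, 4)
#   m3 = np.zeros((3, 4))
#   for n1 in range(0, N, n):  # one ring step per row block of m2
#       n2 = min(n1 + n, N)
#       m3 += m1[:, n1:n2] @ m2[n1:n2]
#   assert np.allclose(m3, m1 @ m2)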


def mmm_nc_sym(a, b, out, alpha, mmm):
    """Symmetric parallel matrix-matrix multiplication.

    :::

              †
      c <- αab  + c

    This function uses the fact that c is symmetric, so that:::

                   †     †
      c <- 0.5*(αab + αba ) + c

    Only the lower half of c is updated.
    """
    comm = a.dist.comm
    M, N = b.shape
    m = (M + comm.size - 1) // comm.size
    mym = len(b.data)
    xp = a.xp

    # Ring-exchange buffers:
    buf1 = xp.empty((m, N), dtype=a.dtype)
    buf2 = xp.empty((m, N), dtype=a.dtype)
    half = comm.size // 2
    aa = a.data
    bb = b.data

    for r in range(half + 1):
        rrequest = None
        srequest = None

        if r < half:
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            skip = (comm.size % 2 == 0 and r == half - 1)
            m1 = min(rrank * m, M)
            m2 = min(m1 + m, M)
            if not (skip and comm.rank < half) and m2 > m1:
                rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False)
            if not (skip and comm.rank >= half) and mym > 0:
                srequest = comm.send(b.data, srank, 11, False)

        if not (comm.size % 2 == 0 and r == half and comm.rank < half):
            m1 = min(((comm.rank - r) % comm.size) * m, M)
            m2 = min(m1 + m, M)
            if r == 0:
                # Diagonal block (uses the symmetry of the result).
                # Note: out.data[:, m1:m2] is a non-contiguous view.
                mmm(alpha, aa, 'N', bb, 'C', 1.0, out.data[:, m1:m2])
            else:
                beta = 1.0 if r <= comm.rank else 0.0
                mmm(alpha, aa, 'N', buf2[:m2 - m1], 'C',
                    beta, out.data[:, m1:m2])

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        bb = buf1
        buf1, buf2 = buf2, buf1

    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                m1 = min(column * m, M)
                m2 = min(m1 + m, M)
                if mym > 0 and m2 > m1:
                    requests.append(
                        comm.send(out.data[:, m1:m2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                m1 = min(row * m, M)
                m2 = min(m1 + m, M)
                if mym > 0 and m2 > m1:
                    block = xp.empty((mym, m2 - m1), out.dtype)
                    blocks.append((m1, m2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for m1, m2, block in blocks:
        out.data[:, m1:m2] += block

    return out


def mmm_nc(a, b, out, alpha, beta, mmm):
    """Parallel matrix-matrix multiplication.

    :::

              †
      c <- αab + βc
    """
    comm = a.dist.comm
    M, N = b.shape
    m = (M + comm.size - 1) // comm.size
    mym = len(b.data)
    xp = a.xp

    # Ring-exchange buffers:
    buf1 = xp.empty((m, N), dtype=a.dtype)
    buf2 = xp.empty((m, N), dtype=a.dtype)
    aa = a.data
    bb = b.data

    for r in range(comm.size):
        rrequest = None
        srequest = None

        if r < comm.size - 1:
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            m1 = min(rrank * m, M)
            m2 = min(m1 + m, M)
            if m2 > m1:
                rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False)
            if mym > 0:
                srequest = comm.send(b.data, srank, 11, False)

        m1 = min(((comm.rank - r) % comm.size) * m, M)
        m2 = min(m1 + m, M)
        # Note: out.data[:, m1:m2] is a non-contiguous view.
        mmm(alpha, aa, 'N', bb[:m2 - m1], 'C', beta, out.data[:, m1:m2])

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        bb = buf1
        buf1, buf2 = buf2, buf1

    return out