Coverage for gpaw/matrix.py: 34%

397 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1"""BLACS distributed matrix object.""" 

2from typing import Dict, Tuple 

3import numpy as np 

4import scipy.linalg as linalg 

5 

6import gpaw.cgpaw as cgpaw 

7from gpaw import debug 

8from gpaw.mpi import serial_comm, _Communicator 

9import gpaw.utilities.blas as blas 

10 

11 

# Module-wide cache of BLACS context handles, keyed by
# (communicator, rows, columns).  BLACSDistribution looks a context up here
# before asking cgpaw for a new one and stores newly created contexts back.
_global_blacs_context_store: Dict[Tuple[_Communicator, int, int], int] = {}

13 

14 

def matrix_matrix_multiply(alpha, a, opa, b, opb, beta=0.0, c=None,
                           symmetric=False):
    """BLAS-style matrix-matrix multiplication.

    Thin front end: a, b and (if given) c are unwrapped to their
    underlying Matrix objects and the work is delegated to
    Matrix.multiply(), whose result is returned.

    Will use dgemm/zgemm/dsyrk/zherk/dsyr2k/zher2k as appropriate or the
    equivalent PBLAS functions for distributed matrices.

    The coefficients alpha and beta are of type float.  Matrices a, b and
    c must have the same type (float or complex).  The strings opa and opb
    must be 'N', 'T', or 'C'.  For opa='N' and opb='N', the operation
    performed is equivalent to::

        c.array[:] = alpha * np.dot(a.array, b.array) + beta * c.array

    Replace a.array with a.array.T or a.array.T.conj() for opa='T' and 'C'
    respectively (similarly for opb).

    Use symmetric=True if the result matrix is symmetric/hermitian
    (only the lower half of c will be evaluated).
    """
    left = _matrix(a)
    right = _matrix(b)
    if c is None:
        target = None
    else:
        target = _matrix(c)
    return left.multiply(alpha, opa, right, opb, beta, target, symmetric)

38 

39 

def suggest_blocking(N, ncpus):
    """Suggest a BLACS blocking for an NxN matrix on ncpus processes.

    Returns a (rows, columns, blocksize) tuple where rows * columns equals
    ncpus and blocksize is a power of two between 1 and 64.
    """
    nprow, npcol = ncpus, 1

    # Walk through the divisors of ncpus, keeping the most recent
    # factorisation, until the column count catches up with the row
    # count.  This makes the process grid as square as possible.
    candidate = npcol
    while candidate < nprow:
        quotient, remainder = divmod(ncpus, candidate)
        if remainder == 0:
            nprow, npcol = quotient, candidate
        candidate += 1

    assert npcol * nprow == ncpus

    # ScaLAPACK creates trouble if there aren't at least a few whole blocks.
    # Pick the block size so that there is always at least one whole block
    # and at least two blocks in total.
    longest_side = max(nprow, npcol)
    blocksize = max((N - 2) // longest_side, 1)

    # Round down to a power of two and clamp to the range [1, 64].
    exponent = int(np.log2(blocksize))
    blocksize = max(min(2**exponent, 64), 1)

    return nprow, npcol, blocksize

70 

71 

class Matrix:
    """Dense matrix with an optional BLACS distribution.

    Attributes set by the constructor:

    shape: tuple
        Global shape (M, N).
    dtype: numpy.dtype
        Element type.
    dist:
        Distribution object describing the layout; its ``shape`` is the
        shape of the local array.
    array: ndarray
        Local storage with shape ``dist.shape``.
    comm:
        Communicator used by invcholesky() and eigh() to sum and
        broadcast the array.  Initialised to serial_comm.
    state: str
        'everything is fine', or 'a sum is needed' when the array must be
        summed over ``comm`` before it can be used.
    """

    def __init__(self, M, N, dtype=None, data=None, dist=None):
        """Matrix object.

        M: int
            Rows.
        N: int
            Columns.
        dtype: type
            Data type (float or complex).  Defaults to the dtype of data,
            or float when no data is given.
        dist: tuple or None
            BLACS distribution given as (communicator, rows, columns,
            blocksize) tuple.  Default is None meaning no distribution.
            An already constructed distribution object is used as is.
        data: ndarray or None.
            Numpy ndarray to use for storage.  By default, a new ndarray
            will be allocated.
        """
        self.shape = (M, N)

        if dtype is None:
            if data is None:
                dtype = float
            else:
                dtype = data.dtype
        self.dtype = np.dtype(dtype)

        # None (or any other falsy value) means "not distributed".
        dist = dist or ()
        if isinstance(dist, tuple):
            dist = create_distribution(M, N, *dist)
        self.dist = dist

        if data is None:
            self.array = np.empty(dist.shape, self.dtype)
        else:
            self.array = data.reshape(dist.shape)

        self.comm = serial_comm
        self.state = 'everything is fine'

    def __len__(self):
        """Return the global number of rows."""
        return self.shape[0]

    def __repr__(self):
        # str(self.dist) looks like 'Name(details)'; reuse the part after
        # the first '(' (which still carries the closing parenthesis).
        dist = str(self.dist).split('(')[1]
        return f'Matrix({self.dtype.name}: {dist}'

    def new(self, dist='inherit'):
        """Create new matrix of same shape and dtype.

        Default is to use same BLACS distribution.  Use dist to use another
        distribution.
        """
        return Matrix(*self.shape, dtype=self.dtype,
                      dist=self.dist if dist == 'inherit' else dist)

    def __setitem__(self, i, x):
        """Evaluate the expression object x into this matrix.

        The index i is ignored.  Assigning a plain ndarray is deliberately
        unsupported and fails with ZeroDivisionError.
        """
        if isinstance(x, np.ndarray):
            1 / 0  # intentional crash: ndarray assignment is not supported
        else:
            x.eval(self)

    def __iadd__(self, x):
        """Add the expression object x to this matrix in place."""
        x.eval(self, 1.0)
        return self

    def multiply(self, alpha, opa, b, opb, beta=0.0, out=None,
                 symmetric=False):
        """BLAS-style Matrix-matrix multiplication.

        See matrix_matrix_multiply() for details.

        If out is None (only allowed for beta == 0.0) a new result matrix
        is allocated on the same communicator and process grid as self.
        Returns the result matrix.
        """
        dist = self.dist
        if out is None:
            assert beta == 0.0
            # Result shape follows from the (possibly transposed) operands.
            if opa == 'N':
                M = self.shape[0]
            else:
                M = self.shape[1]
            if opb == 'N':
                N = b.shape[1]
            else:
                N = b.shape[0]
            out = Matrix(M, N, self.dtype,
                         dist=(dist.comm, dist.rows, dist.columns))
        if dist.comm.size > 1:
            # Special cases that don't need scalapack - most likely also
            # faster:
            if alpha == 1.0 and opa == 'N' and opb == 'N':
                return fastmmm(self, b, out, beta)
            if alpha == 1.0 and beta == 1.0 and opa == 'N' and opb == 'C':
                if symmetric:
                    return fastmmm2(self, b, out)
                else:
                    return fastmmm2notsym(self, b, out)

        dist.multiply(alpha, self, opa, b, opb, beta, out, symmetric)
        return out

    def redist(self, other):
        """Redistribute to other BLACS layout.

        Copies the contents of this matrix into other, converting between
        the two distributions.
        """
        if self is other:
            return
        d1 = self.dist
        d2 = other.dist
        n1 = d1.rows * d1.columns
        n2 = d2.rows * d2.columns
        if n1 == n2 == 1:
            # Neither side is distributed: plain copy.
            other.array[:] = self.array
            return

        if n2 == 1 and d1.blocksize is None:
            # Gather onto rank 0 with plain send/receive calls.
            assert d2.blocksize is None
            comm = d1.comm
            if comm.rank == 0:
                M = len(self)
                m = (M + comm.size - 1) // comm.size
                other.array[:m] = self.array
                for r in range(1, comm.size):
                    m1 = min(r * m, M)
                    m2 = min(m1 + m, M)
                    comm.receive(other.array[m1:m2], r)
            else:
                comm.send(self.array, 0)
            return

        if n1 == 1 and d2.blocksize is None:
            # Scatter from rank 0 with plain send/receive calls.
            assert d1.blocksize is None
            comm = d1.comm
            if comm.rank == 0:
                M = len(self)
                m = (M + comm.size - 1) // comm.size
                other.array[:] = self.array[:m]
                for r in range(1, comm.size):
                    m1 = min(r * m, M)
                    m2 = min(m1 + m, M)
                    comm.send(self.array[m1:m2], r)
            else:
                comm.receive(other.array, 0)
            return

        # General case: use the larger communicator, shrunk to the number
        # of ranks that actually take part, and let the module-level
        # redist() function do the work.
        c = d1.comm if d1.comm.size > d2.comm.size else d2.comm
        n = max(n1, n2)
        if n < c.size:
            c = c.new_communicator(np.arange(n))
        if c is not None:
            M, N = self.shape
            d1 = create_distribution(M, N, c,
                                     d1.rows, d1.columns, d1.blocksize)
            d2 = create_distribution(M, N, c,
                                     d2.rows, d2.columns, d2.blocksize)
            # Use the context of the layout spanning all n ranks.
            if n1 == n:
                ctx = d1.desc[1]
            else:
                ctx = d2.desc[1]
            redist(d1, self.array, d2, other.array, ctx)

    def invcholesky(self):
        """Inverse of Cholesky decomposition.

        Only the lower part is used.  The matrix is replaced in place by
        the inverse of its lower-triangular Cholesky factor.  A pending
        sum over self.comm is carried out first, and the result is
        broadcast over self.comm afterwards.
        """
        if self.state == 'a sum is needed':
            self.comm.sum(self.array, 0)

        if self.comm.rank == 0:
            if self.dist.comm.size > 1:
                # Collect the whole matrix in a 1x1 layout.
                S = self.new(dist=(self.dist.comm, 1, 1))
                self.redist(S)
            else:
                S = self
            if self.dist.comm.rank == 0:
                if debug:
                    # Poison the upper triangle to prove it is never read.
                    S.array[np.triu_indices(S.shape[0], 1)] = 42.0
                L_nn = linalg.cholesky(S.array,
                                       lower=True,
                                       overwrite_a=True,
                                       check_finite=debug)
                S.array[:] = linalg.inv(L_nn,
                                        overwrite_a=True,
                                        check_finite=debug)
            if S is not self:
                S.redist(self)

        if self.comm.size > 1:
            self.comm.broadcast(self.array, 0)
            # The array is now complete on every rank.  This used to be a
            # no-op comparison (==), so the flag was never reset.
            self.state = 'everything is fine'

    def eigh(self, cc=False, scalapack=(None, 1, 1, None)):
        """Calculate eigenvectors and eigenvalues.

        Matrix must be symmetric/hermitian and stored in lower half.
        The matrix is overwritten with the eigenvectors and the
        eigenvalues are returned as a 1-d float array.

        cc: bool
            Complex conjugate matrix before finding eigenvalues.
        scalapack: tuple
            BLACS distribution for ScaLapack to use.  Default is to do
            serial diagonalization.
        """
        slcomm, rows, columns, blocksize = scalapack

        if self.state == 'a sum is needed':
            self.comm.sum(self.array, 0)

        slcomm = slcomm or self.dist.comm
        dist = (slcomm, rows, columns, blocksize)

        needs_redist = (rows != self.dist.rows or
                        columns != self.dist.columns or
                        blocksize != self.dist.blocksize)

        if needs_redist:
            H = self.new(dist=dist)
            self.redist(H)
        else:
            assert self.dist.comm.size == slcomm.size
            H = self

        eps = np.empty(H.shape[0])

        if rows * columns == 1:
            # Serial diagonalization on a single rank.
            if self.comm.rank == 0 and self.dist.comm.rank == 0:
                if cc and H.dtype == complex:
                    np.negative(H.array.imag, H.array.imag)
                if debug:
                    # Poison the upper triangle to prove it is never read.
                    H.array[np.triu_indices(H.shape[0], 1)] = 42.0
                # Eigenvectors are written through the transposed view,
                # i.e. they end up as the rows of H.array.
                eps[:], H.array.T[:] = linalg.eigh(H.array,
                                                   lower=True,
                                                   overwrite_a=True,
                                                   check_finite=debug)
            self.dist.comm.broadcast(eps, 0)
        else:
            if slcomm.rank < rows * columns:
                # The ScaLAPACK path requires cc to be true.
                assert cc
                array = H.array.copy()
                info = cgpaw.scalapack_diagonalize_dc(array, H.dist.desc, 'U',
                                                      H.array, eps)
                assert info == 0, info

            # necessary to broadcast eps when some ranks are not used
            # in current scalapack parameter set
            # eg. (2, 1, 2) with 4 processes
            if rows * columns < slcomm.size:
                H.dist.comm.broadcast(eps, 0)

        if needs_redist:
            H.redist(self)

        assert (self.state == 'a sum is needed') == (
            self.comm.size > 1)
        if self.comm.size > 1:
            self.comm.broadcast(self.array, 0)
            self.comm.broadcast(eps, 0)
            # The array is now complete on every rank.  This used to be a
            # no-op comparison (==), so the flag was never reset.
            self.state = 'everything is fine'

        return eps

    def complex_conjugate(self):
        """Inplace complex conjugation (no-op for non-complex dtypes)."""
        if self.dtype == complex:
            np.negative(self.array.imag, self.array.imag)

333 

334 

def _matrix(M):
    """Dig out Matrix object from wrapper(s).

    Follows the ``matrix`` attribute of M repeatedly until an actual
    Matrix instance is reached, and returns it.
    """
    candidate = M
    while not isinstance(candidate, Matrix):
        candidate = candidate.matrix
    return candidate

340 

341 

class NoDistribution:
    """Trivial layout for a matrix that is not distributed.

    Offers the same attributes as a distributed layout (comm, rows,
    columns, blocksize, shape) with a 1x1 process grid on serial_comm.
    """

    comm = serial_comm
    rows = 1
    columns = 1
    blocksize = None

    def __init__(self, M, N):
        # The local shape is the full MxN shape.
        self.shape = (M, N)

    def __str__(self):
        nrows, ncols = self.shape[0], self.shape[1]
        return f'NoDistribution({nrows}x{ncols})'

    def global_index(self, n):
        """Map a local row index to a global one (identity here)."""
        return n

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        """Multiply a and b into c using the serial BLAS wrappers.

        The general case uses blas.mmm.  With symmetric set, blas.rk is
        used when a and b are the same object and blas.r2k (with half of
        alpha) otherwise.
        """
        if not symmetric:
            blas.mmm(alpha, a.array, opa, b.array, opb, beta, c.array)
            return

        assert opa == 'N'
        assert opb == 'C' or (opb == 'T' and a.dtype == float)

        if a is b:
            blas.rk(alpha, a.array, beta, c.array)
            return

        # Nothing to add when the contracted dimension is empty and c is
        # merely accumulated into.
        if beta == 1.0 and a.shape[1] == 0:
            return
        blas.r2k(0.5 * alpha, a.array, b.array, beta, c.array)

369 

370 

class BLACSDistribution:
    """Layout of an MxN matrix on an r x c grid of processes.

    Attributes:

    comm, rows, columns, blocksize:
        The constructor arguments comm, r, c and b.
    shape: tuple
        Shape (m, n) of the local part of the matrix on this rank.
    desc: ndarray of np.intc
        Array descriptor [1, context, N, M, bc, br, 0, 0, lld] handed to
        the cgpaw ScaLAPACK/PBLAS wrappers.  Only present when a BLACS
        context could be obtained.
    """

    serial = False

    def __init__(self, M, N, comm, r, c, b):
        """Set up the layout.

        M, N: int
            Global number of rows and columns.
        comm:
            Communicator.
        r, c: int
            Number of rows and columns in the process grid.
        b: int or None
            Block size.  None is only allowed when r == 1 or c == 1, in
            which case the matrix is cut into one contiguous slab per
            process.

        Raises ValueError if b is None for a genuinely 2-d process grid.
        """
        self.comm = comm
        self.rows = r
        self.columns = c
        self.blocksize = b

        # Contexts are cached per (communicator, grid shape).
        key = (comm, r, c)
        context = _global_blacs_context_store.get(key)
        if context is None:
            try:
                # NOTE(review): arguments go in (columns, rows) order here
                # and below - presumably to match the C side's storage
                # convention; confirm against the cgpaw implementation.
                context = cgpaw.new_blacs_context(comm.get_c_object(),
                                                  c, r, 'R')
            except AttributeError:
                # Deliberate best effort: carry on without a context
                # (presumably cgpaw was built without BLACS support).
                pass
            else:
                _global_blacs_context_store[key] = context

        if b is None:
            if c == 1:
                br = (M + r - 1) // r
                bc = max(1, N)
            elif r == 1:
                br = M
                bc = (N + c - 1) // c
            else:
                raise ValueError('Please specify block size!')
        else:
            br = bc = b

        if context is None:
            # Fallback without BLACS: contiguous row slabs only.
            assert b is None
            assert c == 1
            n = N
            m = min((comm.rank + 1) * br, M) - min(comm.rank * br, M)
        else:
            n, m = cgpaw.get_blacs_local_shape(context, N, M, bc, br, 0, 0)
        if n < 0 or m < 0:
            n = m = 0
        self.shape = (m, n)

        # Kept separately from desc so that __str__() and global_index()
        # also work in the fallback case where no descriptor exists.
        self._global_shape = (M, N)
        self._block_shape = (br, bc)

        lld = max(1, n)
        if context is not None:
            self.desc = np.array([1, context, N, M, bc, br, 0, 0, lld],
                                 np.intc)

    def __str__(self):
        # Same text as before, but no longer dependent on self.desc, which
        # is missing when there is no BLACS context.
        return ('BLACSDistribution(global={}, local={}, blocksize={})'
                .format(*('{}x{}'.format(*shape)
                          for shape in [self._global_shape,
                                        self.shape,
                                        self._block_shape])))

    def global_index(self, myi):
        """Return rank * (row block size) + myi for local row index myi."""
        return self.comm.rank * int(self._block_shape[0]) + myi

    def multiply(self, alpha, a, opa, b, opb, beta, c, symmetric):
        """Multiply a and b into c using the cgpaw PBLAS wrappers.

        The general case uses pblas_gemm.  With symmetric set, pblas_rk is
        used when a and b are the same object and pblas_r2k (with half of
        alpha) otherwise.  Operands are passed to cgpaw in swapped order.
        """
        if symmetric:
            assert opa == 'N'
            assert opb == 'C' or opb == 'T' and a.dtype == float
            N, K = a.shape
            if a is b:
                cgpaw.pblas_rk(N, K, alpha, a.array,
                               beta, c.array,
                               a.dist.desc, c.dist.desc,
                               'U')
            else:
                cgpaw.pblas_r2k(N, K, 0.5 * alpha, b.array, a.array,
                                beta, c.array,
                                b.dist.desc, a.dist.desc, c.dist.desc,
                                'U')
        else:
            Ka, M = a.shape
            N, Kb = b.shape
            if opa == 'N':
                Ka, M = M, Ka
            if opb == 'N':
                N, Kb = Kb, N
            cgpaw.pblas_gemm(N, M, Ka, alpha, b.array, a.array,
                             beta, c.array,
                             b.dist.desc, a.dist.desc, c.dist.desc,
                             opb, opa)

454 

455 

def redist(dist1, M1, dist2, M2, context):
    """Copy local array M1 (layout dist1) into M2 (layout dist2).

    Delegates to cgpaw.scalapack_redist, passing both array descriptors,
    the matrix dimensions stored in entries 2 and 3 of the source
    descriptor, and the given BLACS context.
    """
    source_desc = dist1.desc
    target_desc = dist2.desc
    first_dim = source_desc[2]
    second_dim = source_desc[3]
    cgpaw.scalapack_redist(source_desc, target_desc,
                           M1, M2,
                           first_dim, second_dim,
                           1, 1, 1, 1,  # 1-indexing
                           context, 'G')

462 

463 

def create_distribution(M, N, comm=None, r=1, c=1, b=None):
    """Return a distribution object for an MxN matrix.

    Without a communicator, or with a communicator of size one, a
    NoDistribution is returned.  Otherwise a BLACSDistribution is built,
    where r or c equal to -1 stands for "the size of the communicator".
    """
    if comm is not None and comm.size != 1:
        grid_rows = comm.size if r == -1 else r
        grid_columns = comm.size if c == -1 else c
        return BLACSDistribution(M, N, comm, grid_rows, grid_columns, b)

    # A single process only makes sense with a 1x1 grid (-1 counts as 1).
    assert (r == 1 and abs(c) == 1) or (c == 1 and abs(r) == 1)
    return NoDistribution(M, N)

473 

474 

def fastmmm(m1, m2, m3, beta):
    """Multiply m1 by m2 into m3 without ScaLAPACK and return m3.

    Each rank multiplies a column slice of its local m1.array with one
    block of m2 at a time via blas.mmm, while the blocks of m2 are passed
    between the ranks of m1.dist.comm with non-blocking sends/receives.
    The caller's beta is used for the first product only; subsequent
    products use beta = 1.0.

    NOTE(review): assumes the rows of m2 are split into contiguous chunks
    of n = ceil(N / size) rows per rank - confirm against the callers.
    """
    comm = m1.dist.comm

    # Start with this rank's own block of m2.
    buf1 = m2.array

    N = len(m1)
    n = (N + comm.size - 1) // comm.size

    for r in range(comm.size):
        if r == 0:
            # Receive buffer sized for a full block of n rows.
            buf2 = np.empty((n, buf1.shape[1]), dtype=buf1.dtype)

        rrequest = None
        srequest = None
        if r < comm.size - 1:
            # Start fetching the block needed in the next iteration ...
            rrank = (comm.rank + r + 1) % comm.size
            rn1 = min(rrank * n, N)
            rn2 = min(rn1 + n, N)
            if rn2 > rn1:
                rrequest = comm.receive(buf2[:rn2 - rn1], rrank, 21, False)
            # ... and hand our own block to the rank that needs it then.
            srank = (comm.rank - r - 1) % comm.size
            if len(m2.array) > 0:
                srequest = comm.send(m2.array, srank, 21, False)

        # Multiply with the block currently held in buf1; it belongs to
        # rank r0 and therefore pairs with columns n1:n2 of m1.
        r0 = (comm.rank + r) % comm.size
        n1 = min(r0 * n, N)
        n2 = min(n1 + n, N)
        blas.mmm(1.0, m1.array[:, n1:n2], 'N', buf1[:n2 - n1], 'N',
                 beta, m3.array)

        # From now on the products are added to m3.
        beta = 1.0

        if r == 0:
            # buf1 still refers to m2.array; switch to a scratch buffer so
            # that m2 is never used as a receive target.
            buf1 = np.empty_like(buf2)

        # The block being received becomes the one to multiply with next.
        buf1, buf2 = buf2, buf1

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

    return m3

518 

519 

def fastmmm2(a, b, out):
    """Symmetric variant of the a * b^H product into out; returns out.

    Blocks of b.array are passed between the ranks of a.dist.comm with
    non-blocking sends/receives and multiplied (blas.mmm with 'N', 'C')
    into column slices of out.array.  Only about half of the exchange
    steps are carried out; in a second phase some ranks send the
    conjugate-transposed blocks they computed to other ranks, which add
    them to their own out.array.

    NOTE(review): assumes rows are split into contiguous chunks of
    m = ceil(M / size) rows per rank - confirm against the callers.
    """
    if a.comm:
        assert b.comm is a.comm
        if a.comm.size > 1:
            assert out.comm == a.comm
            assert out.state == 'a sum is needed'

    comm = a.dist.comm
    M, N = a.shape
    m = (M + comm.size - 1) // comm.size
    mym = len(a.array)  # number of local rows

    buf1 = np.empty((m, N), dtype=a.dtype)
    buf2 = np.empty((m, N), dtype=a.dtype)
    half = comm.size // 2
    aa = a.array
    bb = b.array

    for r in range(half + 1):
        rrequest = None
        srequest = None

        if r < half:
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            # For an even number of ranks, the last exchange is done in
            # one direction only: the lower half of the ranks do not
            # receive and the upper half do not send.
            skip = (comm.size % 2 == 0 and r == half - 1)
            m1 = min(rrank * m, M)
            m2 = min(m1 + m, M)
            if not (skip and comm.rank < half) and m2 > m1:
                rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False)
            if not (skip and comm.rank >= half) and mym > 0:
                srequest = comm.send(b.array, srank, 11, False)

        # Ranks that did not receive in the skipped exchange have nothing
        # to multiply in the final step.
        if not (comm.size % 2 == 0 and r == half and comm.rank < half):
            m1 = min(((comm.rank - r) % comm.size) * m, M)
            m2 = min(m1 + m, M)
            if r == 0:
                # Both factors are local in the first step.
                blas.mmm(1.0, aa, 'N', bb, 'C', 1.0, out.array[:, m1:m2])
            else:
                # Accumulate into the block when r <= rank, otherwise
                # overwrite it.
                beta = 1.0 if r <= comm.rank else 0.0
                blas.mmm(1.0, aa, 'N', buf2[:m2 - m1], 'C',
                         beta, out.array[:, m1:m2])

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # The block just received (in buf1) ends up in buf2 for the next
        # multiplication.
        bb = buf1
        buf1, buf2 = buf2, buf1

    # Second phase: exchange conjugate-transposed blocks between pairs of
    # ranks (row, column) and add them on the receiving side.
    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                m1 = min(column * m, M)
                m2 = min(m1 + m, M)
                if mym > 0 and m2 > m1:
                    requests.append(
                        comm.send(out.array[:, m1:m2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                m1 = min(row * m, M)
                m2 = min(m1 + m, M)
                if mym > 0 and m2 > m1:
                    block = np.empty((mym, m2 - m1), out.dtype)
                    blocks.append((m1, m2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for m1, m2, block in blocks:
        out.array[:, m1:m2] += block

    return out

598 

599 

def fastmmm2notsym(a, b, out):
    """Add the a * b^H product to out without ScaLAPACK; returns out.

    Blocks of b.array are passed between the ranks of a.dist.comm with
    non-blocking sends/receives; in every step the local a.array is
    multiplied (blas.mmm with 'N', 'C' and beta = 1.0) with the block at
    hand and accumulated into the matching column slice of out.array.

    NOTE(review): assumes rows are split into contiguous chunks of
    m = ceil(M / size) rows per rank - confirm against the callers.
    """
    if a.comm:
        assert b.comm is a.comm
        if a.comm.size > 1:
            assert out.comm == a.comm
            assert out.state == 'a sum is needed'

    comm = a.dist.comm
    M, N = a.shape
    m = (M + comm.size - 1) // comm.size
    mym = len(a.array)  # number of local rows

    buf1 = np.empty((m, N), dtype=a.dtype)
    buf2 = np.empty((m, N), dtype=a.dtype)
    aa = a.array
    bb = b.array

    for r in range(comm.size):
        rrequest = None
        srequest = None

        if r < comm.size - 1:
            # Start fetching the block for the next step and send our own
            # block of b to the rank that needs it then.
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            m1 = min(rrank * m, M)
            m2 = min(m1 + m, M)
            if m2 > m1:
                rrequest = comm.receive(buf1[:m2 - m1], rrank, 11, False)
            if mym > 0:
                srequest = comm.send(b.array, srank, 11, False)

        # The block at hand (bb) belongs to rank (rank - r) % size and
        # fills columns m1:m2 of out.
        m1 = min(((comm.rank - r) % comm.size) * m, M)
        m2 = min(m1 + m, M)
        blas.mmm(1.0, aa, 'N', bb[:m2 - m1], 'C', 1.0, out.array[:, m1:m2])

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # Use the block just received next; the other buffer becomes the
        # receive target.
        bb = buf1
        buf1, buf2 = buf2, buf1

    return out