Coverage for gpaw/mpi.py: 67%

558 statements  

coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1# Copyright (C) 2003 CAMP 

2# Please see the accompanying LICENSE file for further information. 

3from __future__ import annotations 

4 

5import atexit 

6import pickle 

7import sys 

8import time 

9import traceback 

10from contextlib import contextmanager 

11from typing import Any 

12 

13import gpaw.cgpaw as cgpaw 

14import numpy as np 

15import warnings 

16from ase.parallel import MPI as ASE_MPI 

17from ase.parallel import world as aseworld 

18 

19import gpaw 

20 

21from ._broadcast_imports import world as _world 

22 

23MASTER = 0 

24 

25 

26def is_contiguous(*args, **kwargs): 

27 from gpaw.utilities import is_contiguous 

28 return is_contiguous(*args, **kwargs) 

29 

30 

31@contextmanager 

32def broadcast_exception(comm): 

33 """Make sure all ranks get a possible exception raised. 

34 

35 This example:: 

36 

37 with broadcast_exception(world): 

38 if world.rank == 0: 

39 x = 1 / 0 

40 

41 will raise ZeroDivisionError in the whole world. 

42 """ 

43 # Each core will send -1 on success or its rank on failure. 

44 try: 

45 yield 

46 except Exception as ex: 

47 rank = comm.max_scalar(comm.rank) 

48 if rank == comm.rank: 

49 broadcast(ex, rank, comm) 

50 raise 

51 else: 

52 rank = comm.max_scalar(-1) 

53 # rank will now be the highest failing rank or -1 

54 if rank >= 0: 

55 raise broadcast(None, rank, comm) 

56 

57 

58class _Communicator: 

59 def __init__(self, comm, parent=None): 

60 """Construct a wrapper of the C-object for any MPI-communicator. 

61 

62 Parameters: 

63 

64 comm: MPI-communicator 

65 Communicator. 

66 

67 Attributes: 

68 

69 ============ ====================================================== 

70 ``size`` Number of ranks in the MPI group. 

71 ``rank`` Number of this CPU in the MPI group. 

72 ``parent`` Parent MPI-communicator. 

73 ============ ====================================================== 

74 """ 

75 self.comm = comm 

76 self.size = comm.size 

77 self.rank = comm.rank 

78 self.parent = parent # XXX check C-object against comm.parent? 

79 

80 def __repr__(self): 

81 return f'MPIComm(size={self.size}, rank={self.rank})' 

82 

83 def new_communicator(self, ranks): 

84 """Create a new MPI communicator for a subset of ranks in a group. 

85 Must be called with identical arguments by all relevant processes. 

86 

87 Note that a valid communicator is only returned to the processes 

88 which are included in the new group; other ranks get None returned. 

89 

90 Parameters: 

91 

92 ranks: ndarray (type int) 

93 List of integers of the ranks to include in the new group. 

94 Note that these ranks correspond to indices in the current 

95 group whereas the rank attribute in the new communicators 

96 correspond to their respective index in the subset. 

97 
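Example (a minimal sketch; the even-numbered ranks of ``world`` form a
new group and every other rank receives None)::

    even_ranks = np.arange(0, world.size, 2)
    comm = world.new_communicator(even_ranks)  # None on the odd ranks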

98 """ 

99 

100 comm = self.comm.new_communicator(ranks) 

101 if comm is None: 

102 # This cpu is not in the new communicator: 

103 return None 

104 else: 

105 return _Communicator(comm, parent=self) 

106 

107 def sum(self, a, root=-1): 

108 """Perform summation by MPI reduce operations of numerical data. 

109 

110 Parameters: 

111 

112 a: ndarray or value (type int, float or complex) 

113 Numerical data to sum over all ranks in the communicator group. 

114 If the data is a single value of type int, float or complex, 

115 the result is returned because the input argument is immutable. 

116 Otherwise, the reduce operation is carried out in-place such 

117 that the elements of the input array will represent the sum of 

118 the equivalent elements across all processes in the group. 

119 root: int (default -1) 

120 Rank of the root process, on which the outcome of the reduce 

121 operation is valid. A root rank of -1 signifies that the result 

122 will be distributed back to all processes, i.e. a broadcast. 

123 
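Example (a minimal sketch; with the default root=-1 every rank ends up
with the same totals)::

    a = np.array([float(comm.rank), 1.0])
    comm.sum(a)                     # in-place: a == [sum of ranks, size]
    n = comm.sum_scalar(comm.rank)  # scalars should use sum_scalar()

The ``product``, ``max`` and ``min`` reductions below follow the same
calling convention.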

124 """ 

125 if isinstance(a, (int, float, complex)): 

126 warnings.warn('Please use sum_scalar(...)', stacklevel=2) 

127 return self.comm.sum_scalar(a, root) 

128 else: 

129 # assert a.ndim != 0 

130 tc = a.dtype 

131 assert is_contiguous(a, tc) 

132 assert root == -1 or 0 <= root < self.size 

133 self.comm.sum(a, root) 

134 

135 def sum_scalar(self, a, root=-1): 

136 assert isinstance(a, (int, float, complex)) 

137 return self.comm.sum_scalar(a, root) 

138 

139 def product(self, a, root=-1): 

140 """Do multiplication by MPI reduce operations of numerical data. 

141 

142 Parameters: 

143 

144 a: ndarray or value (type int or float) 

145 Numerical data to multiply across all ranks in the communicator 

146 group. NB: Find the global product from the local products. 

147 If the data is a single value of type int or float (no complex), 

148 the result is returned because the input argument is immutable. 

149 Otherwise, the reduce operation is carried out in-place such 

150 that the elements of the input array will represent the product 

151 of the equivalent elements across all processes in the group. 

152 root: int (default -1) 

153 Rank of the root process, on which the outcome of the reduce 

154 operation is valid. A root rank of -1 signifies that the result 

155 will be distributed back to all processes, i.e. a broadcast. 

156 

157 """ 

158 if isinstance(a, (int, float)): 

159 1 / 0  # XXX fails for scalar input; the return below is never reached 

160 return self.comm.product(a, root) 

161 else: 

162 tc = a.dtype 

163 assert tc == int or tc == float 

164 assert is_contiguous(a, tc) 

165 assert root == -1 or 0 <= root < self.size 

166 self.comm.product(a, root) 

167 

168 def max(self, a, root=-1): 

169 """Find maximal value by an MPI reduce operation of numerical data. 

170 

171 Parameters: 

172 

173 a: ndarray or value (type int or float) 

174 Numerical data to find the maximum value of across all ranks in 

175 the communicator group. NB: Find global maximum from local max. 

176 If the data is a single value of type int or float (no complex), 

177 the result is returned because the input argument is immutable. 

178 Otherwise, the reduce operation is carried out in-place such 

179 that the elements of the input array will represent the max of 

180 the equivalent elements across all processes in the group. 

181 root: int (default -1) 

182 Rank of the root process, on which the outcome of the reduce 

183 operation is valid. A root rank of -1 signifies that the result 

184 will be distributed back to all processes, i.e. a broadcast. 

185 

186 """ 

187 if isinstance(a, (int, float)): 

188 warnings.warn('Please use max_scalar(...)', stacklevel=2) 

189 return self.comm.max_scalar(a, root) 

190 else: 

191 tc = a.dtype 

192 assert tc == int or tc == float 

193 assert is_contiguous(a, tc) 

194 assert root == -1 or 0 <= root < self.size 

195 self.comm.max(a, root) 

196 

197 def max_scalar(self, a, root=-1): 

198 assert isinstance(a, (int, float, np.int64)) 

199 return self.comm.max_scalar(a, root) 

200 

201 def min(self, a, root=-1): 

202 """Find minimal value by an MPI reduce operation of numerical data. 

203 

204 Parameters: 

205 

206 a: ndarray or value (type int or float) 

207 Numerical data to find the minimal value of across all ranks in 

208 the communicator group. NB: Find global minimum from local min. 

209 If the data is a single value of type int or float (no complex), 

210 the result is returned because the input argument is immutable. 

211 Otherwise, the reduce operation is carried out in-place such 

212 that the elements of the input array will represent the min of 

213 the equivalent elements across all processes in the group. 

214 root: int (default -1) 

215 Rank of the root process, on which the outcome of the reduce 

216 operation is valid. A root rank of -1 signifies that the result 

217 will be distributed back to all processes, i.e. a broadcast. 

218 

219 """ 

220 if isinstance(a, (int, float)): 

221 warnings.warn('Please use min_scalar(...)', stacklevel=2) 

222 return self.comm.min_scalar(a, root) 

223 else: 

224 tc = a.dtype 

225 assert tc == int or tc == float 

226 assert is_contiguous(a, tc) 

227 assert root == -1 or 0 <= root < self.size 

228 self.comm.min(a, root) 

229 

230 def min_scalar(self, a, root=-1): 

231 assert isinstance(a, (int, float)) 

232 return self.comm.min_scalar(a, root) 

233 

234 def scatter(self, a, b, root: int) -> None: 

235 """Distribute data from one rank to all other processes in a group. 

236 

237 Parameters: 

238 

239 a: ndarray (ignored on all ranks different from root; use None) 

240 Source of the data to distribute, i.e. send buffer on root rank. 

241 b: ndarray 

242 Destination of the distributed data, i.e. local receive buffer. 

243 The size of this array multiplied by the number of processes in 

244 the group must match the size of the source array on the root. 

245 root: int 

246 Rank of the root process, from which the source data originates. 

247 

248 The reverse operation is ``gather``. 

249 

250 Example:: 

251 

252 # The master has all the interesting data. Distribute it. 

253 seed = 123456 

254 rng = np.random.default_rng(seed) 

255 if comm.rank == 0: 

256 data = rng.normal(size=N*comm.size) 

257 else: 

258 data = None 

259 mydata = np.empty(N, dtype=float) 

260 comm.scatter(data, mydata, 0) 

261 

262 # .. which is equivalent to .. 

263 

264 if comm.rank == 0: 

265 # Extract my part directly 

266 mydata[:] = data[0:N] 

267 # Distribute parts to the slaves 

268 for rank in range(1, comm.size): 

269 buf = data[rank*N:(rank+1)*N] 

270 comm.send(buf, rank, tag=123) 

271 else: 

272 # Receive from the master 

273 comm.receive(mydata, 0, tag=123) 

274 

275 """ 

276 if self.rank == root: 

277 assert a.dtype == b.dtype 

278 assert a.size == self.size * b.size 

279 assert a.flags.c_contiguous 

280 assert b.flags.c_contiguous 

281 assert 0 <= root < self.size 

282 self.comm.scatter(a, b, root) 

283 

284 def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls): 

285 """All-to-all in a group. 

286 

287 Parameters: 

288 

289 sbuffer: ndarray 

290 Source of the data to distribute, i.e., send buffers on all ranks 

291 scounts: ndarray 

292 Integer array (of length group size) specifying the number of 

293 elements to send to each processor 

294 sdispls: ndarray 

295 Integer array (of length group size). Entry j specifies the 

296 displacement (relative to sendbuf) from which to take the 

297 outgoing data destined for process j. 

298 rbuffer: ndarray 

299 Destination of the distributed data, i.e., local receive buffer. 

300 rcounts: ndarray 

301 Integer array (of length group size) specifying the maximum 

302 number of elements that can be received from each processor. 

303 rdispls: ndarray 

304 Integer array (of length group size). Entry i specifies the 

305 displacement (relative to recvbuf) at which to place the incoming 

306 data from process i. 
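
Example (a minimal sketch; every rank sends exactly one value to every
other rank, so all counts are 1 and the displacements are 0, 1, 2, ...)::

    sbuf = np.full(comm.size, float(comm.rank))
    rbuf = np.empty(comm.size)
    ones = np.ones(comm.size, dtype=int)
    displs = np.arange(comm.size, dtype=int)
    comm.alltoallv(sbuf, ones, displs, rbuf, ones, displs)
    # rbuf now holds [0., 1., ..., comm.size - 1.] on every rank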

307 """ 

308 assert sbuffer.flags.c_contiguous 

309 assert scounts.flags.c_contiguous 

310 assert sdispls.flags.c_contiguous 

311 assert rbuffer.flags.c_contiguous 

312 assert rcounts.flags.c_contiguous 

313 assert rdispls.flags.c_contiguous 

314 assert sbuffer.dtype == rbuffer.dtype 

315 

316 for arr in [scounts, sdispls, rcounts, rdispls]: 

317 assert arr.dtype == int, arr.dtype 

318 assert len(arr) == self.size 

319 

320 assert np.all(0 <= sdispls) 

321 assert np.all(0 <= rdispls) 

322 assert np.all(sdispls + scounts <= sbuffer.size) 

323 assert np.all(rdispls + rcounts <= rbuffer.size) 

324 self.comm.alltoallv(sbuffer, scounts, sdispls, 

325 rbuffer, rcounts, rdispls) 

326 

327 def all_gather(self, a, b): 

328 """Gather data from all ranks onto all processes in a group. 

329 

330 Parameters: 

331 

332 a: ndarray 

333 Source of the data to gather, i.e. send buffer of this rank. 

334 b: ndarray 

335 Destination of the distributed data, i.e. receive buffer. 

336 The size of this array must match the size of the distributed 

337 source arrays multiplied by the number of processes in the group. 

338 

339 Example:: 

340 

341 # All ranks have parts of interesting data. Gather on all ranks. 

342 seed = 123456 

343 rng = np.random.default_rng(seed) 

344 mydata = rng.normal(size=N) 

345 data = np.empty(N*comm.size, dtype=float) 

346 comm.all_gather(mydata, data) 

347 

348 # .. which is equivalent to .. 

349 

350 if comm.rank == 0: 

351 # Insert my part directly 

352 data[0:N] = mydata 

353 # Gather parts from the slaves 

354 buf = np.empty(N, dtype=float) 

355 for rank in range(1, comm.size): 

356 comm.receive(buf, rank, tag=123) 

357 data[rank*N:(rank+1)*N] = buf 

358 else: 

359 # Send to the master 

360 comm.send(mydata, 0, tag=123) 

361 # Broadcast from master to all slaves 

362 comm.broadcast(data, 0) 

363 

364 """ 

365 assert a.flags.contiguous 

366 assert b.flags.contiguous 

367 assert b.dtype == a.dtype 

368 assert (b.shape[0] == self.size and a.shape == b.shape[1:] or 

369 a.size * self.size == b.size) 

370 self.comm.all_gather(a, b) 

371 

372 def gather(self, a, root, b=None): 

373 """Gather data from all ranks onto a single process in a group. 

374 

375 Parameters: 

376 

377 a: ndarray 

378 Source of the data to gather, i.e. send buffer of this rank. 

379 root: int 

380 Rank of the root process, on which the data is to be gathered. 

381 b: ndarray (ignored on all ranks different from root; default None) 

382 Destination of the distributed data, i.e. root's receive buffer. 

383 The size of this array must match the size of the distributed 

384 source arrays multiplied by the number of processes in the group. 

385 

386 The reverse operation is ``scatter``. 

387 

388 Example:: 

389 

390 # All ranks have parts of interesting data. Gather it on master. 

391 seed = 123456 

392 rng = np.random.default_rng(seed) 

393 mydata = rng.normal(size=N) 

394 if comm.rank == 0: 

395 data = np.empty(N*comm.size, dtype=float) 

396 else: 

397 data = None 

398 comm.gather(mydata, 0, data) 

399 

400 # .. which is equivalent to .. 

401 

402 if comm.rank == 0: 

403 # Extract my part directly 

404 data[0:N] = mydata 

405 # Gather parts from the slaves 

406 buf = np.empty(N, dtype=float) 

407 for rank in range(1, comm.size): 

408 comm.receive(buf, rank, tag=123) 

409 data[rank*N:(rank+1)*N] = buf 

410 else: 

411 # Send to the master 

412 comm.send(mydata, 0, tag=123) 

413 

414 """ 

415 assert a.flags.c_contiguous 

416 assert 0 <= root < self.size 

417 if root == self.rank: 

418 assert b.flags.c_contiguous and b.dtype == a.dtype 

419 assert (b.shape[0] == self.size and a.shape == b.shape[1:] or 

420 a.size * self.size == b.size) 

421 self.comm.gather(a, root, b) 

422 else: 

423 assert b is None 

424 self.comm.gather(a, root) 

425 

426 def broadcast(self, a, root): 

427 """Share data from a single process to all ranks in a group. 

428 

429 Parameters: 

430 

431 a: ndarray 

432 Data, i.e. send buffer on root rank, receive buffer elsewhere. 

433 Note that after the broadcast, all ranks have the same data. 

434 root: int 

435 Rank of the root process, from which the data is to be shared. 

436 

437 Example:: 

438 

439 # All ranks have parts of interesting data. Take a given index. 

440 seed = 123456 

441 rng = np.random.default_rng(seed) 

442 mydata[:] = rng.normal(size=N) 

443 

444 # Who has the element at global index 13? Everybody needs it! 

445 index = 13 

446 root, myindex = divmod(index, N) 

447 element = np.empty(1, dtype=float) 

448 if comm.rank == root: 

449 # This process has the requested element so extract it 

450 element[:] = mydata[myindex] 

451 

452 # Broadcast from owner to everyone else 

453 comm.broadcast(element, root) 

454 

455 # .. which is equivalent to .. 

456 

457 if comm.rank == root: 

458 # We are root so send it to the other ranks 

459 for rank in range(comm.size): 

460 if rank != root: 

461 comm.send(element, rank, tag=123) 

462 else: 

463 # We don't have it so receive from root 

464 comm.receive(element, root, tag=123) 

465 

466 """ 

467 assert 0 <= root < self.size 

468 assert is_contiguous(a) 

469 self.comm.broadcast(a, root) 

470 

471 def sendreceive(self, a, dest, b, src, sendtag=123, recvtag=123): 

472 assert 0 <= dest < self.size 

473 assert dest != self.rank 

474 assert is_contiguous(a) 

475 assert 0 <= src < self.size 

476 assert src != self.rank 

477 assert is_contiguous(b) 

478 return self.comm.sendreceive(a, dest, b, src, sendtag, recvtag) 

479 

480 def send(self, a, dest, tag=123, block=True): 

481 assert 0 <= dest < self.size 

482 assert dest != self.rank 

483 assert is_contiguous(a) 

484 if not block: 

485 pass # assert sys.getrefcount(a) > 3 

486 return self.comm.send(a, dest, tag, block) 

487 

488 def ssend(self, a, dest, tag=123): 

489 assert 0 <= dest < self.size 

490 assert dest != self.rank 

491 assert is_contiguous(a) 

492 return self.comm.ssend(a, dest, tag) 

493 

494 def receive(self, a, src, tag=123, block=True): 

495 assert 0 <= src < self.size 

496 assert src != self.rank 

497 assert is_contiguous(a) 

498 return self.comm.receive(a, src, tag, block) 

499 

500 def test(self, request): 

501 """Test whether a non-blocking MPI operation has completed. A boolean 

502 is returned immediately and the request is not modified in any way. 

503 

504 Parameters: 

505 

506 request: MPI request 

507 Request e.g. returned from send/receive when block=False is used. 

508 

509 """ 

510 return self.comm.test(request) 

511 

512 def testall(self, requests): 

513 """Test whether non-blocking MPI operations have completed. A boolean 

514 is returned immediately but requests may have been deallocated as a 

515 result, provided they have completed before or during this invocation. 

516 

517 Parameters: 

518 

519 requests: list of MPI requests 

520 Requests, e.g. as returned from send/receive calls where block=False was used. 

521 

522 """ 

523 return self.comm.testall(requests) # may deallocate requests! 

524 

525 def wait(self, request): 

526 """Wait for a non-blocking MPI operation to complete before returning. 

527 

528 Parameters: 

529 

530 request: MPI request 

531 Request e.g. returned from send/receive when block=False is used. 

532 

533 """ 

534 self.comm.wait(request) 

535 

536 def waitall(self, requests): 

537 """Wait for non-blocking MPI operations to complete before returning. 

538 

539 Parameters: 

540 

541 requests: list 

542 List of MPI requests e.g. aggregated from returned requests of 

543 multiple send/receive calls where block=False was used. 

544 
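Example (a minimal sketch; ``other`` is assumed to be a partner rank
different from ``comm.rank``, and both arrays have the same shape and
dtype)::

    sbuf = np.array([1.0 * comm.rank])
    rbuf = np.empty(1)
    requests = [comm.send(sbuf, other, tag=7, block=False),
                comm.receive(rbuf, other, tag=7, block=False)]
    comm.waitall(requests)  # both transfers have completed afterwards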

545 """ 

546 self.comm.waitall(requests) 

547 

548 def abort(self, errcode): 

549 """Terminate MPI execution environment of all tasks in the group. 

550 This function only returns in the event of an error. 

551 

552 Parameters: 

553 

554 errcode: int 

555 Error code to return to the invoking environment. 

556 

557 """ 

558 return self.comm.abort(errcode) 

559 

560 def name(self): 

561 """Return the name of the processor as a string.""" 

562 return self.comm.name() 

563 

564 def barrier(self): 

565 """Block execution until all process have reached this point.""" 

566 self.comm.barrier() 

567 

568 def compare(self, othercomm): 

569 """Compare communicator to other. 

570 

571 Returns 'ident' if they are identical, 'congruent' if they are 

572 copies of each other, 'similar' if they are permutations of 

573 each other, and otherwise 'unequal'. 

574 

575 This method corresponds to MPI_Comm_compare.""" 

576 if isinstance(self.comm, SerialCommunicator): 

577 return self.comm.compare(othercomm.comm) # argh! 

578 result = self.comm.compare(othercomm.get_c_object()) 

579 assert result in ['ident', 'congruent', 'similar', 'unequal'] 

580 return result 

581 

582 def translate_ranks(self, other, ranks): 

583 """"Translate ranks from communicator to other. 

584 

585 ranks must be valid on this communicator. Returns ranks 

586 on other communicator corresponding to the same processes. 

587 Ranks that are not defined on the other communicator are 

588 assigned values of -1. (In contrast to MPI which would 

589 assign MPI_UNDEFINED).""" 

590 assert hasattr(other, 'translate_ranks'), \ 

591 'Expected communicator, got %s' % other 

592 assert all(0 <= rank for rank in ranks) 

593 assert all(rank < self.size for rank in ranks) 

594 if isinstance(self.comm, SerialCommunicator): 

595 return self.comm.translate_ranks(other.comm, ranks) # argh! 

596 otherranks = self.comm.translate_ranks(other.get_c_object(), ranks) 

597 assert all(-1 <= rank for rank in otherranks) 

598 assert ranks.dtype == otherranks.dtype 

599 return otherranks 

600 

601 def get_members(self): 

602 """Return the subset of processes which are members of this MPI group 

603 in terms of the ranks they are assigned on the parent communicator. 

604 For the world communicator, this is all integers up to ``size``. 

605 

606 Example:: 

607 

608 >>> world.rank, world.size # doctest: +SKIP 

609 (3, 4) 

610 >>> world.get_members() # doctest: +SKIP 

611 array([0, 1, 2, 3]) 

612 >>> comm = world.new_communicator(np.array([2, 3])) # doctest: +SKIP 

613 >>> comm.rank, comm.size # doctest: +SKIP 

614 (1, 2) 

615 >>> comm.get_members() # doctest: +SKIP 

616 array([2, 3]) 

617 >>> comm.get_members()[comm.rank] == world.rank # doctest: +SKIP 

618 True 

619 

620 """ 

621 return self.comm.get_members() 

622 

623 def get_c_object(self): 

624 """Return the C-object wrapped by this debug interface. 

625 

626 Whenever a communicator object is passed to C code, that object 

627 must be a proper C-object - *not* e.g. this debug wrapper. For 

628 this reason, the C-communicator object has a get_c_object() 

629 implementation which returns itself; thus, always call 

630 comm.get_c_object() and pass the resulting object to the C code. 

631 """ 

632 c_obj = self.comm.get_c_object() 

633 if isinstance(c_obj, cgpaw.Communicator): 

634 return c_obj 

635 return c_obj.get_c_object() 

636 

637 

638MPIComm = _Communicator # for type hints 

639 

640 

641# Serial communicator 

642class SerialCommunicator: 

643 size = 1 

644 rank = 0 

645 

646 def __init__(self, parent=None): 

647 self.parent = parent 

648 

649 def __repr__(self): 

650 return 'SerialCommunicator()' 

651 

652 def sum(self, array, root=-1): 

653 if isinstance(array, (int, float, complex)): 

654 warnings.warn('Please use sum_scalar(...)', stacklevel=2) 

655 return array 

656 

657 def sum_scalar(self, a, root=-1): 

658 return a 

659 

660 def scatter(self, s, r, root): 

661 r[:] = s 

662 

663 def min(self, value, root=-1): 

664 if isinstance(value, (int, float, complex)): 

665 warnings.warn('Please use min_scalar(...)', stacklevel=2) 

666 return value 

667 

668 def min_scalar(self, value, root=-1): 

669 return value 

670 

671 def max(self, value, root=-1): 

672 if isinstance(value, (int, float, complex)): 

673 warnings.warn('Please use max_scalar(...)', stacklevel=2) 

674 return value 

675 

676 def max_scalar(self, value, root=-1): 

677 return value 

678 

679 def broadcast(self, buf, root): 

680 pass 

681 

682 def send(self, buff, dest, tag=123, block=True): 

683 pass 

684 

685 def barrier(self): 

686 pass 

687 

688 def gather(self, a, root, b): 

689 b[:] = a 

690 

691 def all_gather(self, a, b): 

692 b[:] = a 

693 

694 def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls): 

695 assert len(scounts) == 1 

696 assert len(sdispls) == 1 

697 assert len(rcounts) == 1 

698 assert len(rdispls) == 1 

699 assert len(sbuffer) == len(rbuffer) 

700 

701 rbuffer[rdispls[0]:rdispls[0] + rcounts[0]] = \ 

702 sbuffer[sdispls[0]:sdispls[0] + scounts[0]] 

703 

704 def new_communicator(self, ranks): 

705 if self.rank not in ranks: 

706 return None 

707 comm = SerialCommunicator(parent=self) 

708 comm.size = len(ranks) 

709 return comm 

710 

711 def test(self, request): 

712 return 1 

713 

714 def testall(self, requests): 

715 return 1 

716 

717 def wait(self, request): 

718 raise NotImplementedError('Calls to mpi wait should not happen in ' 

719 'serial mode') 

720 

721 def waitall(self, requests): 

722 if not requests: 

723 return 

724 raise NotImplementedError('Calls to mpi waitall should not happen in ' 

725 'serial mode') 

726 

727 def get_members(self): 

728 return np.array([0]) 

729 

730 def compare(self, other): 

731 if self == other: 

732 return 'ident' 

733 elif other.size == 1: 

734 return 'congruent' 

735 else: 

736 raise NotImplementedError('Compare serial comm to other') 

737 

738 def translate_ranks(self, other, ranks): 

739 if isinstance(other, SerialCommunicator): 

740 assert all(rank == 0 for rank in ranks) or gpaw.dry_run 

741 return np.zeros(len(ranks), dtype=int) 

742 return np.array([other.rank for rank in ranks]) 

743 # (Non-trivial rank translations are never needed here: on a serial 

744 # communicator every requested rank is 0 and maps to ``other.rank``.) 

745 

746 def get_c_object(self): 

747 if gpaw.dry_run: 

748 return None # won't actually be passed to C 

749 return _world 

750 

751 

752_serial_comm = SerialCommunicator() 

753 

754have_mpi = _world is not None 

755 

756if not have_mpi: 

757 _world = _serial_comm # type: ignore 

758 

759if gpaw.debug: 

760 serial_comm = _Communicator(_serial_comm) 

761 if _world.size == 1: 

762 world = serial_comm 

763 else: 

764 world = _Communicator(_world) 

765else: 

766 serial_comm = _serial_comm # type: ignore 

767 world = _world # type: ignore 

768 

769rank = world.rank 

770size = world.size 

771parallel = (size > 1) 

772 

773 

774def verify_ase_world(): 

775 # ASE does not like that GPAW uses world.size at import time. 

776 # .... because of GPAW's own import time communicator mish-mash. 

777 # Now, GPAW wants to verify world.size and cannot do so, 

778 # because of what ASE does for GPAW's sake. 

779 # This really needs improvement! 

780 assert aseworld is not None 

781 

782 if isinstance(aseworld, ASE_MPI): 

783 # We only want to check if the communicator was already initialized. 

784 # Otherwise the communicator will be initialized as a side effect 

785 # of accessing the .size attribute, 

786 # which ASE's tests will complain about. 

787 check_size = aseworld.comm is not None 

788 else: 

789 check_size = True # A real communicator, so we want to check that 

790 

791 if check_size and world.size != aseworld.size: 

792 raise RuntimeError('Please use "gpaw python" to run in parallel') 

793 

794 

795verify_ase_world() 

796 

797 

798def broadcast(obj, root=0, comm=world): 

799 """Broadcast a Python object across an MPI communicator and return it.""" 

800 if comm.rank == root: 

801 assert obj is not None 

802 b = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL) 

803 else: 

804 assert obj is None 

805 b = None 

806 b = broadcast_bytes(b, root, comm) 

807 if comm.rank == root: 

808 return obj 

809 else: 

810 return pickle.loads(b) 

811 
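# Example usage of broadcast() (a minimal sketch): only the root passes a
# real object, every other rank passes None, and afterwards all ranks hold
# an equal copy:
#
#     if world.rank == 0:
#         obj = {'description': 'anything picklable'}
#     else:
#         obj = None
#     obj = broadcast(obj, root=0, comm=world)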

812 

813def broadcast_float(x, comm): 

814 array = np.array([x]) 

815 comm.broadcast(array, 0) 

816 return array[0] 

817 

818 

819def synchronize_atoms(atoms, comm, tolerance=1e-8): 

820 """Synchronize atoms between multiple CPUs removing numerical noise. 

821 

822 If the atoms differ significantly, raise ValueError on all ranks. 

823 The error object contains the ranks where the check failed. 

824 

825 In debug mode, write atoms to files in case of failure.""" 

826 

827 if len(atoms) == 0: 

828 return 

829 

830 if comm.rank == 0: 

831 src = (atoms.positions, atoms.cell, atoms.numbers, atoms.pbc) 

832 else: 

833 src = None 

834 

835 # XXX replace with ase.cell.same_cell in the future 

836 # (if that functions gets to exist) 

837 # def same_cell(cell1, cell2): 

838 # return ((cell1 is None) == (cell2 is None) and 

839 # (cell1 is None or (cell1 == cell2).all())) 

840 

841 # Cell vectors should be compared with a tolerance like positions? 

842 def same_cell(cell1, cell2, tolerance=1e-8): 

843 return ((cell1 is None) == (cell2 is None) and 

844 (cell1 is None or (abs(cell1 - cell2).max() <= tolerance))) 

845 

846 positions, cell, numbers, pbc = broadcast(src, root=0, comm=comm) 

847 ok = (len(positions) == len(atoms.positions) and 

848 (abs(positions - atoms.positions).max() <= tolerance) and 

849 (numbers == atoms.numbers).all() and 

850 same_cell(cell, atoms.cell) and 

851 (pbc == atoms.pbc).all()) 

852 

853 # We need to fail equally on all ranks to avoid trouble. Thus 

854 # we use an array to gather check results from everyone. 

855 my_fail = np.array(not ok, dtype=bool) 

856 all_fail = np.zeros(comm.size, dtype=bool) 

857 comm.all_gather(my_fail, all_fail) 

858 

859 if all_fail.any(): 

860 if gpaw.debug: 

861 with open('synchronize_atoms_r%d.pckl' % comm.rank, 'wb') as fd: 

862 pickle.dump((atoms.positions, atoms.cell, 

863 atoms.numbers, atoms.pbc, 

864 positions, cell, numbers, pbc), fd) 

865 err_ranks = np.arange(comm.size)[all_fail] 

866 raise ValueError('Mismatch of Atoms objects. In debug ' 

867 'mode, atoms will be dumped to files.', 

868 err_ranks) 

869 

870 atoms.positions = positions 

871 atoms.cell = cell 

872 

873 

874def broadcast_string(string=None, root=0, comm=world): 

875 if comm.rank == root: 

876 string = string.encode() 

877 return broadcast_bytes(string, root, comm).decode() 

878 

879 

880def broadcast_bytes(b=None, root=0, comm=world): 

881 """Broadcast a bytes across an MPI communicator and return it.""" 

882 if comm.rank == root: 

883 assert isinstance(b, bytes) 

884 n = np.array(len(b), int) 

885 else: 

886 assert b is None 

887 n = np.zeros(1, int) 

888 comm.broadcast(n, root) 

889 if comm.rank == root: 

890 b = np.frombuffer(b, np.int8) 

891 else: 

892 b = np.zeros(n, np.int8) 

893 comm.broadcast(b, root) 

894 return b.tobytes() 

895 

896 

897def broadcast_array(array: np.ndarray, *communicators) -> np.ndarray: 

898 """Broadcast np.ndarray across sequence of MPI-communicators.""" 

899 comms = list(communicators) 
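# Work backwards through the communicators: a process takes part in the
# broadcast over the popped communicator only if it has rank 0 in every
# communicator still left in ``comms``, so the data spreads outwards from
# the process that is rank 0 everywhere.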

900 while comms: 

901 comm = comms.pop() 

902 if all(comm.rank == 0 for comm in comms): 

903 comm.broadcast(array, 0) 

904 return array 

905 

906 

907def send(obj, rank: int, comm: MPIComm) -> None: 

908 """Send object to rank on the MPI communicator comm.""" 

909 b = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL) 

910 comm.send(np.array(len(b)), rank) 

911 comm.send(np.frombuffer(b, np.int8).copy(), rank) 

912 

913 

914def receive(rank: int, comm: MPIComm) -> Any: 

915 """Receive object from rank on the MPI communicator comm.""" 

916 n = np.array(0) 

917 comm.receive(n, rank) 

918 buf = np.empty(int(n), np.int8) 

919 comm.receive(buf, rank) 

920 return pickle.loads(buf.tobytes()) 

921 
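# Example usage of send()/receive() (a minimal sketch; assumes
# world.size >= 2): ship any picklable object from rank 0 to rank 1:
#
#     if world.rank == 0:
#         send({'energy': -1.23}, 1, world)
#     elif world.rank == 1:
#         obj = receive(0, world)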

922 

923def send_string(string, rank, comm=world): 

924 b = string.encode() 

925 comm.send(np.array(len(b)), rank) 

926 comm.send(np.frombuffer(b, np.int8).copy(), rank) 

927 

928 

929def receive_string(rank, comm=world): 

930 n = np.array(0) 

931 comm.receive(n, rank) 

932 string = np.empty(n, np.int8) 

933 comm.receive(string, rank) 

934 return string.tobytes().decode() 

935 

936 

937def alltoallv_string(send_dict, comm=world): 

938 scounts = np.zeros(comm.size, dtype=int) 

939 sdispls = np.zeros(comm.size, dtype=int) 

940 stotal = 0 

941 for proc in range(comm.size): 

942 if proc in send_dict: 

943 data = np.frombuffer(send_dict[proc].encode(), np.int8) 

944 scounts[proc] = data.size 

945 sdispls[proc] = stotal 

946 stotal += scounts[proc] 

947 

948 rcounts = np.zeros(comm.size, dtype=int) 

949 comm.alltoallv(scounts, np.ones(comm.size, dtype=int), 

950 np.arange(comm.size, dtype=int), 

951 rcounts, np.ones(comm.size, dtype=int), 

952 np.arange(comm.size, dtype=int)) 

953 rdispls = np.zeros(comm.size, dtype=int) 

954 rtotal = 0 

955 for proc in range(comm.size): 

956 rdispls[proc] = rtotal 

957 rtotal += rcounts[proc]  # running prefix sum gives the displacements 

959 

960 sbuffer = np.zeros(stotal, dtype=np.int8) 

961 for proc in range(comm.size): 

962 sbuffer[sdispls[proc]:(sdispls[proc] + scounts[proc])] = ( 

963 np.frombuffer(send_dict[proc].encode(), np.int8)) 

964 

965 rbuffer = np.zeros(rtotal, dtype=np.int8) 

966 comm.alltoallv(sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls) 

967 

968 rdict = {} 

969 for proc in range(comm.size): 

970 i = rdispls[proc] 

971 rdict[proc] = rbuffer[i:i + rcounts[proc]].tobytes().decode() 

972 

973 return rdict 

974 
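# Example usage of alltoallv_string() (a minimal sketch; note that the send
# buffer below is filled for every rank, so each rank appears as a key):
#
#     send_dict = {proc: 'hello from %d' % world.rank
#                  for proc in range(world.size)}
#     recv_dict = alltoallv_string(send_dict, world)
#     # now recv_dict[p] == 'hello from %d' % p for every rank p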

975 

976def ibarrier(timeout=None, root=0, tag=123, comm=world): 

977 """Non-blocking barrier returning a list of requests to wait for. 

978 An optional time-out may be given, turning the call into a blocking 

979 barrier with an upper time limit, beyond which an exception is raised.""" 

980 requests = [] 

981 byte = np.ones(1, dtype=np.int8) 

982 if comm.rank == root: 

983 # Root: handshake with every other rank: 

984 for rank in range(comm.size): 

985 if rank == root: 

986 continue 

987 rbuf, sbuf = np.empty_like(byte), byte.copy() 

988 requests.append(comm.send(sbuf, rank, tag=2 * tag + 0, 

989 block=False)) 

990 requests.append(comm.receive(rbuf, rank, tag=2 * tag + 1, 

991 block=False)) 

992 else: 

993 rbuf, sbuf = np.empty_like(byte), byte 

994 requests.append(comm.receive(rbuf, root, tag=2 * tag + 0, block=False)) 

995 requests.append(comm.send(sbuf, root, tag=2 * tag + 1, block=False)) 

996 

997 if comm.size == 1 or timeout is None: 

998 return requests 

999 

1000 t0 = time.time() 

1001 while not comm.testall(requests): # automatic clean-up upon success 

1002 if time.time() - t0 > timeout: 

1003 raise RuntimeError('MPI barrier timeout.') 

1004 return [] 

1005 

1006 

1007def run(iterators): 

1008 """Run through list of iterators one step at a time.""" 

1009 if not isinstance(iterators, list): 

1010 # It's a single iterator - empty it: 

1011 for i in iterators: 

1012 pass 

1013 return 

1014 

1015 if len(iterators) == 0: 

1016 return 

1017 

1018 while True: 

1019 try: 

1020 results = [next(iter) for iter in iterators] 

1021 except StopIteration: 

1022 return results 

1023 

1024 

1025class Parallelization: 

1026 def __init__(self, comm, nkpts): 

1027 self.comm = comm 

1028 self.size = comm.size 

1029 self.nkpts = nkpts 

1030 

1031 self.kpt = None 

1032 self.domain = None 

1033 self.band = None 

1034 

1035 self.nclaimed = 1 

1036 self.navail = comm.size 

1037 

1038 def set(self, kpt=None, domain=None, band=None): 

1039 if kpt is not None: 

1040 self.kpt = kpt 

1041 if domain is not None: 

1042 self.domain = domain 

1043 if band is not None: 

1044 self.band = band 

1045 

1046 nclaimed = 1 

1047 for group, name in zip([self.kpt, self.domain, self.band], 

1048 ['k-point', 'domain', 'band']): 

1049 if group is not None: 

1050 assert group > 0, ('Bad: Only {} cores requested for ' 

1051 '{} parallelization'.format(group, name)) 

1052 if self.size % group != 0: 

1053 msg = ('Cannot parallelize as the ' 

1054 'communicator size %d is not divisible by the ' 

1055 'requested number %d of ranks for %s ' 

1056 'parallelization' % (self.size, group, name)) 

1057 raise ValueError(msg) 

1058 nclaimed *= group 

1059 navail = self.size // nclaimed 

1060 

1061 assert self.size % nclaimed == 0 

1062 assert self.size % navail == 0 

1063 

1064 self.navail = navail 

1065 self.nclaimed = nclaimed 

1066 

1067 def get_communicator_sizes(self, kpt=None, domain=None, band=None): 

1068 self.set(kpt=kpt, domain=domain, band=band) 

1069 self.autofinalize() 

1070 return self.kpt, self.domain, self.band 

1071 

1072 def build_communicators(self, kpt=None, domain=None, band=None, 

1073 order='kbd'): 

1074 """Construct communicators. 

1075 

1076 Returns a communicator for k-points, domains, bands and 

1077 k-points/bands. The last one "unites" all ranks that are 

1078 responsible for the same domain. 

1079 

1080 The order must be a permutation of the characters 'kbd', each 

1081 corresponding to each a parallelization mode. The last 

1082 character signifies the communicator that will be assigned 

1083 contiguous ranks, i.e. order='kbd' will yield contiguous 

1084 domain ranks, whereas order='kdb' will yield contiguous band 

1085 ranks.""" 
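# Example (hypothetical sizes): with comm.size == 8, kpt=2, band=2,
# domain=2 and order='kbd', ranks 0 and 1 share the same k-point and band
# indices and differ only in their domain index, so they end up together
# in communicators['d'].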

1086 self.set(kpt=kpt, domain=domain, band=band) 

1087 self.autofinalize() 

1088 

1089 comm = self.comm 

1090 rank = comm.rank 

1091 communicators = {} 

1092 parent_stride = self.size 

1093 offset = 0 

1094 

1095 groups = dict(k=self.kpt, b=self.band, d=self.domain) 

1096 

1097 # Build communicators in hierarchical manner 

1098 # The ranks in the first group have largest separation while 

1099 # the ranks in the last group are next to each other 

1100 for name in order: 

1101 group = groups[name] 

1102 stride = parent_stride // group 

1103 # First rank in this group 

1104 r0 = rank % stride + offset 

1105 # Last rank in this group 

1106 r1 = r0 + stride * group 

1107 ranks = np.arange(r0, r1, stride) 

1108 communicators[name] = comm.new_communicator(ranks) 

1109 parent_stride = stride 

1110 # Offset for the next communicator 

1111 offset += communicators[name].rank * stride 

1112 

1113 # We want a communicator for kpts/bands, i.e. the complement of the 

1114 # grid comm: a communicator uniting all cores with the same domain. 

1115 c1, c2, c3 = (communicators[name] for name in order) 

1116 allranks = [range(c1.size), range(c2.size), range(c3.size)] 

1117 

1118 def get_communicator_complement(name): 

1119 relevant_ranks = list(allranks) 

1120 relevant_ranks[order.find(name)] = [communicators[name].rank] 

1121 ranks = np.array([r3 + c3.size * (r2 + c2.size * r1) 

1122 for r1 in relevant_ranks[0] 

1123 for r2 in relevant_ranks[1] 

1124 for r3 in relevant_ranks[2]]) 

1125 return comm.new_communicator(ranks) 

1126 

1127 # The communicator of all processes that share a domain, i.e. 

1128 # the combination of k-point and band communicators. 

1129 communicators['D'] = get_communicator_complement('d') 

1130 # For each k-point comm rank, a communicator of all 

1131 # band/domain ranks. This is typically used with ScaLAPACK 

1132 # and LCAO orbital stuff. 

1133 communicators['K'] = get_communicator_complement('k') 

1134 return communicators 

1135 

1136 def autofinalize(self): 

1137 if self.kpt is None: 

1138 self.set(kpt=self.get_optimal_kpt_parallelization()) 

1139 if self.domain is None: 

1140 self.set(domain=self.navail) 

1141 if self.band is None: 

1142 self.set(band=self.navail) 

1143 

1144 if self.navail > 1: 

1145 assignments = dict(kpt=self.kpt, 

1146 domain=self.domain, 

1147 band=self.band) 

1148 raise gpaw.BadParallelization( 

1149 f'All the CPUs must be used. Have {assignments} but ' 

1150 f'{self.navail} times more are available.') 

1151 

1152 def get_optimal_kpt_parallelization(self, kptprioritypower=1.4): 

1153 if self.domain and self.band: 

1154 # Try to use all the CPUs for k-point parallelization 

1155 ncpus = min(self.nkpts, self.navail) 

1156 return ncpus 

1157 ncpuvalues, wastevalues = self.find_kpt_parallelizations() 
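# Heuristic score (larger is better): favour spending more cores on
# k-points (ncpus**kptprioritypower, with power > 1) and having more cores
# left over per k-point group (navail // ncpus); the exponent (1 - waste)
# damps choices that would leave k-point slots idle.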

1158 scores = ((self.navail // ncpuvalues) * 

1159 ncpuvalues**kptprioritypower)**(1.0 - wastevalues) 

1160 arg = np.argmax(scores) 

1161 ncpus = ncpuvalues[arg] 

1162 return ncpus 

1163 

1164 def find_kpt_parallelizations(self): 

1165 nkpts = self.nkpts 

1166 ncpuvalues = [] 

1167 wastevalues = [] 

1168 

1169 ncpus = nkpts 

1170 while ncpus > 0: 

1171 if self.navail % ncpus == 0: 

1172 nkptsmax = -(-nkpts // ncpus)  # ceiling division: ceil(nkpts / ncpus) 

1173 effort = nkptsmax * ncpus 

1174 efficiency = nkpts / float(effort) 

1175 waste = 1.0 - efficiency 

1176 wastevalues.append(waste) 

1177 ncpuvalues.append(ncpus) 

1178 ncpus -= 1 

1179 return np.array(ncpuvalues), np.array(wastevalues) 

1180 

1181 

1182def cleanup(): 

1183 error = getattr(sys, 'last_type', None) 

1184 if error is not None: # else: Python script completed or raised SystemExit 

1185 if parallel and not (gpaw.dry_run > 1): 

1186 sys.stdout.flush() 

1187 sys.stderr.write(('GPAW CLEANUP (node %d): %s occurred. ' 

1188 'Calling MPI_Abort!\n') % (world.rank, error)) 

1189 sys.stderr.flush() 

1190 # Give other nodes a moment to crash by themselves (perhaps 

1191 # producing helpful error messages) 

1192 time.sleep(10) 

1193 world.abort(42) 

1194 

1195 

1196def print_mpi_stack_trace(type, value, tb): 

1197 """Format exceptions nicely when running in parallel. 

1198 

1199 Use this function as an except hook. Adds rank 

1200 and line number to each line of the exception. Lines will 

1201 still be printed from different ranks in random order, but 

1202 one can grep for a rank or run 'sort' on the output to obtain 

1203 readable data.""" 

1204 

1205 exception_text = traceback.format_exception(type, value, tb) 

1206 ndigits = len(str(world.size - 1)) 

1207 rankstring = ('%%0%dd' % ndigits) % world.rank 

1208 

1209 lines = [] 

1210 # The exception elements may contain newlines themselves 

1211 for element in exception_text: 

1212 lines.extend(element.splitlines()) 

1213 

1214 line_ndigits = len(str(len(lines) - 1)) 

1215 

1216 for lineno, line in enumerate(lines): 

1217 lineno = ('%%0%dd' % line_ndigits) % lineno 

1218 sys.stderr.write(f'rank={rankstring} L{lineno}: {line}\n') 

1219 

1220 

1221if world.size > 1: # Triggers for dry-run communicators too, but we care not. 

1222 sys.excepthook = print_mpi_stack_trace 

1223 

1224 

1225def exit(error='Manual exit'): 

1226 # Note that exit must be called on *all* MPI tasks 

1227 atexit._exithandlers = [] # not needed because we are intentionally exiting 

1228 if parallel and not (gpaw.dry_run > 1): 

1229 sys.stdout.flush() 

1230 sys.stderr.write(('GPAW CLEANUP (node %d): %s occurred. ' + 

1231 'Calling MPI_Finalize!\n') % (world.rank, error)) 

1232 sys.stderr.flush() 

1233 else: 

1234 cleanup()  # note: cleanup() takes no arguments 

1235 world.barrier() # sync up before exiting 

1236 sys.exit() # quit for serial case, return to cgpaw.c for parallel case 

1237 

1238 

1239atexit.register(cleanup)