Coverage for gpaw/mpi.py: 67%
558 statements
coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1# Copyright (C) 2003 CAMP
2# Please see the accompanying LICENSE file for further information.
3from __future__ import annotations
5import atexit
6import pickle
7import sys
8import time
9import traceback
10from contextlib import contextmanager
11from typing import Any
13import gpaw.cgpaw as cgpaw
14import numpy as np
15import warnings
16from ase.parallel import MPI as ASE_MPI
17from ase.parallel import world as aseworld
19import gpaw
21from ._broadcast_imports import world as _world
23MASTER = 0
26def is_contiguous(*args, **kwargs):
27 from gpaw.utilities import is_contiguous
28 return is_contiguous(*args, **kwargs)
31@contextmanager
32def broadcast_exception(comm):
33 """Ensure that an exception raised on one rank is re-raised on all ranks.
35 This example::
37 with broadcast_exception(world):
38 if world.rank == 0:
39 x = 1 / 0
41 will raise ZeroDivisionError in the whole world.
42 """
43 # Each core will send -1 on success or its rank on failure.
44 try:
45 yield
46 except Exception as ex:
47 rank = comm.max_scalar(comm.rank)
48 if rank == comm.rank:
49 broadcast(ex, rank, comm)
50 raise
51 else:
52 rank = comm.max_scalar(-1)
53 # rank will now be the highest failing rank or -1
54 if rank >= 0:
55 raise broadcast(None, rank, comm)
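# Usage sketch (not part of the original module, assumes an MPI run started
# with ``gpaw python``; a serial run also works): an error raised on one rank
# is re-raised everywhere, so all processes stop together instead of hanging.
def _example_broadcast_exception():
    from gpaw.mpi import broadcast_exception, world
    try:
        with broadcast_exception(world):
            if world.rank == 0:
                raise OSError('input file missing')  # hypothetical failure
    except OSError as err:
        print(f'rank {world.rank} saw: {err}')  # every rank ends up here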
58class _Communicator:
59 def __init__(self, comm, parent=None):
60 """Construct a wrapper of the C-object for any MPI-communicator.
62 Parameters:
64 comm: MPI-communicator
65 Communicator.
67 Attributes:
69 ============ ======================================================
70 ``size`` Number of ranks in the MPI group.
71 ``rank`` Number of this CPU in the MPI group.
72 ``parent`` Parent MPI-communicator.
73 ============ ======================================================
74 """
75 self.comm = comm
76 self.size = comm.size
77 self.rank = comm.rank
78 self.parent = parent # XXX check C-object against comm.parent?
80 def __repr__(self):
81 return f'MPIComm(size={self.size}, rank={self.rank})'
83 def new_communicator(self, ranks):
84 """Create a new MPI communicator for a subset of ranks in a group.
85 Must be called with identical arguments by all relevant processes.
87 Note that a valid communicator is only returned to the processes
88 which are included in the new group; other ranks get None returned.
90 Parameters:
92 ranks: ndarray (type int)
93 List of integers of the ranks to include in the new group.
94 Note that these ranks correspond to indices in the current
95 group, whereas the rank attribute in the new communicator
96 corresponds to its index within the subset.
98 """
100 comm = self.comm.new_communicator(ranks)
101 if comm is None:
102 # This cpu is not in the new communicator:
103 return None
104 else:
105 return _Communicator(comm, parent=self)
107 def sum(self, a, root=-1):
108 """Perform summation by MPI reduce operations of numerical data.
110 Parameters:
112 a: ndarray or value (type int, float or complex)
113 Numerical data to sum over all ranks in the communicator group.
114 If the data is a single value of type int, float or complex,
115 the result is returned because the input argument is immutable.
116 Otherwise, the reduce operation is carried out in-place such
117 that the elements of the input array will represent the sum of
118 the equivalent elements across all processes in the group.
119 root: int (default -1)
120 Rank of the root process, on which the outcome of the reduce
121 operation is valid. A root rank of -1 signifies that the result
122 will be distributed back to all processes, i.e. a broadcast.
124 """
125 if isinstance(a, (int, float, complex)):
126 warnings.warn('Please use sum_scalar(...)', stacklevel=2)
127 return self.comm.sum_scalar(a, root)
128 else:
129 # assert a.ndim != 0
130 tc = a.dtype
131 assert is_contiguous(a, tc)
132 assert root == -1 or 0 <= root < self.size
133 self.comm.sum(a, root)
135 def sum_scalar(self, a, root=-1):
136 assert isinstance(a, (int, float, complex))
137 return self.comm.sum_scalar(a, root)
139 def product(self, a, root=-1):
140 """Do multiplication by MPI reduce operations of numerical data.
142 Parameters:
144 a: ndarray or value (type int or float)
145 Numerical data to multiply across all ranks in the communicator
146 group. NB: Find the global product from the local products.
147 If the data is a single value of type int or float (no complex),
148 the result is returned because the input argument is immutable.
149 Otherwise, the reduce operation is carried out in-place such
150 that the elements of the input array will represent the product
151 of the equivalent elements across all processes in the group.
152 root: int (default -1)
153 Rank of the root process, on which the outcome of the reduce
154 operation is valid. A root rank of -1 signifies that the result
155 will be distributed back to all processes, i.e. a broadcast.
157 """
158 if isinstance(a, (int, float)):
159 1 / 0  # note: reaching this line raises ZeroDivisionError (no product_scalar exists)
160 return self.comm.product(a, root)
161 else:
162 tc = a.dtype
163 assert tc == int or tc == float
164 assert is_contiguous(a, tc)
165 assert root == -1 or 0 <= root < self.size
166 self.comm.product(a, root)
168 def max(self, a, root=-1):
169 """Find maximal value by an MPI reduce operation of numerical data.
171 Parameters:
173 a: ndarray or value (type int or float)
174 Numerical data to find the maximum value of across all ranks in
175 the communicator group. NB: Find global maximum from local max.
176 If the data is a single value of type int or float (no complex),
177 the result is returned because the input argument is immutable.
178 Otherwise, the reduce operation is carried out in-place such
179 that the elements of the input array will represent the max of
180 the equivalent elements across all processes in the group.
181 root: int (default -1)
182 Rank of the root process, on which the outcome of the reduce
183 operation is valid. A root rank of -1 signifies that the result
184 will be distributed back to all processes, i.e. a broadcast.
186 """
187 if isinstance(a, (int, float)):
188 warnings.warn('Please use max_scalar(...)', stacklevel=2)
189 return self.comm.max_scalar(a, root)
190 else:
191 tc = a.dtype
192 assert tc == int or tc == float
193 assert is_contiguous(a, tc)
194 assert root == -1 or 0 <= root < self.size
195 self.comm.max(a, root)
197 def max_scalar(self, a, root=-1):
198 assert isinstance(a, (int, float, np.int64))
199 return self.comm.max_scalar(a, root)
201 def min(self, a, root=-1):
202 """Find minimal value by an MPI reduce operation of numerical data.
204 Parameters:
206 a: ndarray or value (type int or float)
207 Numerical data to find the minimal value of across all ranks in
208 the communicator group. NB: Find global minimum from local min.
209 If the data is a single value of type int or float (no complex),
210 the result is returned because the input argument is immutable.
211 Otherwise, the reduce operation is carried out in-place such
212 that the elements of the input array will represent the min of
213 the equivalent elements across all processes in the group.
214 root: int (default -1)
215 Rank of the root process, on which the outcome of the reduce
216 operation is valid. A root rank of -1 signifies that the result
217 will be distributed back to all processes, i.e. a broadcast.
219 """
220 if isinstance(a, (int, float)):
221 warnings.warn('Please use min_scalar(...)', stacklevel=2)
222 return self.comm.min_scalar(a, root)
223 else:
224 tc = a.dtype
225 assert tc == int or tc == float
226 assert is_contiguous(a, tc)
227 assert root == -1 or 0 <= root < self.size
228 self.comm.min(a, root)
230 def min_scalar(self, a, root=-1):
231 assert isinstance(a, (int, float))
232 return self.comm.min_scalar(a, root)
234 def scatter(self, a, b, root: int) -> None:
235 """Distribute data from one rank to all other processes in a group.
237 Parameters:
239 a: ndarray (ignored on all ranks different from root; use None)
240 Source of the data to distribute, i.e. send buffer on root rank.
241 b: ndarray
242 Destination of the distributed data, i.e. local receive buffer.
243 The size of this array multiplied by the number of processes in
244 the group must match the size of the source array on the root.
245 root: int
246 Rank of the root process, from which the source data originates.
248 The reverse operation is ``gather``.
250 Example::
252 # The master has all the interesting data. Distribute it.
253 seed = 123456
254 rng = np.random.default_rng(seed)
255 if comm.rank == 0:
256 data = rng.normal(size=N*comm.size)
257 else:
258 data = None
259 mydata = np.empty(N, dtype=float)
260 comm.scatter(data, mydata, 0)
262 # .. which is equivalent to ..
264 if comm.rank == 0:
265 # Extract my part directly
266 mydata[:] = data[0:N]
267 # Distribute parts to the slaves
268 for rank in range(1, comm.size):
269 buf = data[rank*N:(rank+1)*N]
270 comm.send(buf, rank, tag=123)
271 else:
272 # Receive from the master
273 comm.receive(mydata, 0, tag=123)
275 """
276 if self.rank == root:
277 assert a.dtype == b.dtype
278 assert a.size == self.size * b.size
279 assert a.flags.c_contiguous
280 assert b.flags.c_contiguous
281 assert 0 <= root < self.size
282 self.comm.scatter(a, b, root)
284 def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls):
285 """All-to-all in a group.
287 Parameters:
289 sbuffer: ndarray
290 Source of the data to distribute, i.e., send buffers on all ranks
291 scounts: ndarray
292 Integer array equal to the group size specifying the number of
293 elements to send to each processor
294 sdispls: ndarray
295 Integer array (of length group size). Entry j specifies the
296 displacement (relative to sendbuf from which to take the
297 outgoing data destined for process j)
298 rbuffer: ndarray
299 Destination of the distributed data, i.e., local receive buffer.
300 rcounts: ndarray
301 Integer array equal to the group size specifying the maximum
302 number of elements that can be received from each processor.
303 rdispls:
304 Integer array (of length group size). Entry i specifies the
305 displacement (relative to recvbuf at which to place the incoming
306 data from process i)
307 """
308 assert sbuffer.flags.c_contiguous
309 assert scounts.flags.c_contiguous
310 assert sdispls.flags.c_contiguous
311 assert rbuffer.flags.c_contiguous
312 assert rcounts.flags.c_contiguous
313 assert rdispls.flags.c_contiguous
314 assert sbuffer.dtype == rbuffer.dtype
316 for arr in [scounts, sdispls, rcounts, rdispls]:
317 assert arr.dtype == int, arr.dtype
318 assert len(arr) == self.size
320 assert np.all(0 <= sdispls)
321 assert np.all(0 <= rdispls)
322 assert np.all(sdispls + scounts <= sbuffer.size)
323 assert np.all(rdispls + rcounts <= rbuffer.size)
324 self.comm.alltoallv(sbuffer, scounts, sdispls,
325 rbuffer, rcounts, rdispls)
327 def all_gather(self, a, b):
328 """Gather data from all ranks onto all processes in a group.
330 Parameters:
332 a: ndarray
333 Source of the data to gather, i.e. send buffer of this rank.
334 b: ndarray
335 Destination of the distributed data, i.e. receive buffer.
336 The size of this array must match the size of the distributed
337 source arrays multiplied by the number of processes in the group.
339 Example::
341 # All ranks have parts of interesting data. Gather on all ranks.
342 seed = 123456
343 rng = np.random.default_rng(seed)
344 mydata = rng.normal(size=N)
345 data = np.empty(N*comm.size, dtype=float)
346 comm.all_gather(mydata, data)
348 # .. which is equivalent to ..
350 if comm.rank == 0:
351 # Insert my part directly
352 data[0:N] = mydata
353 # Gather parts from the slaves
354 buf = np.empty(N, dtype=float)
355 for rank in range(1, comm.size):
356 comm.receive(buf, rank, tag=123)
357 data[rank*N:(rank+1)*N] = buf
358 else:
359 # Send to the master
360 comm.send(mydata, 0, tag=123)
361 # Broadcast from master to all slaves
362 comm.broadcast(data, 0)
364 """
365 assert a.flags.contiguous
366 assert b.flags.contiguous
367 assert b.dtype == a.dtype
368 assert (b.shape[0] == self.size and a.shape == b.shape[1:] or
369 a.size * self.size == b.size)
370 self.comm.all_gather(a, b)
372 def gather(self, a, root, b=None):
373 """Gather data from all ranks onto a single process in a group.
375 Parameters:
377 a: ndarray
378 Source of the data to gather, i.e. send buffer of this rank.
379 root: int
380 Rank of the root process, on which the data is to be gathered.
381 b: ndarray (ignored on all ranks different from root; default None)
382 Destination of the distributed data, i.e. root's receive buffer.
383 The size of this array must match the size of the distributed
384 source arrays multiplied by the number of processes in the group.
386 The reverse operation is ``scatter``.
388 Example::
390 # All ranks have parts of interesting data. Gather it on master.
391 seed = 123456
392 rng = np.random.default_rng(seed)
393 mydata = rng.normal(size=N)
394 if comm.rank == 0:
395 data = np.empty(N*comm.size, dtype=float)
396 else:
397 data = None
398 comm.gather(mydata, 0, data)
400 # .. which is equivalent to ..
402 if comm.rank == 0:
403 # Extract my part directly
404 data[0:N] = mydata
405 # Gather parts from the slaves
406 buf = np.empty(N, dtype=float)
407 for rank in range(1, comm.size):
408 comm.receive(buf, rank, tag=123)
409 data[rank*N:(rank+1)*N] = buf
410 else:
411 # Send to the master
412 comm.send(mydata, 0, tag=123)
414 """
415 assert a.flags.c_contiguous
416 assert 0 <= root < self.size
417 if root == self.rank:
418 assert b.flags.c_contiguous and b.dtype == a.dtype
419 assert (b.shape[0] == self.size and a.shape == b.shape[1:] or
420 a.size * self.size == b.size)
421 self.comm.gather(a, root, b)
422 else:
423 assert b is None
424 self.comm.gather(a, root)
426 def broadcast(self, a, root):
427 """Share data from a single process to all ranks in a group.
429 Parameters:
431 a: ndarray
432 Data, i.e. send buffer on root rank, receive buffer elsewhere.
433 Note that after the broadcast, all ranks have the same data.
434 root: int
435 Rank of the root process, from which the data is to be shared.
437 Example::
439 # All ranks have parts of interesting data. Take a given index.
440 seed = 123456
441 rng = np.random.default_rng(seed)
442 mydata[:] = rng.normal(size=N)
444 # Who has the element at global index 13? Everybody needs it!
445 index = 13
446 root, myindex = divmod(index, N)
447 element = np.empty(1, dtype=float)
448 if comm.rank == root:
449 # This process has the requested element so extract it
450 element[:] = mydata[myindex]
452 # Broadcast from owner to everyone else
453 comm.broadcast(element, root)
455 # .. which is equivalent to ..
457 if comm.rank == root:
458 # We are root so send it to the other ranks
459 for rank in range(comm.size):
460 if rank != root:
461 comm.send(element, rank, tag=123)
462 else:
463 # We don't have it so receive from root
464 comm.receive(element, root, tag=123)
466 """
467 assert 0 <= root < self.size
468 assert is_contiguous(a)
469 self.comm.broadcast(a, root)
471 def sendreceive(self, a, dest, b, src, sendtag=123, recvtag=123):
472 assert 0 <= dest < self.size
473 assert dest != self.rank
474 assert is_contiguous(a)
475 assert 0 <= src < self.size
476 assert src != self.rank
477 assert is_contiguous(b)
478 return self.comm.sendreceive(a, dest, b, src, sendtag, recvtag)
480 def send(self, a, dest, tag=123, block=True):
481 assert 0 <= dest < self.size
482 assert dest != self.rank
483 assert is_contiguous(a)
484 if not block:
485 pass # assert sys.getrefcount(a) > 3
486 return self.comm.send(a, dest, tag, block)
488 def ssend(self, a, dest, tag=123):
489 assert 0 <= dest < self.size
490 assert dest != self.rank
491 assert is_contiguous(a)
492 return self.comm.ssend(a, dest, tag)
494 def receive(self, a, src, tag=123, block=True):
495 assert 0 <= src < self.size
496 assert src != self.rank
497 assert is_contiguous(a)
498 return self.comm.receive(a, src, tag, block)
500 def test(self, request):
501 """Test whether a non-blocking MPI operation has completed. A boolean
502 is returned immediately and the request is not modified in any way.
504 Parameters:
506 request: MPI request
507 Request e.g. returned from send/receive when block=False is used.
509 """
510 return self.comm.test(request)
512 def testall(self, requests):
513 """Test whether non-blocking MPI operations have completed. A boolean
514 is returned immediately but requests may have been deallocated as a
515 result, provided they have completed before or during this invocation.
517 Parameters:
519 request: MPI request
520 Request e.g. returned from send/receive when block=False is used.
522 """
523 return self.comm.testall(requests) # may deallocate requests!
525 def wait(self, request):
526 """Wait for a non-blocking MPI operation to complete before returning.
528 Parameters:
530 request: MPI request
531 Request e.g. returned from send/receive when block=False is used.
533 """
534 self.comm.wait(request)
536 def waitall(self, requests):
537 """Wait for non-blocking MPI operations to complete before returning.
539 Parameters:
541 requests: list
542 List of MPI requests e.g. aggregated from returned requests of
543 multiple send/receive calls where block=False was used.
545 """
546 self.comm.waitall(requests)
548 def abort(self, errcode):
549 """Terminate MPI execution environment of all tasks in the group.
550 This function only returns in the event of an error.
552 Parameters:
554 errcode: int
555 Error code to return to the invoking environment.
557 """
558 return self.comm.abort(errcode)
560 def name(self):
561 """Return the name of the processor as a string."""
562 return self.comm.name()
564 def barrier(self):
565 """Block execution until all process have reached this point."""
566 self.comm.barrier()
568 def compare(self, othercomm):
569 """Compare communicator to other.
571 Returns 'ident' if they are identical, 'congruent' if they are
572 copies of each other, 'similar' if they are permutations of
573 each other, and otherwise 'unequal'.
575 This method corresponds to MPI_Comm_compare."""
576 if isinstance(self.comm, SerialCommunicator):
577 return self.comm.compare(othercomm.comm) # argh!
578 result = self.comm.compare(othercomm.get_c_object())
579 assert result in ['ident', 'congruent', 'similar', 'unequal']
580 return result
582 def translate_ranks(self, other, ranks):
583 """Translate ranks from communicator to other.
585 ranks must be valid on this communicator. Returns ranks
586 on other communicator corresponding to the same processes.
587 Ranks that are not defined on the other communicator are
588 assigned values of -1. (In contrast to MPI which would
589 assign MPI_UNDEFINED)."""
590 assert hasattr(other, 'translate_ranks'), \
591 'Expected communicator, got %s' % other
592 assert all(0 <= rank for rank in ranks)
593 assert all(rank < self.size for rank in ranks)
594 if isinstance(self.comm, SerialCommunicator):
595 return self.comm.translate_ranks(other.comm, ranks) # argh!
596 otherranks = self.comm.translate_ranks(other.get_c_object(), ranks)
597 assert all(-1 <= rank for rank in otherranks)
598 assert ranks.dtype == otherranks.dtype
599 return otherranks
601 def get_members(self):
602 """Return the subset of processes which are members of this MPI group
603 in terms of the ranks they are assigned on the parent communicator.
604 For the world communicator, this is all integers up to ``size``.
606 Example::
608 >>> world.rank, world.size # doctest: +SKIP
609 (3, 4)
610 >>> world.get_members() # doctest: +SKIP
611 array([0, 1, 2, 3])
612 >>> comm = world.new_communicator(np.array([2, 3])) # doctest: +SKIP
613 >>> comm.rank, comm.size # doctest: +SKIP
614 (1, 2)
615 >>> comm.get_members() # doctest: +SKIP
616 array([2, 3])
617 >>> comm.get_members()[comm.rank] == world.rank # doctest: +SKIP
618 True
620 """
621 return self.comm.get_members()
623 def get_c_object(self):
624 """Return the C-object wrapped by this debug interface.
626 Whenever a communicator object is passed to C code, that object
627 must be a proper C-object - *not* e.g. this debug wrapper. For
628 this reason. The C-communicator object has a get_c_object()
629 implementation which returns itself; thus, always call
630 comm.get_c_object() and pass the resulting object to the C code.
631 """
632 c_obj = self.comm.get_c_object()
633 if isinstance(c_obj, cgpaw.Communicator):
634 return c_obj
635 return c_obj.get_c_object()
638MPIComm = _Communicator # for type hints
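# Usage sketch (not part of the original module): split ``world`` into two
# halves and sum within each half.  The ranks passed to new_communicator()
# are indices in the parent communicator; the returned communicator
# renumbers its members from 0, and ranks outside the subset get None.
def _example_new_communicator():
    import numpy as np
    from gpaw.mpi import world
    half = max(world.size // 2, 1)
    if world.rank < half:
        ranks = np.arange(0, half)
    else:
        ranks = np.arange(half, world.size)
    sub = world.new_communicator(ranks)
    data = np.array([float(world.rank)])
    if sub is not None:
        sub.sum(data)  # in-place sum over this half of ``world``
    return data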
641# Serial communicator
642class SerialCommunicator:
643 size = 1
644 rank = 0
646 def __init__(self, parent=None):
647 self.parent = parent
649 def __repr__(self):
650 return 'SerialCommunicator()'
652 def sum(self, array, root=-1):
653 if isinstance(array, (int, float, complex)):
654 warnings.warn('Please use sum_scalar(...)', stacklevel=2)
655 return array
657 def sum_scalar(self, a, root=-1):
658 return a
660 def scatter(self, s, r, root):
661 r[:] = s
663 def min(self, value, root=-1):
664 if isinstance(value, (int, float, complex)):
665 warnings.warn('Please use min_scalar(...)', stacklevel=2)
666 return value
668 def min_scalar(self, value, root=-1):
669 return value
671 def max(self, value, root=-1):
672 if isinstance(value, (int, float, complex)):
673 warnings.warn('Please use max_scalar(...)', stacklevel=2)
674 return value
676 def max_scalar(self, value, root=-1):
677 return value
679 def broadcast(self, buf, root):
680 pass
682 def send(self, buff, dest, tag=123, block=True):
683 pass
685 def barrier(self):
686 pass
688 def gather(self, a, root, b):
689 b[:] = a
691 def all_gather(self, a, b):
692 b[:] = a
694 def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls):
695 assert len(scounts) == 1
696 assert len(sdispls) == 1
697 assert len(rcounts) == 1
698 assert len(rdispls) == 1
699 assert len(sbuffer) == len(rbuffer)
701 rbuffer[rdispls[0]:rdispls[0] + rcounts[0]] = \
702 sbuffer[sdispls[0]:sdispls[0] + scounts[0]]
704 def new_communicator(self, ranks):
705 if self.rank not in ranks:
706 return None
707 comm = SerialCommunicator(parent=self)
708 comm.size = len(ranks)
709 return comm
711 def test(self, request):
712 return 1
714 def testall(self, requests):
715 return 1
717 def wait(self, request):
718 raise NotImplementedError('Calls to mpi wait should not happen in '
719 'serial mode')
721 def waitall(self, requests):
722 if not requests:
723 return
724 raise NotImplementedError('Calls to mpi waitall should not happen in '
725 'serial mode')
727 def get_members(self):
728 return np.array([0])
730 def compare(self, other):
731 if self == other:
732 return 'ident'
733 elif other.size == 1:
734 return 'congruent'
735 else:
736 raise NotImplementedError('Compare serial comm to other')
738 def translate_ranks(self, other, ranks):
739 if isinstance(other, SerialCommunicator):
740 assert all(rank == 0 for rank in ranks) or gpaw.dry_run
741 return np.zeros(len(ranks), dtype=int)
743 raise NotImplementedError(
744 'Translate non-trivial ranks with serial comm')
746 def get_c_object(self):
747 if gpaw.dry_run:
748 return None # won't actually be passed to C
749 return _world
752_serial_comm = SerialCommunicator()
754have_mpi = _world is not None
756if not have_mpi:
757 _world = _serial_comm # type: ignore
759if gpaw.debug:
760 serial_comm = _Communicator(_serial_comm)
761 if _world.size == 1:
762 world = serial_comm
763 else:
764 world = _Communicator(_world)
765else:
766 serial_comm = _serial_comm # type: ignore
767 world = _world # type: ignore
769rank = world.rank
770size = world.size
771parallel = (size > 1)
774def verify_ase_world():
775 # ASE does not like that GPAW uses world.size at import time.
776 # .... because of GPAW's own import time communicator mish-mash.
777 # Now, GPAW wants to verify world.size and cannot do so,
778 # because of what ASE does for GPAW's sake.
779 # This really needs improvement!
780 assert aseworld is not None
782 if isinstance(aseworld, ASE_MPI):
783 # We only want to check if the communicator was already initialized.
784 # Otherwise the communicator will be initialized as a side effect
785 # of accessing the .size attribute,
786 # which ASE's tests will complain about.
787 check_size = aseworld.comm is not None
788 else:
789 check_size = True # A real communicator, so we want to check that
791 if check_size and world.size != aseworld.size:
792 raise RuntimeError('Please use "gpaw python" to run in parallel')
795verify_ase_world()
798def broadcast(obj, root=0, comm=world):
799 """Broadcast a Python object across an MPI communicator and return it."""
800 if comm.rank == root:
801 assert obj is not None
802 b = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
803 else:
804 assert obj is None
805 b = None
806 b = broadcast_bytes(b, root, comm)
807 if comm.rank == root:
808 return obj
809 else:
810 return pickle.loads(b)
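# Usage sketch (not part of the original module): rank 0 builds a small dict
# of (hypothetical) parameters and every other rank receives a pickled copy.
# Non-root ranks must pass None, as asserted above.
def _example_broadcast_object():
    from gpaw.mpi import broadcast, world
    if world.rank == 0:
        params = {'mode': 'fd', 'h': 0.2}
    else:
        params = None
    return broadcast(params, root=0, comm=world)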
813def broadcast_float(x, comm):
814 array = np.array([x])
815 comm.broadcast(array, 0)
816 return array[0]
819def synchronize_atoms(atoms, comm, tolerance=1e-8):
820 """Synchronize atoms between multiple CPUs removing numerical noise.
822 If the atoms differ significantly, raise ValueError on all ranks.
823 The error object contains the ranks where the check failed.
825 In debug mode, write atoms to files in case of failure."""
827 if len(atoms) == 0:
828 return
830 if comm.rank == 0:
831 src = (atoms.positions, atoms.cell, atoms.numbers, atoms.pbc)
832 else:
833 src = None
835 # XXX replace with ase.cell.same_cell in the future
836 # (if that functions gets to exist)
837 # def same_cell(cell1, cell2):
838 # return ((cell1 is None) == (cell2 is None) and
839 # (cell1 is None or (cell1 == cell2).all()))
841 # Cell vectors should be compared with a tolerance like positions?
842 def same_cell(cell1, cell2, tolerance=1e-8):
843 return ((cell1 is None) == (cell2 is None) and
844 (cell1 is None or (abs(cell1 - cell2).max() <= tolerance)))
846 positions, cell, numbers, pbc = broadcast(src, root=0, comm=comm)
847 ok = (len(positions) == len(atoms.positions) and
848 (abs(positions - atoms.positions).max() <= tolerance) and
849 (numbers == atoms.numbers).all() and
850 same_cell(cell, atoms.cell) and
851 (pbc == atoms.pbc).all())
853 # We need to fail equally on all ranks to avoid trouble. Thus
854 # we use an array to gather check results from everyone.
855 my_fail = np.array(not ok, dtype=bool)
856 all_fail = np.zeros(comm.size, dtype=bool)
857 comm.all_gather(my_fail, all_fail)
859 if all_fail.any():
860 if gpaw.debug:
861 with open('synchronize_atoms_r%d.pckl' % comm.rank, 'wb') as fd:
862 pickle.dump((atoms.positions, atoms.cell,
863 atoms.numbers, atoms.pbc,
864 positions, cell, numbers, pbc), fd)
865 err_ranks = np.arange(comm.size)[all_fail]
866 raise ValueError('Mismatch of Atoms objects. In debug '
867 'mode, atoms will be dumped to files.',
868 err_ranks)
870 atoms.positions = positions
871 atoms.cell = cell
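# Usage sketch (not part of the original module, assumes ASE is installed):
# remove tiny rank-to-rank numerical noise from an Atoms object before a
# calculation; rank 0's positions and cell win.
def _example_synchronize_atoms():
    from ase import Atoms
    from gpaw.mpi import synchronize_atoms, world
    atoms = Atoms('H2', positions=[(0, 0, 0), (0, 0, 0.74)],
                  cell=(5.0, 5.0, 5.0), pbc=False)
    atoms.positions[0, 0] += 1e-12 * world.rank  # rank-dependent noise
    synchronize_atoms(atoms, world)  # well below the 1e-8 tolerance
    return atoms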
874def broadcast_string(string=None, root=0, comm=world):
875 if comm.rank == root:
876 string = string.encode()
877 return broadcast_bytes(string, root, comm).decode()
880def broadcast_bytes(b=None, root=0, comm=world):
881 """Broadcast a bytes across an MPI communicator and return it."""
882 if comm.rank == root:
883 assert isinstance(b, bytes)
884 n = np.array(len(b), int)
885 else:
886 assert b is None
887 n = np.zeros(1, int)
888 comm.broadcast(n, root)
889 if comm.rank == root:
890 b = np.frombuffer(b, np.int8)
891 else:
892 b = np.zeros(n, np.int8)
893 comm.broadcast(b, root)
894 return b.tobytes()
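# Usage sketch (not part of the original module): share a text label chosen
# on the master; non-root ranks pass None and get the decoded string back.
def _example_broadcast_string():
    from gpaw.mpi import broadcast_string, world
    label = 'run-042' if world.rank == 0 else None  # hypothetical label
    return broadcast_string(label, root=0, comm=world)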
897def broadcast_array(array: np.ndarray, *communicators) -> np.ndarray:
898 """Broadcast np.ndarray across sequence of MPI-communicators."""
899 comms = list(communicators)
900 while comms:
901 comm = comms.pop()
902 if all(comm.rank == 0 for comm in comms):
903 comm.broadcast(array, 0)
904 return array
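# Usage sketch (not part of the original module): broadcast an array over a
# sequence of communicators.  Here ``world`` plus the trivial serial
# communicator are used so the sketch runs anywhere; in real code the
# arguments are typically nested communicators such as the k-point and
# domain communicators from Parallelization.build_communicators().
def _example_broadcast_array():
    import numpy as np
    from gpaw.mpi import broadcast_array, serial_comm, world
    a = np.zeros(3)
    if world.rank == 0:
        a[:] = [1.0, 2.0, 3.0]
    return broadcast_array(a, world, serial_comm)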
907def send(obj, rank: int, comm: MPIComm) -> None:
908 """Send object to rank on the MPI communicator comm."""
909 b = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
910 comm.send(np.array(len(b)), rank)
911 comm.send(np.frombuffer(b, np.int8).copy(), rank)
914def receive(rank: int, comm: MPIComm) -> Any:
915 """Receive object from rank on the MPI communicator comm."""
916 n = np.array(0)
917 comm.receive(n, rank)
918 buf = np.empty(int(n), np.int8)
919 comm.receive(buf, rank)
920 return pickle.loads(buf.tobytes())
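# Usage sketch (not part of the original module, needs at least two ranks):
# send an arbitrary picklable object from rank 0 to rank 1.
def _example_send_receive():
    from gpaw.mpi import receive, send, world
    if world.size < 2:
        return None
    if world.rank == 0:
        send({'step': 3, 'energy': -1.23}, 1, world)  # hypothetical payload
        return None
    if world.rank == 1:
        return receive(0, world)
    return None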
923def send_string(string, rank, comm=world):
924 b = string.encode()
925 comm.send(np.array(len(b)), rank)
926 comm.send(np.frombuffer(b, np.int8).copy(), rank)
929def receive_string(rank, comm=world):
930 n = np.array(0)
931 comm.receive(n, rank)
932 string = np.empty(n, np.int8)
933 comm.receive(string, rank)
934 return string.tobytes().decode()
937def alltoallv_string(send_dict, comm=world):
938 scounts = np.zeros(comm.size, dtype=int)
939 sdispls = np.zeros(comm.size, dtype=int)
940 stotal = 0
941 for proc in range(comm.size):
942 if proc in send_dict:
943 data = np.frombuffer(send_dict[proc].encode(), np.int8)
944 scounts[proc] = data.size
945 sdispls[proc] = stotal
946 stotal += scounts[proc]
948 rcounts = np.zeros(comm.size, dtype=int)
949 comm.alltoallv(scounts, np.ones(comm.size, dtype=int),
950 np.arange(comm.size, dtype=int),
951 rcounts, np.ones(comm.size, dtype=int),
952 np.arange(comm.size, dtype=int))
953 rdispls = np.zeros(comm.size, dtype=int)
954 rtotal = 0
955 for proc in range(comm.size):
956 rdispls[proc] = rtotal
957 rtotal += rcounts[proc]  # running total gives the next displacement
960 sbuffer = np.zeros(stotal, dtype=np.int8)
961 for proc in range(comm.size):
962 sbuffer[sdispls[proc]:(sdispls[proc] + scounts[proc])] = (
963 np.frombuffer(send_dict[proc].encode(), np.int8))
965 rbuffer = np.zeros(rtotal, dtype=np.int8)
966 comm.alltoallv(sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls)
968 rdict = {}
969 for proc in range(comm.size):
970 i = rdispls[proc]
971 rdict[proc] = rbuffer[i:i + rcounts[proc]].tobytes().decode()
973 return rdict
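# Usage sketch (not part of the original module): every rank sends a short
# greeting to every rank (including itself) and gets one string per sender.
def _example_alltoallv_string():
    from gpaw.mpi import alltoallv_string, world
    send_dict = {rank: f'hello {rank} from {world.rank}'
                 for rank in range(world.size)}
    recv_dict = alltoallv_string(send_dict, world)
    return recv_dict  # recv_dict[r] is the string rank r addressed to us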
976def ibarrier(timeout=None, root=0, tag=123, comm=world):
977 """Non-blocking barrier returning a list of requests to wait for.
978 An optional time-out may be given, turning the call into a blocking
979 barrier with an upper time limit, beyond which an exception is raised."""
980 requests = []
981 byte = np.ones(1, dtype=np.int8)
982 if comm.rank == root:
983 # Everybody else:
984 for rank in range(comm.size):
985 if rank == root:
986 continue
987 rbuf, sbuf = np.empty_like(byte), byte.copy()
988 requests.append(comm.send(sbuf, rank, tag=2 * tag + 0,
989 block=False))
990 requests.append(comm.receive(rbuf, rank, tag=2 * tag + 1,
991 block=False))
992 else:
993 rbuf, sbuf = np.empty_like(byte), byte
994 requests.append(comm.receive(rbuf, root, tag=2 * tag + 0, block=False))
995 requests.append(comm.send(sbuf, root, tag=2 * tag + 1, block=False))
997 if comm.size == 1 or timeout is None:
998 return requests
1000 t0 = time.time()
1001 while not comm.testall(requests): # automatic clean-up upon success
1002 if time.time() - t0 > timeout:
1003 raise RuntimeError('MPI barrier timeout.')
1004 return []
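# Usage sketch (not part of the original module): a synchronization point
# that raises RuntimeError if any rank lags behind by more than 60 seconds.
def _example_ibarrier():
    from gpaw.mpi import ibarrier, world
    requests = ibarrier(timeout=60.0, comm=world)  # blocks here, up to 60 s
    world.waitall(requests)  # empty on success; only needed when timeout=None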
1007def run(iterators):
1008 """Run through list of iterators one step at a time."""
1009 if not isinstance(iterators, list):
1010 # It's a single iterator - empty it:
1011 for i in iterators:
1012 pass
1013 return
1015 if len(iterators) == 0:
1016 return
1018 while True:
1019 try:
1020 results = [next(iter) for iter in iterators]
1021 except StopIteration:
1022 return results
1025class Parallelization:
1026 def __init__(self, comm, nkpts):
1027 self.comm = comm
1028 self.size = comm.size
1029 self.nkpts = nkpts
1031 self.kpt = None
1032 self.domain = None
1033 self.band = None
1035 self.nclaimed = 1
1036 self.navail = comm.size
1038 def set(self, kpt=None, domain=None, band=None):
1039 if kpt is not None:
1040 self.kpt = kpt
1041 if domain is not None:
1042 self.domain = domain
1043 if band is not None:
1044 self.band = band
1046 nclaimed = 1
1047 for group, name in zip([self.kpt, self.domain, self.band],
1048 ['k-point', 'domain', 'band']):
1049 if group is not None:
1050 assert group > 0, ('Bad: Only {} cores requested for '
1051 '{} parallelization'.format(group, name))
1052 if self.size % group != 0:
1053 msg = ('Cannot parallelize as the '
1054 'communicator size %d is not divisible by the '
1055 'requested number %d of ranks for %s '
1056 'parallelization' % (self.size, group, name))
1057 raise ValueError(msg)
1058 nclaimed *= group
1059 navail = self.size // nclaimed
1061 assert self.size % nclaimed == 0
1062 assert self.size % navail == 0
1064 self.navail = navail
1065 self.nclaimed = nclaimed
1067 def get_communicator_sizes(self, kpt=None, domain=None, band=None):
1068 self.set(kpt=kpt, domain=domain, band=band)
1069 self.autofinalize()
1070 return self.kpt, self.domain, self.band
1072 def build_communicators(self, kpt=None, domain=None, band=None,
1073 order='kbd'):
1074 """Construct communicators.
1076 Returns a communicator for k-points, domains, bands and
1077 k-points/bands. The last one "unites" all ranks that are
1078 responsible for the same domain.
1080 The order must be a permutation of the characters 'kbd', each
1081 corresponding to a parallelization mode. The last
1082 character signifies the communicator that will be assigned
1083 contiguous ranks, i.e. order='kbd' will yield contiguous
1084 domain ranks, whereas order='kdb' will yield contiguous band
1085 ranks."""
1086 self.set(kpt=kpt, domain=domain, band=band)
1087 self.autofinalize()
1089 comm = self.comm
1090 rank = comm.rank
1091 communicators = {}
1092 parent_stride = self.size
1093 offset = 0
1095 groups = dict(k=self.kpt, b=self.band, d=self.domain)
1097 # Build communicators in hierarchical manner
1098 # The ranks in the first group have largest separation while
1099 # the ranks in the last group are next to each other
1100 for name in order:
1101 group = groups[name]
1102 stride = parent_stride // group
1103 # First rank in this group
1104 r0 = rank % stride + offset
1105 # Last rank in this group
1106 r1 = r0 + stride * group
1107 ranks = np.arange(r0, r1, stride)
1108 communicators[name] = comm.new_communicator(ranks)
1109 parent_stride = stride
1110 # Offset for the next communicator
1111 offset += communicators[name].rank * stride
1113 # We want a communicator for kpts/bands, i.e. the complement of the
1114 # grid comm: a communicator uniting all cores with the same domain.
1115 c1, c2, c3 = (communicators[name] for name in order)
1116 allranks = [range(c1.size), range(c2.size), range(c3.size)]
1118 def get_communicator_complement(name):
1119 relevant_ranks = list(allranks)
1120 relevant_ranks[order.find(name)] = [communicators[name].rank]
1121 ranks = np.array([r3 + c3.size * (r2 + c2.size * r1)
1122 for r1 in relevant_ranks[0]
1123 for r2 in relevant_ranks[1]
1124 for r3 in relevant_ranks[2]])
1125 return comm.new_communicator(ranks)
1127 # The communicator of all processes that share a domain, i.e.
1128 # the combination of k-point and band communicators.
1129 communicators['D'] = get_communicator_complement('d')
1130 # For each k-point comm rank, a communicator of all
1131 # band/domain ranks. This is typically used with ScaLAPACK
1132 # and LCAO orbital stuff.
1133 communicators['K'] = get_communicator_complement('k')
1134 return communicators
1136 def autofinalize(self):
1137 if self.kpt is None:
1138 self.set(kpt=self.get_optimal_kpt_parallelization())
1139 if self.domain is None:
1140 self.set(domain=self.navail)
1141 if self.band is None:
1142 self.set(band=self.navail)
1144 if self.navail > 1:
1145 assignments = dict(kpt=self.kpt,
1146 domain=self.domain,
1147 band=self.band)
1148 raise gpaw.BadParallelization(
1149 f'All the CPUs must be used. Have {assignments} but '
1150 f'{self.navail} times more are available.')
1152 def get_optimal_kpt_parallelization(self, kptprioritypower=1.4):
1153 if self.domain and self.band:
1154 # Try to use all the CPUs for k-point parallelization
1155 ncpus = min(self.nkpts, self.navail)
1156 return ncpus
1157 ncpuvalues, wastevalues = self.find_kpt_parallelizations()
1158 scores = ((self.navail // ncpuvalues) *
1159 ncpuvalues**kptprioritypower)**(1.0 - wastevalues)
1160 arg = np.argmax(scores)
1161 ncpus = ncpuvalues[arg]
1162 return ncpus
1164 def find_kpt_parallelizations(self):
1165 nkpts = self.nkpts
1166 ncpuvalues = []
1167 wastevalues = []
1169 ncpus = nkpts
1170 while ncpus > 0:
1171 if self.navail % ncpus == 0:
1172 nkptsmax = -(-nkpts // ncpus)
1173 effort = nkptsmax * ncpus
1174 efficiency = nkpts / float(effort)
1175 waste = 1.0 - efficiency
1176 wastevalues.append(waste)
1177 ncpuvalues.append(ncpus)
1178 ncpus -= 1
1179 return np.array(ncpuvalues), np.array(wastevalues)
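# Usage sketch (not part of the original module): let Parallelization choose
# how to split ``world`` for a calculation with 8 irreducible k-points and
# build the corresponding communicators.
def _example_parallelization():
    from gpaw.mpi import Parallelization, world
    par = Parallelization(world, nkpts=8)
    comms = par.build_communicators()
    # comms['k'], comms['d'] and comms['b'] are the k-point, domain and band
    # communicators; comms['D'] unites all ranks that share the same domain.
    return comms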
1182def cleanup():
1183 error = getattr(sys, 'last_type', None)
1184 if error is not None: # else: Python script completed or raise SystemExit
1185 if parallel and not (gpaw.dry_run > 1):
1186 sys.stdout.flush()
1187 sys.stderr.write(('GPAW CLEANUP (node %d): %s occurred. '
1188 'Calling MPI_Abort!\n') % (world.rank, error))
1189 sys.stderr.flush()
1190 # Give other nodes a moment to crash by themselves (perhaps
1191 # producing helpful error messages)
1192 time.sleep(10)
1193 world.abort(42)
1196def print_mpi_stack_trace(type, value, tb):
1197 """Format exceptions nicely when running in parallel.
1199 Use this function as an excepthook. Adds rank
1200 and line number to each line of the exception. Lines will
1201 still be printed from different ranks in random order, but
1202 one can grep for a rank or run 'sort' on the output to obtain
1203 readable data."""
1205 exception_text = traceback.format_exception(type, value, tb)
1206 ndigits = len(str(world.size - 1))
1207 rankstring = ('%%0%dd' % ndigits) % world.rank
1209 lines = []
1210 # The exception elements may contain newlines themselves
1211 for element in exception_text:
1212 lines.extend(element.splitlines())
1214 line_ndigits = len(str(len(lines) - 1))
1216 for lineno, line in enumerate(lines):
1217 lineno = ('%%0%dd' % line_ndigits) % lineno
1218 sys.stderr.write(f'rank={rankstring} L{lineno}: {line}\n')
1221if world.size > 1: # Triggers for dry-run communicators too, but we care not.
1222 sys.excepthook = print_mpi_stack_trace
1225def exit(error='Manual exit'):
1226 # Note that exit must be called on *all* MPI tasks
1227 atexit._exithandlers = [] # not needed because we are intentionally exiting
1228 if parallel and not (gpaw.dry_run > 1):
1229 sys.stdout.flush()
1230 sys.stderr.write(('GPAW CLEANUP (node %d): %s occurred. ' +
1231 'Calling MPI_Finalize!\n') % (world.rank, error))
1232 sys.stderr.flush()
1233 else:
1234 cleanup()  # cleanup() takes no arguments
1235 world.barrier() # sync up before exiting
1236 sys.exit() # quit for serial case, return to cgpaw.c for parallel case
1239atexit.register(cleanup)