Coverage for gpaw/mpi4pywrapper.py: 42%
66 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1try:
2 from mpi4py.MPI import Request, SUM, MAX, IN_PLACE
3except ImportError:
4 pass
class MPI4PYWrapper:
    """Wrap an mpi4py communicator behind a small, simplified interface.

    Every method forwards to the corresponding call on ``self.comm``.
    Buffer-based methods use mpi4py's capitalised calls (``Allreduce``,
    ``Bcast``, ``Send``, ...) and operate on buffer-like objects in place;
    the ``*_scalar`` methods use mpi4py's lowercase object-based reductions
    and return the reduced value.
    """

    def __init__(self, comm, parent=None):
        """Wrap *comm*.

        comm: the mpi4py communicator to forward to; its ``size`` and
            ``rank`` are cached as attributes.
        parent: optional wrapper this communicator was derived from
            (``new_communicator`` passes ``self`` here).
        """
        self.comm = comm
        self.size = comm.size
        self.rank = comm.rank
        self.parent = parent  # XXX check C-object against comm.parent?

    def new_communicator(self, ranks):
        """Create a sub-communicator containing only *ranks*.

        ``Comm.Create`` is collective over ``self.comm``, so every rank of
        this communicator must call this method with the same *ranks*.

        Returns a new ``MPI4PYWrapper`` (whose ``parent`` is ``self``) on
        ranks listed in *ranks*, and ``None`` on all other ranks.
        """
        parent_group = self.comm.Get_group()
        try:
            new_group = parent_group.Incl(ranks)
            try:
                comm = self.comm.Create(new_group)
            finally:
                # The group handles are only needed to build the new
                # communicator; release them so they do not leak.
                new_group.Free()
        finally:
            parent_group.Free()
        if self.comm.rank in ranks:
            return MPI4PYWrapper(comm, parent=self)
        # This cpu is not in the new communicator:
        return None

    def max_scalar(self, a, root=-1):
        """Reduce the scalar *a* with ``MAX`` across all ranks.

        With ``root=-1`` every rank gets the result (allreduce); otherwise
        the reduction is rooted at rank *root* (see ``sum_scalar``).
        """
        # Forward *root*: it was previously hard-coded to -1, which turned
        # every rooted max into an allreduce.
        return self.sum_scalar(a, root=root, _op=MAX)

    def sum_scalar(self, a, root=-1, _op=None):
        """Reduce the Python scalar *a* across all ranks.

        a: an ``int``, ``float`` or ``complex`` (checked with an assert).
        root: ``-1`` for an allreduce whose result is returned on every
            rank; otherwise the value of ``comm.reduce(..., root=root)``
            is returned (mpi4py gives the result on *root* only).
        _op: internal override of the reduction op; defaults to ``SUM``.
        """
        if _op is None:
            _op = SUM
        assert isinstance(a, (int, float, complex))
        if root == -1:
            return self.comm.allreduce(a, op=_op)
        return self.comm.reduce(a, root=root, op=_op)

    def sum(self, a, root=-1):
        """Sum the buffer *a* element-wise across ranks, in place.

        With ``root=-1`` every rank's *a* is overwritten with the sum.
        Otherwise only rank *root* receives the sum in *a*; other ranks
        only contribute their data (no receive buffer is passed).
        """
        if root == -1:
            self.comm.Allreduce(IN_PLACE, a, op=SUM)
        elif root == self.rank:
            self.comm.Reduce(IN_PLACE, a, root=root, op=SUM)
        else:
            self.comm.Reduce(a, None, root=root, op=SUM)

    def scatter(self, a, b, root):
        """Scatter send buffer *a* from *root* into receive buffer *b*."""
        self.comm.Scatter(a, b, root)

    def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls):
        """All-to-all exchange with per-rank counts and displacements.

        The MPI datatype of each buffer is taken from its ``dtype.char``.
        """
        self.comm.Alltoallv((sbuffer, (scounts, sdispls), sbuffer.dtype.char),
                            (rbuffer, (rcounts, rdispls), rbuffer.dtype.char))

    def all_gather(self, a, b):
        """Gather buffer *a* from every rank into *b* on every rank."""
        self.comm.Allgather(a, b)

    def gather(self, a, root, b=None):
        """Gather buffer *a* from every rank into *b* on rank *root*."""
        self.comm.Gather(a, b, root)

    def broadcast(self, a, root):
        """Broadcast buffer *a* from rank *root* to all ranks, in place."""
        self.comm.Bcast(a, root)

    def sendreceive(self, a, dest, b, src, sendtag=123, recvtag=123):
        """Send *a* to *dest* while receiving into *b* from *src*.

        Returns whatever ``Comm.Sendrecv`` returns.
        """
        return self.comm.Sendrecv(a, dest, sendtag, b, src, recvtag)

    def send(self, a, dest, tag=123, block=True):
        """Send buffer *a* to rank *dest*.

        Blocking by default (returns ``None``); with ``block=False`` a
        non-blocking ``Isend`` is started and its request is returned.
        """
        if block:
            self.comm.Send(a, dest, tag)
            return None
        return self.comm.Isend(a, dest, tag)

    def ssend(self, a, dest, tag=123):
        """Synchronous-mode send of buffer *a* to rank *dest*."""
        return self.comm.Ssend(a, dest, tag)

    def receive(self, a, src, tag=123, block=True):
        """Receive into buffer *a* from rank *src*.

        Blocking by default (returns ``None``); with ``block=False`` a
        non-blocking ``Irecv`` is started and its request is returned.
        """
        if block:
            self.comm.Recv(a, src, tag)
            return None
        return self.comm.Irecv(a, src, tag)

    def test(self, request):
        """Return the result of ``request.test()`` for one request."""
        return request.test()

    def testall(self, requests):
        """Return the result of ``Request.testall`` over *requests*."""
        return Request.testall(requests)

    def wait(self, request):
        """Block until *request* completes."""
        request.wait()

    def waitall(self, requests):
        """Block until every request in *requests* completes."""
        Request.waitall(requests)

    def name(self):
        """Return the communicator's name (``Comm.Get_name``)."""
        return self.comm.Get_name()

    def barrier(self):
        """Synchronise all ranks of the communicator."""
        self.comm.barrier()

    def get_c_object(self):
        """Return the wrapped mpi4py communicator."""
        return self.comm