Coverage for gpaw/mpi4pywrapper.py: 42%

66 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1try: 

2 from mpi4py.MPI import Request, SUM, MAX, IN_PLACE 

3except ImportError: 

4 pass 

5 

6 

class MPI4PYWrapper:
    """Communicator wrapper that forwards calls to an mpi4py communicator.

    Every method delegates to the matching call on the wrapped ``comm``
    (or, for request handling, on mpi4py's ``Request`` class).  Methods
    with a capitalised mpi4py counterpart (``Allreduce``, ``Bcast``,
    ``Send``, ...) use the buffer-based mpi4py API.  The ``*_scalar``
    methods use the lowercase, object-based API.
    """

    def __init__(self, comm, parent=None):
        """Wrap *comm*.

        Parameters
        ----------
        comm:
            The communicator to wrap (presumably an ``mpi4py.MPI.Comm``).
            Its ``size`` and ``rank`` are cached on the wrapper.
        parent:
            The wrapper this communicator was derived from, if any.
            ``new_communicator`` passes ``self`` here.
        """
        self.comm = comm
        self.size = comm.size
        self.rank = comm.rank
        self.parent = parent  # XXX check C-object against comm.parent?

    def new_communicator(self, ranks):
        """Create a sub-communicator made of the ranks listed in *ranks*.

        ``Comm.Create`` is collective over ``self.comm``, so every rank of
        the current communicator must make this call.

        Returns
        -------
        A new ``MPI4PYWrapper`` whose ``parent`` is ``self`` on ranks that
        are members of *ranks*, and ``None`` on all other ranks.
        """
        # NOTE(review): the Group handles from ``self.comm.group`` and
        # ``.Incl`` are never freed here; consider releasing them (watch
        # out for an empty *ranks*, which may yield MPI_GROUP_EMPTY).
        comm = self.comm.Create(self.comm.group.Incl(ranks))
        if self.comm.rank in ranks:
            return MPI4PYWrapper(comm, parent=self)
        # This cpu is not in the new communicator:
        return None

    def max_scalar(self, a, root=-1):
        """Reduce the scalar *a* with MAX across all ranks.

        With ``root == -1`` every rank receives the result (allreduce).
        Otherwise the reduction goes to rank *root* (reduce).  See
        ``sum_scalar`` for the accepted types of *a*.
        """
        # Forward the caller's root.  Hard-coding root=-1 here made
        # max_scalar(a, root=r) silently behave like an allreduce.
        return self.sum_scalar(a, root=root, _op=MAX)

    def sum_scalar(self, a, root=-1, _op=None):
        """Reduce the scalar *a* across all ranks (SUM unless *_op* given).

        Parameters
        ----------
        a:
            A Python ``int``, ``float`` or ``complex`` (or a subclass).
            Anything else fails the assertion below.
        root:
            ``-1`` means every rank receives the result via ``allreduce``.
            Any other value reduces to that rank via ``reduce``.
        _op:
            Internal: the MPI reduction operation.  Defaults to ``SUM``.

        Returns
        -------
        The value returned by mpi4py's ``allreduce`` or ``reduce``.
        """
        if _op is None:
            _op = SUM
        assert isinstance(a, (int, float, complex))
        if root == -1:
            return self.comm.allreduce(a, op=_op)
        else:
            return self.comm.reduce(a, root=root, op=_op)

    def sum(self, a, root=-1):
        """Sum the buffer *a* element-wise across ranks.

        With ``root == -1`` the result is written into *a* on every rank
        (in-place ``Allreduce``).  Otherwise rank *root* reduces in place
        into its own *a*, and the other ranks pass *a* only as the send
        buffer.
        """
        if root == -1:
            self.comm.Allreduce(IN_PLACE, a, op=SUM)
        else:
            if root == self.rank:
                self.comm.Reduce(IN_PLACE, a, root=root, op=SUM)
            else:
                self.comm.Reduce(a, None, root=root, op=SUM)

    def scatter(self, a, b, root):
        """Scatter send buffer *a* from *root* into receive buffer *b*."""
        self.comm.Scatter(a, b, root)

    def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls):
        """All-to-all exchange with per-rank counts and displacements.

        The MPI datatypes are taken from ``sbuffer.dtype.char`` and
        ``rbuffer.dtype.char``, so both buffers are presumably numpy
        arrays.
        """
        self.comm.Alltoallv((sbuffer, (scounts, sdispls), sbuffer.dtype.char),
                            (rbuffer, (rcounts, rdispls), rbuffer.dtype.char))

    def all_gather(self, a, b):
        """Gather *a* from every rank into *b* on every rank."""
        self.comm.Allgather(a, b)

    def gather(self, a, root, b=None):
        """Gather *a* from every rank into *b* on *root*.

        *b* may be left as ``None`` on non-root ranks.
        """
        self.comm.Gather(a, b, root)

    def broadcast(self, a, root):
        """Broadcast buffer *a* from *root* to all ranks (in place)."""
        self.comm.Bcast(a, root)

    def sendreceive(self, a, dest, b, src, sendtag=123, recvtag=123):
        """Send *a* to *dest* and receive into *b* from *src* in one call."""
        return self.comm.Sendrecv(a, dest, sendtag, b, src, recvtag)

    def send(self, a, dest, tag=123, block=True):
        """Send buffer *a* to rank *dest*.

        Blocks and returns ``None`` when *block* is true.  Otherwise it
        returns the request object from ``Isend``.
        """
        if block:
            self.comm.Send(a, dest, tag)
        else:
            return self.comm.Isend(a, dest, tag)

    def ssend(self, a, dest, tag=123):
        """Synchronous-mode send of buffer *a* to rank *dest*."""
        return self.comm.Ssend(a, dest, tag)

    def receive(self, a, src, tag=123, block=True):
        """Receive from rank *src* into buffer *a*.

        Blocks and returns ``None`` when *block* is true.  Otherwise it
        returns the request object from ``Irecv``.
        """
        if block:
            self.comm.Recv(a, src, tag)
        else:
            return self.comm.Irecv(a, src, tag)

    def test(self, request):
        """Test *request* for completion without blocking."""
        # NOTE(review): mpi4py's lowercase ``Request.test()`` returns a
        # ``(flag, msg)`` tuple, which is always truthy.  The uppercase
        # ``Test()`` returns a bool.  Confirm which one the callers expect.
        return request.test()

    def testall(self, requests):
        """Test all *requests* for completion without blocking."""
        # NOTE(review): like ``test``, lowercase ``Request.testall``
        # returns a ``(flag, msgs)`` tuple rather than a bool.
        return Request.testall(requests)

    def wait(self, request):
        """Block until *request* completes."""
        request.wait()

    def waitall(self, requests):
        """Block until every request in *requests* completes."""
        Request.waitall(requests)

    def name(self):
        """Return the communicator's name via ``Comm.Get_name``."""
        return self.comm.Get_name()

    def barrier(self):
        """Synchronise all ranks of the communicator."""
        self.comm.barrier()

    def get_c_object(self):
        """Return the wrapped communicator object itself."""
        return self.comm