Coverage for gpaw/gpu/mpi.py: 44%
97 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-12 00:18 +0000
1from gpaw.gpu import cupy as cp
2import numpy as np
class CuPyMPI:
    """Quick'n'dirty wrapper to make things work without a GPU-aware MPI.

    Every call is forwarded to the wrapped host communicator ``comm``.
    NumPy arrays are passed through unchanged.  Any other array is treated
    as a device (CuPy) array: it is copied to a host buffer with ``.get()``
    before the call and, for operations that write into it, copied back
    to the device afterwards.
    """

    def __init__(self, comm):
        self.comm = comm
        self.rank = comm.rank
        self.size = comm.size

    def __repr__(self):
        return f'CuPyMPI({self.comm})'

    @staticmethod
    def _upload(target, host):
        """Overwrite device array *target* with the contents of *host*."""
        # ``[...]`` rather than ``[:]`` so that 0-d arrays work as well.
        target[...] = cp.asarray(host)

    def _complete(self, request):
        """Copy a finished receive's host buffer into its device target."""
        # Only requests created by receive() carry a target; sends use None.
        if request.target is not None:
            self._upload(request.target, request.buffer)

    def sum(self, array, root=-1):
        """Sum *array* over the communicator via ``comm.sum``.

        ``root`` is passed straight through.  For device arrays the
        reduced host buffer is written back into *array*.

        Raises:
            TypeError: if *array* is a Python scalar; use sum_scalar().
        """
        if isinstance(array, (int, float, complex)):
            # Scalars have a dedicated method; reject them explicitly
            # instead of failing with an unrelated error.
            raise TypeError(
                f'CuPyMPI.sum() expects an array, got '
                f'{type(array).__name__}; use sum_scalar() for scalars')
        if isinstance(array, np.ndarray):
            self.comm.sum(array, root)
            return
        host = array.get()
        self.comm.sum(host, root)
        array.set(host)

    # Scalar reductions need no host/device staging: forward directly.
    def sum_scalar(self, a, root=-1):
        return self.comm.sum_scalar(a, root)

    def min_scalar(self, a, root=-1):
        return self.comm.min_scalar(a, root)

    def max_scalar(self, a, root=-1):
        return self.comm.max_scalar(a, root)

    def max(self, array):
        """Apply ``comm.max`` to *array*.

        Host arrays and scalars are forwarded unchanged.  Device arrays are
        staged through host memory and the result is written back.
        """
        if isinstance(array, np.ndarray) or np.isscalar(array):
            self.comm.max(array)
            return
        host = array.get()
        self.comm.max(host)
        array.set(host)

    def all_gather(self, a, b):
        """Forward to ``comm.all_gather``; the result is written into *b*."""
        if isinstance(a, np.ndarray):
            self.comm.all_gather(a, b)
            return
        host_b = np.empty(b.shape, b.dtype)
        self.comm.all_gather(a.get(), host_b)
        self._upload(b, host_b)

    def gather(self, a, rank, b):
        """Forward to ``comm.gather``; *b* is filled only on rank *rank*."""
        if isinstance(a, np.ndarray):
            self.comm.gather(a, rank, b)
            return
        # Only the receiving rank needs a host buffer for the result.
        host_b = np.empty(b.shape, b.dtype) if rank == self.rank else None
        self.comm.gather(a.get(), rank, host_b)
        if rank == self.rank:
            self._upload(b, host_b)

    def scatter(self, fro, to, root=0):
        """Forward to ``comm.scatter``; *fro* is only read on rank *root*."""
        if isinstance(to, np.ndarray):
            self.comm.scatter(fro, to, root)
            return
        host_to = np.empty(to.shape, to.dtype)
        host_fro = fro.get() if self.rank == root else None
        self.comm.scatter(host_fro, host_to, root)
        self._upload(to, host_to)

    def broadcast(self, a, root):
        """Forward to ``comm.broadcast``; *a* is overwritten in place."""
        if isinstance(a, np.ndarray):
            self.comm.broadcast(a, root)
            return
        host = a.get()
        self.comm.broadcast(host, root)
        self._upload(a, host)

    def receive(self, a, rank, tag=0, block=True):
        """Receive into *a*.

        Returns ``None`` for blocking device receives.  For non-blocking
        device receives it returns a CuPyRequest; pass that to wait() or
        waitall() to copy the data into *a*.  Host arrays return whatever
        ``comm.receive`` returns.
        """
        if isinstance(a, np.ndarray):
            return self.comm.receive(a, rank, tag, block)
        host = np.empty(a.shape, a.dtype)
        request = self.comm.receive(host, rank, tag, block)
        if block:
            self._upload(a, host)
            return
        return CuPyRequest(request, host, a)

    def ssend(self, a, rank, tag):
        """Forward to ``comm.ssend``, staging device arrays through host."""
        self.comm.ssend(a if isinstance(a, np.ndarray) else a.get(),
                        rank, tag)

    def send(self, a, rank, tag=0, block=True):
        """Send *a*.

        A non-blocking send of a device array returns a CuPyRequest.
        """
        if isinstance(a, np.ndarray):
            return self.comm.send(a, rank, tag, block)
        host = a.get()
        request = self.comm.send(host, rank, tag, block)
        if not block:
            # The request holds ``host``, presumably so the buffer stays
            # alive while the non-blocking send is in flight.
            return CuPyRequest(request, host)

    def alltoallv(self,
                  fro, ssizes, soffsets,
                  to, rsizes, roffsets):
        """Forward to ``comm.alltoallv``; the result is written into *to*."""
        if isinstance(fro, np.ndarray):
            self.comm.alltoallv(fro, ssizes, soffsets,
                                to, rsizes, roffsets)
            return
        host_to = np.empty(to.shape, to.dtype)
        self.comm.alltoallv(fro.get(), ssizes, soffsets,
                            host_to, rsizes, roffsets)
        self._upload(to, host_to)

    def wait(self, request):
        """Wait for *request*, a raw host request or a CuPyRequest."""
        if not isinstance(request, CuPyRequest):
            return self.comm.wait(request)
        self.comm.wait(request.request)
        self._complete(request)

    def waitall(self, requests):
        """Wait for all *requests*.

        Like wait(), this accepts raw host-communicator requests (returned
        for NumPy arrays) as well as CuPyRequest objects.
        """
        # Materialise first: *requests* is iterated twice below.
        requests = list(requests)
        self.comm.waitall([r.request if isinstance(r, CuPyRequest) else r
                           for r in requests])
        for r in requests:
            if isinstance(r, CuPyRequest):
                self._complete(r)

    def get_c_object(self):
        return self.comm.get_c_object()


class CuPyRequest:
    """Pending non-blocking operation involving a device array.

    Attributes are set from the constructor arguments:
    ``request`` is the request returned by the host communicator,
    ``buffer`` is the host array used for the transfer, and
    ``target`` is the device array that ``buffer`` is copied into once
    the operation completes (``None`` for sends).
    """

    def __init__(self, request, buffer, target=None):
        self.request = request
        self.buffer = buffer
        self.target = target