Coverage for gpaw/gpu/mpi.py: 44%

97 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

1from gpaw.gpu import cupy as cp 

2import numpy as np 

3 

4 

class CuPyMPI:
    """Quick'n'dirty wrapper to make things work without a GPU-aware MPI.

    Every method accepts either NumPy arrays, which are handed directly to
    the wrapped communicator, or device (CuPy-like) arrays, which are
    staged through host NumPy buffers.  Device data goes to the host with
    ``.get()`` and comes back with ``.set()`` or ``cp.asarray``.

    Parameters
    ----------
    comm:
        The wrapped communicator.  Its ``rank`` and ``size`` are copied onto
        this object, and the methods below forward to its methods of the
        same name.
    """

    def __init__(self, comm):
        self.comm = comm
        self.rank = comm.rank
        self.size = comm.size

    def __repr__(self):
        return f'CuPyMPI({self.comm})'

    def sum(self, array, root=-1):
        """Sum-reduce *array* in place via the wrapped ``comm.sum``.

        *root* is forwarded unchanged.  Python scalars are rejected with a
        ``TypeError``; use :meth:`sum_scalar` for those.
        """
        if isinstance(array, (float, int, complex)):
            # A scalar cannot be reduced in place.  Fail loudly and say
            # what to use instead (this used to be a bare ``1 / 0`` trap).
            raise TypeError('CuPyMPI.sum() requires an array; '
                            'use sum_scalar() for scalars')
        if isinstance(array, np.ndarray):
            self.comm.sum(array, root)
            return
        # Device array: reduce a host copy, then copy the result back.
        host = array.get()
        self.comm.sum(host, root)
        array.set(host)

    def sum_scalar(self, a, root=-1):
        """Forward to the wrapped ``comm.sum_scalar`` and return its result."""
        return self.comm.sum_scalar(a, root)

    def min_scalar(self, a, root=-1):
        """Forward to the wrapped ``comm.min_scalar`` and return its result."""
        return self.comm.min_scalar(a, root)

    def max_scalar(self, a, root=-1):
        """Forward to the wrapped ``comm.max_scalar`` and return its result."""
        return self.comm.max_scalar(a, root)

    def max(self, array):
        """Max-reduce *array* in place via the wrapped ``comm.max``."""
        if isinstance(array, np.ndarray):
            self.comm.max(array)
            return
        # Device array: the wrapped communicator only understands host
        # memory, so stage through a host copy, as ``sum`` does.
        host = array.get()
        self.comm.max(host)
        array.set(host)

    def all_gather(self, a, b):
        """Gather *a* from every rank into *b* on every rank."""
        if isinstance(a, np.ndarray):
            self.comm.all_gather(a, b)
            return
        # Device arrays: gather into a host buffer shaped like *b*.
        host_b = np.empty(b.shape, b.dtype)
        self.comm.all_gather(a.get(), host_b)
        b[...] = cp.asarray(host_b)

    def gather(self, a, rank, b):
        """Gather *a* from every rank into *b* on rank *rank*.

        *b* is only touched on the receiving rank.
        """
        if isinstance(a, np.ndarray):
            self.comm.gather(a, rank, b)
            return
        # Only the receiving rank needs a host landing buffer.
        c = np.empty(b.shape, b.dtype) if rank == self.rank else None
        self.comm.gather(a.get(), rank, c)
        if rank == self.rank:
            b[...] = cp.asarray(c)

    def scatter(self, fro, to, root=0):
        """Scatter *fro* (read on *root* only) into *to* on every rank."""
        if isinstance(to, np.ndarray):
            self.comm.scatter(fro, to, root)
            return
        b = np.empty(to.shape, to.dtype)
        # *fro* is only read on the root rank.
        a = fro.get() if self.rank == root else None
        self.comm.scatter(a, b, root)
        to[...] = cp.asarray(b)

    def broadcast(self, a, root):
        """Broadcast *a* from *root* to all ranks, in place."""
        if isinstance(a, np.ndarray):
            self.comm.broadcast(a, root)
            return
        b = a.get()
        self.comm.broadcast(b, root)
        a[...] = cp.asarray(b)

    def receive(self, a, rank, tag=0, block=True):
        """Receive into *a* from *rank*.

        For a device array with ``block=False`` this returns a
        :class:`CuPyRequest`; *a* is filled when that request is passed to
        :meth:`wait` or :meth:`waitall`.
        """
        if isinstance(a, np.ndarray):
            return self.comm.receive(a, rank, tag, block)
        b = np.empty(a.shape, a.dtype)
        req = self.comm.receive(b, rank, tag, block)
        if block:
            a[...] = cp.asarray(b)
            return
        return CuPyRequest(req, b, a)

    def ssend(self, a, rank, tag):
        """Synchronous send of *a* to *rank*."""
        if isinstance(a, np.ndarray):
            self.comm.ssend(a, rank, tag)
        else:
            self.comm.ssend(a.get(), rank, tag)

    def send(self, a, rank, tag=0, block=True):
        """Send *a* to *rank*.

        For a device array with ``block=False`` this returns a
        :class:`CuPyRequest` that holds a reference to the host copy being
        sent.
        """
        if isinstance(a, np.ndarray):
            return self.comm.send(a, rank, tag, block)
        b = a.get()
        request = self.comm.send(b, rank, tag, block)
        if not block:
            return CuPyRequest(request, b)

    def alltoallv(self,
                  fro, ssizes, soffsets,
                  to, rsizes, roffsets):
        """All-to-all exchange with per-rank sizes and offsets."""
        if isinstance(fro, np.ndarray):
            # Host data: no staging needed (previously this crashed on
            # ``fro.get()``).
            self.comm.alltoallv(fro, ssizes, soffsets,
                                to, rsizes, roffsets)
            return
        a = np.empty(to.shape, to.dtype)
        self.comm.alltoallv(fro.get(), ssizes, soffsets,
                            a, rsizes, roffsets)
        to[...] = cp.asarray(a)

    def wait(self, request):
        """Wait for one request, copying received data to the device if needed."""
        if not isinstance(request, CuPyRequest):
            return self.comm.wait(request)
        self.comm.wait(request.request)
        if request.target is not None:
            request.target[...] = cp.asarray(request.buffer)

    def waitall(self, requests):
        """Wait for all *requests*, which may mix plain and CuPy requests.

        *requests* is iterated twice, so it must be a sequence rather than
        a one-shot iterator.
        """
        # Plain requests come from the NumPy paths of ``send``/``receive``.
        # Unwrap only the CuPyRequest instances, exactly as ``wait`` does.
        self.comm.waitall([request.request
                           if isinstance(request, CuPyRequest) else request
                           for request in requests])
        for request in requests:
            if (isinstance(request, CuPyRequest)
                    and request.target is not None):
                request.target[...] = cp.asarray(request.buffer)

    def get_c_object(self):
        """Return whatever the wrapped communicator's ``get_c_object`` returns."""
        return self.comm.get_c_object()

120 

121 

class CuPyRequest:
    """Pending non-blocking operation created by the CuPyMPI wrapper.

    Bundles the wrapped communicator's request with the host staging array
    and, optionally, the device array that should receive the staged data
    once the request has completed.
    """

    def __init__(self, request, buffer, target=None):
        # Storing *buffer* here keeps the host staging array referenced
        # for as long as this request object lives.
        self.request, self.buffer, self.target = request, buffer, target