Coverage for gpaw/utilities/grid.py: 67% (138 statements)


"""Redistribute grid arrays and atomic data between domain decompositions."""
from functools import partial

import numpy as np
from gpaw.utilities.grid_redistribute import general_redistribute
from gpaw.utilities.partition import AtomPartition, AtomicMatrixDistributor


class GridRedistributor:
    """Redistribute grid arrays between the gd and aux_gd decompositions.

    broadcast_comm complements gd.comm within comm, i.e.
    gd.comm.size * broadcast_comm.size == comm.size.  Redistribution is a
    no-op ("disabled") when gd and aux_gd have the same parallel domain
    decomposition."""

    def __init__(self, comm, broadcast_comm, gd, aux_gd):
        self.comm = comm
        self.broadcast_comm = broadcast_comm
        self.gd = gd
        self.aux_gd = aux_gd
        self.enabled = np.any(gd.parsize_c != aux_gd.parsize_c)

        assert gd.comm.size * broadcast_comm.size == comm.size
        if self.enabled:
            assert comm.compare(aux_gd.comm) in ['ident', 'congruent']
        else:
            assert gd.comm.compare(aux_gd.comm) in ['ident', 'congruent']

        if aux_gd.comm.rank == 0:
            aux_ranks = gd.comm.translate_ranks(aux_gd.comm,
                                                np.arange(gd.comm.size))
        else:
            aux_ranks = np.empty(gd.comm.size, dtype=int)
        aux_gd.comm.broadcast(aux_ranks, 0)

        auxrank2rank = dict(zip(aux_ranks, np.arange(gd.comm.size)))

        def rank2parpos1(rank):
            if rank in auxrank2rank:
                return gd.get_processor_position_from_rank(auxrank2rank[rank])
            else:
                return None

        rank2parpos2 = aux_gd.get_processor_position_from_rank

        try:
            gd.n_cp
        except AttributeError:  # AtomPAW
            # No-op placeholders with the same (src, dst) call signature as
            # the real redistribution functions set up below.
            self._distribute = self._collect = lambda *args: None
            return  # XXX

        self._distribute = partial(general_redistribute, aux_gd.comm,
                                   gd.n_cp, aux_gd.n_cp,
                                   rank2parpos1, rank2parpos2)
        self._collect = partial(general_redistribute, aux_gd.comm,
                                aux_gd.n_cp, gd.n_cp,
                                rank2parpos2, rank2parpos1)

    def distribute(self, src_xg, dst_xg=None):
        """Redistribute src_xg from the gd layout to the aux_gd layout."""
        if not self.enabled:
            assert src_xg is dst_xg or dst_xg is None
            return src_xg
        if dst_xg is None:
            dst_xg = self.aux_gd.empty(src_xg.shape[:-3], dtype=src_xg.dtype)
        self._distribute(src_xg, dst_xg)
        return dst_xg

    def collect(self, src_xg, dst_xg=None):
        """Redistribute src_xg from the aux_gd layout back to the gd layout,
        then broadcast the result over broadcast_comm."""
        if not self.enabled:
            assert src_xg is dst_xg or dst_xg is None
            return src_xg
        if dst_xg is None:
            dst_xg = self.gd.empty(src_xg.shape[:-3], src_xg.dtype)
        self._collect(src_xg, dst_xg)
        self.broadcast_comm.broadcast(dst_xg, 0)
        return dst_xg

    def get_atom_distributions(self, spos_ac):
        return AtomDistributions(self.comm, self.broadcast_comm,
                                 self.gd, self.aux_gd, spos_ac)
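# Illustrative sketch, not part of the original module: one way the class
# above might be used, assuming comm, broadcast_comm, gd and aux_gd already
# satisfy the constructor's asserts (gd.comm.size * broadcast_comm.size ==
# comm.size, and aux_gd.comm congruent with comm when the decompositions
# differ).  The function name and the two-component array are hypothetical.
def _example_grid_redistribution(comm, broadcast_comm, gd, aux_gd):
    redist = GridRedistributor(comm, broadcast_comm, gd, aux_gd)
    a_xg = gd.zeros(2)              # two grid arrays in the gd decomposition
    b_xg = redist.distribute(a_xg)  # same data in the aux_gd decomposition
    c_xg = redist.collect(b_xg)     # back to gd and broadcast
    return c_xg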

class AtomDistributions:
    """Hold the atom partitions corresponding to gd, aux_gd and an even
    "work" distribution, plus distributors for moving atomic matrices
    between them."""

    def __init__(self, comm, broadcast_comm, gd, aux_gd, spos_ac):
        self.comm = comm
        self.broadcast_comm = broadcast_comm
        self.gd = gd
        self.aux_gd = aux_gd

        rank_a = gd.get_ranks_from_positions(spos_ac)
        aux_rank_a = aux_gd.get_ranks_from_positions(spos_ac)
        self.partition = AtomPartition(gd.comm, rank_a, name='gd')

        if gd is aux_gd:
            name = 'aux-unextended'
        else:
            name = 'aux-extended'
        self.aux_partition = AtomPartition(aux_gd.comm, aux_rank_a, name=name)

        # Dummy integer rank array (all atoms on rank 0), then even it out:
        self.work_partition = AtomPartition(comm, np.zeros(len(spos_ac), int),
                                            name='work').as_even_partition()

        if gd is aux_gd:
            aux_broadcast_comm = gd.comm.new_communicator([gd.comm.rank])
        else:
            aux_broadcast_comm = broadcast_comm

        self.aux_dist = AtomicMatrixDistributor(self.partition,
                                                aux_broadcast_comm,
                                                self.aux_partition)
        self.work_dist = AtomicMatrixDistributor(self.partition,
                                                 broadcast_comm,
                                                 self.work_partition)

    def to_aux(self, arraydict):
        if self.gd is self.aux_gd:
            return arraydict.copy()
        return self.aux_dist.distribute(arraydict)

    def from_aux(self, arraydict):
        if self.gd is self.aux_gd:
            return arraydict.copy()
        return self.aux_dist.collect(arraydict)

    def to_work(self, arraydict):
        return self.work_dist.distribute(arraydict)

    def from_work(self, arraydict):
        return self.work_dist.collect(arraydict)
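# Illustrative sketch, not part of the original module: round-trip an atomic
# array dictionary through the evenly distributed "work" partition.  Here
# ``dist`` is assumed to come from GridRedistributor.get_atom_distributions()
# and ``D_asp`` to live in the gd partition; both names are hypothetical.
def _example_work_round_trip(dist, D_asp):
    D_work_asp = dist.to_work(D_asp)         # spread evenly over dist.comm
    D_back_asp = dist.from_work(D_work_asp)  # back to the gd partition
    return D_back_asp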

def get_domains_from_gd(comm, gd, offset_c=None):
    """Return (n_cp, rank2parpos) describing gd's decomposition within comm."""
    ranks = gd.comm.translate_ranks(comm, np.arange(gd.comm.size))
    assert (ranks >= 0).all(), 'comm not parent of gd.comm'

    def rank2parpos(rank):
        gdrank = comm.translate_ranks(gd.comm, np.array([rank]))[0]
        # XXX segfault when not passing an array!
        if gdrank == -1:
            return None
        return gd.get_processor_position_from_rank(gdrank)

    def add_offset(n_cp, offset_c):
        n_cp = [n_p.copy() for n_p in n_cp]
        for c in range(3):
            n_cp[c] += offset_c[c]
        return n_cp

    n_cp = gd.n_cp
    if offset_c is not None:
        n_cp = add_offset(n_cp, offset_c)

    return n_cp, rank2parpos


def grid2grid(comm, gd1, gd2, src_g, dst_g, offset1_c=None, offset2_c=None,
              xp=np):
    """Copy src_g (decomposed as gd1) into dst_g (decomposed as gd2).

    comm must be a parent communicator of both gd1.comm and gd2.comm."""
    assert np.all(src_g.shape == gd1.n_c)
    assert np.all(dst_g.shape == gd2.n_c)

    n1_cp, rank2parpos1 = get_domains_from_gd(comm, gd1, offset_c=offset1_c)
    n2_cp, rank2parpos2 = get_domains_from_gd(comm, gd2, offset_c=offset2_c)

    general_redistribute(comm,
                         n1_cp, n2_cp,
                         rank2parpos1, rank2parpos2,
                         src_g, dst_g, xp=xp)
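# Illustrative sketch, not part of the original module: copy a scalar field
# between two differently decomposed grids, where ``comm`` is assumed to be a
# parent of both gd1.comm and gd2.comm.  See main() below for a fuller test.
def _example_grid2grid(comm, gd1, gd2):
    a1_g = gd1.zeros()
    a1_g[:] = comm.rank       # something rank-dependent to redistribute
    a2_g = gd2.zeros()
    grid2grid(comm, gd1, gd2, a1_g, a2_g)
    return a2_g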



def main():
    from gpaw.grid_descriptor import GridDescriptor
    from gpaw.mpi import world

    serial = world.new_communicator([world.rank])

    # Generator which must be advanced identically on all ranks:
    gen = np.random.RandomState(0)

    # This one is used only by the master rank:
    gen_serial = np.random.RandomState(17)

    maxsize = 5
    for _ in range(1):
        N1_c = gen.randint(1, maxsize, 3)
        N2_c = gen.randint(1, maxsize, 3)

        gd1 = GridDescriptor(N1_c, N1_c)
        gd2 = GridDescriptor(N2_c, N2_c)
        serial_gd1 = gd1.new_descriptor(comm=serial)
        # serial_gd2 = gd2.new_descriptor(comm=serial)

        a1_serial = serial_gd1.empty()
        a1_serial.flat[:] = gen_serial.rand(a1_serial.size)

        if world.rank == 0:
            print('r0: a1 serial', a1_serial.ravel())

        a1 = gd1.empty()
        a1[:] = -1

        grid2grid(world, serial_gd1, gd1, a1_serial, a1)

        print(world.rank, 'a1 distributed', a1.ravel())
        world.barrier()

        a2 = gd2.zeros()
        a2[:] = -2
        grid2grid(world, gd1, gd2, a1, a2)
        print(world.rank, 'a2 distributed', a2.ravel())
        world.barrier()

        gd1 = GridDescriptor(N1_c, N1_c * 0.2)

        a1 = gd1.empty()
        a1.flat[:] = gen.rand(a1.size)

        grid2grid(world, gd1, gd2, a1, a2)


if __name__ == '__main__':
    main()