Coverage for gpaw/arraydict.py: 78%

100 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from collections.abc import MutableMapping 

2import numpy as np 

3 

4 

class DefaultKeyMapping:
    """Identity mapping between atom indices and user-facing keys.

    Both directions simply hand back the object they were given.
    """

    def key2a(self, key):
        """Return the atom index for ``key`` (the key itself)."""
        atom_index = key
        return atom_index

    def a2key(self, a):
        """Return the key for atom index ``a`` (the index itself)."""
        user_key = a
        return user_key

11 

12 

class TableKeyMapping:
    """One-to-one key mapping backed by two lookup tables.

    ``a2key_a[a]`` gives the key for atom index ``a`` and ``key2a_k[key]``
    gives the atom index for ``key``.  Both tables are stored as given;
    lookups raise whatever the underlying containers raise on a miss.
    """

    def __init__(self, a2key_a, key2a_k):
        self.a2key_a, self.key2a_k = a2key_a, key2a_k

    def a2key(self, a):
        """Look up the key belonging to atom index ``a``."""
        forward_table = self.a2key_a
        return forward_table[a]

    def key2a(self, key):
        """Look up the atom index belonging to ``key``."""
        reverse_table = self.key2a_k
        return reverse_table[key]

23 

24 

class ArrayDict(MutableMapping):
    """Distributed dictionary of fixed-size, fixed-dtype arrays.

    Maps each atom index owned by this rank (``partition.my_indices``) to a
    numpy array whose shape is ``shapes_a[a]`` and whose dtype is ``dtype``.
    Only the locally owned indices are stored; ``shapes_a`` is global and
    has one entry for each of the ``partition.natoms`` atoms.

    Elements are initialized as empty (uninitialized) numpy arrays unless a
    mapping ``d`` of initial values is supplied.

    Unlike a normal dictionary, this class implements a strict loop ordering
    which is consistent with that of the underlying atom partition: iteration
    always follows ``partition.my_indices``.
    """

    def __init__(self, partition, shapes_a, dtype=float, d=None,
                 keymap=None):
        """Create the local part of the distributed dictionary.

        Parameters:

        partition:
            Atom partition describing which atom indices this rank owns.
            The attributes used by this class are ``natoms``,
            ``my_indices``, ``comm`` (``rank``/``size``), ``rank_a`` and
            ``redistribute()``.
        shapes_a:
            Sequence with one array shape per atom (length must equal
            ``partition.natoms``), or a callable ``shapes_a(a)`` that is
            evaluated once for every atom index.
        dtype:
            Dtype that every stored array must have.
        d:
            Optional mapping of initial arrays.  The arrays are stored as
            given (not copied) and must cover exactly the locally owned
            indices.  If omitted, uninitialized arrays are allocated.
        keymap:
            Object providing ``a2key(a)`` (used by ``__repr__``) and
            ``key2a(key)``; defaults to the identity mapping.
        """
        self.partition = partition
        if callable(shapes_a):
            shapes_a = [shapes_a(a) for a in range(self.partition.natoms)]
        self.shapes_a = shapes_a  # global
        assert len(shapes_a) == partition.natoms
        self.dtype = dtype
        self.data = {}

        # This will be a terrible hack to make it easier to customize
        # the "keys". Normally keys are 0...N and correspond to
        # rank_a of the AtomPartition. But the keymap allows the user
        # to pass around a one-to-one mapping between the "physical"
        # keys 0...N and some other objects that the user may fancy.
        if keymap is None:
            keymap = DefaultKeyMapping()
        self.keymap = keymap

        if d is None:
            for a in partition.my_indices:
                self[a] = np.empty(self.shapes_a[a], dtype=dtype)
        else:
            self.update(d)
        self.check_consistency()

    # copy() is dangerous since redistributions write back
    # into arrays, and redistribution of a copy could lead to bugs
    # if the other copy changes.
    #
    # Let's just use deepcopy to stay out of trouble. Then the
    # dictionary copy() is not invoked confusingly.
    #
    # -askhl
    def copy(self):
        """Return a new ArrayDict with freshly allocated copies of all arrays.

        The partition, shapes, dtype and keymap are shared with ``self``.
        """
        copy = ArrayDict(self.partition, self.shapes_a, dtype=self.dtype,
                         keymap=self.keymap)
        for a in self:
            # ``[...]`` (not ``[:]``) so that 0-d entries are copied too.
            copy[a][...] = self[a]
        return copy

    def __getitem__(self, a):
        """Return the stored array for atom ``a``, checking shape and dtype."""
        value = self.data[a]
        assert value.shape == self.shapes_a[a]
        assert value.dtype == self.dtype
        return value

    def __setitem__(self, a, value):
        """Store ``value`` (without copying) as the array for atom ``a``.

        ``value`` must already have shape ``shapes_a[a]`` and dtype
        ``self.dtype``.
        """
        assert value.shape == self.shapes_a[a], \
            f'defined shape {self.shapes_a[a]} vs new {value.shape}'
        assert value.dtype == self.dtype
        self.data[a] = value

    def redistribute(self, partition):
        """Redistribute according to specified partition (in place)."""
        def get_empty(a):
            # Allocates an uninitialized array for atom ``a``; presumably
            # used by the partition for atoms arriving on this rank.
            return np.empty(self.shapes_a[a], self.dtype)

        self.partition.redistribute(partition, self, get_empty)
        self.partition = partition  # Better with immutable partition?
        self.check_consistency()

    def check_consistency(self):
        """Assert that the stored keys, shapes and dtypes match the spec."""
        k1 = set(self.partition.my_indices)
        k2 = set(self.data.keys())
        assert k1 == k2, f'Required keys {k1} different from actual {k2}'
        for a, array in self.items():
            assert array.dtype == self.dtype
            assert array.shape == self.shapes_a[a], \
                (f'array shape {array.shape} '
                 f'vs specified shape {self.shapes_a[a]}')

    def toarray(self, axis=None):
        """Concatenate the local arrays in ``partition.my_indices`` order.

        With ``axis=None`` every array is flattened first and the result is
        1-D; otherwise the arrays are concatenated along ``axis``.  If no
        arrays are stored locally, an empty 1-D array is returned.
        """
        # We could also implement it as a contiguous buffer.
        if len(self) == 0:
            # XXXXXX how should we deal with globally or locally empty arrays?
            # This will probably lead to bugs unless we get all the
            # dimensions right.
            return np.empty(0, self.dtype)
        if axis is None:
            return np.concatenate([self[a].ravel()
                                   for a in self.partition.my_indices])
        else:
            # XXX self[a].shape must all be consistent except along axis
            return np.concatenate([self[a] for a in self.partition.my_indices],
                                  axis=axis)

    def fromarray(self, data):
        """Fill the local arrays in place from a flat buffer.

        This is the inverse of ``toarray()`` with ``axis=None``: consecutive
        chunks of the 1-D ``data`` are written into the arrays in
        ``partition.my_indices`` order.  ``data`` must have ``self.dtype``.

        Raises ValueError if a chunk cannot be reshaped to its target array,
        e.g. when ``data`` is too short.
        """
        assert data.dtype == self.dtype
        offset = 0
        for a in self.partition.my_indices:
            target = self[a]
            # ``target.size`` is always an int (1 for 0-d arrays), whereas
            # ``np.prod(())`` would be the float 1.0 and break slicing.
            size = target.size
            # Write through the stored array itself: ``ravel()`` may return
            # a copy for non-contiguous arrays, which would drop the data.
            target[...] = data[offset:offset + size].reshape(target.shape)
            offset += size

    def redistribute_and_broadcast(self, dist_comm, dup_comm):
        """Return a copy distributed over dist_comm, replicated over dup_comm.

        Atom ``a`` is assigned to rank ``partition.rank_a[a] %
        dist_comm.size`` of ``dist_comm`` in the returned ArrayDict.  The
        data is first redistributed (on a copy, so ``self`` is untouched)
        over ``self.partition.comm`` using those same rank numbers.  Then
        rank 0 of ``dup_comm`` fills the flat buffer and broadcasts it to
        the rest of ``dup_comm``.

        NOTE(review): this presumably requires that the ranks owning data
        after the first redistribution are exactly those with
        ``dup_comm.rank == 0``, and that their ``dist_comm`` rank equals
        their rank in ``self.partition.comm`` -- confirm against how the
        communicators are constructed.
        """
        # Data exists on self which is a "nice" distribution but now
        # we want it on sub_partition which has a smaller communicator
        # whose parent is self.comm.
        #
        # We want our own data replicated on each

        # XXX direct comparison of communicators are unsafe as we do not use
        # MPI_Comm_compare

        # assert subpartition.comm.parent == self.partition.comm
        from gpaw.utilities.partition import AtomPartition

        newrank_a = self.partition.rank_a % dist_comm.size
        masters_only_partition = AtomPartition(self.partition.comm, newrank_a)
        dst_partition = AtomPartition(dist_comm, newrank_a)
        copy = self.copy()
        copy.redistribute(masters_only_partition)

        dst = ArrayDict(dst_partition, self.shapes_a, dtype=self.dtype,
                        keymap=self.keymap)
        data = dst.toarray()
        if dup_comm.rank == 0:
            data0 = copy.toarray()
            data[:] = data0
        dup_comm.broadcast(data, 0)
        dst.fromarray(data)
        return dst

    def __iter__(self):
        # These functions enforce the same ordering as self.partition
        # when looping.
        return iter(self.partition.my_indices)

    def __repr__(self):
        tokens = []
        for key in sorted(self.keys()):
            shapestr = 'x'.join(map(str, self.shapes_a[key]))
            tokens.append(f'{self.keymap.a2key(key)}:{shapestr}')
        text = ', '.join(tokens)
        return '%s@rank%d/%d {%s}' % (self.__class__.__name__,
                                      self.partition.comm.rank,
                                      self.partition.comm.size,
                                      text)

    def __delitem__(self, a):
        # Actually this is not quite right; we effectively delete items
        # when we redistribute the arraydict. But this is another
        # code path so let's resolve this later if necessary.
        raise TypeError('Deleting arraydict elements not supported since '
                        'doing so violates the input list-of-shapes')

    def __len__(self):
        """Number of locally stored arrays."""
        return len(self.data)