Coverage for gpaw/projections.py: 79%

131 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from typing import Any, Optional, Dict 

2from collections.abc import Mapping 

3 

4import numpy as np 

5 

6from gpaw.matrix import Matrix 

7from gpaw.mpi import serial_comm 

8from gpaw.utilities.partition import AtomPartition 

9from gpaw.typing import Array2D, ArrayLike1D 

10 

# Loose type alias for MPI communicator objects (such as ``serial_comm``
# or band communicators); no structural typing is enforced.
MPIComm = Any

12 

13 

class Projections(Mapping):
    """Projection coefficients stored as a (band x projector) matrix.

    The data lives in a ``gpaw.matrix.Matrix`` with ``nbands`` rows (laid
    out according to ``bdist``, by default ``(bcomm, bcomm.size, 1)``) and
    one column per projector of every atom that ``atom_partition`` assigns
    to this rank.  In the non-collinear case every atom contributes two
    spinor blocks and the local columns are ordered spinor-major, i.e. they
    are viewed as ``(mynbands, 2, local_nproj)``.

    The object is a ``collections.abc.Mapping`` from the atom indices owned
    by this rank to that atom's block of :attr:`array`.
    """

    def __init__(self,
                 nbands: int,
                 nproj_a: ArrayLike1D,
                 atom_partition: AtomPartition,
                 bcomm: MPIComm = None,
                 collinear=True,
                 spin=0,
                 dtype=None,
                 data=None,
                 bdist=None):
        """Create projections.

        Parameters
        ----------
        nbands:
            Number of rows (bands) of the full matrix.
        nproj_a:
            Number of projectors for each atom, indexed by atom.
        atom_partition:
            Decides which atoms are stored on this rank (``my_indices``).
        bcomm:
            Band communicator; defaults to ``serial_comm``.  Must be None
            when ``bdist`` is given.
        collinear:
            If False, two spinor components are stored per projector.
        spin:
            Spin index; stored and propagated by :meth:`new`.
        dtype:
            Element type.  When neither ``dtype`` nor ``data`` is given it
            defaults to float (collinear) or complex (non-collinear).
        data:
            Optional existing array handed to ``Matrix``.
        bdist:
            Explicit ``Matrix`` distribution whose first element is the band
            communicator.  Remembered so that :meth:`new` can reuse it.
        """
        if bdist is None:
            self.bcomm = bcomm or serial_comm
            bdist = (self.bcomm, self.bcomm.size, 1)
        else:
            assert bcomm is None
            self.bdist = bdist  # For calling "new"
            self.bcomm = bdist[0]

        self.nproj_a = np.asarray(nproj_a)
        self.atom_partition = atom_partition
        self.collinear = collinear
        self.spin = spin
        self.nbands = nbands

        # Column ranges [I1, I2) of each locally owned atom, in the order
        # given by atom_partition.my_indices.
        self.indices = []
        self.map = {}
        I1 = 0

        for a in self.atom_partition.my_indices:
            ni = self.nproj_a[a]
            I2 = I1 + ni
            self.indices.append((a, I1, I2))
            self.map[a] = (I1, I2)
            I1 = I2

        if not collinear:
            # Two spinor components per projector.
            I1 *= 2

        if dtype is None and data is None:
            dtype = float if collinear else complex

        self.matrix = Matrix(nbands, I1, dtype, data, dist=bdist)

        if collinear:
            self.myshape = self.matrix.array.shape
        else:
            self.myshape = (len(self.matrix.array), 2, I1 // 2)

    @property
    def array(self):
        """Local data: ``(mynbands, I)`` or ``(mynbands, 2, I)`` if
        non-collinear (a reshaped view of ``self.matrix.array``)."""
        if self.collinear:
            return self.matrix.array
        else:
            return self.matrix.array.reshape(self.myshape)

    def new(self, bcomm='inherit', nbands=None, atom_partition=None):
        """Return a new Projections object with a similar layout.

        The data is not copied.  ``bcomm='inherit'`` reuses this object's
        distribution (its ``bdist`` if one was given, otherwise its
        ``bcomm``); ``bcomm=None`` means ``serial_comm``.  ``nbands`` and
        ``atom_partition`` default to this object's values.
        """
        # Compare against None so that an explicit nbands=0 is honoured.
        if nbands is None:
            nbands = self.nbands
        if atom_partition is None:
            atom_partition = self.atom_partition

        if bcomm == 'inherit':
            if hasattr(self, 'bdist'):
                # Created with an explicit distribution: reuse it verbatim.
                return Projections(
                    nbands, self.nproj_a, atom_partition,
                    collinear=self.collinear,
                    spin=self.spin,
                    dtype=self.matrix.dtype,
                    bdist=self.bdist)
            bcomm = self.bcomm
        elif bcomm is None:
            bcomm = serial_comm

        return Projections(
            nbands, self.nproj_a, atom_partition,
            bcomm, self.collinear, self.spin, self.matrix.dtype)

    def view(self, n1: int, n2: int) -> 'Projections':
        """Return Projections wrapping rows ``n1:n2`` of the local array.

        The rows are passed to ``Matrix`` as a NumPy slice, so the data is
        presumably shared rather than copied (depends on ``Matrix``).
        """
        return Projections(n2 - n1, self.nproj_a,
                           self.atom_partition,
                           self.bcomm, self.collinear, self.spin,
                           self.matrix.dtype, self.matrix.array[n1:n2])

    def __getitem__(self, a):
        """Block of locally owned atom ``a`` (last axis: its projectors)."""
        I1, I2 = self.map[a]
        return self.array[..., I1:I2]

    def __iter__(self):
        """Iterate over the atom indices owned by this rank."""
        return iter(self.map)

    def __len__(self):
        """Number of atoms owned by this rank."""
        return len(self.map)

    def broadcast(self) -> 'Projections':
        """Return a copy in which every rank holds all atoms.

        Each atom's block is broadcast over ``atom_partition.comm`` from the
        rank that owns it.
        """
        # Partition in which all atoms belong to rank 0 of a serial comm,
        # i.e. every process stores every atom.
        ap = AtomPartition(serial_comm, np.zeros(len(self.nproj_a), int))
        P = self.new(atom_partition=ap)
        comm = self.atom_partition.comm
        for a, rank in enumerate(self.atom_partition.rank_a):
            P1_ni = P[a]
            if comm.rank == rank:
                P_ni = self[a].copy()
            else:
                P_ni = np.empty_like(P1_ni)
            comm.broadcast(P_ni, rank)
            P1_ni[:] = P_ni
        return P

    def redist(self, atom_partition) -> 'Projections':
        """Return a copy redistributed according to ``atom_partition``."""
        P = self.new(atom_partition=atom_partition)
        arraydict = self.toarraydict()
        arraydict.redistribute(atom_partition)
        P.fromarraydict(arraydict)
        return P

    def collect(self) -> Optional[Array2D]:
        """Collect all bands and atoms to master.

        Returns an array with all bands as rows and the columns ordered
        like ``self.matrix`` on a single rank (spinor-major if
        non-collinear) on rank 0 of both the band and atom communicators,
        and None on every other rank.
        """
        if self.bcomm.size == 1:
            P = self.matrix
        else:
            # Gather all bands onto bcomm rank 0.
            P = self.matrix.new(dist=(self.bcomm, 1, 1))
            self.matrix.redist(P)

        if self.bcomm.rank > 0:
            return None

        if self.atom_partition.comm.size == 1:
            return P.array

        P_In = self.collect_atoms(P)
        if P_In is not None:
            return P_In.T

        return None

    def toarraydict(self):
        """Copy this rank's atom blocks into an ``atom_partition`` arraydict.

        Each atom's entry has shape ``self.myshape[:-1] + (nproj,)``.
        """
        shape = self.myshape[:-1]
        shapes = [shape + (nproj,) for nproj in self.nproj_a]

        d = self.atom_partition.arraydict(shapes, self.matrix.array.dtype)
        for a, I1, I2 in self.indices:
            d[a][:] = self.array[..., I1:I2]  # Blocks will be contiguous
        return d

    def fromarraydict(self, d):
        """Copy this rank's atom blocks from arraydict ``d``.

        ``d`` must use the same partition as this object.
        """
        assert d.partition == self.atom_partition
        for a, I1, I2 in self.indices:
            self.array[..., I1:I2] = d[a]

    def collect_atoms(self, P):
        """Gather the columns of all atoms onto rank 0 of the atom comm.

        ``P`` is a Matrix whose columns are laid out like ``self.matrix``
        (this rank's atoms, two spinor blocks if non-collinear) and whose
        rows are the bands to collect.

        Returns an ``(nspinors * nproj, nbands)`` array on rank 0.  Its rows
        are ordered spinor-major, then by atom, then by projector, matching
        the column order of ``self.matrix`` on a single rank.  Returns None
        on the other ranks.
        """
        comm = self.atom_partition.comm
        ns = 1 if self.collinear else 2
        nbands, nlocal = P.array.shape
        # View local columns as (band, spinor, local projector) so that both
        # spinor components of an atom are sliced together.
        P_nsI = P.array.reshape((nbands, ns, nlocal // ns))

        if comm.rank != 0:
            # Rank 0 receives atoms in ascending order; MPI preserves message
            # order between a pair of ranks, so send in that order too.
            for a, I1, I2 in sorted(self.indices):
                comm.send(P_nsI[:, :, I1:I2].transpose(2, 1, 0).copy(), 0)
            return None

        nproj = sum(self.nproj_a)
        # (projector, spinor, band): each atom's rows form a contiguous
        # receive buffer.
        P_Isn = np.empty((nproj, ns, nbands), dtype=P.array.dtype)

        I1 = 0
        for a, (ni, rank) in enumerate(zip(self.nproj_a,
                                           self.atom_partition.rank_a)):
            I2 = I1 + ni
            if rank == 0:
                myI1, myI2 = self.map[a]
                P_Isn[I1:I2] = P_nsI[:, :, myI1:myI2].transpose(2, 1, 0)
            else:
                comm.receive(P_Isn[I1:I2], rank)
            I1 = I2

        # (I, s, n) -> (s, I, n) -> (s * I, n): spinor-major row order.
        return P_Isn.transpose(1, 0, 2).reshape((ns * nproj, nbands))

    def as_dict_on_master(self, n1: int, n2: int) -> Dict[int, Array2D]:
        """Return ``{a: P[n1:n2] block of atom a}`` for all atoms on master.

        Blocks have shape ``(n2 - n1, ni)``, or ``(n2 - n1, 2, ni)`` for
        non-collinear data (the same convention as ``self[a]``).  Returns an
        empty dict on ranks where :meth:`collect` returns None.
        """
        P_nI = self.collect()
        if P_nI is None:
            return {}
        if not self.collinear:
            # Collected columns are spinor-major: split off the spinor axis.
            P_nI = P_nI.reshape((len(P_nI), 2, sum(self.nproj_a)))
        I1 = 0
        P_ani = {}
        for a, ni in enumerate(self.nproj_a):
            I2 = I1 + ni
            P_ani[a] = P_nI[n1:n2, ..., I1:I2]
            I1 = I2
        return P_ani