Coverage for gpaw/projections.py: 79%
131 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from typing import Any, Optional, Dict
2from collections.abc import Mapping
4import numpy as np
6from gpaw.matrix import Matrix
7from gpaw.mpi import serial_comm
8from gpaw.utilities.partition import AtomPartition
9from gpaw.typing import Array2D, ArrayLike1D
# Annotation used for communicator parameters below; no narrower type is
# imported here, so any object is accepted by type checkers.
MPIComm = Any
class Projections(Mapping):
    """Per-atom blocks of a (bands x projectors) coefficient matrix.

    The coefficients live in ``self.matrix`` (a gpaw ``Matrix``).  Its rows
    are bands.  Its columns are the projectors of the atoms owned by this
    rank of ``atom_partition`` (``atom_partition.my_indices``), laid out
    atom after atom.  When ``collinear`` is False, the number of columns is
    doubled and :attr:`array` exposes them with shape
    ``(bands, 2, projectors)``.

    As a :class:`collections.abc.Mapping`, ``P[a]`` returns the block of a
    locally owned atom ``a``, sliced out of :attr:`array`.  Iteration and
    ``len`` cover only the locally owned atoms.
    """

    def __init__(self,
                 nbands: int,
                 nproj_a: ArrayLike1D,
                 atom_partition: AtomPartition,
                 bcomm: MPIComm = None,
                 collinear=True,
                 spin=0,
                 dtype=None,
                 data=None,
                 bdist=None):
        """Create the projection matrix, optionally wrapping ``data``.

        nbands: total number of bands (rows of the matrix).
        nproj_a: number of projectors of every atom, indexed by atom.
        atom_partition: atom-to-rank assignment.  Only the atoms in
            ``atom_partition.my_indices`` get columns on this rank.
        bcomm: band communicator.  ``None`` means ``serial_comm``.  It must
            be ``None`` when ``bdist`` is given.
        collinear: if False, two components per projector are stored.
        spin: stored on the instance and forwarded by :meth:`new` and
            :meth:`view`.
        dtype: element type.  If both ``dtype`` and ``data`` are None, it
            defaults to float (collinear) or complex (non-collinear).
        data: optional existing array, passed to ``Matrix``.
        bdist: explicit distribution tuple, passed to ``Matrix`` as
            ``dist``.  Its first item is the band communicator.  It is
            remembered so that :meth:`new` can reproduce it.
        """
        if bdist is None:
            self.bcomm = bcomm or serial_comm
            bdist = (self.bcomm, self.bcomm.size, 1)
        else:
            assert bcomm is None
            # Only set in this branch; new() checks for its presence.
            self.bdist = bdist  # For calling "new"
            self.bcomm = bdist[0]

        self.nproj_a = np.asarray(nproj_a)
        self.atom_partition = atom_partition
        self.collinear = collinear
        self.spin = spin
        self.nbands = nbands

        # Column range [I1, I2) of every locally owned atom, in the order of
        # my_indices.  It is kept as an ordered list (for iteration) and as
        # a dict (for lookup by atom index).  The ranges count projectors of
        # one component only, also in the non-collinear case.
        self.indices = []
        self.map = {}
        I1 = 0
        for a in self.atom_partition.my_indices:
            ni = self.nproj_a[a]
            I2 = I1 + ni
            self.indices.append((a, I1, I2))
            self.map[a] = (I1, I2)
            I1 = I2

        if not collinear:
            I1 *= 2

        if dtype is None and data is None:
            dtype = float if collinear else complex

        self.matrix = Matrix(nbands, I1, dtype, data, dist=bdist)

        if collinear:
            self.myshape = self.matrix.array.shape
        else:
            self.myshape = (len(self.matrix.array), 2, I1 // 2)

    @property
    def array(self):
        """Return the local coefficient array.

        The result is ``self.matrix.array`` when collinear.  Otherwise it is
        that array reshaped to ``self.myshape``, i.e.
        ``(local rows, 2, projectors)``.
        """
        if self.collinear:
            return self.matrix.array
        else:
            return self.matrix.array.reshape(self.myshape)

    def new(self, bcomm='inherit', nbands=None, atom_partition=None):
        """Return a new Projections object with the same layout.

        No data is passed to the new object, so coefficients are not copied.
        ``nproj_a``, ``collinear``, ``spin`` and ``dtype`` are kept.

        bcomm: ``'inherit'`` (the default) reuses this object's band
            distribution.  That is the stored ``bdist`` if there is one,
            otherwise ``self.bcomm``.  ``None`` means ``serial_comm``.
            Any other value is used as the band communicator.
        nbands: number of bands.  Falsy values (``None`` or 0) mean
            ``self.nbands``.
        atom_partition: defaults to ``self.atom_partition``.
        """
        if bcomm == 'inherit':
            if hasattr(self, 'bdist'):
                return Projections(
                    nbands or self.nbands, self.nproj_a,
                    self.atom_partition if atom_partition is None
                    else atom_partition,
                    collinear=self.collinear,
                    spin=self.spin,
                    dtype=self.matrix.dtype,
                    bdist=self.bdist)
            else:
                bcomm = self.bcomm
        elif bcomm is None:
            bcomm = serial_comm

        return Projections(
            nbands or self.nbands, self.nproj_a,
            self.atom_partition if atom_partition is None else atom_partition,
            bcomm, self.collinear, self.spin, self.matrix.dtype)

    def view(self, n1: int, n2: int) -> 'Projections':
        """Return Projections built on rows ``n1:n2`` of the local array.

        The slice ``self.matrix.array[n1:n2]`` is passed as ``data``.
        NOTE(review): presumably ``Matrix`` wraps it without copying, so the
        result shares memory with this object; confirm.  Note also that
        ``n2 - n1`` is used as the total band count and ``self.bcomm`` is
        used instead of any stored ``bdist``.  This looks intended for the
        non-band-distributed case only; verify before using it with band
        parallelization.
        """
        return Projections(n2 - n1, self.nproj_a,
                           self.atom_partition,
                           self.bcomm, self.collinear, self.spin,
                           self.matrix.dtype, self.matrix.array[n1:n2])

    def __getitem__(self, a):
        """Return the coefficient block of locally owned atom ``a``.

        Raises KeyError if ``a`` is not owned by this rank.
        """
        I1, I2 = self.map[a]
        return self.array[..., I1:I2]

    def __iter__(self):
        """Iterate over the locally owned atom indices."""
        return iter(self.map)

    def __len__(self):
        """Return the number of locally owned atoms."""
        return len(self.map)

    def broadcast(self) -> 'Projections':
        """Return a copy in which every rank holds the blocks of all atoms.

        The new atom partition puts every atom on rank 0 of ``serial_comm``,
        so all atoms are local.  Each atom's block is broadcast over
        ``atom_partition.comm`` from the rank that owns it.  The band
        distribution is inherited through :meth:`new`.
        """
        ap = AtomPartition(serial_comm, np.zeros(len(self.nproj_a), int))
        P = self.new(atom_partition=ap)
        comm = self.atom_partition.comm
        for a, rank in enumerate(self.atom_partition.rank_a):
            P1_ni = P[a]
            if comm.rank == rank:
                # copy() gives the broadcast a contiguous buffer.
                P_ni = self[a].copy()
            else:
                P_ni = np.empty_like(P1_ni)
            comm.broadcast(P_ni, rank)
            P1_ni[:] = P_ni
        return P

    def redist(self, atom_partition) -> 'Projections':
        """Return a copy whose atoms are distributed as in ``atom_partition``.

        The data is moved through the arraydict returned by
        :meth:`toarraydict`.
        """
        P = self.new(atom_partition=atom_partition)
        arraydict = self.toarraydict()
        arraydict.redistribute(atom_partition)
        P.fromarraydict(arraydict)
        return P

    def collect(self) -> Optional[Array2D]:
        """Collect all bands and atoms to master.

        First, the bands are gathered onto rank 0 of ``self.bcomm``.  This is
        skipped when ``bcomm.size == 1``.  Then the atoms are gathered onto
        rank 0 of ``atom_partition.comm``.

        Returns the full ``(nbands, nI)`` array on the master, where
        ``nI = sum(nproj_a)``, or twice that when non-collinear.  All other
        ranks get None.
        """
        if self.bcomm.size == 1:
            P = self.matrix
        else:
            P = self.matrix.new(dist=(self.bcomm, 1, 1))
            self.matrix.redist(P)

        if self.bcomm.rank > 0:
            return None

        if self.atom_partition.comm.size == 1:
            return P.array

        P_In = self.collect_atoms(P)
        if P_In is not None:
            return P_In.T

        return None

    def toarraydict(self):
        """Copy the local atom blocks into a new arraydict.

        The arraydict is created by ``atom_partition.arraydict``.  Atom
        ``a`` has shape ``myshape[:-1] + (nproj_a[a],)``.
        """
        shape = self.myshape[:-1]
        shapes = [shape + (nproj,) for nproj in self.nproj_a]

        d = self.atom_partition.arraydict(shapes, self.matrix.array.dtype)
        for a, I1, I2 in self.indices:
            d[a][:] = self.array[..., I1:I2]  # Blocks will be contiguous
        return d

    def fromarraydict(self, d):
        """Copy the local atom blocks from arraydict ``d`` into this object.

        ``d`` must use the same atom partition as this object.
        """
        assert d.partition == self.atom_partition
        for a, I1, I2 in self.indices:
            self.array[..., I1:I2] = d[a]

    def collect_atoms(self, P):
        """Gather every atom's columns of ``P`` on rank 0 of the atom comm.

        P: matrix-like object whose ``array`` holds all bands as rows and
            this rank's projector columns in the same layout as
            ``self.matrix``.  That layout is one component block, or two
            component blocks when non-collinear.

        Returns, on rank 0 of ``atom_partition.comm``, an ``(nI, nbands)``
        array with atoms in global order.  Here ``nI = sum(nproj_a)`` when
        collinear and ``2 * sum(nproj_a)`` otherwise.  When non-collinear,
        all projectors of component 0 come first, then those of component
        1.  This matches the column layout of the undistributed matrix.
        Every other rank sends its blocks to rank 0 and returns None.
        """
        comm = self.atom_partition.comm
        nspinors = 1 if self.collinear else 2
        nbands, ncols = P.array.shape
        dtype = P.array.dtype
        # View the columns as (component, local projector).  Each atom's
        # block can then be cut out of all components at once.
        P_nsI = P.array.reshape((nbands, nspinors, ncols // nspinors))

        if comm.rank != 0:
            for a, I1, I2 in self.indices:
                # Contiguous (nspinors * ni, nbands) buffer for the send.
                block = P_nsI[:, :, I1:I2].transpose(1, 2, 0).copy()
                comm.send(block.reshape((nspinors * (I2 - I1), nbands)), 0)
            return None

        nproj = sum(self.nproj_a)
        P_sIn = np.empty((nspinors, nproj, nbands), dtype=dtype)
        I1 = 0
        myI1 = 0
        # Atoms owned by rank 0 are read locally, in the same order as
        # self.indices.  The others arrive in global atom order, and each
        # sender sends its atoms in that order too.
        for ni, rank in zip(self.nproj_a, self.atom_partition.rank_a):
            I2 = I1 + ni
            if rank == 0:
                myI2 = myI1 + ni
                P_sIn[:, I1:I2] = P_nsI[:, :, myI1:myI2].transpose(1, 2, 0)
                myI1 = myI2
            else:
                # Receive into a contiguous buffer.  P_sIn[:, I1:I2] is not
                # contiguous when there are two components.
                buf_In = np.empty((nspinors * ni, nbands), dtype=dtype)
                comm.receive(buf_In, rank)
                P_sIn[:, I1:I2] = buf_In.reshape((nspinors, ni, nbands))
            I1 = I2
        return P_sIn.reshape((nspinors * nproj, nbands))

    def as_dict_on_master(self, n1: int, n2: int) -> Dict[int, Array2D]:
        """Return ``{a: P_ni}`` for bands ``n1:n2`` on the master rank.

        Uses :meth:`collect`.  Non-master ranks get an empty dict.
        NOTE(review): each atom gets columns ``I1:I2`` of the collected
        array.  When non-collinear, that covers only the first component
        block.  Confirm that this is what callers expect.
        """
        P_nI = self.collect()
        if P_nI is None:
            return {}
        I1 = 0
        P_ani = {}
        for a, ni in enumerate(self.nproj_a):
            I2 = I1 + ni
            P_ani[a] = P_nI[n1:n2, I1:I2]
            I1 = I2
        return P_ani