Coverage for gpaw/band_descriptor.py: 36%
155 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1# Copyright (C) 2003 CAMP
2# Please see the accompanying LICENSE file for further information.
4"""Band-descriptors for blocked/strided groups.
6This module contains classes defining two kinds of band groups:
8* Blocked groups with contiguous band indices.
9* Strided groups with evenly-spaced band indices.
10"""
12import numpy as np
14from gpaw.mpi import serial_comm
# Flag handed to ``comm.send`` as its fourth positional argument in
# ``BandDescriptor.distribute``; the call's return value is kept as a
# request and later passed to ``comm.wait``.
# NOTE(review): presumably this is the ``block`` argument of the GPAW
# communicator's ``send`` (False -> non-blocking) -- confirm in gpaw.mpi.
NONBLOCKING = False
class BandDescriptor:
    r"""Descriptor-class for ordered lists of bands

    A ``BandDescriptor`` object holds information on how functions, such
    as wave functions and corresponding occupation numbers, are divided
    into groups according to band indices. The main information here is
    how many bands are stored on each processor and who gets what.

    This is how a 12 band array is laid out in memory on 3 cpu's::

      a) Blocked groups       b) Strided groups

           3    7   11         9  10  11
     myn   2 \  6 \ 10   myn   6   7   8
      |    1  \ 5  \ 9    |    3   4   5
      |    0    4    8    |    0   1   2
      |                   |
      +----- band_rank    +----- band_rank

    Example:

    >>> a = np.zeros((3, 4))
    >>> a.ravel()[:] = range(12)
    >>> a
    array([[ 0.,  1.,  2.,  3.],
           [ 4.,  5.,  6.,  7.],
           [ 8.,  9., 10., 11.]])
    >>> b = np.zeros((4, 3))
    >>> b.ravel()[:] = range(12)
    >>> b.T
    array([[ 0.,  3.,  6.,  9.],
           [ 1.,  4.,  7., 10.],
           [ 2.,  5.,  8., 11.]])
    """

    def __init__(self, nbands: int, comm=None, strided=False):
        """Construct band-descriptor object.

        Parameters:

        nbands: int
            Global number of bands.
        comm: MPI-communicator
            Communicator for band-groups.  Defaults to ``serial_comm``
            when ``None`` is given.
        strided: bool
            Enable strided band distribution for better
            load balancing with many unoccupied bands.  Requires
            ``nbands`` to be divisible by ``comm.size`` (asserted).

        Note that if comm.size is 1, then all bands are contained on
        a single CPU and blocked/strided grouping loses its meaning.

        Attributes:

        =============== ===================================================
        ``nbands``      Number of bands in total.
        ``mynbands``    Number of bands on this CPU.
        ``maxmynbands`` Largest number of bands on any CPU
                        (``ceil(nbands / comm.size)``).
        ``beg``         Beginning of band indices in group (inclusive).
        ``end``         End of band indices in group (exclusive).
        ``step``        Stride for band indices between ``beg`` and
                        ``end``.
        ``strided``     Whether the strided layout is used.
        ``comm``        MPI-communicator for band distribution.
        =============== ===================================================
        """
        if comm is None:
            comm = serial_comm

        self.comm = comm
        self.nbands = nbands
        self.strided = strided

        # Ceiling division: size of the largest group.
        self.maxmynbands = (nbands + comm.size - 1) // comm.size

        if strided:
            assert nbands % comm.size == 0
            self.beg = comm.rank
            self.end = nbands
            self.step = comm.size
            self.mynbands = self.maxmynbands
        else:
            # Trailing ranks may get fewer bands (possibly none), so both
            # ends are clamped to ``nbands``.
            self.beg = min(nbands, comm.rank * self.maxmynbands)
            self.end = min(nbands, self.beg + self.maxmynbands)
            self.mynbands = self.end - self.beg
            self.step = 1

    def __len__(self):
        """Return the number of bands stored on this rank."""
        return self.mynbands

    def get_slice(self, band_rank=None):
        """Return the slice of global bands which belong to a given rank.

        ``band_rank`` defaults to the rank of this process in ``comm``.
        For strided groups the slice has an open end and must be resolved
        against ``nbands`` (e.g. with ``slice.indices``).
        """
        if band_rank is None:
            band_rank = self.comm.rank
        assert band_rank in range(self.comm.size)

        if self.strided:
            return slice(band_rank, None, self.comm.size)

        # Compute the range from ``band_rank`` rather than from
        # ``self.beg``/``self.end``, which only describe the calling rank.
        # The clamping mirrors ``__init__`` so that
        # ``get_slice() == slice(self.beg, self.end)`` still holds.
        beg = min(self.nbands, band_rank * self.maxmynbands)
        end = min(self.nbands, beg + self.maxmynbands)
        return slice(beg, end)

    def get_band_indices(self, band_rank=None):
        """Return the global band indices which belong to a given rank."""
        nslice = self.get_slice(band_rank)
        return np.arange(*nslice.indices(self.nbands))

    def get_band_ranks(self):
        """Return array of ranks as a function of global band indices."""
        rank_n = np.empty(self.nbands, dtype=int)
        for band_rank in range(self.comm.size):
            rank_n[self.get_slice(band_rank)] = band_rank
        # Every band must have been assigned to exactly one valid rank.
        assert (rank_n >= 0).all() and (rank_n < self.comm.size).all()
        return rank_n

    def who_has(self, n):
        """Convert global band index to rank information and local index.

        Returns the tuple ``(band_rank, myn)``.
        """
        if self.strided:
            myn, band_rank = divmod(n, self.comm.size)
        else:
            band_rank, myn = divmod(n, self.maxmynbands)
        return band_rank, myn

    def global_index(self, myn, band_rank=None):
        """Convert rank information and local index to global index.

        ``band_rank`` defaults to the rank of this process in ``comm``.
        """
        if band_rank is None:
            band_rank = self.comm.rank

        if self.strided:
            return band_rank + myn * self.comm.size

        return band_rank * self.maxmynbands + myn

    def get_size_of_global_array(self):
        """Return the shape ``(nbands,)`` of a global band array."""
        return (self.nbands,)

    def zeros(self, n=(), dtype=float, global_array=False):
        """Return new zeroed array with bands along the first axis.

        The type can be set with the ``dtype`` keyword (default:
        ``float``).  Extra trailing dimensions can be added with
        ``n=dim`` (an integer or a sequence of integers).
        An array spanning all bands instead of only the local ones can
        be allocated with ``global_array=True``."""
        return self._new_array(n, dtype, True, global_array)

    def empty(self, n=(), dtype=float, global_array=False):
        """Return new uninitialized array with bands along the first axis.

        The type can be set with the ``dtype`` keyword (default:
        ``float``).  Extra trailing dimensions can be added with
        ``n=dim`` (an integer or a sequence of integers).
        An array spanning all bands instead of only the local ones can
        be allocated with ``global_array=True``."""
        return self._new_array(n, dtype, False, global_array)

    def _new_array(self, n=(), dtype=float, zero=True, global_array=False):
        """Allocate an array of shape ``(bands,) + n``.

        ``bands`` is ``nbands`` if ``global_array`` is true and
        ``mynbands`` otherwise; ``zero`` selects ``np.zeros`` over
        ``np.empty``.
        """
        if global_array:
            shape = tuple(self.get_size_of_global_array())
        else:
            shape = (self.mynbands,)

        # Accept plain and NumPy integers as well as any sequence.
        if isinstance(n, (int, np.integer)):
            n = (int(n),)
        shape += tuple(n)

        if zero:
            return np.zeros(shape, dtype)
        return np.empty(shape, dtype)

    def collect(self, a_nx, broadcast=False):
        """Collect distributed array to master-CPU or all CPU's.

        ``a_nx`` is the local part with the local bands along the first
        axis.  Returns the global array on rank 0 (and on every rank if
        ``broadcast`` is true); other ranks get ``None``.  With a single
        rank, ``a_nx`` itself is returned (no copy).
        """
        if self.comm.size == 1:
            return a_nx

        xshape = a_nx.shape[1:]

        # Optimization for blocked groups
        if not self.strided:
            if self.comm.size * self.maxmynbands > self.nbands:
                # Uneven group sizes need padding before gathering.
                return self.nasty_non_strided_collect(a_nx, broadcast)
            if broadcast:
                A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
                self.comm.all_gather(a_nx, A_nx)
                return A_nx

            if self.comm.rank == 0:
                A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
            else:
                A_nx = None
            self.comm.gather(a_nx, 0, A_nx)
            return A_nx

        # Collect all arrays on the master:
        if self.comm.rank != 0:
            # There can be several sends before the corresponding receives
            # are posted, so use synchronous send here
            self.comm.ssend(a_nx, 0, 3011)
            if broadcast:
                A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
                self.comm.broadcast(A_nx, 0)
                return A_nx
            return None

        # Put the band groups from the slaves into the big array
        # for the whole collection of bands:
        A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
        for band_rank in range(self.comm.size):
            if band_rank != 0:
                a_nx = self.empty(xshape, a_nx.dtype, global_array=False)
                self.comm.receive(a_nx, band_rank, 3011)
            A_nx[self.get_slice(band_rank), ...] = a_nx

        if broadcast:
            self.comm.broadcast(A_nx, 0)
        return A_nx

    def distribute(self, B_nx, b_nx):
        """Distribute full array B_nx to band groups, result in b_nx.

        ``b_nx`` must be allocated by the caller; it is filled in place
        and nothing is returned.  ``B_nx`` is read on rank 0 only, except
        with a single rank where it is copied directly.
        """
        S = self.comm.size

        if S == 1:
            b_nx[:] = B_nx
            return

        # Optimization for blocked groups
        if not self.strided:
            M2 = self.maxmynbands
            if M2 * S == self.nbands:
                self.comm.scatter(B_nx, b_nx, 0)
                return

            # Uneven groups: scatter from a buffer padded to S * M2 bands
            # so that every rank receives a chunk of equal size.
            if self.comm.rank == 0:
                C_nx = np.empty((S * M2,) + B_nx.shape[1:], B_nx.dtype)
                C_nx[:self.nbands] = B_nx
            else:
                C_nx = None

            if self.mynbands < M2:
                c_nx = np.empty((M2,) + b_nx.shape[1:], b_nx.dtype)
            else:
                c_nx = b_nx

            self.comm.scatter(C_nx, c_nx, 0)

            if self.mynbands < M2:
                # Drop the padding again.
                b_nx[:] = c_nx[:self.mynbands]

            return

        if self.comm.rank != 0:
            self.comm.receive(b_nx, 0, 421)
            return

        requests = []
        for band_rank in range(self.comm.size):
            if band_rank != 0:
                a_nx = B_nx[self.get_slice(band_rank), ...].copy()
                request = self.comm.send(a_nx, band_rank, 421, NONBLOCKING)
                # Remember to store a reference to the
                # send buffer (a_nx) so that it isn't
                # deallocated:
                requests.append((request, a_nx))
            else:
                b_nx[:] = B_nx[self.get_slice(), ...]

        for request, a_nx in requests:
            self.comm.wait(request)

    def nasty_non_strided_collect(self, a_nx, broadcast):
        """Collect blocked groups of unequal size.

        Each rank pads its local array to ``maxmynbands`` bands so that a
        plain gather can be used; rank 0 then cuts the result back to
        ``nbands``.  Returns the global array on rank 0 (on all ranks if
        ``broadcast`` is true) and ``None`` elsewhere.
        """
        xshape = a_nx.shape[1:]
        if broadcast:
            A_nx = self.nasty_non_strided_collect(a_nx, False)
            if A_nx is None:
                # Non-master ranks need a receive buffer for the broadcast.
                A_nx = self.empty(xshape, a_nx.dtype, global_array=True)
            self.comm.broadcast(A_nx, 0)
            return A_nx

        S = self.comm.size
        M2 = self.maxmynbands
        if self.comm.rank == 0:
            A_nx = np.empty((S * M2,) + xshape, a_nx.dtype)
        else:
            A_nx = None

        if self.mynbands < M2:
            b_nx = np.empty((M2,) + xshape, a_nx.dtype)
            b_nx[:self.mynbands] = a_nx
        else:
            b_nx = a_nx

        self.comm.gather(b_nx, 0, A_nx)

        if self.comm.rank > 0:
            return None

        return A_nx[:self.nbands]