Coverage for gpaw/response/pw_parallelization.py: 57%
143 statements
import numpy as np
from gpaw.blacs import BlacsDescriptor, BlacsGrid, Redistributor


class Blocks1D:
    def __init__(self, blockcomm, N):
        self.blockcomm = blockcomm
        self.N = N  # Global number of points

        self.blocksize = (N + blockcomm.size - 1) // blockcomm.size
        self.a = min(blockcomm.rank * self.blocksize, N)
        self.b = min(self.a + self.blocksize, N)
        self.nlocal = self.b - self.a

        self.myslice = slice(self.a, self.b)

    def all_gather(self, in_myix):
        """All-gather array where the first dimension is block distributed.

        Here, myi is understood as the distributed index, whereas x are the
        remaining (global) dimensions."""
        # Set up buffers
        buf_myix = self.local_communication_buffer(in_myix)
        buf_ix = self.global_communication_buffer(in_myix)

        # Execute all-gather
        self.blockcomm.all_gather(buf_myix, buf_ix)
        out_ix = buf_ix[:self.N]

        return out_ix

    def gather(self, in_myix, root=0):
        """Gather array to root where the first dimension is block distributed.
        """
        assert root in range(self.blockcomm.size)

        # Set up buffers
        buf_myix = self.local_communication_buffer(in_myix)
        if self.blockcomm.rank == root:
            buf_ix = self.global_communication_buffer(in_myix)
        else:
            buf_ix = None

        # Execute gather
        self.blockcomm.gather(buf_myix, root, buf_ix)
        if self.blockcomm.rank == root:
            out_ix = buf_ix[:self.N]
        else:
            out_ix = None

        return out_ix

    def local_communication_buffer(self, in_myix):
        """Set up local communication buffer."""
        if in_myix.shape[0] == self.blocksize and in_myix.flags.contiguous:
            buf_myix = in_myix  # Use input array as communication buffer
        else:
            assert in_myix.shape[0] == self.nlocal
            buf_myix = np.empty(
                (self.blocksize,) + in_myix.shape[1:], in_myix.dtype)
            buf_myix[:self.nlocal] = in_myix

        return buf_myix

    def global_communication_buffer(self, in_myix):
        """Set up global communication buffer."""
        buf_ix = np.empty(
            (self.blockcomm.size * self.blocksize,) + in_myix.shape[1:],
            dtype=in_myix.dtype)
        return buf_ix

    def find_global_index(self, i):
        """Find rank and local index of the global index i."""
        rank = i // self.blocksize
        li = i % self.blocksize

        return rank, li
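
The class above can be exercised with a small, self-contained sketch (not part of the module). It assumes GPAW is available and the script is launched under MPI so that gpaw.mpi.world has several ranks; the array contents and the sizes (10 rows, 3 columns) are arbitrary illustration data:

import numpy as np
from gpaw.mpi import world
from gpaw.response.pw_parallelization import Blocks1D

# Distribute N = 10 rows over the ranks of the world communicator
blocks = Blocks1D(world, 10)

# Each rank fills only its own block of rows
data_myix = np.empty((blocks.nlocal, 3))
data_myix[:] = world.rank

# Collect the full (10, 3) array on every rank
data_ix = blocks.all_gather(data_myix)
assert data_ix.shape == (10, 3)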


def block_partition(comm, nblocks):
    r"""Partition the communicator into a 2D array with horizontal
    and vertical communication.

            Communication between blocks (blockcomm)
    <----------------------------------------------->
     _______________________________________________
    |     |     |     |     |     |     |     |     |  ⋀
    |  0  |  1  |  2  |  3  |  4  |  5  |  6  |  7  |  |
    |_____|_____|_____|_____|_____|_____|_____|_____|  |
    |     |     |     |     |     |     |     |     |  |  Communication inside
    |  8  |  9  |  10 |  11 |  12 |  13 |  14 |  15 |  |  blocks
    |_____|_____|_____|_____|_____|_____|_____|_____|  |  (intrablockcomm)
    |     |     |     |     |     |     |     |     |  |
    |  16 |  17 |  18 |  19 |  20 |  21 |  22 |  23 |  |
    |_____|_____|_____|_____|_____|_____|_____|_____|  ⋁

    """
    if nblocks == 'max':
        # Maximize the number of blocks
        nblocks = comm.size
    assert isinstance(nblocks, int)
    assert nblocks > 0 and nblocks <= comm.size, comm.size
    assert comm.size % nblocks == 0, comm.size

    # Communicator between different blocks
    if nblocks == comm.size:
        blockcomm = comm
    else:
        rank1 = comm.rank // nblocks * nblocks
        rank2 = rank1 + nblocks
        blockcomm = comm.new_communicator(range(rank1, rank2))

    # Communicator inside each block
    ranks = range(comm.rank % nblocks, comm.size, nblocks)
    if nblocks == 1:
        assert len(ranks) == comm.size
        intrablockcomm = comm
    else:
        intrablockcomm = comm.new_communicator(ranks)

    assert blockcomm.size * intrablockcomm.size == comm.size

    return blockcomm, intrablockcomm
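
A hedged usage sketch for the partition (again not part of the module, and assuming an MPI-launched GPAW job). With nblocks='max' every rank becomes its own block, so the divisibility requirement is satisfied for any communicator size:

from gpaw.mpi import world
from gpaw.response.pw_parallelization import block_partition

# Split world into one block per rank; the intra-block communicators
# then have size 1.
blockcomm, intrablockcomm = block_partition(world, nblocks='max')
assert blockcomm.size * intrablockcomm.size == world.size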


class PlaneWaveBlockDistributor:
    """Functionality to shuffle block distribution of pair functions
    in the plane wave basis."""

    def __init__(self, world, blockcomm, intrablockcomm):
        self.world = world
        self.blockcomm = blockcomm
        self.intrablockcomm = intrablockcomm

    @property
    def fully_block_distributed(self):
        return self.world.compare(self.blockcomm) == 'ident'

    def new_distributor(self, *, nblocks):
        """Set up a new PlaneWaveBlockDistributor."""
        world = self.world
        blockcomm, intrablockcomm = block_partition(comm=world,
                                                    nblocks=nblocks)
        blockdist = PlaneWaveBlockDistributor(world, blockcomm, intrablockcomm)

        return blockdist

    def _redistribute(self, in_wGG, nw):
        """Redistribute array.

        Switch between two kinds of parallel distributions:

        1) parallel over G-vectors (second dimension of in_wGG)
        2) parallel over frequency (first dimension of in_wGG)

        Returns a new array with the redistributed data.
        """
        comm = self.blockcomm

        if comm.size == 1:
            return in_wGG

        mynw = (nw + comm.size - 1) // comm.size
        nG = in_wGG.shape[2]
        mynG = (nG + comm.size - 1) // comm.size

        bg1 = BlacsGrid(comm, comm.size, 1)
        bg2 = BlacsGrid(comm, 1, comm.size)
        md1 = BlacsDescriptor(bg1, nw, nG**2, mynw, nG**2)
        md2 = BlacsDescriptor(bg2, nw, nG**2, nw, mynG * nG)

        if len(in_wGG) == nw:
            mdin = md2
            mdout = md1
        else:
            mdin = md1
            mdout = md2

        r = Redistributor(comm, mdin, mdout)

        # mdout.shape[1] is always divisible by nG because
        # every block starts at a multiple of nG, and the last block
        # ends at nG² which of course also is divisible. Nevertheless:
        assert mdout.shape[1] % nG == 0
        # (If it were not divisible, we would "lose" some numbers and the
        # redistribution would be corrupted.)

        inbuf = in_wGG.reshape(mdin.shape)
        # numpy.reshape does not *guarantee* that the reshaped view will
        # be contiguous. To support redistribution of input arrays with an
        # arbitrary allocation layout, we make sure that the corresponding
        # input BLACS buffer is contiguous
        inbuf = np.ascontiguousarray(inbuf)

        outbuf = np.empty(mdout.shape, complex)

        r.redistribute(inbuf, outbuf)

        outshape = (mdout.shape[0], mdout.shape[1] // nG, nG)
        out_wGG = outbuf.reshape(outshape)
        assert out_wGG.flags.contiguous  # Since mdout.shape[1] % nG == 0

        return out_wGG
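
For concreteness, here is a worked shape example for the two BLACS descriptors under assumed parameters (nw = 32, nG = 100, comm.size = 4; purely illustrative numbers, runnable without MPI):

nw, nG, size = 32, 100, 4        # assumed illustration values
mynw = (nw + size - 1) // size   # = 8 frequencies per rank
mynG = (nG + size - 1) // size   # = 25 G-vectors per rank
# md1 local block: (mynw, nG**2)   = (8, 10000) -> frequency distributed
# md2 local block: (nw, mynG * nG) = (32, 2500) -> G-vector distributed
# 2500 % 100 == 0, consistent with the divisibility assert above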

    def has_distribution(self, array, nw, distribution):
        """Check if array 'array' has distribution 'distribution'."""
        if distribution not in ['wGG', 'WgG', 'zGG', 'ZgG']:
            raise ValueError(f'Invalid dist_type: {distribution}')

        comm = self.blockcomm
        if comm.size == 1:
            # In serial, all distributions are equivalent
            return True

        # At the moment, only wGG and WgG distributions are supported. zGG and
        # ZgG are complex frequency aliases for wGG and WgG respectively
        assert len(array.shape) == 3
        nG = array.shape[2]
        if array.shape[1] < array.shape[2]:
            # Looks like array is WgG/ZgG distributed
            gblocks = Blocks1D(comm, nG)
            assert array.shape == (nw, gblocks.nlocal, nG)
            return distribution in ['WgG', 'ZgG']
        else:
            # Looks like array is wGG/zGG distributed
            wblocks = Blocks1D(comm, nw)
            assert array.shape == (wblocks.nlocal, nG, nG)
            return distribution in ['wGG', 'zGG']

    def distribute_as(self, array, nw, distribution):
        """Redistribute array.

        Switch between two kinds of parallel distributions:

        1) parallel over G-vectors (distribution in ['WgG', 'ZgG'])
        2) parallel over frequency (distribution in ['wGG', 'zGG'])
        """
        if self.has_distribution(array, nw, distribution):
            # If the array already has the requested distribution, do nothing
            return array
        else:
            return self._redistribute(array, nw)

    def distribute_frequencies(self, in_wGG, nw):
        """Distribute frequencies to all cores."""
        world = self.world
        comm = self.blockcomm

        if world.size == 1:
            return in_wGG

        mynw = (nw + world.size - 1) // world.size
        nG = in_wGG.shape[2]
        mynG = (nG + comm.size - 1) // comm.size

        wa = min(world.rank * mynw, nw)
        wb = min(wa + mynw, nw)

        if self.blockcomm.size == 1:
            return in_wGG[wa:wb].copy()

        if self.intrablockcomm.rank == 0:
            bg1 = BlacsGrid(comm, 1, comm.size)
            in_wGG = in_wGG.reshape((nw, -1))
        else:
            bg1 = BlacsGrid(None, 1, 1)
            # bg1 = DryRunBlacsGrid(mpi.serial_comm, 1, 1)
            in_wGG = np.zeros((0, 0), complex)
        md1 = BlacsDescriptor(bg1, nw, nG**2, nw, mynG * nG)

        bg2 = BlacsGrid(world, world.size, 1)
        md2 = BlacsDescriptor(bg2, nw, nG**2, mynw, nG**2)

        r = Redistributor(world, md1, md2)
        shape = (wb - wa, nG, nG)
        out_wGG = np.empty(shape, complex)
        r.redistribute(in_wGG, out_wGG.reshape((wb - wa, nG**2)))

        return out_wGG
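
Finally, a sketch of how the distributor might be used end to end (an assumption-laden example, not part of the module: it presumes an MPI run, picks nblocks='max', and uses illustrative sizes nw and nG with complex zeros standing in for actual pair-function data):

import numpy as np
from gpaw.mpi import world
from gpaw.response.pw_parallelization import (
    Blocks1D, PlaneWaveBlockDistributor, block_partition)

nw, nG = 16, 9  # illustrative frequency and plane-wave counts
blockcomm, intrablockcomm = block_partition(world, nblocks='max')
blockdist = PlaneWaveBlockDistributor(world, blockcomm, intrablockcomm)

# Start from a frequency-distributed (wGG) pair function ...
wblocks = Blocks1D(blockcomm, nw)
chi_wGG = np.zeros((wblocks.nlocal, nG, nG), complex)

# ... and switch to the plane-wave distributed (WgG) layout
chi_WgG = blockdist.distribute_as(chi_wGG, nw, 'WgG')
gblocks = Blocks1D(blockcomm, nG)
assert chi_WgG.shape == (nw, gblocks.nlocal, nG)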