Coverage for gpaw/utilities/grid.py: 67%
138 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from functools import partial
3import numpy as np
4from gpaw.utilities.grid_redistribute import general_redistribute
5from gpaw.utilities.partition import AtomPartition, AtomicMatrixDistributor
class GridRedistributor:
    """Move grid arrays between the layouts of ``gd`` and ``aux_gd``.

    Redistribution is only performed ("enabled") when the processor
    grids of ``gd`` and ``aux_gd`` differ in at least one direction;
    otherwise ``distribute``/``collect`` hand back their input unchanged.

    Parameters
    ----------
    comm:
        Communicator spanning every rank involved; must be identical or
        congruent with ``aux_gd.comm`` when redistribution is enabled.
    broadcast_comm:
        Communicator over which ``collect`` broadcasts its result from
        rank 0.  ``gd.comm.size * broadcast_comm.size`` must equal
        ``comm.size``.
    gd, aux_gd:
        Source and auxiliary grid descriptors.
    """

    def __init__(self, comm, broadcast_comm, gd, aux_gd):
        self.comm = comm
        self.broadcast_comm = broadcast_comm
        self.gd = gd
        self.aux_gd = aux_gd
        self.enabled = np.any(gd.parsize_c != aux_gd.parsize_c)

        assert gd.comm.size * broadcast_comm.size == comm.size
        if self.enabled:
            assert comm.compare(aux_gd.comm) in ['ident', 'congruent']
        else:
            assert gd.comm.compare(aux_gd.comm) in ['ident', 'congruent']

        # Rank 0 of aux_gd.comm translates every gd.comm rank into an
        # aux_gd.comm rank and broadcasts the table to the other aux ranks.
        if aux_gd.comm.rank == 0:
            aux_ranks = gd.comm.translate_ranks(aux_gd.comm,
                                                np.arange(gd.comm.size))
        else:
            aux_ranks = np.empty(gd.comm.size, dtype=int)
        aux_gd.comm.broadcast(aux_ranks, 0)

        auxrank2rank = dict(zip(aux_ranks, np.arange(gd.comm.size)))

        def rank2parpos1(rank):
            # Map an aux_gd.comm rank to its processor position in gd, or
            # None when that rank is not a member of gd.comm.
            if rank in auxrank2rank:
                return gd.get_processor_position_from_rank(auxrank2rank[rank])
            else:
                return None

        rank2parpos2 = aux_gd.get_processor_position_from_rank

        try:
            gd.n_cp
        except AttributeError:  # AtomPAW
            # Without domain boundaries (n_cp) no redistribution is
            # possible.  distribute/collect call these as f(src, dst); the
            # former one-argument stub failed there with an unrelated
            # TypeError, so raise an explicit error instead.
            def unsupported(src_xg, dst_xg):
                raise NotImplementedError(
                    'grid redistribution requires gd.n_cp')

            self._distribute = self._collect = unsupported
            return  # XXX

        self._distribute = partial(general_redistribute, aux_gd.comm,
                                   gd.n_cp, aux_gd.n_cp,
                                   rank2parpos1, rank2parpos2)
        self._collect = partial(general_redistribute, aux_gd.comm,
                                aux_gd.n_cp, gd.n_cp,
                                rank2parpos2, rank2parpos1)

    def distribute(self, src_xg, dst_xg=None):
        """Return ``src_xg`` (gd layout) redistributed to the aux_gd layout.

        All but the last three axes of ``src_xg`` are kept as leading
        dimensions of the output.  When redistribution is disabled,
        ``src_xg`` itself is returned and ``dst_xg`` must be ``None`` or
        ``src_xg``.
        """
        if not self.enabled:
            assert src_xg is dst_xg or dst_xg is None
            return src_xg
        if dst_xg is None:
            dst_xg = self.aux_gd.empty(src_xg.shape[:-3], dtype=src_xg.dtype)
        self._distribute(src_xg, dst_xg)
        return dst_xg

    def collect(self, src_xg, dst_xg=None):
        """Return ``src_xg`` (aux_gd layout) gathered back to the gd layout.

        The result is broadcast from rank 0 of ``broadcast_comm``.  When
        redistribution is disabled, ``src_xg`` itself is returned and
        ``dst_xg`` must be ``None`` or ``src_xg``.
        """
        if not self.enabled:
            assert src_xg is dst_xg or dst_xg is None
            return src_xg
        if dst_xg is None:
            # dtype by keyword, consistent with distribute().
            dst_xg = self.gd.empty(src_xg.shape[:-3], dtype=src_xg.dtype)
        self._collect(src_xg, dst_xg)
        self.broadcast_comm.broadcast(dst_xg, 0)
        return dst_xg

    def get_atom_distributions(self, spos_ac):
        """Return an ``AtomDistributions`` for atoms at ``spos_ac``.

        ``spos_ac`` is presumably scaled (fractional) positions, one row
        per atom -- confirm against callers.
        """
        return AtomDistributions(self.comm, self.broadcast_comm,
                                 self.gd, self.aux_gd, spos_ac)
class AtomDistributions:
    """Atom partitions and matrix distributors for one set of positions.

    Three partitions of the atoms are built: one following ``gd``, one
    following ``aux_gd`` and an even "work" partition over ``comm``.
    Per-atom array dictionaries can then be moved between the ``gd``
    partition and either of the other two.
    """

    def __init__(self, comm, broadcast_comm, gd, aux_gd, spos_ac):
        self.comm = comm
        self.broadcast_comm = broadcast_comm
        self.gd = gd
        self.aux_gd = aux_gd

        shared_grid = gd is aux_gd
        natoms = len(spos_ac)

        owner_a = gd.get_ranks_from_positions(spos_ac)
        aux_owner_a = aux_gd.get_ranks_from_positions(spos_ac)
        self.partition = AtomPartition(gd.comm, owner_a, name='gd')

        aux_label = 'aux-unextended' if shared_grid else 'aux-extended'
        self.aux_partition = AtomPartition(aux_gd.comm, aux_owner_a,
                                           name=aux_label)

        # Start from an all-on-rank-0 partition and let AtomPartition
        # rebalance it evenly over comm.
        initial_work = AtomPartition(comm, np.zeros(natoms), name='work')
        self.work_partition = initial_work.as_even_partition()

        # When both descriptors are the same object, the aux distributor
        # uses a communicator containing only this rank of gd.comm instead
        # of broadcast_comm.
        if shared_grid:
            aux_bcast = gd.comm.new_communicator([gd.comm.rank])
        else:
            aux_bcast = broadcast_comm

        self.aux_dist = AtomicMatrixDistributor(self.partition,
                                                aux_bcast,
                                                self.aux_partition)
        self.work_dist = AtomicMatrixDistributor(self.partition,
                                                 broadcast_comm,
                                                 self.work_partition)

    def _shares_grid(self):
        # True when gd and aux_gd are literally the same descriptor.
        return self.gd is self.aux_gd

    def to_aux(self, arraydict):
        """Move ``arraydict`` from the gd partition to the aux partition.

        With a shared grid, a shallow copy (``arraydict.copy()``) is
        returned instead.
        """
        if not self._shares_grid():
            return self.aux_dist.distribute(arraydict)
        return arraydict.copy()

    def from_aux(self, arraydict):
        """Move ``arraydict`` from the aux partition back to gd's.

        With a shared grid, a shallow copy (``arraydict.copy()``) is
        returned instead.
        """
        if not self._shares_grid():
            return self.aux_dist.collect(arraydict)
        return arraydict.copy()

    def to_work(self, arraydict):
        """Move ``arraydict`` from the gd partition to the work partition."""
        distributor = self.work_dist
        return distributor.distribute(arraydict)

    def from_work(self, arraydict):
        """Move ``arraydict`` from the work partition back to gd's."""
        distributor = self.work_dist
        return distributor.collect(arraydict)
def get_domains_from_gd(comm, gd, offset_c=None):
    """Describe the domain decomposition of ``gd`` relative to ``comm``.

    Returns ``(n_cp, rank2parpos)``:

    * ``n_cp`` is ``gd.n_cp`` itself when ``offset_c`` is None, otherwise
      a list of copies of its arrays with ``offset_c[c]`` added in place
      to entry ``c`` for ``c`` in 0..2 (``gd.n_cp`` is left untouched).
    * ``rank2parpos(rank)`` maps a rank of ``comm`` to its processor
      position in ``gd``, or ``None`` if that rank is not in ``gd.comm``.

    Every rank of ``gd.comm`` must also be a member of ``comm``.
    """
    comm_ranks = gd.comm.translate_ranks(comm, np.arange(gd.comm.size))
    assert (comm_ranks >= 0).all(), 'comm not parent of gd.comm'

    def rank2parpos(rank):
        # translate_ranks must receive an array: passing a bare scalar
        # was observed to segfault.
        translated = comm.translate_ranks(gd.comm, np.array([rank]))
        gdrank = translated[0]
        return (None if gdrank == -1
                else gd.get_processor_position_from_rank(gdrank))

    if offset_c is None:
        return gd.n_cp, rank2parpos

    # Copy first so the in-place shift never touches gd.n_cp.
    shifted_cp = [n_p.copy() for n_p in gd.n_cp]
    for axis in range(3):
        shifted_cp[axis] += offset_c[axis]
    return shifted_cp, rank2parpos
def grid2grid(comm, gd1, gd2, src_g, dst_g, offset1_c=None, offset2_c=None,
              xp=np):
    """Redistribute ``src_g`` (laid out on ``gd1``) into ``dst_g`` (on ``gd2``).

    ``dst_g`` is filled in place through ``general_redistribute`` over
    ``comm``; nothing is returned.  ``offset1_c``/``offset2_c`` optionally
    shift the domain boundaries of ``gd1``/``gd2`` (see
    ``get_domains_from_gd``), and ``xp`` is forwarded as the array module.
    The local array shapes must match each descriptor's ``n_c``.
    """
    for array, descriptor in ((src_g, gd1), (dst_g, gd2)):
        assert np.all(array.shape == descriptor.n_c)

    domains1 = get_domains_from_gd(comm, gd1, offset_c=offset1_c)
    domains2 = get_domains_from_gd(comm, gd2, offset_c=offset2_c)
    (n1_cp, rank2parpos1), (n2_cp, rank2parpos2) = domains1, domains2

    general_redistribute(comm, n1_cp, n2_cp, rank2parpos1, rank2parpos2,
                         src_g, dst_g, xp=xp)
def main():
    """Ad-hoc check of ``grid2grid`` that prints arrays (asserts nothing).

    Runs on every rank of ``gpaw.mpi.world``. It spreads a serial random
    array over ``gd1``, redistributes it to ``gd2``, then repeats
    ``gd1 -> gd2`` with a rebuilt ``gd1``.
    """
    from gpaw.grid_descriptor import GridDescriptor
    from gpaw.mpi import world

    serial = world.new_communicator([world.rank])

    # Seeded identically and drawn in the same order on every rank, so all
    # ranks agree on the grid shapes.
    gen = np.random.RandomState(0)

    # Fills the serial reference array.  Every rank draws the same values;
    # only rank 0 prints them.
    gen_serial = np.random.RandomState(17)

    maxsize = 5
    N1_c = gen.randint(1, maxsize, 3)
    N2_c = gen.randint(1, maxsize, 3)

    gd1 = GridDescriptor(N1_c, N1_c)
    gd2 = GridDescriptor(N2_c, N2_c)
    serial_gd1 = gd1.new_descriptor(comm=serial)

    a1_serial = serial_gd1.empty()
    a1_serial.flat[:] = gen_serial.rand(a1_serial.size)
    if world.rank == 0:
        print('r0: a1 serial', a1_serial.ravel())

    # Serial reference -> gd1 layout over world (-1 marks unfilled points).
    a1 = gd1.empty()
    a1[:] = -1
    grid2grid(world, serial_gd1, gd1, a1_serial, a1)
    print(world.rank, 'a1 distributed', a1.ravel())
    world.barrier()

    # gd1 layout -> gd2 layout (-2 marks unfilled points).
    a2 = gd2.zeros()
    a2[:] = -2
    grid2grid(world, gd1, gd2, a1, a2)
    print(world.rank, 'a2 distributed', a2.ravel())
    world.barrier()

    # Repeat gd1 -> gd2 with gd1 rebuilt from N1_c * 0.2 (presumably the
    # cell argument) and fresh random data.
    gd1 = GridDescriptor(N1_c, N1_c * 0.2)
    a1 = gd1.empty()
    a1.flat[:] = gen.rand(a1.size)
    grid2grid(world, gd1, gd2, a1, a2)


if __name__ == '__main__':
    main()