Coverage for gpaw/response/kspair.py: 45%
368 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from __future__ import annotations
3import numpy as np
4from functools import cached_property
6from gpaw.projections import Projections, serial_comm
7from gpaw.response import ResponseGroundStateAdapter, ResponseContext, timer
8from gpaw.response.pw_parallelization import Blocks1D
class IrreducibleKPoint:
    """Irreducible k-point data pertaining to a certain set of transitions."""

    def __init__(self, ik, eps_h, f_h, Ph, psit_hG, h_myt):
        """Construct the IrreducibleKPoint data object.

        The data is indexed by the composite band and spin index h = (n, s),
        which can be unfolded to the local transition index myt.

        Parameters
        ----------
        ik : int
            Irreducible k-point index.
        eps_h : np.ndarray
            Eigenvalues, indexed by h.
        f_h : np.ndarray
            Occupation numbers, indexed by h.
        Ph : Projections
            PAW projections with one "band" entry per h.
        psit_hG : np.ndarray
            Pseudo wave function plane-wave components, indexed by h.
        h_myt : np.ndarray
            Integer array mapping each local transition index myt to h.
        """
        self.ik = ik  # Irreducible k-point index
        self.eps_h = eps_h  # Eigenvalues
        self.f_h = f_h  # Occupation numbers
        self.Ph = Ph  # PAW projections
        self.psit_hG = psit_hG  # Pseudo wave function plane-wave components
        self.h_myt = h_myt  # myt -> h index mapping

    @cached_property
    def nh(self):
        """Number of composite indices h, checked to agree across all data."""
        nh = len(self.eps_h)
        assert len(self.f_h) == nh
        assert self.Ph.nbands == nh
        assert len(self.psit_hG) == nh

        return nh

    @property
    def eps_myt(self):
        """Eigenvalues unfolded to the local transition index myt."""
        return self.eps_h[self.h_myt]

    @property
    def f_myt(self):
        """Occupation numbers unfolded to the local transition index myt."""
        return self.f_h[self.h_myt]

    def projectors_in_transition_index(self, Ph):
        """Unfold the projections Ph (indexed by h) to the transition index.

        Returns a new, band-serial Projections object with one entry per
        local transition myt.
        """
        nmyt = len(self.h_myt)
        # Construct the projections by hand rather than through Ph.new():
        # Projections.new() interprets nbands == 0 as "inherit Ph.nbands",
        # which would give a wrongly sized object for a process without any
        # local transitions.
        Pmyt = Projections(nmyt, Ph.nproj_a, Ph.atom_partition, serial_comm,
                           Ph.collinear, Ph.spin, Ph.matrix.dtype)
        Pmyt.array[:] = Ph.array[self.h_myt]
        return Pmyt
class KohnShamKPointPair:
    """Data of pairs of Kohn-Sham orbitals pertaining to transitions k -> k'."""

    def __init__(self, K1, K2, ikpt1, ikpt2, transitions, tblocks):
        """Construct the KohnShamKPointPair from the k-point data of k and k'.

        K1, K2 : int, int
            k-point indices of k and k'
        ikpt1, ikpt2 : IrreducibleKPoint, IrreducibleKPoint
            k-point data of the two specific k-points in the irreducible part
            of the BZ which are related to K1 and K2 by symmetry respectively.
        transitions : PairTransitions
            The band and spin transitions of the pair.
        tblocks : Blocks1D
            Block distribution of the transitions among processes.
        """
        self.K1, self.K2 = K1, K2
        self.ikpt1, self.ikpt2 = ikpt1, ikpt2
        self.transitions = transitions
        self.tblocks = tblocks

    def _restrict_to_local(self, array_t):
        """Pick out the entries of a transition array held by this process."""
        return array_t[self.tblocks.myslice]

    def get_all(self, in_mytx):
        """Gather a locally held data array over all transitions."""
        return self.tblocks.all_gather(in_mytx)

    @property
    def deps_myt(self):
        """Eigenvalue differences eps(k') - eps(k) of the local transitions."""
        eps2_myt = self.ikpt2.eps_myt
        eps1_myt = self.ikpt1.eps_myt
        return eps2_myt - eps1_myt

    @property
    def df_myt(self):
        """Occupation differences f(k') - f(k) of the local transitions."""
        f2_myt = self.ikpt2.f_myt
        f1_myt = self.ikpt1.f_myt
        return f2_myt - f1_myt

    def get_local_band_indices(self):
        """Band indices (n1_myt, n2_myt) of the local transitions."""
        n1_t, n2_t = self.transitions.get_band_indices()
        return self._restrict_to_local(n1_t), self._restrict_to_local(n2_t)

    def get_local_spin_indices(self):
        """Spin indices (s1_myt, s2_myt) of the local transitions."""
        s1_t, s2_t = self.transitions.get_spin_indices()
        return self._restrict_to_local(s1_t), self._restrict_to_local(s2_t)

    def get_local_intraband_mask(self):
        """Intraband mask of the local transitions."""
        return self._restrict_to_local(self.transitions.get_intraband_mask())
class KohnShamKPointPairExtractor:
    """Functionality to extract KohnShamKPointPairs from a
    ResponseGroundStateAdapter."""

    def __init__(self, gs, context, *,
                 transitions_blockcomm, kpts_blockcomm):
        """
        Parameters
        ----------
        gs : ResponseGroundStateAdapter
        context : ResponseContext
        transitions_blockcomm : gpaw.mpi.Communicator
            Communicator to distribute band and spin transitions
        kpts_blockcomm : gpaw.mpi.Communicator
            Communicator over which the k-points are distributed
        """
        assert isinstance(gs, ResponseGroundStateAdapter)
        self.gs = gs
        assert isinstance(context, ResponseContext)
        self.context = context

        if gs.is_parallelized():
            # Parallel extraction communicates over the ground state world
            assert context.comm is gs.world
            # map_who_has() relies on the absence of grid parallelization
            assert gs.gd.comm.size == 1

        self.transitions_blockcomm = transitions_blockcomm
        self.kpts_blockcomm = kpts_blockcomm

        # Block distribution of the transitions, set up per k-point pair
        self.tblocks = None

        # Pending non-blocking receive/send requests used when
        # redistributing the k-point data
        self.rrequests, self.srequests = [], []
135 @timer('Get Kohn-Sham pairs')
136 def get_kpoint_pairs(self, k1_pc, k2_pc,
137 transitions) -> KohnShamKPointPair | None:
138 """Get all pairs of Kohn-Sham orbitals for transitions k -> k'
140 (n1_t, k1_p, s1_t) -> (n2_t, k2_p, s2_t)
142 Here, t is a composite band and spin transition index accounted for by
143 the input PairTransitions object, whereas p indexes the k-point that
144 each rank of the k-point block communicator needs to extract."""
145 assert k1_pc.shape == k2_pc.shape
147 # Distribute transitions and extract data for transitions in
148 # this process' block
149 self.tblocks = Blocks1D(self.transitions_blockcomm, len(transitions))
151 K1, ikpt1 = self.get_kpoints(k1_pc, transitions.n1_t, transitions.s1_t)
152 K2, ikpt2 = self.get_kpoints(k2_pc, transitions.n2_t, transitions.s2_t)
154 # The process might not have a Kohn-Sham k-point pair to return, due to
155 # the distribution over kpts_blockcomm
156 if self.kpts_blockcomm.rank not in range(len(k1_pc)):
157 return None
159 assert K1 is not None and ikpt1 is not None
160 assert K2 is not None and ikpt2 is not None
162 return KohnShamKPointPair(K1, K2, ikpt1, ikpt2,
163 transitions, self.tblocks)
165 def get_kpoints(self, k_pc, n_t, s_t):
166 """Get the process' own k-point data and help other processes
167 extracting theirs."""
168 assert len(n_t) == len(s_t)
169 assert len(k_pc) <= self.kpts_blockcomm.size
171 # Use the data extraction factory to extract the kptdata
172 kptdata = self.extract_kptdata(k_pc, n_t, s_t)
174 if self.kpts_blockcomm.rank not in range(len(k_pc)):
175 return None, None # The process has no data of its own
177 assert kptdata is not None
178 K = kptdata[0]
179 ikpt = IrreducibleKPoint(*kptdata[1:])
181 return K, ikpt
183 @timer('Extracting data from the ground state calculator object')
184 def extract_kptdata(self, k_pc, n_t, s_t):
185 """Extract the input data needed to construct the IrreducibleKPoints.
186 """
187 if self.gs.is_parallelized():
188 return self.parallel_extract_kptdata(k_pc, n_t, s_t)
189 else:
190 return self.serial_extract_kptdata(k_pc, n_t, s_t)
191 # Useful for debugging:
192 # return self.parallel_extract_kptdata(k_pc, n_t, s_t)
194 def parallel_extract_kptdata(self, k_pc, n_t, s_t):
195 """Extract the k-point data from a parallelized calculator."""
196 (myK, myik, myu_eu,
197 myn_eueh, ik_r2,
198 nrh_r2, eh_eur2reh,
199 rh_eur2reh, h_r1rh,
200 h_myt) = self.get_parallel_extraction_protocol(k_pc, n_t, s_t)
202 (eps_r1rh, f_r1rh,
203 P_r1rhI, psit_r1rhG,
204 eps_r2rh, f_r2rh,
205 P_r2rhI, psit_r2rhG) = self.allocate_transfer_arrays(myik, nrh_r2,
206 ik_r2, h_r1rh)
208 # Do actual extraction
209 for myu, myn_eh, eh_r2reh, rh_r2reh in zip(myu_eu, myn_eueh,
210 eh_eur2reh, rh_eur2reh):
212 eps_eh, f_eh, P_ehI = self.extract_wfs_data(myu, myn_eh)
214 for r2, (eh_reh, rh_reh) in enumerate(zip(eh_r2reh, rh_r2reh)):
215 if eh_reh:
216 eps_r2rh[r2][rh_reh] = eps_eh[eh_reh]
217 f_r2rh[r2][rh_reh] = f_eh[eh_reh]
218 P_r2rhI[r2][rh_reh] = P_ehI[eh_reh]
220 # Wavefunctions are heavy objects which can only be extracted
221 # for one band index at a time, handle them seperately
222 self.add_wave_function(myu, myn_eh, eh_r2reh,
223 rh_r2reh, psit_r2rhG)
225 self.distribute_extracted_data(eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
226 eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG)
228 # Some processes may not have to return a k-point
229 if myik is None:
230 data = None
231 else:
232 eps_h, f_h, Ph, psit_hG = self.collect_kptdata(
233 myik, h_r1rh, eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG)
234 data = myK, myik, eps_h, f_h, Ph, psit_hG, h_myt
236 # Wait for communication to finish
237 with self.context.timer('Waiting to complete mpi.send'):
238 while self.srequests:
239 self.context.comm.wait(self.srequests.pop(0))
241 return data
    @timer('Create data extraction protocol')
    def get_parallel_extraction_protocol(self, k_pc, n_t, s_t):
        """Figure out how to extract data efficiently in parallel.

        Index conventions used below:
            p  : index of k_pc, i.e. the rank of the k-point block
                 communicator that should end up with that k-point
            r1 : world rank that extracts (and sends) a piece of data
            r2 : world rank that receives (and stores) a piece of data
            h  : receiver-side composite band and spin index h = (n, s)
            eh : index into the unique bands extracted for a given u
            rh : position in the transfer buffer between two ranks
            u  : ground state k-point index u = ik * nspins + s

        Returns
        -------
        myK, myik
            BZ and IBZ k-point indices of the process' own k-point, or
            (None, None) if the process has no k-point.
        myu_eu, myn_eueh
            For each extraction unit e performed by this process: the local u
            index and the local band indices to extract.
        ik_r2
            IBZ k-point index for each receiving rank r2 (None if unset).
        nrh_r2
            Number of h entries this process sends to each rank r2.
        eh_eur2reh, rh_eur2reh
            For each extraction unit and receiving rank: which extracted
            entries (eh) to send and where to put them in the send buffer
            (rh).
        h_r1rh
            For each sending rank r1: the h indices under which the received
            entries are stored.
        h_myt
            Mapping from the local transition index myt to h.
        """
        comm = self.context.comm
        get_extraction_info = self.create_get_extraction_info()

        # (K, ik) for each process
        mykpt = (None, None)

        # Extraction protocol
        myu_eu = []
        myn_eueh = []

        # Data distribution protocol
        nrh_r2 = np.zeros(comm.size, dtype=int)
        ik_r2 = [None for _ in range(comm.size)]
        eh_eur2reh = []
        rh_eur2reh = []
        h_r1rh = [list([]) for _ in range(comm.size)]

        # h to t index mapping
        t_myt = self.tblocks.myslice
        n_myt, s_myt = n_t[t_myt], s_t[t_myt]
        h_myt = np.empty(self.tblocks.nlocal, dtype=int)

        nt = len(n_t)
        assert nt == len(s_t)
        t_t = np.arange(nt)
        # Running counter of h indices assigned on this (receiving) process
        nh = 0
        for p, k_c in enumerate(k_pc):  # p indicates the receiving process
            K = self.gs.kpoints.kptfinder.find(k_c)
            ik = self.gs.kd.bz2ibz_k[K]
            # All ranks of the p-th transitions block receive data for ik
            for r2 in range(p * self.tblocks.blockcomm.size,
                            min((p + 1) * self.tblocks.blockcomm.size,
                                comm.size)):
                ik_r2[r2] = ik

            if p == self.kpts_blockcomm.rank:
                mykpt = (K, ik)

            # Find out which rank should store the data of each transition
            # in the KohnShamKPointPair
            r2_t, myt_t = self.map_who_has(p, t_t)

            # Find out how to extract data
            # In the ground state, kpts are indexed by u=(s, k)
            for s in set(s_t):
                thiss_myt = s_myt == s
                thiss_t = s_t == s
                t_ct = t_t[thiss_t]
                n_ct = n_t[thiss_t]
                r2_ct = r2_t[t_ct]

                # Find out where data is in GS
                u = ik * self.gs.nspins + s
                myu, r1_ct, myn_ct = get_extraction_info(u, n_ct, r2_ct)

                # If the process is extracting or receiving data,
                # figure out how to do so
                if comm.rank in np.append(r1_ct, r2_ct):
                    # Does this process have anything to send?
                    thisr1_ct = r1_ct == comm.rank
                    if np.any(thisr1_ct):
                        eh_r2reh = [list([]) for _ in range(comm.size)]
                        rh_r2reh = [list([]) for _ in range(comm.size)]
                        # Find composite indices h = (n, s)
                        n_et = n_ct[thisr1_ct]
                        n_eh = np.unique(n_et)
                        # Find composite local band indices
                        # NOTE(review): eh indices are computed from n_eh but
                        # applied to data extracted with myn_eh; this presumes
                        # the n -> myn mapping preserves ordering. Confirm
                        # for non-standard band distributions.
                        myn_eh = np.unique(myn_ct[thisr1_ct])

                        # Where to send the data
                        r2_et = r2_ct[thisr1_ct]
                        for r2 in np.unique(r2_et):
                            thisr2_et = r2_et == r2
                            # Which bands n is the process sending?
                            n_reh = np.unique(n_et[thisr2_et])
                            eh_reh = []
                            for n in n_reh:
                                eh_reh.append(np.where(n_eh == n)[0][0])
                            # How to send it: append after the entries
                            # already scheduled for rank r2
                            eh_r2reh[r2] = eh_reh
                            nreh = len(eh_reh)
                            rh_r2reh[r2] = np.arange(nreh) + nrh_r2[r2]
                            nrh_r2[r2] += nreh

                        myu_eu.append(myu)
                        myn_eueh.append(myn_eh)
                        eh_eur2reh.append(eh_r2reh)
                        rh_eur2reh.append(rh_r2reh)

                    # Does this process have anything to receive?
                    thisr2_ct = r2_ct == comm.rank
                    if np.any(thisr2_ct):
                        # Find unique composite indices h = (n, s)
                        n_rt = n_ct[thisr2_ct]
                        n_rn = np.unique(n_rt)
                        nrn = len(n_rn)
                        h_rn = np.arange(nrn) + nh
                        nh += nrn

                        # Where to get the data from
                        r1_rt = r1_ct[thisr2_ct]
                        for r1 in np.unique(r1_rt):
                            thisr1_rt = r1_rt == r1
                            # Which bands n is the process getting?
                            n_reh = np.unique(n_rt[thisr1_rt])
                            # Where to put them
                            for n in n_reh:
                                h = h_rn[np.where(n_rn == n)[0][0]]
                                h_r1rh[r1].append(h)

                                # h to t mapping
                                thisn_myt = n_myt == n
                                thish_myt = np.logical_and(thisn_myt,
                                                           thiss_myt)
                                h_myt[thish_myt] = h

        return (*mykpt, myu_eu, myn_eueh, ik_r2, nrh_r2,
                eh_eur2reh, rh_eur2reh, h_r1rh, h_myt)
363 def create_get_extraction_info(self):
364 """Creator component of the extraction information factory."""
365 if self.gs.is_parallelized():
366 return self.get_parallel_extraction_info
367 else:
368 return self.get_serial_extraction_info
370 @staticmethod
371 def get_serial_extraction_info(u, n_ct, r2_ct):
372 """Figure out where to extract the data from in the gs calc"""
373 # Let the process extract its own data
374 myu = u # The process has access to all data
375 r1_ct = r2_ct
376 myn_ct = n_ct
378 return myu, r1_ct, myn_ct
380 def get_parallel_extraction_info(self, u, n_ct, *unused):
381 """Figure out where to extract the data from in the gs calc"""
382 gs = self.gs
383 # Find out where data is in GS
384 k, s = divmod(u, gs.nspins)
385 kptrank, q = gs.kd.who_has(k)
386 myu = q * gs.nspins + s
387 r1_ct, myn_ct = [], []
388 for n in n_ct:
389 bandrank, myn = gs.bd.who_has(n)
390 # XXX this will fail when using non-standard nesting
391 # of communicators.
392 r1 = (kptrank * gs.gd.comm.size * gs.bd.comm.size
393 + bandrank * gs.gd.comm.size)
394 r1_ct.append(r1)
395 myn_ct.append(myn)
397 return myu, np.array(r1_ct), np.array(myn_ct)
399 @timer('Allocate transfer arrays')
400 def allocate_transfer_arrays(self, myik, nrh_r2, ik_r2, h_r1rh):
401 """Allocate arrays for intermediate storage of data."""
402 kptex = self.gs.kpt_u[0]
403 Pshape = kptex.projections.array.shape
404 Pdtype = kptex.projections.matrix.dtype
405 psitdtype = kptex.psit.array.dtype
407 # Number of h-indeces to receive
408 nrh_r1 = [len(h_rh) for h_rh in h_r1rh]
410 # if self.kpts_blockcomm.rank in range(len(ik_p)):
411 if myik is not None:
412 ng = self.gs.global_pd.ng_q[myik]
413 eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG = [], [], [], []
414 for nrh in nrh_r1:
415 if nrh >= 1:
416 eps_r1rh.append(np.empty(nrh))
417 f_r1rh.append(np.empty(nrh))
418 P_r1rhI.append(np.empty((nrh,) + Pshape[1:], dtype=Pdtype))
419 psit_r1rhG.append(np.empty((nrh, ng), dtype=psitdtype))
420 else:
421 eps_r1rh.append(None)
422 f_r1rh.append(None)
423 P_r1rhI.append(None)
424 psit_r1rhG.append(None)
425 else:
426 eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG = None, None, None, None
428 eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG = [], [], [], []
429 for nrh, ik in zip(nrh_r2, ik_r2):
430 if nrh:
431 eps_r2rh.append(np.empty(nrh))
432 f_r2rh.append(np.empty(nrh))
433 P_r2rhI.append(np.empty((nrh,) + Pshape[1:], dtype=Pdtype))
434 ng = self.gs.global_pd.ng_q[ik]
435 psit_r2rhG.append(np.empty((nrh, ng), dtype=psitdtype))
436 else:
437 eps_r2rh.append(None)
438 f_r2rh.append(None)
439 P_r2rhI.append(None)
440 psit_r2rhG.append(None)
442 return (eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
443 eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG)
445 def map_who_has(self, p, t_t):
446 """Convert k-point and transition index to global world rank
447 and local transition index"""
448 trank_t, myt_t = np.divmod(t_t, self.tblocks.blocksize)
449 return p * self.tblocks.blockcomm.size + trank_t, myt_t
451 @timer('Extracting eps, f and P_I from wfs')
452 def extract_wfs_data(self, myu, myn_eh):
453 kpt = self.gs.kpt_u[myu]
454 # Get eig and occ
455 eps_eh, f_eh = kpt.eps_n[myn_eh], kpt.f_n[myn_eh] / kpt.weight
457 # Get projections
458 assert kpt.projections.atom_partition.comm.size == 1
459 P_ehI = kpt.projections.array[myn_eh]
461 return eps_eh, f_eh, P_ehI
463 @timer('Extracting wave function from wfs')
464 def add_wave_function(self, myu, myn_eh,
465 eh_r2reh, rh_r2reh, psit_r2rhG):
466 """Add the plane wave coefficients of the smooth part of
467 the wave function to the psit_r2rtG arrays."""
468 kpt = self.gs.kpt_u[myu]
470 for eh_reh, rh_reh, psit_rhG in zip(eh_r2reh, rh_r2reh, psit_r2rhG):
471 if eh_reh:
472 for eh, rh in zip(eh_reh, rh_reh):
473 psit_rhG[rh] = kpt.psit_nG[myn_eh[eh]]
475 @timer('Distributing kptdata')
476 def distribute_extracted_data(self, eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
477 eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG):
478 """Send the extracted data to appropriate destinations"""
479 comm = self.context.comm
480 # Store the data extracted by the process itself
481 rank = comm.rank
482 # Check if there is actually some data to store
483 if eps_r2rh[rank] is not None:
484 eps_r1rh[rank] = eps_r2rh[rank]
485 f_r1rh[rank] = f_r2rh[rank]
486 P_r1rhI[rank] = P_r2rhI[rank]
487 psit_r1rhG[rank] = psit_r2rhG[rank]
489 # Receive data
490 if eps_r1rh is not None: # The process may not be receiving anything
491 for r1, (eps_rh, f_rh,
492 P_rhI, psit_rhG) in enumerate(zip(eps_r1rh, f_r1rh,
493 P_r1rhI, psit_r1rhG)):
494 # Check if there is any data to receive
495 if r1 != rank and eps_rh is not None:
496 rreq1 = comm.receive(eps_rh, r1, tag=201, block=False)
497 rreq2 = comm.receive(f_rh, r1, tag=202, block=False)
498 rreq3 = comm.receive(P_rhI, r1, tag=203, block=False)
499 rreq4 = comm.receive(psit_rhG, r1, tag=204, block=False)
500 self.rrequests += [rreq1, rreq2, rreq3, rreq4]
502 # Send data
503 for r2, (eps_rh, f_rh,
504 P_rhI, psit_rhG) in enumerate(zip(eps_r2rh, f_r2rh,
505 P_r2rhI, psit_r2rhG)):
506 # Check if there is any data to send
507 if r2 != rank and eps_rh is not None:
508 sreq1 = comm.send(eps_rh, r2, tag=201, block=False)
509 sreq2 = comm.send(f_rh, r2, tag=202, block=False)
510 sreq3 = comm.send(P_rhI, r2, tag=203, block=False)
511 sreq4 = comm.send(psit_rhG, r2, tag=204, block=False)
512 self.srequests += [sreq1, sreq2, sreq3, sreq4]
514 with self.context.timer('Waiting to complete mpi.receive'):
515 while self.rrequests:
516 comm.wait(self.rrequests.pop(0))
518 @timer('Collecting kptdata')
519 def collect_kptdata(self, myik, h_r1rh,
520 eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG):
521 """From the extracted data, collect the IrreducibleKPoint data arrays
522 """
523 # Allocate data arrays
524 maxh_r1 = [max(h_rh) for h_rh in h_r1rh if h_rh]
525 if maxh_r1:
526 nh = max(maxh_r1) + 1
527 else: # Carry around empty arrays
528 assert self.tblocks.a == self.tblocks.b
529 nh = 0
530 eps_h = np.empty(nh)
531 f_h = np.empty(nh)
532 Ph = self.new_projections(nh)
533 psit_hG = self.new_wfs(nh, self.gs.global_pd.ng_q[myik])
535 # Store extracted data in the arrays
536 for (h_rh, eps_rh,
537 f_rh, P_rhI, psit_rhG) in zip(h_r1rh, eps_r1rh,
538 f_r1rh, P_r1rhI, psit_r1rhG):
539 if h_rh:
540 eps_h[h_rh] = eps_rh
541 f_h[h_rh] = f_rh
542 Ph.array[h_rh] = P_rhI
543 psit_hG[h_rh] = psit_rhG
545 return eps_h, f_h, Ph, psit_hG
547 def new_projections(self, nh):
548 proj = self.gs.kpt_u[0].projections
549 # We have to initialize the projections by hand, because
550 # Projections.new() interprets nbands == 0 to imply that it should
551 # inherit the preexisting number of bands...
552 return Projections(nh, proj.nproj_a, proj.atom_partition, serial_comm,
553 proj.collinear, proj.spin, proj.matrix.dtype)
555 def new_wfs(self, nh, nG):
556 assert self.gs.dtype == self.gs.kpt_u[0].psit.array.dtype
557 return np.empty((nh, nG), self.gs.dtype)
559 def serial_extract_kptdata(self, k_pc, n_t, s_t):
560 """Extract the k-point data from a serial calculator.
562 Since all the processes can access all of the data, each process
563 extracts the data of its own k-point without any need for
564 communication."""
565 if self.kpts_blockcomm.rank not in range(len(k_pc)):
566 # No data to extract
567 return None
569 # Find k-point indeces
570 k_c = k_pc[self.kpts_blockcomm.rank]
571 K = self.gs.kpoints.kptfinder.find(k_c)
572 ik = self.gs.kd.bz2ibz_k[K]
574 (myu_eu, myn_eurn, nh,
575 h_eurn, h_myt) = self.get_serial_extraction_protocol(ik, n_t, s_t)
577 # Allocate transfer arrays
578 eps_h = np.empty(nh)
579 f_h = np.empty(nh)
580 Ph = self.new_projections(nh)
581 psit_hG = self.new_wfs(nh, self.gs.pd.ng_q[ik])
583 # Extract data from the ground state
584 for myu, myn_rn, h_rn in zip(myu_eu, myn_eurn, h_eurn):
585 kpt = self.gs.kpt_u[myu]
586 with self.context.timer('Extracting eps, f and P_I from wfs'):
587 eps_h[h_rn] = kpt.eps_n[myn_rn]
588 f_h[h_rn] = kpt.f_n[myn_rn] / kpt.weight
589 Ph.array[h_rn] = kpt.projections.array[myn_rn]
591 with self.context.timer('Extracting wave function from wfs'):
592 for myn, h in zip(myn_rn, h_rn):
593 psit_hG[h] = kpt.psit_nG[myn]
595 return K, ik, eps_h, f_h, Ph, psit_hG, h_myt
597 @timer('Create data extraction protocol')
598 def get_serial_extraction_protocol(self, ik, n_t, s_t):
599 """Figure out how to extract data efficiently in serial."""
601 # Only extract the transitions handled by the process itself
602 t_myt = self.tblocks.myslice
603 n_myt = n_t[t_myt]
604 s_myt = s_t[t_myt]
606 # In the ground state, kpts are indexed by u=(s, k)
607 myu_eu = []
608 myn_eurn = []
609 nh = 0
610 h_eurn = []
611 h_myt = np.empty(self.tblocks.nlocal, dtype=int)
612 for s in set(s_myt):
613 thiss_myt = s_myt == s
614 n_ct = n_myt[thiss_myt]
616 # Find unique composite h = (n, u) indeces
617 n_rn = np.unique(n_ct)
618 nrn = len(n_rn)
619 h_eurn.append(np.arange(nrn) + nh)
620 nh += nrn
622 # Find mapping between h and the transition index
623 for n, h in zip(n_rn, h_eurn[-1]):
624 thisn_myt = n_myt == n
625 thish_myt = np.logical_and(thisn_myt, thiss_myt)
626 h_myt[thish_myt] = h
628 # Find out where data is
629 u = ik * self.gs.nspins + s
630 # The process has access to all data
631 myu = u
632 myn_rn = n_rn
634 myu_eu.append(myu)
635 myn_eurn.append(myn_rn)
637 return myu_eu, myn_eurn, nh, h_eurn, h_myt