Coverage for gpaw/response/kspair.py: 45%

368 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from functools import cached_property 

5 

6from gpaw.projections import Projections, serial_comm 

7from gpaw.response import ResponseGroundStateAdapter, ResponseContext, timer 

8from gpaw.response.pw_parallelization import Blocks1D 

9 

10 

class IrreducibleKPoint:
    """Irreducible k-point data pertaining to a certain set of transitions."""

    def __init__(self, ik, eps_h, f_h, Ph, psit_hG, h_myt):
        """Construct the IrreducibleKPoint data object.

        The data is indexed by the composite band and spin index h = (n, s),
        which can be unfolded to the local transition index myt.

        Parameters
        ----------
        ik : int
            Irreducible k-point index
        eps_h : np.ndarray
            Eigenvalues, indexed by h
        f_h : np.ndarray
            Occupation numbers, indexed by h
        Ph : Projections
            PAW projections with one "band" per h index
        psit_hG : np.ndarray
            Pseudo wave function plane-wave components, first axis indexed
            by h
        h_myt : np.ndarray
            myt -> h index mapping
        """
        self.ik = ik  # Irreducible k-point index
        self.eps_h = eps_h  # Eigenvalues
        self.f_h = f_h  # Occupation numbers
        self.Ph = Ph  # PAW projections
        self.psit_hG = psit_hG  # Pseudo wave function plane-wave components
        self.h_myt = h_myt  # myt -> h index mapping

    @cached_property
    def nh(self):
        """Number of composite h = (n, s) indices.

        Also checks that all the h-indexed data have the same length.
        """
        nh = len(self.eps_h)
        assert len(self.f_h) == nh
        assert self.Ph.nbands == nh
        assert len(self.psit_hG) == nh

        return nh

    @property
    def eps_myt(self):
        """Eigenvalues unfolded to the local transition index myt."""
        return self.eps_h[self.h_myt]

    @property
    def f_myt(self):
        """Occupation numbers unfolded to the local transition index myt."""
        return self.f_h[self.h_myt]

    def projectors_in_transition_index(self, Ph):
        """Unfold the projections Ph from the h index to the myt index.

        The returned Projections object has one "band" per local transition
        and a serial band communicator.
        """
        nmyt = len(self.h_myt)
        # Construct the projections by hand instead of via
        # Ph.new(nbands=nmyt, bcomm=None): Projections.new() interprets
        # nbands == 0 as "inherit Ph.nbands", which would produce the wrong
        # number of bands (and a broadcasting error below) whenever the
        # process has no local transitions.
        Pmyt = Projections(nmyt, Ph.nproj_a, Ph.atom_partition, serial_comm,
                           Ph.collinear, Ph.spin, Ph.matrix.dtype)
        Pmyt.array[:] = Ph.array[self.h_myt]
        return Pmyt

48 

49 

class KohnShamKPointPair:
    """Kohn-Sham orbital pair data for a set of transitions k -> k'."""

    def __init__(self, K1, K2, ikpt1, ikpt2, transitions, tblocks):
        """Bundle the k-point data of k and k' into one pair object.

        K1, K2 : int, int
            k-point indices of k and k'
        ikpt1, ikpt2 : IrreducibleKPoint, IrreducibleKPoint
            k-point data of the two specific k-points in the irreducible part
            of the BZ which are related to K1 and K2 by symmetry respectively.
        transitions : PairTransitions
            The composite band and spin transitions t
        tblocks : Blocks1D
            Block distribution of the transitions over processes
        """
        self.K1, self.K2 = K1, K2
        self.ikpt1, self.ikpt2 = ikpt1, ikpt2
        self.transitions, self.tblocks = transitions, tblocks

    def _local(self, *arrays_t):
        """Slice out this process' block from each global transition array."""
        myslice = self.tblocks.myslice
        return tuple(array_t[myslice] for array_t in arrays_t)

    def get_all(self, in_mytx):
        """Gather an array distributed over the local transitions, so that
        it covers all transitions."""
        return self.tblocks.all_gather(in_mytx)

    @property
    def deps_myt(self):
        """Eigenvalue differences eps2 - eps1 of the local transitions."""
        eps2_myt = self.ikpt2.eps_myt
        eps1_myt = self.ikpt1.eps_myt
        return eps2_myt - eps1_myt

    @property
    def df_myt(self):
        """Occupation differences f2 - f1 of the local transitions."""
        f2_myt = self.ikpt2.f_myt
        f1_myt = self.ikpt1.f_myt
        return f2_myt - f1_myt

    def get_local_band_indices(self):
        """Band indices (n1_myt, n2_myt) of this process' transitions."""
        n1_t, n2_t = self.transitions.get_band_indices()
        return self._local(n1_t, n2_t)

    def get_local_spin_indices(self):
        """Spin indices (s1_myt, s2_myt) of this process' transitions."""
        s1_t, s2_t = self.transitions.get_spin_indices()
        return self._local(s1_t, s2_t)

    def get_local_intraband_mask(self):
        """Intraband mask of this process' transitions."""
        (intraband_myt,) = self._local(self.transitions.get_intraband_mask())
        return intraband_myt

97 

98 

class KohnShamKPointPairExtractor:
    """Functionality to extract KohnShamKPointPairs from a
    ResponseGroundStateAdapter.

    Index conventions used throughout:
        p : k-point index of k_pc, one per rank of kpts_blockcomm
        t : composite (band, spin) transition index; myt is its local version
        h : composite (band, spin) index of the orbitals actually needed
        u : ground state k-point/spin index, u = ik * nspins + s
        r1, r2 : world ranks sending (r1) and receiving (r2) data
        e : index over the data extractions performed by this process
    """

    def __init__(self, gs, context, *,
                 transitions_blockcomm, kpts_blockcomm):
        """
        Parameters
        ----------
        gs : ResponseGroundStateAdapter
        context : ResponseContext
        transitions_blockcomm : gpaw.mpi.Communicator
            Communicator to distribute band and spin transitions
        kpts_blockcomm : gpaw.mpi.Communicator
            Communicator over which the k-points are distributed
        """
        assert isinstance(gs, ResponseGroundStateAdapter)
        self.gs = gs
        assert isinstance(context, ResponseContext)
        self.context = context

        if self.gs.is_parallelized():
            assert self.context.comm is self.gs.world
            # We assume no grid-parallelization in `map_who_has()`
            assert self.gs.gd.comm.size == 1

        self.transitions_blockcomm = transitions_blockcomm
        self.kpts_blockcomm = kpts_blockcomm

        # Block distribution of the transitions, set anew by every call to
        # get_kpoint_pairs()
        self.tblocks = None

        # Pending nonblocking receive/send requests used to redistribute
        # the k-point data
        self.rrequests = []
        self.srequests = []

    @timer('Get Kohn-Sham pairs')
    def get_kpoint_pairs(self, k1_pc, k2_pc,
                         transitions) -> KohnShamKPointPair | None:
        """Get all pairs of Kohn-Sham orbitals for transitions k -> k'

        (n1_t, k1_p, s1_t) -> (n2_t, k2_p, s2_t)

        Here, t is a composite band and spin transition index accounted for by
        the input PairTransitions object, whereas p indexes the k-point that
        each rank of the k-point block communicator needs to extract.

        When the ground state calculator is parallelized, the k-point data is
        redistributed between the processes, so every process has to call
        this method, also those which end up returning None."""
        assert k1_pc.shape == k2_pc.shape

        # Distribute transitions and extract data for transitions in
        # this process' block
        self.tblocks = Blocks1D(self.transitions_blockcomm, len(transitions))

        K1, ikpt1 = self.get_kpoints(k1_pc, transitions.n1_t, transitions.s1_t)
        K2, ikpt2 = self.get_kpoints(k2_pc, transitions.n2_t, transitions.s2_t)

        # The process might not have a Kohn-Sham k-point pair to return, due to
        # the distribution over kpts_blockcomm
        if self.kpts_blockcomm.rank not in range(len(k1_pc)):
            return None

        assert K1 is not None and ikpt1 is not None
        assert K2 is not None and ikpt2 is not None

        return KohnShamKPointPair(K1, K2, ikpt1, ikpt2,
                                  transitions, self.tblocks)

    def get_kpoints(self, k_pc, n_t, s_t):
        """Get the process' own k-point data and help other processes
        extracting theirs.

        Returns
        -------
        K : int or None
            k-point index of the process' own k-point
        ikpt : IrreducibleKPoint or None
            Data of the corresponding irreducible k-point
        Both are None if the process has no k-point of its own.
        """
        assert len(n_t) == len(s_t)
        assert len(k_pc) <= self.kpts_blockcomm.size

        # Use the data extraction factory to extract the kptdata
        kptdata = self.extract_kptdata(k_pc, n_t, s_t)

        if self.kpts_blockcomm.rank not in range(len(k_pc)):
            return None, None  # The process has no data of its own

        assert kptdata is not None
        K = kptdata[0]
        ikpt = IrreducibleKPoint(*kptdata[1:])

        return K, ikpt

    @timer('Extracting data from the ground state calculator object')
    def extract_kptdata(self, k_pc, n_t, s_t):
        """Extract the input data needed to construct the IrreducibleKPoints.

        Returns the tuple (K, ik, eps_h, f_h, Ph, psit_hG, h_myt), or None if
        the process has no k-point of its own.
        """
        # Useful for debugging: the parallel extraction can also handle
        # serial calculators (see create_get_extraction_info()), so one may
        # always return self.parallel_extract_kptdata(k_pc, n_t, s_t)
        if self.gs.is_parallelized():
            return self.parallel_extract_kptdata(k_pc, n_t, s_t)
        else:
            return self.serial_extract_kptdata(k_pc, n_t, s_t)

    def parallel_extract_kptdata(self, k_pc, n_t, s_t):
        """Extract the k-point data from a parallelized calculator.

        Each process extracts the ground state data it holds locally, sends
        it to the processes that need it and receives the data it needs
        itself from the processes that hold it.
        """
        (myK, myik, myu_eu,
         myn_eueh, ik_r2,
         nrh_r2, eh_eur2reh,
         rh_eur2reh, h_r1rh,
         h_myt) = self.get_parallel_extraction_protocol(k_pc, n_t, s_t)

        (eps_r1rh, f_r1rh,
         P_r1rhI, psit_r1rhG,
         eps_r2rh, f_r2rh,
         P_r2rhI, psit_r2rhG) = self.allocate_transfer_arrays(myik, nrh_r2,
                                                              ik_r2, h_r1rh)

        # Do actual extraction, filling in the send buffers (r2 arrays)
        for myu, myn_eh, eh_r2reh, rh_r2reh in zip(myu_eu, myn_eueh,
                                                   eh_eur2reh, rh_eur2reh):

            eps_eh, f_eh, P_ehI = self.extract_wfs_data(myu, myn_eh)

            for r2, (eh_reh, rh_reh) in enumerate(zip(eh_r2reh, rh_r2reh)):
                if eh_reh:
                    eps_r2rh[r2][rh_reh] = eps_eh[eh_reh]
                    f_r2rh[r2][rh_reh] = f_eh[eh_reh]
                    P_r2rhI[r2][rh_reh] = P_ehI[eh_reh]

            # Wavefunctions are heavy objects which can only be extracted
            # for one band index at a time, handle them separately
            self.add_wave_function(myu, myn_eh, eh_r2reh,
                                   rh_r2reh, psit_r2rhG)

        self.distribute_extracted_data(eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
                                       eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG)

        # Some processes may not have to return a k-point
        if myik is None:
            data = None
        else:
            eps_h, f_h, Ph, psit_hG = self.collect_kptdata(
                myik, h_r1rh, eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG)
            data = myK, myik, eps_h, f_h, Ph, psit_hG, h_myt

        # Wait for communication to finish
        with self.context.timer('Waiting to complete mpi.send'):
            while self.srequests:
                self.context.comm.wait(self.srequests.pop(0))

        return data

    @timer('Create data extraction protocol')
    def get_parallel_extraction_protocol(self, k_pc, n_t, s_t):
        """Figure out how to extract data efficiently in parallel.

        Returns
        -------
        myK, myik : int or None
            k-point and irreducible k-point index of the process' own
            k-point (None if it has none)
        myu_eu : list
            Local ground state index myu of each extraction e
        myn_eueh : list of np.ndarray
            Local band indices to extract in each extraction
        ik_r2 : list
            Irreducible k-point index of the data received by each rank r2
        nrh_r2 : np.ndarray
            Number of h indices this process sends to each rank r2
        eh_eur2reh, rh_eur2reh : list
            For each extraction and receiving rank r2: the indices eh of the
            extracted data to send, and their positions rh in the send buffer
        h_r1rh : list of lists
            For each sending rank r1, the h indices at which the received data
            is to be stored
        h_myt : np.ndarray
            myt -> h index mapping
        """
        comm = self.context.comm
        get_extraction_info = self.create_get_extraction_info()

        # (K, ik) for each process
        mykpt = (None, None)

        # Extraction protocol
        myu_eu = []
        myn_eueh = []

        # Data distribution protocol
        nrh_r2 = np.zeros(comm.size, dtype=int)
        ik_r2 = [None for _ in range(comm.size)]
        eh_eur2reh = []
        rh_eur2reh = []
        h_r1rh = [[] for _ in range(comm.size)]

        # h to t index mapping
        t_myt = self.tblocks.myslice
        n_myt, s_myt = n_t[t_myt], s_t[t_myt]
        h_myt = np.empty(self.tblocks.nlocal, dtype=int)

        nt = len(n_t)
        assert nt == len(s_t)
        t_t = np.arange(nt)
        nh = 0
        for p, k_c in enumerate(k_pc):  # p indicates the receiving process
            K = self.gs.kpoints.kptfinder.find(k_c)
            ik = self.gs.kd.bz2ibz_k[K]
            for r2 in range(p * self.tblocks.blockcomm.size,
                            min((p + 1) * self.tblocks.blockcomm.size,
                                comm.size)):
                ik_r2[r2] = ik

            if p == self.kpts_blockcomm.rank:
                mykpt = (K, ik)

            # Find out who should store the data in KSKPpoint
            r2_t, _ = self.map_who_has(p, t_t)

            # Find out how to extract data
            # In the ground state, kpts are indexed by u=(s, k)
            # NB: All processes must loop over the spins in the same (sorted)
            # order, since the send buffer offsets accumulated in nrh_r2 by
            # the senders have to match the order in which the receivers
            # assign the h indices.
            for s in np.unique(s_t):
                thiss_myt = s_myt == s
                thiss_t = s_t == s
                t_ct = t_t[thiss_t]
                n_ct = n_t[thiss_t]
                r2_ct = r2_t[t_ct]

                # Find out where data is in GS
                u = ik * self.gs.nspins + s
                myu, r1_ct, myn_ct = get_extraction_info(u, n_ct, r2_ct)

                # If the process is extracting or receiving data,
                # figure out how to do so
                if comm.rank in np.append(r1_ct, r2_ct):
                    # Does this process have anything to send?
                    thisr1_ct = r1_ct == comm.rank
                    if np.any(thisr1_ct):
                        eh_r2reh = [[] for _ in range(comm.size)]
                        rh_r2reh = [[] for _ in range(comm.size)]
                        # Find composite indices h = (n, s)
                        n_et = n_ct[thisr1_ct]
                        n_eh = np.unique(n_et)
                        # Find composite local band indices.
                        # NOTE(review): eh indexes both n_eh and myn_eh, which
                        # assumes n -> myn to be monotonic on this rank --
                        # TODO confirm for all band distributions.
                        myn_eh = np.unique(myn_ct[thisr1_ct])

                        # Where to send the data
                        r2_et = r2_ct[thisr1_ct]
                        for r2 in np.unique(r2_et):
                            thisr2_et = r2_et == r2
                            # What ns are the process sending?
                            n_reh = np.unique(n_et[thisr2_et])
                            eh_reh = []
                            for n in n_reh:
                                eh_reh.append(np.where(n_eh == n)[0][0])
                            # How to send it
                            eh_r2reh[r2] = eh_reh
                            nreh = len(eh_reh)
                            rh_r2reh[r2] = np.arange(nreh) + nrh_r2[r2]
                            nrh_r2[r2] += nreh

                        myu_eu.append(myu)
                        myn_eueh.append(myn_eh)
                        eh_eur2reh.append(eh_r2reh)
                        rh_eur2reh.append(rh_r2reh)

                    # Does this process have anything to receive?
                    thisr2_ct = r2_ct == comm.rank
                    if np.any(thisr2_ct):
                        # Find unique composite indices h = (n, s)
                        n_rt = n_ct[thisr2_ct]
                        n_rn = np.unique(n_rt)
                        nrn = len(n_rn)
                        h_rn = np.arange(nrn) + nh
                        nh += nrn

                        # Where to get the data from
                        r1_rt = r1_ct[thisr2_ct]
                        for r1 in np.unique(r1_rt):
                            thisr1_rt = r1_rt == r1
                            # What ns are the process getting?
                            n_reh = np.unique(n_rt[thisr1_rt])
                            # Where to put them
                            for n in n_reh:
                                h = h_rn[np.where(n_rn == n)[0][0]]
                                h_r1rh[r1].append(h)

                                # h to t mapping
                                thisn_myt = n_myt == n
                                thish_myt = np.logical_and(thisn_myt,
                                                           thiss_myt)
                                h_myt[thish_myt] = h

        return (*mykpt, myu_eu, myn_eueh, ik_r2, nrh_r2,
                eh_eur2reh, rh_eur2reh, h_r1rh, h_myt)

    def create_get_extraction_info(self):
        """Creator component of the extraction information factory.

        Returns a callable (u, n_ct, r2_ct) -> (myu, r1_ct, myn_ct).
        """
        if self.gs.is_parallelized():
            return self.get_parallel_extraction_info
        else:
            return self.get_serial_extraction_info

    @staticmethod
    def get_serial_extraction_info(u, n_ct, r2_ct):
        """Figure out where to extract the data from in the gs calc

        In the serial case, each receiving process r2 extracts its own data.
        """
        # Let the process extract its own data
        myu = u  # The process has access to all data
        r1_ct = r2_ct
        myn_ct = n_ct

        return myu, r1_ct, myn_ct

    def get_parallel_extraction_info(self, u, n_ct, *unused):
        """Figure out where to extract the data from in the gs calc

        Returns the local ground state index myu on the owning process, as
        well as the world rank r1 owning each band n and the band's local
        index myn on that rank.
        """
        gs = self.gs
        # Find out where data is in GS
        k, s = divmod(u, gs.nspins)
        kptrank, q = gs.kd.who_has(k)
        myu = q * gs.nspins + s
        r1_ct, myn_ct = [], []
        for n in n_ct:
            bandrank, myn = gs.bd.who_has(n)
            # XXX this will fail when using non-standard nesting
            # of communicators.
            r1 = (kptrank * gs.gd.comm.size * gs.bd.comm.size
                  + bandrank * gs.gd.comm.size)
            r1_ct.append(r1)
            myn_ct.append(myn)

        return myu, np.array(r1_ct), np.array(myn_ct)

    @timer('Allocate transfer arrays')
    def allocate_transfer_arrays(self, myik, nrh_r2, ik_r2, h_r1rh):
        """Allocate arrays for intermediate storage of data.

        The r1 arrays are receive buffers (one per sending rank r1) and the
        r2 arrays are send buffers (one per receiving rank r2). Ranks with
        nothing to transfer get a None entry. If the process has no k-point
        of its own (myik is None), the r1 lists themselves are None.
        """
        kptex = self.gs.kpt_u[0]
        Pshape = kptex.projections.array.shape
        Pdtype = kptex.projections.matrix.dtype
        psitdtype = kptex.psit.array.dtype

        # Number of h-indices to receive
        nrh_r1 = [len(h_rh) for h_rh in h_r1rh]

        if myik is not None:
            ng = self.gs.global_pd.ng_q[myik]
            eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG = [], [], [], []
            for nrh in nrh_r1:
                if nrh >= 1:
                    eps_r1rh.append(np.empty(nrh))
                    f_r1rh.append(np.empty(nrh))
                    P_r1rhI.append(np.empty((nrh,) + Pshape[1:], dtype=Pdtype))
                    psit_r1rhG.append(np.empty((nrh, ng), dtype=psitdtype))
                else:
                    eps_r1rh.append(None)
                    f_r1rh.append(None)
                    P_r1rhI.append(None)
                    psit_r1rhG.append(None)
        else:
            eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG = None, None, None, None

        eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG = [], [], [], []
        for nrh, ik in zip(nrh_r2, ik_r2):
            if nrh:
                eps_r2rh.append(np.empty(nrh))
                f_r2rh.append(np.empty(nrh))
                P_r2rhI.append(np.empty((nrh,) + Pshape[1:], dtype=Pdtype))
                ng = self.gs.global_pd.ng_q[ik]
                psit_r2rhG.append(np.empty((nrh, ng), dtype=psitdtype))
            else:
                eps_r2rh.append(None)
                f_r2rh.append(None)
                P_r2rhI.append(None)
                psit_r2rhG.append(None)

        return (eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
                eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG)

    def map_who_has(self, p, t_t):
        """Convert k-point and transition index to global world rank
        and local transition index

        The world rank is computed as p * blockcomm.size plus the rank
        within the transitions block communicator.
        """
        trank_t, myt_t = np.divmod(t_t, self.tblocks.blocksize)
        return p * self.tblocks.blockcomm.size + trank_t, myt_t

    @timer('Extracting eps, f and P_I from wfs')
    def extract_wfs_data(self, myu, myn_eh):
        """Extract eigenvalues, occupation numbers and PAW projections of
        the local bands myn_eh at the local ground state index myu.

        The occupation numbers are divided by the k-point weight.
        """
        kpt = self.gs.kpt_u[myu]
        # Get eig and occ
        eps_eh, f_eh = kpt.eps_n[myn_eh], kpt.f_n[myn_eh] / kpt.weight

        # Get projections
        assert kpt.projections.atom_partition.comm.size == 1
        P_ehI = kpt.projections.array[myn_eh]

        return eps_eh, f_eh, P_ehI

    @timer('Extracting wave function from wfs')
    def add_wave_function(self, myu, myn_eh,
                          eh_r2reh, rh_r2reh, psit_r2rhG):
        """Add the plane wave coefficients of the smooth part of
        the wave function to the psit_r2rhG arrays."""
        kpt = self.gs.kpt_u[myu]

        for eh_reh, rh_reh, psit_rhG in zip(eh_r2reh, rh_r2reh, psit_r2rhG):
            if eh_reh:
                for eh, rh in zip(eh_reh, rh_reh):
                    psit_rhG[rh] = kpt.psit_nG[myn_eh[eh]]

    @timer('Distributing kptdata')
    def distribute_extracted_data(self, eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
                                  eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG):
        """Send the extracted data to appropriate destinations

        Nonblocking receives are posted into the r1 buffers and nonblocking
        sends from the r2 buffers. The method returns once all receives have
        completed; the send requests are left in self.srequests for the
        caller to wait on.
        """
        comm = self.context.comm
        # Store the data extracted by the process itself
        rank = comm.rank
        # Check if there is actually some data to store
        if eps_r2rh[rank] is not None:
            eps_r1rh[rank] = eps_r2rh[rank]
            f_r1rh[rank] = f_r2rh[rank]
            P_r1rhI[rank] = P_r2rhI[rank]
            psit_r1rhG[rank] = psit_r2rhG[rank]

        # Receive data
        if eps_r1rh is not None:  # The process may not be receiving anything
            for r1, (eps_rh, f_rh,
                     P_rhI, psit_rhG) in enumerate(zip(eps_r1rh, f_r1rh,
                                                       P_r1rhI, psit_r1rhG)):
                # Check if there is any data to receive
                if r1 != rank and eps_rh is not None:
                    rreq1 = comm.receive(eps_rh, r1, tag=201, block=False)
                    rreq2 = comm.receive(f_rh, r1, tag=202, block=False)
                    rreq3 = comm.receive(P_rhI, r1, tag=203, block=False)
                    rreq4 = comm.receive(psit_rhG, r1, tag=204, block=False)
                    self.rrequests += [rreq1, rreq2, rreq3, rreq4]

        # Send data
        for r2, (eps_rh, f_rh,
                 P_rhI, psit_rhG) in enumerate(zip(eps_r2rh, f_r2rh,
                                                   P_r2rhI, psit_r2rhG)):
            # Check if there is any data to send
            if r2 != rank and eps_rh is not None:
                sreq1 = comm.send(eps_rh, r2, tag=201, block=False)
                sreq2 = comm.send(f_rh, r2, tag=202, block=False)
                sreq3 = comm.send(P_rhI, r2, tag=203, block=False)
                sreq4 = comm.send(psit_rhG, r2, tag=204, block=False)
                self.srequests += [sreq1, sreq2, sreq3, sreq4]

        with self.context.timer('Waiting to complete mpi.receive'):
            while self.rrequests:
                comm.wait(self.rrequests.pop(0))

    @timer('Collecting kptdata')
    def collect_kptdata(self, myik, h_r1rh,
                        eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG):
        """From the extracted data, collect the IrreducibleKPoint data arrays

        The data received from rank r1 is stored at the h indices h_r1rh[r1].
        """
        # Allocate data arrays
        maxh_r1 = [max(h_rh) for h_rh in h_r1rh if h_rh]
        if maxh_r1:
            nh = max(maxh_r1) + 1
        else:  # Carry around empty arrays
            # No data received; presumably only valid for an empty block of
            # transitions
            assert self.tblocks.a == self.tblocks.b
            nh = 0
        eps_h = np.empty(nh)
        f_h = np.empty(nh)
        Ph = self.new_projections(nh)
        psit_hG = self.new_wfs(nh, self.gs.global_pd.ng_q[myik])

        # Store extracted data in the arrays
        for (h_rh, eps_rh,
             f_rh, P_rhI, psit_rhG) in zip(h_r1rh, eps_r1rh,
                                           f_r1rh, P_r1rhI, psit_r1rhG):
            if h_rh:
                eps_h[h_rh] = eps_rh
                f_h[h_rh] = f_rh
                Ph.array[h_rh] = P_rhI
                psit_hG[h_rh] = psit_rhG

        return eps_h, f_h, Ph, psit_hG

    def new_projections(self, nh):
        """Allocate serial Projections with nh bands, matching the atoms and
        dtype of the ground state projections."""
        proj = self.gs.kpt_u[0].projections
        # We have to initialize the projections by hand, because
        # Projections.new() interprets nbands == 0 to imply that it should
        # inherit the preexisting number of bands...
        return Projections(nh, proj.nproj_a, proj.atom_partition, serial_comm,
                           proj.collinear, proj.spin, proj.matrix.dtype)

    def new_wfs(self, nh, nG):
        """Allocate an uninitialized (nh, nG) wave function array."""
        assert self.gs.dtype == self.gs.kpt_u[0].psit.array.dtype
        return np.empty((nh, nG), self.gs.dtype)

    def serial_extract_kptdata(self, k_pc, n_t, s_t):
        """Extract the k-point data from a serial calculator.

        Since all the processes can access all of the data, each process
        extracts the data of its own k-point without any need for
        communication."""
        if self.kpts_blockcomm.rank not in range(len(k_pc)):
            # No data to extract
            return None

        # Find k-point indices
        k_c = k_pc[self.kpts_blockcomm.rank]
        K = self.gs.kpoints.kptfinder.find(k_c)
        ik = self.gs.kd.bz2ibz_k[K]

        (myu_eu, myn_eurn, nh,
         h_eurn, h_myt) = self.get_serial_extraction_protocol(ik, n_t, s_t)

        # Allocate the output arrays
        eps_h = np.empty(nh)
        f_h = np.empty(nh)
        Ph = self.new_projections(nh)
        psit_hG = self.new_wfs(nh, self.gs.pd.ng_q[ik])

        # Extract data from the ground state
        for myu, myn_rn, h_rn in zip(myu_eu, myn_eurn, h_eurn):
            kpt = self.gs.kpt_u[myu]
            with self.context.timer('Extracting eps, f and P_I from wfs'):
                eps_h[h_rn] = kpt.eps_n[myn_rn]
                f_h[h_rn] = kpt.f_n[myn_rn] / kpt.weight
                Ph.array[h_rn] = kpt.projections.array[myn_rn]

            with self.context.timer('Extracting wave function from wfs'):
                for myn, h in zip(myn_rn, h_rn):
                    psit_hG[h] = kpt.psit_nG[myn]

        return K, ik, eps_h, f_h, Ph, psit_hG, h_myt

    @timer('Create data extraction protocol')
    def get_serial_extraction_protocol(self, ik, n_t, s_t):
        """Figure out how to extract data efficiently in serial.

        Returns
        -------
        myu_eu : list
            Ground state index u = ik * nspins + s of each extraction e
        myn_eurn : list of np.ndarray
            Unique band indices to extract in each extraction
        nh : int
            Total number of composite h = (n, s) indices
        h_eurn : list of np.ndarray
            h index of each extracted band
        h_myt : np.ndarray
            myt -> h index mapping
        """
        # Only extract the transitions handled by the process itself
        t_myt = self.tblocks.myslice
        n_myt = n_t[t_myt]
        s_myt = s_t[t_myt]

        # In the ground state, kpts are indexed by u=(s, k)
        myu_eu = []
        myn_eurn = []
        nh = 0
        h_eurn = []
        h_myt = np.empty(self.tblocks.nlocal, dtype=int)
        for s in np.unique(s_myt):
            thiss_myt = s_myt == s
            n_ct = n_myt[thiss_myt]

            # Find unique composite h = (n, s) indices
            n_rn = np.unique(n_ct)
            nrn = len(n_rn)
            h_eurn.append(np.arange(nrn) + nh)
            nh += nrn

            # Find mapping between h and the transition index
            for n, h in zip(n_rn, h_eurn[-1]):
                thisn_myt = n_myt == n
                thish_myt = np.logical_and(thisn_myt, thiss_myt)
                h_myt[thish_myt] = h

            # Find out where data is
            u = ik * self.gs.nspins + s
            # The process has access to all data
            myu = u
            myn_rn = n_rn

            myu_eu.append(myu)
            myn_eurn.append(myn_rn)

        return myu_eu, myn_eurn, nh, h_eurn, h_myt