Coverage for gpaw/lrtddft/kssingle.py: 95%

416 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

"""Kohn-Sham single particle excitations and related objects.

Defines KSSingles, a list of Kohn-Sham single particle transitions
i -> j, and KSSingle, a single such transition with its matrix elements.
"""

4import sys 

5import json 

6import numpy as np 

7from copy import copy 

8 

9from ase.units import Bohr, Hartree, alpha 

10 

11import gpaw.mpi as mpi 

12from gpaw.utilities import packed_index 

13from gpaw.lrtddft.excitation import Excitation, ExcitationList, get_filehandle 

14from gpaw.pair_density import PairDensity 

15from gpaw.fd_operators import Gradient 

16from gpaw.utilities.tools import coordinates 

17from .kssrestrictor import KSSRestrictor 

18 

19 

class KSSingles(ExcitationList):
    """Kohn-Sham single particle excitations.

    A list of KSSingle objects, one per selected transition i -> j,
    evaluated from a ground state calculation by :meth:`calculate` or
    read from a file by :meth:`read`.

    Input parameters:

    restrict:
      dictionary of selection criteria handed to a KSSRestrictor.
      :meth:`select` uses the entries

      eps:
        Minimal occupation difference for a transition
      istart:
        First occupied state to consider
      jend:
        Last unoccupied state to consider

      and the energy window returned by ``KSSRestrictor.emin_emax()``
      (presumably built from an ``energy_range`` entry given as
      [emin, emax] or emax in eV, compare :meth:`emin_emax` -- verify
      against KSSRestrictor).
    log, txt:
      passed on to ExcitationList

    The ground state calculator is taken from ``atoms.calc`` in
    :meth:`calculate`, which also accepts ``nspins``, the number of
    physical spins to consider. ``nspins`` only has an effect for
    spin-unpolarised ground state calculations.
    """

    def __init__(self,
                 restrict=None,
                 log=None,
                 txt=None):
        ExcitationList.__init__(self, log=log, txt=txt)
        self.world = mpi.world

        # None replaces the former mutable {} default; the restrictor is
        # still updated with an empty dict in that case
        self.restrict = KSSRestrictor()
        self.restrict.update({} if restrict is None else restrict)

    def calculate(self, atoms, nspins=None):
        """Evaluate the transitions of the ground state in atoms.calc.

        atoms: Atoms object with a converged ground state calculator
        nspins: number of physical spins, see :meth:`select`

        Returns self.
        """
        calculator = atoms.calc
        self.calculator = calculator

        # LCAO calculation requires special actions
        self.lcao = calculator.wfs.mode == 'lcao'

        # deny hybrids as their empty states are wrong
        # (check currently disabled):
        # gsxc = calculator.hamiltonian.xc
        # hybrid = hasattr(gsxc, 'hybrid') and gsxc.hybrid > 0.0
        # assert(not hybrid)

        # ensure correctly initialized wave functions
        calculator.converge_wave_functions()
        self.world = calculator.wfs.world

        # parallelization over bands not yet supported
        assert calculator.wfs.bd.comm.size == 1

        # do the evaluation
        self.select(nspins)

        trkm = self.get_trk()
        self.log('KSS {} transitions (restrict={})'.format(
            len(self), self.restrict))
        self.log('KSS TRK sum %g (%g,%g,%g)' %
                 (np.sum(trkm) / 3., trkm[0], trkm[1], trkm[2]))
        pol = self.get_polarizabilities(lmax=3)
        self.log('KSS polarisabilities(l=0-3) %g, %g, %g, %g' %
                 tuple(pol.tolist()))
        return self

    @staticmethod
    def emin_emax(energy_range):
        """Convert an energy range in eV to (emin, emax) in Hartree.

        energy_range: None (no limits), a pair [emin, emax] or a single
            number emax. Missing limits are -/+ sys.float_info.max.
        """
        emin = -sys.float_info.max
        emax = sys.float_info.max
        if energy_range is not None:
            try:
                emin, emax = energy_range
                emin /= Hartree
                emax /= Hartree
            except TypeError:
                # not unpackable: a single upper limit
                emax = energy_range / Hartree
        return emin, emax

    def select(self, nspins=None):
        """Select KSSingles according to the given criteria.

        If no calculator is attached (the list was read from a file),
        entries rejected by ``self.restrict.is_good`` are removed.
        Otherwise all transitions of ``self.calculator`` fulfilling the
        criteria are evaluated and appended.

        nspins: number of physical spins. For a spin-unpolarised ground
            state, nspins=2 creates every transition for both physical
            spins with half the occupation difference (fijscale=0.5).
        """

        # criteria
        emin, emax = self.restrict.emin_emax()
        istart = self.restrict['istart']
        jend = self.restrict['jend']
        eps = self.restrict['eps']

        if not hasattr(self, 'calculator'):  # I'm read from a file
            # throw away all not needed entries
            for i, ks in reversed(list(enumerate(self))):
                if not self.restrict.is_good(ks):
                    del self[i]
            return None

        paw = self.calculator
        wfs = paw.wfs
        self.dtype = wfs.dtype
        self.kpt_u = wfs.kpt_u

        if not self.lcao and self.kpt_u[0].psit_nG is None:
            raise RuntimeError('No wave functions in calculator!')

        # here, we need to take care of the spins also for
        # closed shell systems (Sz=0)
        # vspin is the virtual spin of the wave functions,
        #       i.e. the spin used in the ground state calculation
        # pspin is the physical spin of the wave functions
        #       i.e. the spin of the excited states
        self.nvspins = wfs.nspins
        self.npspins = wfs.nspins
        fijscale = 1
        ispins = [0]
        nks = wfs.kd.nibzkpts * wfs.kd.nspins
        if self.nvspins < 2:
            if (nspins or 0) > self.nvspins:
                self.npspins = nspins
                fijscale = 0.5
                ispins = [0, 1]
                nks *= 2

        kpt_comm = self.calculator.wfs.kd.comm
        nbands = len(self.kpt_u[0].f_n)

        # index criteria: pairs i < j with i >= istart and j <= jend
        n_n = np.arange(nbands)
        index_ok_nn = ((n_n[:, None] < n_n[None, :]) &
                       (n_n[:, None] >= istart) &
                       (n_n[None, :] <= jend))

        # select: take[u, i, j] is 1 for accepted transitions; every rank
        # fills the k-points it owns and the sum over kpt_comm makes the
        # table known everywhere
        take = np.zeros((nks, nbands, nbands), dtype=int)
        u = 0
        for ispin in ispins:
            for k in range(wfs.kd.nibzkpts):
                q = k - wfs.kd.k0
                # the same spin range as used for nks and the loop below
                for s in range(wfs.kd.nspins):
                    if q >= 0 and q < wfs.kd.mynk:
                        kpt = wfs.kpt_qs[q][s]
                        f_n = kpt.f_n[:nbands]
                        e_n = kpt.eps_n[:nbands]
                        fij_nn = (f_n[:, None] - f_n[None, :]) / kpt.weight
                        epsij_nn = e_n[None, :] - e_n[:, None]
                        take[u] = (index_ok_nn & (fij_nn > eps) &
                                   (epsij_nn >= emin) & (epsij_nn < emax))
                    u += 1
        kpt_comm.sum(take)

        self.log()
        self.log('Kohn-Sham single transitions')
        self.log()

        # calculate in parallel
        u = 0
        for ispin in ispins:
            for k in range(wfs.kd.nibzkpts):
                q = k - wfs.kd.k0
                for s in range(wfs.kd.nspins):
                    for i in range(nbands):
                        for j in range(i + 1, nbands):
                            if take[u, i, j]:
                                if q >= 0 and q < wfs.kd.mynk:
                                    kpt = wfs.kpt_qs[q][s]
                                    pspin = max(kpt.s, ispin)
                                    self.append(
                                        KSSingle(i, j, pspin, kpt, paw,
                                                 fijscale=fijscale,
                                                 dtype=self.dtype))
                                else:
                                    # placeholder, filled by distribute()
                                    self.append(KSSingle(i, j, pspin=0,
                                                         kpt=None, paw=paw,
                                                         dtype=self.dtype))
                    u += 1

        # distribute
        for kss in self:
            kss.distribute()

    @classmethod
    def read(cls, filename=None, fh=None, restrict=None, log=None):
        """Read myself from a file.

        filename: name of the file to open (via get_filehandle); the
            '# KSSingles' section is searched for within the file
        fh: already opened file handle, assumed to be positioned at the
            '# KSSingles' header line; it is not closed
        restrict: optional dictionary of selection criteria applied on
            top of those stored in the file
        log: passed on to the constructor

        Raises RuntimeError if no KSSingles data is found.
        """
        assert (filename is not None) or (fh is not None)

        def fail(f):
            # cls is the class itself; cls.__class__ would be ``type``
            name = getattr(f, 'name', repr(f))
            raise RuntimeError(name + ' does not contain ' +
                               cls.__name__ + ' data')

        if fh is None:
            f = get_filehandle(cls, filename)
        else:
            f = fh

        try:
            if fh is None:
                # there can be other information, i.e. the LrTDDFT header
                try:
                    content = f.read()
                    f.seek(content.index('# KSSingles'))
                    del content
                    f.readline()
                except ValueError:
                    fail(f)
            elif not f.readline().strip() == '# KSSingles':
                # we assume to be at the right place and read the header
                fail(f)

            words = f.readline().split()
            n = int(words[0])
            kssl = cls(log=log)
            if len(words) == 1:
                # very old output style for real wave functions
                # (finite systems)
                kssl.dtype = float
                restrict_from_file = {}
            else:
                if words[1].startswith('complex'):
                    kssl.dtype = complex
                else:
                    kssl.dtype = float
                restrict_from_file = json.loads(f.readline())
                if not isinstance(restrict_from_file, dict):
                    # old output style: only eps was stored
                    restrict_from_file = {'eps': restrict_from_file}
            kssl.npspins = 1
            for _ in range(n):
                kss = KSSingle(string=f.readline(), dtype=kssl.dtype)
                kssl.append(kss)
                kssl.npspins = max(kssl.npspins, kss.pspin + 1)
        finally:
            # close only what we opened ourselves, also on errors
            if fh is None:
                f.close()

        kssl.update()
        kssl.restrict.update(restrict_from_file)
        if restrict is not None and len(restrict):
            kssl.restrict.update(restrict)
            kssl.select()

        return kssl

    def update(self):
        """Derive istart, jend and the spin counts from the entries.

        Sets restrict['istart'] / restrict['jend'] to the smallest i and
        largest j present (the restriction is left unchanged for an
        empty list), npspins / nvspins to 2 if any entry has physical /
        virtual spin 1 (else 1), and drops cached arrays.
        """
        if len(self):
            self.restrict.update({
                'istart': min(kss.i for kss in self),
                'jend': max([0] + [kss.j for kss in self])})
        self.npspins = 2 if any(kss.pspin == 1 for kss in self) else 1
        self.nvspins = 2 if any(kss.spin == 1 for kss in self) else 1

        # cached arrays from set_arrays are no longer valid
        if hasattr(self, 'energies'):
            del self.energies

    def set_arrays(self):
        """Collect per-transition quantities into arrays (cached).

        muv / magn are None if no entry provides them.
        """
        if hasattr(self, 'energies'):
            return
        energies = []
        fij = []
        me = []
        mur = []
        muv = []
        magn = []
        for k in self:
            energies.append(k.energy)
            fij.append(k.fij)
            me.append(k.me)
            mur.append(k.mur)
            if k.muv is not None:
                muv.append(k.muv)
            if k.magn is not None:
                magn.append(k.magn)
        self.energies = np.array(energies)
        self.fij = np.array(fij)
        self.me = np.array(me)
        self.mur = np.array(mur)
        if len(muv):
            self.muv = np.array(muv)
        else:
            self.muv = None
        if len(magn):
            self.magn = np.array(magn)
        else:
            self.magn = None

    def write(self, filename=None, fh=None):
        """Write current state to a file.

        'filename' is the filename. If the filename ends in .gz,
        the file is automatically saved in compressed gzip format.

        'fh' is a filehandle. This can be used to write into already
        opened files; it is not closed.

        Only world rank 0 writes.
        """
        if self.world.rank != 0:
            return

        if fh is None:
            f = get_filehandle(self, filename, mode='w')
        else:
            f = fh

        try:
            f.write('# KSSingles\n')
            f.write(f'{len(self)} {np.dtype(self.dtype)}\n')
            f.write(json.dumps(self.restrict.values) + '\n')
            for kss in self:
                f.write(kss.outstring())
        finally:
            # close only what we opened ourselves, also on errors
            if fh is None:
                f.close()

    def overlap(self, ov_nn, other):
        """Matrix element overlaps determined from wave function overlaps.

        Parameters
        ----------
        ov_nn: array
            Wave function overlap factors from a displaced calculator.
            Index 0 corresponds to our own wavefunctions conjugated and
            index 1 to the others' wavefunctions

        Returns
        -------
        ov_pp: array
            Overlap corresponding to matrix elements.
            Index 0 corresponds to our own matrix elements conjugated and
            index 1 to the others' matrix elements
        """
        n0 = len(self)
        n1 = len(other)
        ov_pp = np.zeros((n0, n1), dtype=ov_nn.dtype)
        i1_p = [ex.i for ex in other]
        j1_p = [ex.j for ex in other]
        for p0, ex0 in enumerate(self):
            # <i0|i1>* <j0|j1> for all transitions p1 of other at once
            ov_pp[p0, :] = ov_nn[ex0.i, i1_p].conj() * ov_nn[ex0.j, j1_p]
        return ov_pp

349 

350 

class KSSingle(Excitation, PairDensity):
    """Single Kohn-Sham transition containing all its indices.

    Attributes (atomic units):

    i, j: indices of the occupied and the empty state
    pspin: physical spin
    spin: virtual spin, i.e. spin in the ground state calc.
    k, weight: k-point index and weight
    energy: eps_j - eps_i
    fij: (f_i - f_j) * fijscale, the weighted occupation difference
    me = sqrt(fij * energy) * <i|r|j>
    mur = - <i|r|j>
    muv = - <i|nabla|j> / energy
    magn = alpha / 2 * <i|r x nabla|j>, i.e. divided by (2 c)

    muv and magn may be None for transitions read from a file that
    does not contain them.
    """

    # matrix elements handled by the arithmetic operators
    _matrix_elements = ('me', 'mur', 'muv', 'magn')

    def __init__(self, iidx=None, jidx=None, pspin=None, kpt=None,
                 paw=None, string=None, fijscale=1, dtype=float):
        """
        iidx: index of occupied state
        jidx: index of empty state
        pspin: physical spin
        kpt: kpoint object; None on ranks not owning the k-point, where
             all quantities stay zero until distribute() is called
        paw: calculator
        string: string to be initialized from (format of outstring)
        fijscale: factor applied to the occupation difference
        dtype: dtype of matrix elements
        """
        if string is not None:
            self.fromstring(string, dtype)
            return None

        # normal entry

        PairDensity.__init__(self, paw)
        PairDensity.initialize(self, kpt, iidx, jidx)

        self.pspin = pspin

        self.energy = 0.0
        self.fij = 0.0

        self.me = np.zeros((3), dtype=dtype)
        self.mur = np.zeros((3), dtype=dtype)
        self.muv = np.zeros((3), dtype=dtype)
        self.magn = np.zeros((3), dtype=dtype)

        self.kpt_comm = paw.wfs.kd.comm

        # leave empty if not my kpt
        if kpt is None:
            return

        wfs = paw.wfs
        gd = wfs.gd

        self.energy = kpt.eps_n[jidx] - kpt.eps_n[iidx]
        self.fij = (kpt.f_n[iidx] - kpt.f_n[jidx]) * fijscale

        # calculate matrix elements -----------

        # length form ..........................

        # coarse grid contribution
        # <i|r|j> is the negative of the dipole moment (because of negative
        # e- charge)
        me = - gd.calculate_dipole_moment(self.get())

        # augmentation contributions
        ma = np.zeros(me.shape, dtype=dtype)
        pos_av = paw.atoms.get_positions() / Bohr
        for a, P_ni in kpt.P_ani.items():
            Ra = pos_av[a]
            Pi_i = P_ni[self.i].conj()
            Pj_i = P_ni[self.j]
            Delta_pL = wfs.setups[a].Delta_pL
            has_l1 = wfs.setups[a].lmax >= 1
            ni = len(Pi_i)
            ma0 = 0
            ma1 = np.zeros(me.shape, dtype=me.dtype)
            for i in range(ni):
                for j in range(ni):
                    pij = Pi_i[i] * Pj_i[j]
                    ij = packed_index(i, j, ni)
                    # L=0 term
                    ma0 += Delta_pL[ij, 0] * pij
                    # L=1 terms
                    if has_l1:
                        # see spherical_harmonics.py for
                        # L=1:y L=2:z; L=3:x
                        ma1 += np.array([Delta_pL[ij, 3], Delta_pL[ij, 1],
                                         Delta_pL[ij, 2]]) * pij
            ma += np.sqrt(4 * np.pi / 3) * ma1 + Ra * np.sqrt(4 * np.pi) * ma0
        gd.comm.sum(ma)

        self.me = np.sqrt(self.energy * self.fij) * (me + ma)
        self.mur = - (me + ma)

        # velocity form .............................

        if self.lcao:
            self.wfi = _get_and_distribute_wf(wfs, iidx, kpt.k, pspin)
            self.wfj = _get_and_distribute_wf(wfs, jidx, kpt.k, pspin)

        me = np.zeros(self.mur.shape, dtype=dtype)

        # get derivatives
        dtype = self.wfj.dtype
        dwfj_cg = gd.empty((3), dtype=dtype)
        if not hasattr(gd, 'ddr'):
            # cached on the grid descriptor for later transitions
            gd.ddr = [Gradient(gd, c, dtype=dtype, n=2).apply
                      for c in range(3)]
        for c in range(3):
            gd.ddr[c](self.wfj, dwfj_cg[c], kpt.phase_cd)
            me[c] = gd.integrate(self.wfi.conj() * dwfj_cg[c])

        # XXX is this the best choice, maybe center of mass?
        origin = 0.5 * np.diag(paw.wfs.gd.cell_cv)

        # augmentation contributions

        # <psi_i|grad|psi_j>
        ma = np.zeros(me.shape, dtype=me.dtype)
        # Ra x <psi_i|grad|psi_j> for magnetic transition dipole
        mRa = np.zeros(me.shape, dtype=me.dtype)
        for a, P_ni in kpt.P_ani.items():
            Pi_i = P_ni[self.i].conj()
            Pj_i = P_ni[self.j]
            nabla_iiv = paw.wfs.setups[a].nabla_iiv
            # ma_c[c] = sum_{i1, i2} Pi_i[i1] * Pj_i[i2] * nabla_iiv[i1, i2, c]
            ma_c = np.einsum('i,j,ijc->c', Pi_i, Pj_i, nabla_iiv)
            mRa += np.cross(paw.atoms[a].position / Bohr - origin, ma_c)
            ma += ma_c
        gd.comm.sum(ma)
        gd.comm.sum(mRa)

        self.muv = - (me + ma) / self.energy

        # magnetic transition dipole ................

        # m_ij = -(1/2c) <i|L|j> = i/2c <i|r x p|j>
        # see Autschbach et al., J. Chem. Phys., 116, 6930 (2002)

        r_cg, r2_g = coordinates(gd, origin=origin)
        magn = np.zeros(me.shape, dtype=dtype)

        # <psi_i|r x grad|psi_j>
        wfi_g = self.wfi.conj()
        for ci in range(3):
            cj = (ci + 1) % 3
            ck = (ci + 2) % 3
            magn[ci] = gd.integrate(wfi_g * r_cg[cj] * dwfj_cg[ck] -
                                    wfi_g * r_cg[ck] * dwfj_cg[cj])

        # augmentation contributions
        # <psi_i| r x nabla |psi_j>
        # = <psi_i| (r - Ra + Ra) x nabla |psi_j>
        # = <psi_i| (r - Ra) x nabla |psi_j> + Ra x <psi_i| nabla |psi_j>

        ma = np.zeros(magn.shape, dtype=magn.dtype)
        for a, P_ni in kpt.P_ani.items():
            Pi_i = P_ni[self.i].conj()
            Pj_i = P_ni[self.j]
            rxnabla_iiv = paw.wfs.setups[a].rxnabla_iiv
            # ma[c] += sum_{i1, i2} Pi_i[i1] * Pj_i[i2] * rxnabla_iiv[i1, i2, c]
            ma += np.einsum('i,j,ijc->c', Pi_i, Pj_i, rxnabla_iiv)
        gd.comm.sum(ma)

        self.magn = alpha / 2. * (magn + ma + mRa)

    def distribute(self):
        """Distribute results to all cores.

        Ranks not owning the k-point hold zeros, so summing over the
        k-point communicator copies the owner's values everywhere.
        """
        self.spin = self.kpt_comm.sum_scalar(self.spin)
        self.pspin = self.kpt_comm.sum_scalar(self.pspin)
        self.k = self.kpt_comm.sum_scalar(self.k)
        self.weight = self.kpt_comm.sum_scalar(self.weight)
        self.energy = self.kpt_comm.sum_scalar(self.energy)
        self.fij = self.kpt_comm.sum_scalar(self.fij)

        self.kpt_comm.sum(self.me)
        self.kpt_comm.sum(self.mur)
        self.kpt_comm.sum(self.muv)
        self.kpt_comm.sum(self.magn)

    def _combine(self, other, op):
        """Return a copy of self with op applied to both matrix elements.

        A matrix element missing (None) in either operand is None in the
        result; previously this raised a TypeError.
        """
        result = copy(self)
        for name in self._matrix_elements:
            mine = getattr(self, name)
            theirs = getattr(other, name)
            if mine is None or theirs is None:
                setattr(result, name, None)
            else:
                setattr(result, name, op(mine, theirs))
        return result

    def __add__(self, other):
        """Add two KSSingles (matrix elements only; indices from self)."""
        return self._combine(other, np.add)

    def __sub__(self, other):
        """Subtract two KSSingles (matrix elements only; indices from self)."""
        return self._combine(other, np.subtract)

    def __rmul__(self, x):
        return self.__mul__(x)

    def __mul__(self, x):
        """Multiply the matrix elements of a KSSingle with a real number.

        Python and numpy integers/floats are accepted; anything else
        returns NotImplemented (a TypeError for the caller).
        """
        if not isinstance(x, (int, float, np.integer, np.floating)):
            return NotImplemented
        result = copy(self)
        for name in self._matrix_elements:
            value = getattr(self, name)
            setattr(result, name, None if value is None else value * x)
        return result

    def __truediv__(self, x):
        return self.__mul__(1. / x)

    __div__ = __truediv__

    def fromstring(self, string, dtype=float):
        """Initialize from a line written by :meth:`outstring`."""
        words = string.split()
        self.i = int(words.pop(0))
        self.j = int(words.pop(0))
        self.pspin = int(words.pop(0))
        self.spin = int(words.pop(0))
        if dtype == float:
            # real wave functions: Gamma point only
            self.k = 0
            self.weight = 1
        else:
            self.k = int(words.pop(0))
            self.weight = float(words.pop(0))
        self.energy = float(words.pop(0))
        self.fij = float(words.pop(0))
        self.mur = np.array([dtype(words.pop(0)) for _ in range(3)])
        self.me = - self.mur * np.sqrt(self.energy * self.fij)
        # velocity form and magnetic moment are optional in the file
        self.muv = self.magn = None
        if len(words):
            self.muv = np.array([dtype(words.pop(0)) for _ in range(3)])
        if len(words):
            self.magn = np.array([dtype(words.pop(0)) for _ in range(3)])
        return None

    def outstring(self):
        """Return a one-line text representation (see fromstring)."""
        if self.mur.dtype == float:
            string = '{:d} {:d} {:d} {:d} {:.10g} {:f}'.format(
                self.i, self.j, self.pspin, self.spin, self.energy, self.fij)
        else:
            # weight and energy both with 10 significant digits; the energy
            # was formerly written with only 6 ({:g})
            string = (
                '{:d} {:d} {:d} {:d} {:d} {:.10g} {:.10g} {:g}'.format(
                    self.i, self.j, self.pspin, self.spin, self.k,
                    self.weight, self.energy, self.fij))
        string += '  '

        def format_me(me):
            string = ''
            if me.dtype == float:
                for m in me:
                    string += f' {m:.5e}'
            else:
                for m in me:
                    string += ' {0.real:.5e}{0.imag:+.5e}j'.format(m)
            return string

        string += '  ' + format_me(self.mur)
        if self.muv is not None:
            string += '  ' + format_me(self.muv)
        if self.magn is not None:
            string += '  ' + format_me(self.magn)
        string += '\n'

        return string

    def __str__(self):
        string = '# <KSSingle> %d->%d %d(%d) eji=%g[eV]' % \
            (self.i, self.j, self.pspin, self.spin,
             self.energy * Hartree)
        if self.me.dtype == float:
            string += f' ({self.me[0]:g},{self.me[1]:g},{self.me[2]:g})'
        else:
            string += f' kpt={self.k:d} w={self.weight:g}'
            string += ' ('
            # use velocity form
            s = - np.sqrt(self.energy * self.fij)
            for c, m in enumerate(s * self.me):
                string += '{0.real:.5e}{0.imag:+.5e}j'.format(m)
                if c < 2:
                    string += ','
            string += ')'
        return string

    def __eq__(self, other):
        """KSSingles are considered equal when their indices are equal."""
        try:
            return (self.pspin == other.pspin and self.k == other.k and
                    self.i == other.i and self.j == other.j)
        except AttributeError:
            return NotImplemented

    def __hash__(self):
        """Hash over the same indices as __eq__ (pspin, k, i, j).

        Using the virtual spin here, as before, could give equal objects
        different hashes.
        """
        if not hasattr(self, 'hash'):
            self.hash = hash((self.pspin, self.k, self.i, self.j))
        return self.hash

    #
    # User interface: ##
    #

    def get_weight(self):
        return self.fij

660 

661 

def _get_and_distribute_wf(wfs, n, k, s):
    """Return wave function n of k-point k and spin s distributed over wfs.gd.

    The global real-space array is obtained from
    ``wfs.get_wave_function_array(..., realspace=True, periodic=False)``,
    broadcast from world rank 0 and distributed over the grid descriptor.
    """
    gd = wfs.gd
    wf = wfs.get_wave_function_array(n=n, k=k, s=s, realspace=True,
                                     periodic=False)
    if wfs.world.rank != 0:
        # only rank 0's array is used; the others provide a receive buffer
        wf = gd.empty(dtype=wfs.dtype, global_array=True)
    wf = np.ascontiguousarray(wf)
    wfs.world.broadcast(wf, 0)
    # gd.distribute returns a new local array, no preallocation needed
    return gd.distribute(wf)