Coverage for gpaw/wannier90.py: 84%

400 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import numpy as np 

2from gpaw.utilities.blas import gemmdot 

3from gpaw.ibz2bz import (get_overlap, get_overlap_coefficients, 

4 get_phase_shifted_overlap_coefficients, 

5 IBZ2BZMaps) 

6from gpaw.spinorbit import soc_eigenstates 

7 

8 

class Wannier90:
    """Write Wannier90 input files from a GPAW calculation.

    Produces ``<seed>.win`` (``write_input``), ``<seed>.amn``
    (``write_projections``), ``<seed>.eig`` (``write_eigenvalues``),
    ``<seed>.mmn`` (``write_overlaps``) and ``UNK*`` files
    (``write_wavefunctions``).  All writers except ``write_input`` read the
    band selection back from ``<seed>.win`` via ``get_bands``;
    ``write_overlaps`` additionally reads the neighbour table from
    ``<seed>.nnkp``.
    """

    def __init__(self, calc, seed=None, bands=None, orbitals_ai=None,
                 spin=0, spinors=False):
        """
        Parameters
        ----------
        calc: GPAW calculator
        seed: str
            Prefix of all written files; defaults to the chemical formula.
        bands: sequence of int
            Band indices to include; defaults to all bands of ``calc``.
        orbitals_ai: list of sequences of int
            For each atom, the projector-orbital indices used as initial
            projections; defaults to every orbital of the setup's projector
            channels with ``n != -1``.
        spin: int
            Spin channel used when ``spinors`` is False.
        spinors: bool
            Use spin-orbit eigenstates from ``soc_eigenstates``.  Requires an
            unreduced k-point set (``nbzkpts == nibzkpts``).
        """
        if seed is None:
            seed = calc.atoms.get_chemical_formula()
        self.seed = seed

        if bands is None:
            bands = range(calc.get_number_of_bands())
        self.bands = bands

        Na = len(calc.atoms)
        if orbitals_ai is None:
            orbitals_ai = self._default_orbitals(calc)

        self.calc = calc
        self.ibz2bz = IBZ2BZMaps.from_calculator(calc)
        self.Nn = len(bands)
        self.Na = Na
        self.orbitals_ai = orbitals_ai
        self.Nw = np.sum([len(orbitals_ai[ai]) for ai in range(Na)])
        self.kpts_kc = calc.get_ibz_k_points()
        self.Nk = len(self.kpts_kc)
        self.spin = spin
        self.spinors = spinors

        if spinors:
            # spinorbit.WaveFunctions.transform currently does not support
            # transformation of wavefunctions, only projections.
            # XXX: should be updated in the future
            assert calc.wfs.kd.nbzkpts == calc.wfs.kd.nibzkpts
            self.soc = soc_eigenstates(calc)
        else:
            self.soc = None

    @staticmethod
    def _default_orbitals(calc):
        """Return, per atom, a ``range`` over its default projector orbitals.

        Every projector channel of the atom's setup with ``n != -1``
        contributes ``2 * l + 1`` orbitals.
        """
        orbitals_ai = []
        for ia in range(len(calc.atoms)):
            setup = calc.wfs.setups[ia]
            ni = sum(2 * l + 1
                     for l, n in zip(setup.l_j, setup.n_j) if n != -1)
            orbitals_ai.append(range(ni))
        return orbitals_ai

    @staticmethod
    def _excluded_bands(bands):
        """Return the 1-based indices below ``max(bands)`` missing from bands.

        Raises ValueError for an empty ``bands`` (from ``max``).
        """
        present = set(bands)
        return [n + 1 for n in range(max(bands)) if n not in present]

    def write_input(self,
                    mp=None,
                    plot=False,
                    num_iter=100,
                    write_xyz=False,
                    write_rmn=False,
                    translate_home_cell=False,
                    dis_num_iter=200,
                    dis_froz_max=0.1,
                    dis_mix_ratio=0.5,
                    dis_win_min=None,
                    dis_win_max=None,
                    search_shells=None,
                    write_u_matrices=False):
        """Write the Wannier90 input file ``<seed>.win``.

        ``mp`` overrides the Monkhorst-Pack grid (defaults to
        ``calc.wfs.kd.N_c``).  The ``dis_*`` energies are given relative to
        the Fermi level and are only written when there are more bands than
        Wannier functions.  With spinors, every band index ``n`` becomes the
        pair ``2n, 2n + 1`` and the number of Wannier functions is doubled.
        """
        calc = self.calc
        seed = self.seed
        bands = self.bands
        orbitals_ai = self.orbitals_ai
        spinors = self.spinors

        if seed is None:
            seed = calc.atoms.get_chemical_formula()

        if bands is None:
            bands = range(calc.get_number_of_bands())

        Na = len(calc.atoms)
        if orbitals_ai is None:
            orbitals_ai = self._default_orbitals(calc)
        assert len(orbitals_ai) == Na

        Nw = np.sum([len(orbitals_ai[ai]) for ai in range(Na)])
        if spinors:
            Nw *= 2
            bands = [m for n in bands for m in (2 * n, 2 * n + 1)]

        pos_ac = calc.spos_ac

        with open(seed + '.win', 'w') as f:
            # Projection lines are written as s-orbitals at the atom sites,
            # annotated with the (n, l) of the selected projector orbital.
            print('begin projections', file=f)
            for ia, orbitals_i in enumerate(orbitals_ai):
                setup = calc.wfs.setups[ia]
                l_i = []
                n_i = []
                for n, l in zip(setup.n_j, setup.l_j):
                    if n != -1:
                        l_i += (2 * l + 1) * [l]
                        n_i += (2 * l + 1) * [n]
                r_c = pos_ac[ia]
                for orb in orbitals_i:
                    print(f'f={r_c[0]:1.2f}, {r_c[1]:1.2f}, '
                          f'{r_c[2]:1.2f} : s ', end='', file=f)
                    print(f'# n = {n_i[orb]}, l = {l_i[orb]}', file=f)
            print('end projections', file=f)
            print(file=f)

            if spinors:
                print('spinors = True', file=f)
            else:
                print('spinors = False', file=f)
            if write_u_matrices:
                print('write_u_matrices = True', file=f)
            print('write_hr = True', file=f)
            if write_xyz:
                print('write_xyz = True', file=f)
            if write_rmn:
                print('write_tb = True', file=f)
                print('write_rmn = True', file=f)
            if translate_home_cell:
                print('translate_home_cell = True', file=f)
            print(file=f)
            print('num_bands = %d' % len(bands), file=f)

            if search_shells is not None:
                print(f"search_shells = {search_shells}", file=f)

            # Bands below the highest selected one that are not selected
            # must be listed explicitly (1-based) for Wannier90.
            excluded = self._excluded_bands(bands)
            if excluded:
                print('exclude_bands : ' + ','.join(map(str, excluded)),
                      file=f)
            print(file=f)

            print('guiding_centres = True', file=f)
            print('num_wann = %d' % Nw, file=f)
            print('num_iter = %d' % num_iter, file=f)
            print(file=f)

            if len(bands) > Nw:
                # Disentanglement windows are specified relative to E_F.
                ef = calc.get_fermi_level()
                print('fermi_energy = %2.3f' % ef, file=f)
                if dis_froz_max is not None:
                    print('dis_froz_max = %2.3f' % (ef + dis_froz_max),
                          file=f)
                if dis_win_min is not None:
                    print('dis_win_min = %2.3f' % (ef + dis_win_min), file=f)
                if dis_win_max is not None:
                    print('dis_win_max = %2.3f' % (ef + dis_win_max), file=f)
                print('dis_num_iter = %d' % dis_num_iter, file=f)
                print('dis_mix_ratio = %1.1f' % dis_mix_ratio, file=f)
                print(file=f)

            print('begin unit_cell_cart', file=f)
            for cell_c in calc.atoms.cell:
                print(f'{cell_c[0]:14.10f} {cell_c[1]:14.10f} '
                      f'{cell_c[2]:14.10f}', file=f)
            print('end unit_cell_cart', file=f)
            print(file=f)

            print('begin atoms_frac', file=f)
            for atom, pos_c in zip(calc.atoms, pos_ac):
                print(atom.symbol, end='', file=f)
                print(f'{pos_c[0]:14.10f} {pos_c[1]:14.10f} '
                      f'{pos_c[2]:14.10f}', file=f)
            print('end atoms_frac', file=f)
            print(file=f)

            if plot:
                print('wannier_plot = True', file=f)
                print('wvfn_formatted = True', file=f)
                print(file=f)

            N_c = mp if mp is not None else calc.wfs.kd.N_c
            print('mp_grid =', N_c[0], N_c[1], N_c[2], file=f)
            print(file=f)
            print('begin kpoints', file=f)
            for kpt in calc.get_bz_k_points():
                print(f'{kpt[0]:14.10f} {kpt[1]:14.10f} {kpt[2]:14.10f}',
                      file=f)
            print('end kpoints', file=f)

    def write_projections(self):
        """Write the initial projections A_mn(k) to ``<seed>.amn``.

        ``spinors``, ``num_wann`` and ``mp_grid`` are read back from
        ``<seed>.win``; the projections are the (conjugated) PAW projections
        of the selected bands onto the selected projector orbitals.

        Raises
        ------
        ValueError
            If ``<seed>.win`` lacks ``num_wann`` or ``mp_grid``.
        """
        calc = self.calc
        seed = self.seed
        spin = self.spin
        orbitals_ai = self.orbitals_ai
        soc = self.soc

        if seed is None:
            seed = calc.atoms.get_chemical_formula()

        bands = get_bands(seed)
        Nn = len(bands)

        spinors = False
        Nw = None
        Nk = None
        with open(seed + '.win') as win_file:
            for line in win_file:
                l_e = line.split()
                if not l_e:
                    continue
                if l_e[0] == 'spinors':
                    spinors = l_e[2] in ['T', 'true', '1', 'True']
                if l_e[0] == 'num_wann':
                    Nw = int(l_e[2])
                if l_e[0] == 'mp_grid':
                    Nk = int(l_e[2]) * int(l_e[3]) * int(l_e[4])
                    assert Nk == len(calc.get_bz_k_points())
        if Nw is None or Nk is None:
            raise ValueError(f'{seed}.win must define num_wann and mp_grid')

        Na = len(calc.atoms)
        if orbitals_ai is None:
            orbitals_ai = self._default_orbitals(calc)
        assert len(orbitals_ai) == Na

        if spinors:
            # Each scalar orbital i becomes the spinor pair 2i, 2i + 1.
            orbitals_ai = [[j for i in orbitals_i for j in (2 * i, 2 * i + 1)]
                           for orbitals_i in orbitals_ai]

        assert Nw == sum(len(orbitals_i) for orbitals_i in orbitals_ai)

        P_kni = np.zeros((Nk, Nn, Nw), complex)
        for ik in range(Nk):
            if spinors:
                P_ani = soc[ik].P_amj
            else:
                P_ani = get_projections_in_bz(calc.wfs,
                                              ik,
                                              spin,
                                              self.ibz2bz,
                                              bcomm=None)
            # Fill the Wannier-function axis atom by atom.
            icount = 0
            for ai in range(Na):
                ni = len(orbitals_ai[ai])
                P_ni = P_ani[ai][bands]
                P_ni = P_ni[:, orbitals_ai[ai]]
                P_kni[ik, :, icount:ni + icount] = P_ni.conj()
                icount += ni

        with open(seed + '.amn', 'w') as f:
            print('Kohn-Sham input generated from GPAW calculation', file=f)
            print('%10d %6d %6d' % (Nn, Nk, Nw), file=f)
            for ik in range(Nk):
                for i in range(Nw):
                    for n in range(Nn):
                        P = P_kni[ik, n, i]
                        data = (n + 1, i + 1, ik + 1, P.real, P.imag)
                        print('%4d %4d %4d %18.12f %20.12f' % data, file=f)

    def write_eigenvalues(self):
        """Write the selected eigenvalues for every BZ k-point to
        ``<seed>.eig``."""
        calc = self.calc
        seed = self.seed
        spin = self.spin
        soc = self.soc

        if seed is None:
            seed = calc.atoms.get_chemical_formula()

        bands = get_bands(seed)

        with open(seed + '.eig', 'w') as f:
            for ik in range(len(calc.get_bz_k_points())):
                if soc is None:
                    ibzk = calc.wfs.kd.bz2ibz_k[ik]  # IBZ k-point
                    e_n = calc.get_eigenvalues(kpt=ibzk, spin=spin)
                else:
                    e_n = soc[ik].eig_m
                for i, n in enumerate(bands):
                    data = (i + 1, ik + 1, e_n[n])
                    print('%5d %5d %14.6f' % data, file=f)

    def write_overlaps(self, less_memory=False):
        """Write the overlap matrices M_mn(k, b) to ``<seed>.mmn``.

        The neighbour table (neighbour k-point index and G-vector shift for
        each k-point) is read from the ``nnkpts`` block of ``<seed>.nnkp``.

        Parameters
        ----------
        less_memory: bool
            If True, wave functions are regenerated for every k-point pair
            instead of being kept in memory for all k-points.

        Raises
        ------
        ValueError
            If ``<seed>.nnkp`` has no ``begin nnkpts`` block.
        """
        calc = self.calc
        seed = self.seed
        spin = self.spin
        soc = self.soc
        ibz2bz = self.ibz2bz

        if seed is None:
            seed = calc.atoms.get_chemical_formula()

        spinors = soc is not None

        bands = get_bands(seed)
        Nn = len(bands)
        kpts_kc = calc.get_bz_k_points()
        Nk = len(kpts_kc)

        with open(seed + '.nnkp') as nnkp:
            lines = nnkp.readlines()
        for il, line in enumerate(lines):
            words = line.split()
            if len(words) > 1 and words[0] == 'begin' and words[1] == 'nnkpts':
                # Next line: neighbours per k-point; then the table itself.
                Nb = int(lines[il + 1].split()[0])
                i0 = il + 2
                break
        else:
            raise ValueError(f'No "begin nnkpts" block found in {seed}.nnkp')

        icell_cv = (2 * np.pi) * np.linalg.inv(calc.wfs.gd.cell_cv).T
        r_g = calc.wfs.gd.get_grid_point_coordinates()

        spos_ac = calc.spos_ac
        wfs = calc.wfs
        dO_aii = get_overlap_coefficients(wfs)

        if not less_memory:
            u_knG = [self.wavefunctions(ik, bands) for ik in range(Nk)]

        proj_k = []
        for ik in range(Nk):
            if spinors:
                proj_k.append(soc[ik].projections)
            else:
                proj_k.append(get_projections_in_bz(wfs,
                                                    ik, spin,
                                                    ibz2bz,
                                                    bcomm=None))

        with open(seed + '.mmn', 'w') as f:
            print('Kohn-Sham input generated from GPAW calculation', file=f)
            print('%10d %6d %6d' % (Nn, Nk, Nb), file=f)

            for ik1 in range(Nk):
                if less_memory:
                    u1_nG = self.wavefunctions(ik1, bands)
                else:
                    u1_nG = u_knG[ik1]
                for ib in range(Nb):
                    # b denotes nearest neighbor k-points
                    line = lines[i0 + ik1 * Nb + ib].split()
                    ik2 = int(line[1]) - 1
                    if less_memory:
                        u2_nG = self.wavefunctions(ik2, bands)
                    else:
                        u2_nG = u_knG[ik2]

                    G_c = np.array([int(line[i]) for i in range(2, 5)])
                    bG_v = np.dot(G_c, icell_cv)
                    # Apply the exp(-iG.r) shift to the neighbour's
                    # periodic parts.
                    u2_nG = u2_nG * np.exp(
                        -1.0j * gemmdot(bG_v, r_g, beta=0.0))
                    bG_c = kpts_kc[ik2] - kpts_kc[ik1] + G_c
                    phase_shifted_dO_aii = \
                        get_phase_shifted_overlap_coefficients(
                            dO_aii, spos_ac, -bG_c)
                    M_mm = get_overlap(bands,
                                       wfs.gd,
                                       u1_nG,
                                       u2_nG,
                                       proj_k[ik1],
                                       proj_k[ik2],
                                       phase_shifted_dO_aii)
                    indices = (ik1 + 1, ik2 + 1, G_c[0], G_c[1], G_c[2])
                    print('%3d %3d %4d %3d %3d' % indices, file=f)
                    for m1 in range(len(M_mm)):
                        for m2 in range(len(M_mm)):
                            M = M_mm[m2, m1]
                            print(f'{M.real:20.12f} {M.imag:20.12f}',
                                  file=f)

    def write_wavefunctions(self):
        """Write formatted ``UNK<k>.<spin>`` files, one per BZ k-point.

        Each file starts with the three grid dimensions, the 1-based k-point
        index and the number of bands. It then contains, for every selected
        band, the real and imaginary parts of the periodic wave function,
        with x running fastest, then y, then z.  For spinors, the two spin
        components of a band are written one after the other (assumed to
        match Wannier90's formatted spinor layout -- TODO confirm).
        """
        calc = self.calc
        soc = self.soc
        spin = self.spin
        seed = self.seed

        spinors = soc is not None

        if seed is None:
            seed = calc.atoms.get_chemical_formula()

        bands = get_bands(seed)
        Nn = len(bands)
        Nk = len(calc.get_bz_k_points())

        for ik in range(Nk):
            if spinors:
                # For spinors, G denotes spin and grid: G = (s, gx, gy, gz)
                u_nG = soc[ik].wavefunctions(calc, periodic=True)
            else:
                # For non-spinors, G denotes grid: G = (gx, gy, gz)
                u_nG = self.wavefunctions(ik, bands)
            u_nG = np.asarray(u_nG)
            # The real-space grid is always the last three axes; the spin
            # axis (if any) must not end up in the header.
            grid_c = tuple(u_nG.shape[-3:])

            # NOTE(review): Wannier90 may expect a '.NC' suffix for spinor
            # UNK files instead of the spin index -- confirm before plotting
            # spinor Wannier functions.
            with open('UNK%s.%d' % (str(ik + 1).zfill(5), spin + 1),
                      'w') as f:
                print(grid_c[0], grid_c[1], grid_c[2], ik + 1, Nn, file=f)
                for n in bands:
                    # One component for scalar states, two for spinors.
                    for u_G in u_nG[n].reshape((-1,) + grid_c):
                        # Fortran order: x fastest, then y, then z.
                        for u in u_G.ravel(order='F'):
                            print(u.real, u.imag, file=f)

    def wavefunctions(self, bz_index, bands):
        """Return periodic wave functions at BZ k-point ``bz_index``.

        For spinors, only the ``bands`` states are returned.  Otherwise all
        states up to ``bands[-1]`` are returned (so indexing with the
        original band numbers works), mapped from the IBZ to the BZ.
        """
        maxband = bands[-1] + 1
        if self.spinors:
            # For spinors, G denotes spin and grid: G = (s, gx, gy, gz)
            return self.soc[bz_index].wavefunctions(
                self.calc, periodic=True)[bands]
        # For non-spinors, G denotes grid: G = (gx, gy, gz)
        ibz_index = self.calc.wfs.kd.bz2ibz_k[bz_index]
        ut_nR = np.array([self.calc.wfs.get_wave_function_array(
            n, ibz_index, self.spin,
            periodic=True) for n in range(maxband)])
        ut_nR_sym = np.array([self.ibz2bz[bz_index].map_pseudo_wave_to_BZ(
            ut_nR[n]) for n in range(maxband)])

        return ut_nR_sym

472 

473 

def get_bands(seed):
    """Return the 0-based band indices selected in ``<seed>.win``.

    Reads ``num_bands`` and the optional ``exclude_bands`` keyword.
    Keywords are matched case-insensitively, and the value may follow ``=``,
    ``:`` or plain whitespace.  Text after ``!`` or ``#`` is ignored.
    ``exclude_bands`` accepts 1-based indices separated by commas and/or
    spaces, including ranges such as ``1-4``.

    Returns
    -------
    ``range(num_bands)`` when no bands are excluded; otherwise a list of the
    first ``num_bands`` indices that are not excluded.

    Raises
    ------
    ValueError
        If ``num_bands`` is not present in the file.
    """
    Nn = None
    exclude_bands = None
    with open(seed + '.win') as win_file:
        for line in win_file:
            line = line.split('!', 1)[0].split('#', 1)[0]
            l_e = line.replace('=', ' ').replace(':', ' ').split()
            if not l_e:
                continue
            key = l_e[0].lower()
            if key == 'num_bands':
                Nn = int(l_e[1])
            elif key == 'exclude_bands':
                exclude_bands = set()
                for token in ' '.join(l_e[1:]).replace(',', ' ').split():
                    # "a-b" is an inclusive 1-based range, "a" a single band.
                    first, _, last = token.partition('-')
                    exclude_bands.update(
                        range(int(first) - 1, int(last or first)))

    if Nn is None:
        raise ValueError(f'num_bands not found in {seed}.win')
    if exclude_bands is None:
        return range(Nn)
    return [n for n in range(Nn + len(exclude_bands))
            if n not in exclude_bands]

493 

494 

def get_projections_in_bz(wfs, K, s, ibz2bz, bcomm=None):
    """Return the PAW projections for BZ k-point ``K`` and spin ``s``.

    The projections of the IBZ k-point that ``K`` maps to are copied
    (first ``wfs.bd.nbands`` bands) into a fresh projections object. That
    object is then transformed to ``K`` with ``ibz2bz[K].map_projections``.

    Parameters
    ----------
    wfs: calc.wfs object
    K: int
        BZ k-point index.
    s: int
        Spin index.
    ibz2bz: IBZ2BZMaps
    bcomm: band communicator (optional)
    """
    ibz_kpt = wfs.kpt_qs[wfs.kd.bz2ibz_k[K]][s]
    nbands = wfs.bd.nbands

    ibz_proj = ibz_kpt.projections.new(nbands=nbands, bcomm=bcomm)
    ibz_proj.array[:] = ibz_kpt.projections.array[:nbands]

    return ibz2bz[K].map_projections(ibz_proj)

513 

514 

def read_umat(seed, kd, dis=False):
    """Read a Wannier90 transformation matrix file.

    Parameters
    ----------
    seed: str
        Either a full file name containing ``.mat``, or a seed to which
        ``_u.mat`` (or ``_u_dis.mat`` when ``dis`` is True) is appended.
    kd: k-point descriptor
        Used to locate each k-point of the file in ``kd.bzk_kc``.
    dis: bool
        Read the disentanglement matrices instead of the rotation ones.

    Returns
    -------
    uwan: complex array of shape (nw1, nw2, nk), indexed by the BZ k-point
        index found in ``kd``.
    nk, nw1, nw2: int
        Dimensions from the file header.
    """
    if ".mat" not in seed:
        seed += "_u_dis.mat" if dis else "_u.mat"

    with open(seed, "r") as f:
        f.readline()  # first line is a comment
        nk, nw1, nw2 = [int(i) for i in f.readline().split()]
        assert nk == kd.nbzkpts
        uwan = np.empty([nw1, nw2, nk], dtype=complex)
        iklist = []  # BZ indices found, to check completeness below
        for _ in range(nk):
            f.readline()  # empty line
            K_c = [float(x) for x in f.readline().split()]
            # The file's k-point order need not match kd's; look it up.
            ik = kd.where_is_q(K_c, kd.bzk_kc)
            assert np.allclose(np.array(K_c), kd.bzk_kc[ik])
            iklist.append(ik)
            for ib1 in range(nw1):
                for ib2 in range(nw2):
                    re_part, im_part = [float(x) for x in
                                        f.readline().split()]
                    uwan[ib1, ib2, ik] = complex(re_part, im_part)

    assert set(iklist) == set(range(nk))  # check that all k:s were found
    return uwan, nk, nw1, nw2

543 

544 

def read_uwan(seed, kd, dis=False):
    """Read the Wannier transformation matrices for all k-points.

    Parameters
    ----------
    seed: str
        Seed of the Wannier90 calculation (must not contain ``.mat``).
    kd: k-point descriptor
        Passed on to ``read_umat``.
    dis: bool
        Should be True if nband > nwan.  The ``_u.mat`` matrices are then
        multiplied, k-point by k-point, with the ``_u_dis.mat`` matrices
        (the transformation to the optimal subspace).

    Returns
    -------
    uwan, nk, nw1, nw2
        The matrices indexed as ``[:, :, k]`` plus the dimensions.  When
        ``dis`` is True, the dimensions are those of the ``_u_dis.mat`` file.
    """
    assert '.mat' not in seed
    u_rot, nk, nw1, nw2 = read_umat(seed, kd, dis=False)
    if not dis:
        return u_rot, nk, nw1, nw2

    u_dis, nk, nw1, nw2 = read_umat(seed, kd, dis=True)
    combined = np.zeros_like(u_dis)
    for k in range(nk):
        combined[:, :, k] = u_rot[:, :, k] @ u_dis[:, :, k]
    return combined, nk, nw1, nw2