Coverage for gpaw/unfold.py: 66%

259 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1import numpy as np 

2import pickle 

3 

4from ase.units import Hartree, Bohr 

5 

6from gpaw.kpt_descriptor import to1bz 

7from gpaw.new.ase_interface import GPAW 

8from gpaw.spinorbit import soc_eigenstates 

9from gpaw.pw.descriptor import PWDescriptor 

10import gpaw.mpi as mpi 

11 

12 

class Unfold:
    """Unfold the bands of a supercell (SC) calculation onto the primitive
    cell (PC).

    As a convention (when possible) capital-letter variables refer to the
    SC while lowercase ones refer to the PC.

    Parameters
    ----------
    name : str
        Tag used to build the file names ``weights_<name>.pckl`` (cached
        spectral weights) and ``sf_<name>.pckl`` (spectral function).
    calc :
        First argument to ``GPAW(...)``; the calculator is re-created with
        a serial communicator (presumably a ``.gpw`` path -- TODO confirm).
    M :
        3x3 matrix (stored as float) with ``K = M @ k``, see
        ``find_K_from_k``.
    spin : int
        Spin channel used when spin-orbit coupling is not included.
    spinorbit : bool
        If true, eigenvalues/eigenvectors come from ``soc_eigenstates``
        and the number of bands is doubled.
    theta, phi, scale :
        Forwarded to ``soc_eigenstates``.
    """

    def __init__(self,
                 name=None,
                 calc=None,
                 M=None,
                 spin=0,
                 spinorbit=None,
                 theta=90,
                 scale=1.0,
                 phi=90):

        self.name = name
        self.calc = GPAW(calc, txt=None, communicator=mpi.serial_comm)
        self.M = np.array(M, dtype=float)
        self.spinorbit = spinorbit
        self.spin = spin

        self.gd = self.calc.wfs.gd.new_descriptor()

        self.kd = self.calc.wfs.kd
        if self.calc.wfs.mode == 'pw':
            self.pd = self.calc.wfs.pd
        else:
            # Non-PW modes: build a plane-wave descriptor on the same grid
            # so the real-space wave functions can be Fourier transformed.
            self.pd = PWDescriptor(ecut=None, gd=self.gd, kd=self.kd,
                                   dtype=complex)

        self.acell_cv = self.gd.cell_cv
        self.bcell_cv = 2 * np.pi * self.gd.icell_cv

        self.nb = self.calc.get_number_of_bands()

        self.v_Kmn = None
        if spinorbit:
            assert self.calc.density.collinear
            self.nb *= 2
            if mpi.world.rank == 0:
                print('Calculating spinorbit Corrections')
            soc = soc_eigenstates(self.calc,
                                  scale=scale, theta=theta, phi=phi)
            self.e_mK = soc.eigenvalues().T
            self.v_Kmn = soc.eigenvectors()
            if mpi.world.rank == 0:
                print('Done with the spinorbit Corrections')

    def get_K_index(self, K):
        """Return the IBZ index of the SC k point ``K`` (scaled coords).

        ``K`` is first mapped into the first Brillouin zone with ``to1bz``
        and then looked up in ``self.kd.ibzk_kc``.
        """
        K = np.array([K])
        bzKG = to1bz(K, self.acell_cv)[0]
        iK = self.kd.where_is_q(bzKG, self.kd.ibzk_kc)
        return iK

    def get_g(self, iK):
        """Select the SC reciprocal vectors that are also PC vectors.

        Not all G vectors are relevant for the unfolding, only those whose
        PC coordinates ``G @ inv(M).T`` are integers (to within 1e-5).

        Returns
        -------
        (iG_list, g_list)
            Indices into the plane-wave basis of ``iK`` and the matching
            G vectors in SC scaled coordinates (one row per vector).
        """
        G_Gv = self.pd.get_reciprocal_vectors(q=iK, add_q=False)
        G_Gc = np.dot(G_Gv, np.linalg.inv(self.bcell_cv))
        # Coordinates of every G in the PC reciprocal basis.
        a_Gc = np.dot(G_Gc, np.linalg.inv(self.M).T)
        # Distance to the nearest integer, so both 2.000001 and 1.999999
        # are accepted as integers.
        is_pc_G = np.all(np.abs(a_Gc - np.round(a_Gc)) < 1e-5, axis=1)
        return np.nonzero(is_pc_G)[0], G_Gc[is_pc_G]

    def get_G_index(self, iK, G, G_list):
        """Return the indices of the rows of ``G_list`` equal to ``G``.

        Rows match when the sum of absolute component differences is below
        1e-5.  ``iK`` is accepted for interface compatibility and unused.
        ``G_list`` is not modified.
        """
        sumG = np.sum(np.abs(G_list - G), axis=1)
        return np.where(sumG < 1e-5)[0]

    def get_eigenvalues(self, iK):
        """Return the eigenvalues at IBZ k point ``iK`` divided by Hartree.

        Without spin-orbit they come from ``calc.get_eigenvalues`` for
        ``self.spin``; with spin-orbit from the stored ``self.e_mK``.
        """
        if not self.spinorbit:
            e_m = self.calc.get_eigenvalues(kpt=iK, spin=self.spin)
        else:
            e_m = self.e_mK[:, iK]
        return np.array(e_m) / Hartree

    def get_pw_wavefunctions_k(self, iK):
        """Return the plane-wave coefficients of the wave functions at iK.

        The real-space wave functions from ``get_rs_wavefunctions_k`` are
        transformed with ``self.pd.fft``.  The result has shape (nbands, nG)
        in the collinear, non-spin-orbit case and (nbands, 2, nG) otherwise
        (one plane-wave vector per spinor component).
        """
        psi_mR = get_rs_wavefunctions_k(self.calc, iK, self.spinorbit,
                                        self.v_Kmn, spin=self.spin)
        if not self.spinorbit and self.calc.density.collinear:
            return np.array([self.pd.fft(psi_R, iK) for psi_R in psi_mR])
        # Spinors: transform the two components separately.
        return np.array([[self.pd.fft(psi_sR[0], iK),
                          self.pd.fft(psi_sR[1], iK)]
                         for psi_sR in psi_mR], dtype=complex)

    def get_spectral_weights_k(self, k_t):
        r"""Return the spectral weights for a given k in the PC:

            P_mK(k_t) = \sum_n |<Km|k_t n>|**2

        which can be shown to be equivalent to:

            P_mK(k_t) = \sum_g |C_Km(g+k_t-K)|**2

        normalized by the total weight \sum_G |C_Km(G)|**2 of each band.
        """
        K_c, G_t = find_K_from_k(k_t, self.M)
        iK = self.get_K_index(K_c)
        iG_list, g_list = self.get_g(iK)
        gG_t_list = g_list + G_t

        G_Gv = self.pd.get_reciprocal_vectors(q=iK, add_q=False)
        G_Gc = np.dot(G_Gv, np.linalg.inv(self.bcell_cv))

        # Plane-wave indices of every g + G_t.  Each lookup can give zero
        # or more matches; they are all kept (duplicates included), as in a
        # sum over the list.
        igG_t = [self.get_G_index(iK, g, G_Gc) for g in gG_t_list]
        if igG_t:
            iG_i = np.concatenate(igG_t)
        else:
            iG_i = np.zeros(0, dtype=int)

        C_mG = self.get_pw_wavefunctions_k(iK)[:self.nb]
        # Give collinear (m, G) data a length-1 spinor axis so both cases
        # reduce the same way over (spinor, G).
        w_msG = np.abs(C_mG.reshape(len(C_mG), -1, C_mG.shape[-1]))**2
        norm_m = w_msG.sum(axis=(1, 2))
        return w_msG[:, :, iG_i].sum(axis=(1, 2)) / norm_m

    def get_spectral_weights(self, kpoints, filename=None):
        """Return ``(e_mK, P_mK)`` for the k points in ``kpoints``.

        If ``filename`` is given the tuple is unpickled from that file.
        Otherwise ``weights_<name>.pckl`` acts as a cache: it is loaded if
        it can be opened; if not, eigenvalues and weights are computed with
        the k points distributed round-robin over ``mpi.world``, summed
        across ranks, and rank 0 writes the cache.  Freshly computed
        arrays have shape (nbands, len(kpoints)).
        """
        # NOTE: pickle must only be used on trusted files.
        if filename is not None:
            with open(filename, 'rb') as fd:
                e_mK, P_mK = pickle.load(fd)
            return e_mK, P_mK

        cache = 'weights_' + self.name + '.pckl'
        try:
            with open(cache, 'rb') as fd:
                e_mK, P_mK = pickle.load(fd)
        except OSError:
            pass
        else:
            return e_mK, P_mK

        Nk = len(kpoints)
        Nb = self.nb
        world = mpi.world
        if world.rank == 0:
            print('Getting EigenValues and Weights')

        e_Km = np.zeros((Nk, Nb))
        P_Km = np.zeros((Nk, Nb))
        # Each rank fills its own k points and leaves zeros elsewhere, so
        # the in-place sums below assemble the complete arrays.
        for ik in range(world.rank, Nk, world.size):
            k = kpoints[ik]
            print('kpoint: %s' % k)
            K_c, _ = find_K_from_k(k, self.M)
            iK = self.get_K_index(K_c)
            e_Km[ik] = self.get_eigenvalues(iK)
            P_Km[ik] = self.get_spectral_weights_k(k)

        world.barrier()
        world.sum(e_Km)
        world.sum(P_Km)

        e_mK = e_Km.T
        P_mK = P_Km.T
        if world.rank == 0:
            with open(cache, 'wb') as fd:
                pickle.dump((e_mK, P_mK), fd)
        return e_mK, P_mK

    def spectral_function(self, kpts, x, X, points_name, width=0.002,
                          npts=10000, filename=None):
        r"""Compute and save the spectral function for all the ks in kpts:

                                          eta / pi
            A_k(e) = \sum_m P_mK(k) x ---------------------
                                      (e - e_mk)**2 + eta**2

        on ``npts`` energy points from the lowest eigenvalue minus
        5 * width to the highest eigenvalue plus 5 * width.  The width
        keyword is FWHM = 2 * eta, in the units of ``get_eigenvalues``
        (eigenvalues divided by Hartree).

        Rank 0 pickles ``(e * Hartree, A_ke, x, X, points_name)`` to
        ``sf_<name>.pckl``; ``x``, ``X`` and ``points_name`` are stored
        unchanged for plotting.  Returns None.
        """
        Nk = len(kpts)
        A_ke = np.zeros((Nk, npts), float)

        world = mpi.world
        e_mK, P_mK = self.get_spectral_weights(kpts, filename)
        if world.rank == 0:
            print('Calculating the Spectral Function')
        emin = np.min(e_mK) - 5 * width
        emax = np.max(e_mK) + 5 * width
        e = np.linspace(emin, emax, npts)

        # Sum of Lorentzians (FWHM = width) centred on every eigenvalue,
        # each weighted by its spectral weight.
        for ik in range(Nk):
            for e0, P in zip(e_mK[:, ik], P_mK[:, ik]):
                D = (width / 2 / np.pi) / ((e - e0)**2 + (width / 2)**2)
                A_ke[ik] += P * D
        if world.rank == 0:
            with open('sf_' + self.name + '.pckl', 'wb') as fd:
                pickle.dump((e * Hartree, A_ke, x, X, points_name), fd)
            print('Spectral Function calculation completed!')

261 

262 

def find_K_from_k(k, M):
    """Map a PC k vector to the SC and split off the unfolding vector.

    Takes a k vector in scaled coordinates and computes ``KG = M @ k``.
    Every component above 0.5000001 or below -0.4999999 is shifted by its
    nearest integer. The integer shifts are returned as the unfolding G
    vector, and the shifted vector as K (both in scaled coordinates).
    """
    KG = np.dot(M, k)
    nearest = np.round(KG)
    # Components that lie outside (-0.5, 0.5], up to a small tolerance.
    wrap = (KG > 0.5000001) | (KG < -0.4999999)
    G = np.where(wrap, nearest, 0).astype(int)
    K = np.where(wrap, KG - nearest, KG)
    return K, G

279 

280 

def get_rs_wavefunctions_k(calc, iK, spinorbit=False, v_Kmn=None, spin=0):
    """Return the real-space wave functions at IBZ k point ``iK``.

    Without spin-orbit, this collects
    ``calc.get_pseudo_wave_function(m, iK, spin, periodic=True)`` for every
    band and multiplies the result by ``Bohr**1.5``.

    With spin-orbit, the wave-function arrays of spin channels 0 and
    ``nspins - 1`` are multiplied by ``exp(-2 pi i k.r)`` on the grid. They
    are then combined with the eigenvectors ``v_Kmn[iK]``: even columns act
    on channel 0 and odd columns on the other channel. The two spinor
    components are stacked on axis 1, giving shape
    (len(v_Kmn[iK]), 2, *grid). For spinors the number of bands is doubled
    and a spin dimension is added.
    """
    N_c = calc.wfs.gd.N_c
    Nb = calc.wfs.bd.nbands

    if calc.wfs.mode == 'lcao' and not calc.wfs.positions_set:
        calc.initialize_positions()

    if not spinorbit:
        psit_mgrid = np.array(
            [calc.get_pseudo_wave_function(m, iK, spin, periodic=True)
             for m in range(Nb)]) * Bohr**1.5
        return psit_mgrid

    # Only the spin-orbit path needs the k vector and the grid phase, so
    # the full-grid exponential is not computed in the collinear case.
    k_c = calc.wfs.kd.ibzk_kc[iK]
    Ns = calc.wfs.nspins
    # exp(-2 pi i (R_c / N_c) . k_c) for every grid index R_c; presumably
    # this removes the Bloch phase, like periodic=True above (TODO confirm).
    eikr_R = np.exp(-2j * np.pi * np.dot(np.indices(N_c).T,
                                         k_c / N_c).T)

    v_mn = v_Kmn[iK]
    v0_mn = v_mn[:, ::2]
    v1_mn = v_mn[:, 1::2]

    u0_ngrid = np.array(
        [calc.wfs.get_wave_function_array(n, iK, 0) * eikr_R
         for n in range(Nb)])
    u1_ngrid = np.array(
        [calc.wfs.get_wave_function_array(n, iK, (Ns - 1)) * eikr_R
         for n in range(Nb)])
    # Contract the band axis with the eigenvectors, then restore the grid
    # axis order: result is (m, N0, N1, N2).
    u0_mG = np.swapaxes(np.dot(v0_mn, np.swapaxes(u0_ngrid, 0, 2)), 1, 2)
    u1_mG = np.swapaxes(np.dot(v1_mn, np.swapaxes(u1_ngrid, 0, 2)), 1, 2)
    ut_mgrid = np.stack((u0_mG, u1_mG), axis=1).astype(complex, copy=False)
    return ut_mgrid

321 

322 

def plot_spectral_function(filename, color='blue', eref=None,
                           emin=None, emax=None, scale=1):
    """Plot the spectral function along the k-point path.

    Reads ``<filename>.pckl``, containing ``(e, A_ke, x, X, points_name)``
    as written by ``Unfold.spectral_function``. The spectral function is
    normalized to its maximum, drawn with ``imshow`` and saved to
    ``<filename>_spec.png``.

    Parameters
    ----------
    filename : str
        Pickle file name without the ``.pckl`` extension.
    color : str
        Passed to ``make_colormap``.
    eref : float, optional
        Subtracted from the energies.
    emin, emax : float, optional
        Vertical plot limits; default to the energy range of the data.
    scale : float
        Multiplies the normalized spectral function before plotting.

    Raises
    ------
    SystemExit
        If the pickle file cannot be opened.
    """
    # NOTE: pickle must only be used on trusted files.
    try:
        with open(filename + '.pckl', 'rb') as fd:
            e, A_ke, x, X, points_name = pickle.load(fd)
    except OSError:
        print('You Need to Calculate the SF first!')
        raise SystemExit()

    import matplotlib.pyplot as plt
    print('Plotting Spectral Function')

    if eref is not None:
        e -= eref
    if emin is None:
        emin = e.min()
    if emax is None:
        emax = e.max()

    A_ke /= np.max(A_ke)
    A_ek = A_ke.T * scale

    mycmap = make_colormap(color)

    plt.figure()

    plt.plot([0, x[-1]], 2 * [0.0], '--', c='0.5')
    # NOTE(review): the 0.23 offset shifts every value along the colormap
    # (which is white up to 0.25); its exact purpose is not documented.
    plt.imshow(A_ek + 0.23,
               cmap=mycmap,
               aspect='auto',
               origin='lower',
               vmin=0.,
               vmax=1,
               extent=[0, x[-1], e.min(), e.max()])

    for k in X[1:-1]:
        plt.plot([k, k], [emin, emax], lw=0.5, c='0.5')
    plt.xticks(X, points_name, size=20)
    plt.yticks(size=20)
    plt.ylabel('E(eV)', size=20)
    plt.axis([0, x[-1], emin, emax])
    plt.savefig(filename + '_spec.png')
    plt.show()

370 

371 

def plot_band_structure(e_mK, P_mK, x, X, points_name,
                        weights_mK=None, color='red', fit=True, nfit=200):
    """Plot the band structure using the P_mK weights directly.

    Each point is drawn as a filled circle. Its colour is ``P_mK``, using
    ``make_colormap(color)`` over the range 0..1. Its size is
    ``20 * weights_mK * P_mK``, or ``20 * P_mK`` when ``weights_mK`` is
    None.

    Parameters
    ----------
    e_mK, P_mK : ndarray
        Energies and weights of the same shape. The ``np.tile(x, ...)``
        pairing below implies one row per band and one column per entry of
        ``x``.
    x : sequence
        Path coordinates of the k points; ``x[-1]`` sets the x range.
    X, points_name : sequences
        Positions and labels of the high-symmetry points.
    weights_mK : ndarray, optional
        Extra marker-size weights. The caller's array is not modified.
    color : str
        Passed to ``make_colormap``.
    fit, nfit :
        Accepted for backward compatibility; not used.
    """
    import matplotlib.pyplot as plt
    print('Plotting Bands Structure')
    emin = e_mK.min()
    emax = e_mK.max()

    new_cmap = make_colormap(color)

    # Build a new array instead of multiplying in place, so the caller's
    # weights are left untouched.
    if weights_mK is None:
        size_mK = np.asarray(P_mK)
    else:
        size_mK = np.asarray(weights_mK) * P_mK

    plt.figure()
    plt.plot([0, x[-1]], 2 * [0.0], '--', c='0.5')
    plt.scatter(np.tile(x, len(e_mK)), e_mK.reshape(-1),
                c=P_mK.reshape(-1),
                cmap=new_cmap,
                vmin=0.,
                vmax=1.,
                s=20. * size_mK.reshape(-1),
                marker='o',
                edgecolor='none')

    for k in X[1:-1]:
        plt.plot([k, k], [emin, emax], lw=0.5, c='0.5')
    plt.xticks(X, points_name, size=20)
    plt.yticks(size=20)
    plt.ylabel('E(eV)', size=20)
    plt.axis([0, x[-1], emin, emax])
    plt.show()

408 

409 

def make_colormap(main_color):
    """Return the custom colormap used by plot_spectral_function and
    plot_band_structure.

    ``main_color`` must be 'blue', 'red' or 'green'. The map is white for
    values up to 0.25 and the saturated main colour at 0.5. That channel
    then darkens to 0.75 at 1.0. The green map additionally ramps alpha
    from 0 (transparent) to 1 (opaque from 0.5 on).

    Raises
    ------
    ValueError
        If ``main_color`` is not one of the supported colours.
    """
    channels = ('red', 'green', 'blue')
    if main_color not in channels:
        raise ValueError('main_color must be one of %s, got %r'
                         % (channels, main_color))

    from matplotlib.colors import LinearSegmentedColormap

    # Segment data rows are (x, y_left, y_right).
    # The main channel stays at 1 up to x = 0.5 and then drops to 0.75.
    main = ((0.0, 1.0, 1.0),
            (0.25, 1.0, 1.0),
            (0.5, 1.0, 1.0),
            (1.0, 0.75, 0.75))
    # The other channels drop to 0 at x = 0.5.
    other = ((0.0, 1.0, 1.0),
             (0.25, 1.0, 1.0),
             (0.5, 0.0, 0.0),
             (1.0, 0.0, 0.0))
    cdict = {ch: main if ch == main_color else other for ch in channels}
    if main_color == 'green':
        cdict['alpha'] = ((0.0, 0.0, 0.0),
                          (0.25, 0.1, 0.1),
                          (0.5, 1.0, 1.0),
                          (1.0, 1.0, 1.0))

    cmap = LinearSegmentedColormap('mymap', cdict)
    return cmap

470 

471 

def get_vacuum_level(calc, plot_pot=False):
    """Get the vacuum energy level from a given calculator.

    First calls ``calc.restore_state()``. The Hartree potential is then
    obtained: in 'pw' mode by inverse FFT of ``calc.hamiltonian.vHt_q``,
    otherwise directly from ``calc.hamiltonian.vHt_g``. The result is
    multiplied by ``Hartree`` and averaged over the first two grid axes.
    The value at the first point of the remaining axis is returned.

    NOTE(review): this assumes the vacuum region sits at index 0 of the
    third axis -- confirm for the intended geometries.

    With ``plot_pot`` the averaged profile is plotted with matplotlib.
    """
    calc.restore_state()
    if calc.wfs.mode != 'pw':
        potential = calc.hamiltonian.vHt_g * Hartree
    else:
        pd3 = calc.hamiltonian.pd3
        vHt_q = calc.hamiltonian.vHt_q
        potential = pd3.ifft(vHt_q) * Hartree

    # Plane average: collapse axis 0, then (former) axis 1.
    profile_z = potential.mean(axis=0).mean(axis=0)

    if plot_pot:
        import matplotlib.pyplot as plt
        plt.plot(profile_z)
        plt.show()
    return profile_z[0]