Coverage for gpaw/unfold.py: 66%
259 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1import numpy as np
2import pickle
4from ase.units import Hartree, Bohr
6from gpaw.kpt_descriptor import to1bz
7from gpaw.new.ase_interface import GPAW
8from gpaw.spinorbit import soc_eigenstates
9from gpaw.pw.descriptor import PWDescriptor
10import gpaw.mpi as mpi
class Unfold:
    """Unfold the bands of a supercell (SC) calculation onto the primitive
    cell (PC).

    As a convention (when possible) capital-letter variables are related to
    the SC while lowercase ones are related to the PC.
    """

    def __init__(self,
                 name=None,
                 calc=None,
                 M=None,
                 spin=0,
                 spinorbit=None,
                 theta=90,
                 scale=1.0,
                 phi=90):
        """Load the SC calculation and prepare the descriptors.

        name: label used to build the pickle file names
            ('weights_<name>.pckl' and 'sf_<name>.pckl').
        calc: forwarded to GPAW(...); loaded with a serial communicator.
        M: transformation matrix, stored as a float array and used by
            find_K_from_k (K = M k) and get_g.
        spin: spin index used when spinorbit is not requested.
        spinorbit: if true, eigenvalues and eigenvectors are taken from
            soc_eigenstates and the number of bands is doubled.
        theta, scale, phi: forwarded unchanged to soc_eigenstates.
        """
        self.name = name
        self.calc = GPAW(calc, txt=None, communicator=mpi.serial_comm)
        self.M = np.array(M, dtype=float)
        self.spinorbit = spinorbit
        self.spin = spin

        self.gd = self.calc.wfs.gd.new_descriptor()

        self.kd = self.calc.wfs.kd
        if self.calc.wfs.mode == 'pw':
            self.pd = self.calc.wfs.pd
        else:
            # Non-PW modes have no plane-wave descriptor of their own, so
            # build one on top of the grid and k-point descriptors.
            self.pd = PWDescriptor(ecut=None, gd=self.gd, kd=self.kd,
                                   dtype=complex)

        self.acell_cv = self.gd.cell_cv
        self.bcell_cv = 2 * np.pi * self.gd.icell_cv

        self.nb = self.calc.get_number_of_bands()

        self.v_Kmn = None
        if spinorbit:
            assert self.calc.density.collinear
            # Spinors: twice as many states as scalar bands.
            self.nb *= 2
            if mpi.world.rank == 0:
                print('Calculating spinorbit Corrections')
            soc = soc_eigenstates(self.calc,
                                  scale=scale, theta=theta, phi=phi)
            self.e_mK = soc.eigenvalues().T
            self.v_Kmn = soc.eigenvectors()
            if mpi.world.rank == 0:
                print('Done with the spinorbit Corrections')

    def get_K_index(self, K):
        """Find the index of a given K.

        K is folded with to1bz and then looked up among the k-points in
        self.kd.ibzk_kc via self.kd.where_is_q.
        """
        K = np.array([K])
        bzKG = to1bz(K, self.acell_cv)[0]
        iK = self.kd.where_is_q(bzKG, self.kd.ibzk_kc)
        return iK

    def get_g(self, iK):
        """Not all the G vectors are relevant for the bands unfolding,
        but only the ones that match the PC reciprocal vectors.
        This function finds the relevant ones.

        Returns a pair of arrays: the indices of the selected G vectors and
        the selected G vectors themselves (in scaled coordinates).
        """
        G_Gv_temp = self.pd.get_reciprocal_vectors(q=iK, add_q=False)
        G_Gc_temp = np.dot(G_Gv_temp, np.linalg.inv(self.bcell_cv))

        # Loop invariant: invert M once instead of once per G vector.
        invM_T = np.linalg.inv(self.M).T

        iG_list = []
        g_list = []
        for iG, G in enumerate(G_Gc_temp):
            a = np.dot(G, invM_T)
            # Components whose fractional part is just above an integer ...
            check = np.abs(a) % 1 < 1e-5
            # ... and, for the remaining ones, just below an integer.
            check2 = np.abs((np.abs(a[np.where(~check)]) % 1) - 1) < 1e-5
            if all(check) or all(check2):
                iG_list.append(iG)
                g_list.append(G)

        return np.array(iG_list), np.array(g_list)

    def get_G_index(self, iK, G, G_list):
        """Find the index of a given G.

        Returns the array of row indices of G_list whose summed absolute
        difference from G is below 1e-5. G_list is not modified.
        """
        # Work on a difference array rather than subtracting in place, so
        # the caller's G_list is left untouched.
        sumG = np.sum(np.abs(G_list - G), axis=1)
        iG = np.where(sumG < 1e-5)[0]
        return iG

    def get_eigenvalues(self, iK):
        """Get the list of eigenvalues for a given iK.

        The values are divided by Hartree before being returned.
        """
        if not self.spinorbit:
            e_m = self.calc.get_eigenvalues(kpt=iK, spin=self.spin)
        else:
            e_m = self.e_mK[:, iK]
        return np.array(e_m) / Hartree

    def get_pw_wavefunctions_k(self, iK):
        """Get the list of Fourier coefficients of the WaveFunction for a
        given iK. For spinors the number of bands is doubled and a spin
        dimension is added."""
        psi_mgrid = get_rs_wavefunctions_k(self.calc, iK, self.spinorbit,
                                           self.v_Kmn, spin=self.spin)
        if not self.spinorbit and self.calc.density.collinear:
            return np.array([self.pd.fft(psi_mgrid[i], iK)
                             for i in range(len(psi_mgrid))])

        # Spinor case: transform the two spin components separately.
        u0_mG = np.array([self.pd.fft(psi_mgrid[i, 0], iK)
                          for i in range(psi_mgrid.shape[0])])
        u1_mG = np.array([self.pd.fft(psi_mgrid[i, 1], iK)
                          for i in range(psi_mgrid.shape[0])])

        u_mG = np.zeros((len(u0_mG), 2, u0_mG.shape[1]), complex)
        u_mG[:, 0] = u0_mG
        u_mG[:, 1] = u1_mG
        return u_mG

    def get_spectral_weights_k(self, k_t):
        r"""Returns the spectral weights for a given k in the PC:

            P_mK(k_t) = \sum_n |<Km|k_t n>|**2

        which can be shown to be equivalent to:

            P_mK(k_t) = \sum_g |C_Km(g+k_t-K)|**2
        """
        K_c, G_t = find_K_from_k(k_t, self.M)
        iK = self.get_K_index(K_c)
        _, g_list = self.get_g(iK)
        gG_t_list = g_list + G_t

        G_Gv = self.pd.get_reciprocal_vectors(q=iK, add_q=False)
        G_Gc = np.dot(G_Gv, np.linalg.inv(self.bcell_cv))

        # get_G_index does not modify its arguments, so no copies needed.
        igG_t_list = [self.get_G_index(iK, g, G_Gc) for g in gG_t_list]

        C_mG = self.get_pw_wavefunctions_k(iK)
        P_m = []
        if not self.spinorbit and self.calc.density.collinear:
            for m in range(self.nb):
                P = 0.
                norm = np.sum(np.linalg.norm(C_mG[m, :])**2)
                for iG in igG_t_list:
                    P += np.linalg.norm(C_mG[m, iG])**2
                P_m.append(P / norm)
        else:
            # Spinor coefficients: sum the weight of both spin components.
            for m in range(self.nb):
                P = 0.
                norm = np.sum(np.linalg.norm(C_mG[m, 0, :])**2 +
                              np.linalg.norm(C_mG[m, 1, :])**2)
                for iG in igG_t_list:
                    P += (np.linalg.norm(C_mG[m, 0, iG])**2 +
                          np.linalg.norm(C_mG[m, 1, iG])**2)
                P_m.append(P / norm)

        return np.array(P_m)

    def get_spectral_weights(self, kpoints, filename=None):
        """Collect the spectral weights for the k points in the kpoints list.

        If filename is given, (e_mK, P_mK) is loaded from that pickle file.
        Otherwise 'weights_<name>.pckl' is tried first; if it cannot be
        opened, the eigenvalues and weights are computed and rank 0 writes
        them to that file.

        This function is parallelized over k's.

        NOTE: pickle files are executed on load; only read files that come
        from a trusted source.
        """
        Nk = len(kpoints)
        Nb = self.nb

        world = mpi.world
        if filename is not None:
            with open(filename, 'rb') as fd:
                e_mK, P_mK = pickle.load(fd)
            return e_mK, P_mK

        try:
            with open('weights_' + self.name + '.pckl', 'rb') as fd:
                e_mK, P_mK = pickle.load(fd)
        except OSError:
            if world.rank == 0:
                print('Getting EigenValues and Weights')

            e_Km = np.zeros((Nk, Nb))
            P_Km = np.zeros((Nk, Nb))
            # Round-robin distribution of the k-points over the ranks.
            for ik in range(world.rank, Nk, world.size):
                k = kpoints[ik]
                # f-string: '%s' % k would fail when k is a tuple.
                print(f'kpoint: {k}')
                K_c, _ = find_K_from_k(k, self.M)
                iK = self.get_K_index(K_c)
                e_Km[ik] = self.get_eigenvalues(iK)
                P_Km[ik] = self.get_spectral_weights_k(k)

            # Rows not owned by this rank are zero, so summing over ranks
            # assembles the full arrays.
            world.barrier()
            world.sum(e_Km)
            world.sum(P_Km)

            e_mK = np.array(e_Km).T
            P_mK = np.array(P_Km).T
            if world.rank == 0:
                with open('weights_' + self.name + '.pckl', 'wb') as fd:
                    pickle.dump((e_mK, P_mK), fd)

        return e_mK, P_mK

    def spectral_function(self, kpts, x, X, points_name, width=0.002,
                          npts=10000, filename=None):
        r"""Computes the spectral function for all the ks in kpts:

                                          eta / pi
            A_k(e) = \sum_m P_mK(k) x ---------------------
                                      (e - e_mk)**2 + eta**2

        at each k-point, on npts energy points in the range
        [min(e_mK) - 5 * width, max(e_mK) + 5 * width].
        The width keyword is FWHM = 2 * eta and is in the same units as the
        eigenvalues returned by get_spectral_weights.

        Rank 0 pickles (e * Hartree, A_ke, x, X, points_name) to
        'sf_<name>.pckl'; x, X and points_name are stored unchanged.
        Nothing is returned.
        """
        Nk = len(kpts)
        A_ke = np.zeros((Nk, npts), float)

        world = mpi.world
        e_mK, P_mK = self.get_spectral_weights(kpts, filename)
        if world.rank == 0:
            print('Calculating the Spectral Function')
        emin = np.min(e_mK) - 5 * width
        emax = np.max(e_mK) + 5 * width
        e = np.linspace(emin, emax, npts)

        for ik in range(Nk):
            for ie, e0 in enumerate(e_mK[:, ik]):
                # Lorentzian with eta = width / 2 centred on e0.
                D = (width / 2 / np.pi) / ((e - e0)**2 + (width / 2)**2)
                A_ke[ik] += P_mK[ie, ik] * D
        if world.rank == 0:
            with open('sf_' + self.name + '.pckl', 'wb') as fd:
                pickle.dump((e * Hartree, A_ke, x, X, points_name), fd)
            print('Spectral Function calculation completed!')
        return
def find_K_from_k(k, M):
    """Map a PC k vector (scaled coordinates) to the SC.

    Computes K = M k and, for every component lying outside the window
    [-0.4999999, 0.5000001], subtracts the nearest integer. Returns the
    shifted K vector together with the integer vector G that was removed
    (the unfolding G), both in scaled coordinates.
    """
    KG = np.dot(M, k)
    G = np.zeros(3, dtype=int)

    for axis in range(3):
        inside_window = not (KG[axis] > 0.5000001 or KG[axis] < -0.4999999)
        if inside_window:
            continue
        # Outside the window the shift is the same on both sides:
        # subtract the rounded value (for negative components this equals
        # adding its absolute value).
        nearest = np.round(KG[axis])
        G[axis] = int(nearest)
        KG[axis] -= nearest

    return KG, G
def get_rs_wavefunctions_k(calc, iK, spinorbit=False, v_Kmn=None, spin=0):
    """Return the real-space wave functions of all bands at k-point iK.

    Without spinorbit the result is the array of pseudo wave functions
    (requested with periodic=True) for the given spin, scaled by Bohr**1.5.
    With spinorbit the scalar wave functions are combined using the
    coefficients v_Kmn[iK]; the number of bands is doubled and a spin
    dimension of length 2 is added as the second axis.
    """
    N_c = calc.wfs.gd.N_c
    k_c = calc.wfs.kd.ibzk_kc[iK]
    Nb = calc.wfs.bd.nbands
    Ns = calc.wfs.nspins
    # Phase factor exp(-2 pi i k.n/N) on the grid; presumably strips the
    # Bloch phase from the wave function arrays -- TODO confirm.
    grid_indices = np.indices(N_c).T
    eikr_R = np.exp(-2j * np.pi * np.dot(grid_indices, k_c / N_c).T)

    needs_positions = calc.wfs.mode == 'lcao' and not calc.wfs.positions_set
    if needs_positions:
        calc.initialize_positions()

    if not spinorbit:
        bands = [calc.get_pseudo_wave_function(m, iK, spin, periodic=True)
                 for m in range(Nb)]
        return np.array(bands) * Bohr**1.5

    v_mn = v_Kmn[iK]

    def scalar_bands(s):
        # Phase-corrected scalar wave functions for spin channel s.
        return np.array([calc.wfs.get_wave_function_array(n, iK, s) * eikr_R
                         for n in range(Nb)])

    def combine(coef_mn, u_ngrid):
        # Linear combination over the scalar band index n.
        mixed = np.dot(coef_mn, np.swapaxes(u_ngrid, 0, 2))
        return np.swapaxes(mixed, 1, 2)

    u0_ngrid = scalar_bands(0)
    u1_ngrid = scalar_bands(Ns - 1)
    # Even/odd columns of v_mn hold the two spin components.
    up_mgrid = combine(v_mn[:, ::2], u0_ngrid)
    down_mgrid = combine(v_mn[:, 1::2], u1_ngrid)

    shape = (up_mgrid.shape[0], 2, up_mgrid.shape[1],
             up_mgrid.shape[2], up_mgrid.shape[3])
    ut_mgrid = np.zeros(shape, complex)
    ut_mgrid[:, 0] = up_mgrid
    ut_mgrid[:, 1] = down_mgrid
    return ut_mgrid
def plot_spectral_function(filename, color='blue', eref=None,
                           emin=None, emax=None, scale=1):
    """Function to plot spectral function corresponding to the bandstructure
    along the kpoints path.

    Reads (e, A_ke, x, X, points_name) from '<filename>.pckl', plots the
    normalized spectral function as an image, saves it to
    '<filename>_spec.png' and shows it.

    filename: pickle file name without the '.pckl' extension.
    color: main colour passed to make_colormap.
    eref: if given, subtracted from the energy axis.
    emin, emax: plotted energy window; default to the full energy range.
    scale: factor applied to the normalized spectral function.

    Raises SystemExit if the pickle file cannot be opened.

    NOTE: pickle files are executed on load; only read files that come
    from a trusted source.
    """
    try:
        # Context manager so the file handle is closed deterministically.
        with open(filename + '.pckl', 'rb') as fd:
            e, A_ke, x, X, points_name = pickle.load(fd)
    except OSError:
        print('You Need to Calculate the SF first!')
        raise SystemExit()

    import matplotlib.pyplot as plt
    print('Plotting Spectral Function')

    if eref is not None:
        e -= eref
    if emin is None:
        emin = e.min()
    if emax is None:
        emax = e.max()

    # Normalize to the global maximum, then put energy on the first axis.
    A_ke /= np.max(A_ke)
    A_ek = A_ke.T * scale

    mycmap = make_colormap(color)

    plt.figure()

    plt.plot([0, x[-1]], 2 * [0.0], '--', c='0.5')
    plt.imshow(A_ek + 0.23,
               cmap=mycmap,
               aspect='auto',
               origin='lower',
               vmin=0.,
               vmax=1,
               extent=[0, x[-1], e.min(), e.max()])

    # Vertical guides at the interior special points.
    for k in X[1:-1]:
        plt.plot([k, k], [emin, emax], lw=0.5, c='0.5')
    plt.xticks(X, points_name, size=20)
    plt.yticks(size=20)
    plt.ylabel('E(eV)', size=20)
    plt.axis([0, x[-1], emin, emax])
    plt.savefig(filename + '_spec.png')
    plt.show()
def plot_band_structure(e_mK, P_mK, x, X, points_name,
                        weights_mK=None, color='red', fit=True, nfit=200):
    """Function to plot the bandstructure using the P_mK weights directly.

    Each point is represented with a filled circle, whose colour follows
    P_mK and whose size follows weights_mK * P_mK (or P_mK alone when
    weights_mK is None).

    e_mK, P_mK: eigenvalues and spectral weights, flattened for plotting.
    x: positions along the path, tiled once per row of e_mK.
    X, points_name: tick positions and labels of the special points.
    weights_mK: optional extra weights; not modified by this function.
    color: main colour passed to make_colormap.
    fit, nfit: accepted for backward compatibility; not used here.
    """
    import matplotlib.pyplot as plt
    print('Plotting Bands Structure')
    emin = e_mK.min()
    emax = e_mK.max()

    new_cmap = make_colormap(color)

    # Build a new array instead of multiplying in place, so the caller's
    # weights_mK is left untouched.
    if weights_mK is None:
        sizes_mK = P_mK.copy()
    else:
        sizes_mK = weights_mK * P_mK

    plt.figure()
    plt.plot([0, x[-1]], 2 * [0.0], '--', c='0.5')
    plt.scatter(np.tile(x, len(e_mK)), e_mK.reshape(-1),
                c=P_mK.reshape(-1),
                cmap=new_cmap,
                vmin=0.,
                vmax=1.,
                s=20. * sizes_mK.reshape(-1),
                marker='o',
                edgecolor='none')

    # Vertical guides at the interior special points.
    for k in X[1:-1]:
        plt.plot([k, k], [emin, emax], lw=0.5, c='0.5')
    plt.xticks(X, points_name, size=20)
    plt.yticks(size=20)
    plt.ylabel('E(eV)', size=20)
    plt.axis([0, x[-1], emin, emax])
    plt.show()
def make_colormap(main_color):
    """Custom colormaps used in plot_spectral_function and
    plot_band_structure.

    main_color: one of 'red', 'green' or 'blue'.

    Returns a LinearSegmentedColormap named 'mymap' that is white up to
    0.25 and ramps to the main colour above it.

    Raises ValueError for any other main_color.
    """
    channels = ('red', 'green', 'blue')
    # Validate first: previously an unknown colour fell through every
    # branch and failed later with an unbound local variable.
    if main_color not in channels:
        raise ValueError(
            f'Unknown main_color {main_color!r}; expected one of {channels}')

    from matplotlib.colors import LinearSegmentedColormap

    # Channel of the main colour: stays on, slightly darkened at the top.
    kept = ((0.0, 1.0, 1.0),
            (0.25, 1.0, 1.0),
            (0.5, 1.0, 1.0),
            (1.0, 0.75, 0.75))
    # The other two channels: on for the white region, off from 0.5 upwards.
    dropped = ((0.0, 1.0, 1.0),
               (0.25, 1.0, 1.0),
               (0.5, 0.0, 0.0),
               (1.0, 0.0, 0.0))

    cdict = {channel: kept if channel == main_color else dropped
             for channel in channels}

    if main_color == 'green':
        # Only the green map defines an alpha ramp in the original tables;
        # kept as is to preserve the existing appearance.
        cdict['alpha'] = ((0.0, 0.0, 0.0),
                          (0.25, 0.1, 0.1),
                          (0.5, 1.0, 1.0),
                          (1.0, 1.0, 1.0))

    cmap = LinearSegmentedColormap('mymap', cdict)
    return cmap
def get_vacuum_level(calc, plot_pot=False):
    """Get the vacuum energy level from a given calculator.

    The calculator state is restored, the Hartree potential is taken from
    the Hamiltonian (inverse-transformed first in 'pw' mode), multiplied by
    Hartree and averaged over its first two axes. The first element of the
    resulting profile is returned; with plot_pot the profile is also
    plotted.
    """
    calc.restore_state()

    if calc.wfs.mode != 'pw':
        potential = calc.hamiltonian.vHt_g * Hartree
    else:
        # In 'pw' mode the potential is stored as vHt_q and must be
        # transformed back with the Hamiltonian's pd3 descriptor.
        potential = calc.hamiltonian.pd3.ifft(calc.hamiltonian.vHt_q) * Hartree

    # Average over the first axis twice, leaving a profile along the
    # remaining axis.
    plane_average = np.mean(potential, axis=0)
    profile = np.mean(plane_average, axis=0)

    if plot_pot:
        import matplotlib.pyplot as plt
        plt.plot(profile)
        plt.show()

    return profile[0]