Coverage for gpaw/response/tool.py: 37%
161 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import sys
3import numpy as np
4from scipy.optimize import leastsq
6from ase.units import Ha
7import gpaw.mpi as mpi
8from gpaw.response.integrators import Domain
9from gpaw.response.qpd import SingleQPWDescriptor
10from gpaw.response.pair import KPointPairFactory, get_gs_and_context
def check_degenerate_bands(filename, etol):
    """Print band pairs that are nearly degenerate but unequally occupied.

    The ground-state calculation stored in ``filename`` is loaded. Basic
    information (number of electrons, bands and IBZ k-points) is printed
    first. Then, for every IBZ k-point ``k`` and band ``n >= 1``, the line
    ``k n e_n e_(n-1)`` is printed when both of these hold:

    * band ``n - 1`` is more occupied than band ``n`` by more than 1e-5;
    * the two eigenvalues differ by less than ``etol``.

    Parameters
    ----------
    filename : str
        File handed to ``gpaw.GPAW`` to restore the calculation.
    etol : float
        Energy tolerance for calling two neighbouring bands degenerate.
        It uses the same units as ``calc.get_eigenvalues``.
    """
    from gpaw import GPAW

    calc = GPAW(filename, txt=None)
    print('Number of Electrons :', calc.get_number_of_electrons())

    nibzkpt = calc.get_ibz_k_points().shape[0]
    nbands = calc.get_number_of_bands()
    print('Number of Bands :', nbands)
    print('Number of ibz-kpoints :', nibzkpt)

    kpts = range(nibzkpt)
    e_kn = np.array([calc.get_eigenvalues(k) for k in kpts])
    f_kn = np.array([calc.get_occupation_numbers(k) for k in kpts])

    for k, (e_n, f_n) in enumerate(zip(e_kn, f_kn)):
        for n in range(1, nbands):
            # The energy gap is only inspected when the lower band is
            # noticeably more occupied than the upper one.
            if (f_n[n - 1] - f_n[n] > 1e-5
                    and abs(e_n[n] - e_n[n - 1]) < etol):
                print(k, n, e_n[n], e_n[n - 1])
def get_orbitals(calc):
    """Get LCAO orbitals on 3D grid by lcao_to_grid method.

    The basis functions of all setups are placed at the calculator's atomic
    positions. Each one is then expanded separately on the wave-function
    grid by passing identity coefficients to ``lcao_to_grid``.

    Returns the array created by ``calc.wfs.gd.zeros(nLCAO)``. Entry ``M``
    holds basis function ``M`` on the grid.
    """
    from gpaw.lfc import BasisFunctions

    wfs = calc.wfs
    basis_a = [setup.basis_functions_J for setup in wfs.setups]
    bfs = BasisFunctions(wfs.gd, basis_a, wfs.kd.comm, cut=True)
    bfs.set_positions(calc.spos_ac)

    # NOTE(review): the orbital count is taken from the number of bands,
    # which presumes nbands equals the number of LCAO basis functions.
    nlcao = calc.get_number_of_bands()
    orb_MG = wfs.gd.zeros(nlcao)
    # Identity coefficients: row M selects basis function M alone.
    bfs.lcao_to_grid(np.identity(nlcao), orb_MG, q=-1)
    return orb_MG
def get_bz_transitions(filename, q_c, bzk_kc,
                       spins='all',
                       ecut=50, txt=sys.stdout):
    """
    Get transitions in the Brillouin zone from kpoints bzk_kc
    contributing to the linear response at wave vector q_c.

    Parameters
    ----------
    filename :
        Ground-state file passed to ``get_gs_and_context``.
    q_c :
        Wave vector passed to ``SingleQPWDescriptor.from_q``.
    bzk_kc : array
        k-points in scaled (reciprocal-lattice) coordinates. They are
        converted to Cartesian coordinates before building the domain.
    spins : 'all' or iterable of int
        Spin channels to include. ``'all'`` selects every spin of the
        ground state. Otherwise each entry must be a valid spin index.
    ecut : float
        Plane-wave cutoff in eV. It is divided by ``Ha`` (converted to
        Hartree) before use.
    txt :
        Output destination passed to ``get_gs_and_context``.

    Returns
    -------
    kptpair_factory, qpd, domain
    """
    ecut /= Ha

    gs, context = get_gs_and_context(filename, txt=txt, world=mpi.world,
                                     timer=None)

    kptpair_factory = KPointPairFactory(gs=gs, context=context)
    qpd = SingleQPWDescriptor.from_q(q_c, ecut, gs.gd)
    # Scaled -> Cartesian; presumably icell_cv is the reciprocal cell
    # without the 2 pi factor, which is applied here.
    bzk_kv = np.dot(bzk_kc, qpd.gd.icell_cv) * 2 * np.pi

    nspins = kptpair_factory.gs.nspins
    # Only compare against 'all' when spins is a string: comparing a numpy
    # array with a string would yield an array with ambiguous truth value.
    if isinstance(spins, str) and spins == 'all':
        spins = range(nspins)
    else:
        if iter(spins) is spins:
            # One-shot iterator: materialize it so the validation loop
            # below does not exhaust it before it reaches Domain.
            spins = list(spins)
        for spin in spins:
            assert spin in range(nspins)

    domain = Domain(bzk_kv, spins)
    return kptpair_factory, qpd, domain
def get_chi0_integrand(kptpair_factory, qpd, n_n, m_m, point):
    """
    Calculates the pair densities, occupational differences
    and energy differences of transitions from certain kpoint
    and spin.

    The transitions run from bands ``n_n[0] .. n_n[-1]`` to bands
    ``m_m[0] .. m_m[-1]`` at the k-point and spin given by ``point``.
    In the optical limit of ``qpd``, the optical pair density is used
    instead of the plane-wave pair density.

    Returns ``(pair_density, df_nm, eps_n, eps_m)``.
    """
    # point.kpt_c holds a Cartesian vector here: it is converted to scaled
    # coordinates just below (Domain is built from bzk_kv by
    # get_bz_transitions).
    k_v = point.kpt_c
    use_optical = qpd.optical_limit
    k_c = np.dot(qpd.gd.cell_cv, k_v) / (2 * np.pi)
    K = kptpair_factory.gs.kpoints.kptfinder.find(k_c)

    pair_calc = kptpair_factory.pair_calculator()
    # Band ranges are passed as half-open [first, last + 1) intervals.
    kptpair = kptpair_factory.get_kpoint_pair(
        qpd, point.spin, K,
        n_n[0], n_n[-1] + 1,
        m_m[0], m_m[-1] + 1)

    get_paw_corrections = pair_calc.gs.pair_density_paw_corrections
    pawcorr = get_paw_corrections(qpd)

    df_nm = kptpair.get_occupation_differences()
    eps_n = kptpair.kpt1.eps_n
    eps_m = kptpair.kpt2.eps_n

    if use_optical:
        get_density = pair_calc.get_optical_pair_density
    else:
        get_density = pair_calc.get_pair_density
    pair_density = get_density(qpd, kptpair, n_n, m_m, pawcorr=pawcorr)

    return pair_density, df_nm, eps_n, eps_m
def get_degeneracy_matrix(eps_n, tol=1.e-3):
    """
    Generate a matrix that can sum over degenerate values.

    Consecutive entries of ``eps_n`` form one group while each entry
    differs from its predecessor by less than ``tol``. The comparison is
    chained between neighbours, not made against the group's first value.

    Returns ``(degmat, eps_N)``:

    * ``degmat`` has one row per group and ``len(eps_n)`` columns. A row
      contains 1 in the columns that belong to the group and 0 elsewhere.
    * ``eps_N`` holds the first value of each group.
    """
    nvalues = len(eps_n)
    rows = []
    eps_N = []
    first = 0
    while first < nvalues:
        last = first
        while (last + 1 < nvalues
               and abs(eps_n[last] - eps_n[last + 1]) < tol):
            last += 1
        eps_N.append(eps_n[first])
        row = [0] * nvalues
        row[first:last + 1] = [1] * (last - first + 1)
        rows.append(row)
        first = last + 1

    return np.array(rows), np.array(eps_N)
def get_individual_transition_strengths(n_nmG, df_nm, G1, G2):
    """Return Re[df_nm * n_nm(G1) * conj(n_nm(G2))] for every transition.

    ``n_nmG`` is indexed by its last axis at the plane-wave indices ``G1``
    and ``G2``. The weighted product is formed elementwise with ``df_nm``.
    """
    rho1_nm = n_nmG[:, :, G1]
    rho2_nm = n_nmG[:, :, G2]
    return np.real(df_nm * rho1_nm * np.conj(rho2_nm))
def find_peaks(x, y, threshold=None):
    """ Find peaks for a certain curve.

    A peak is an interior point ``i`` with ``y[i - 1] < y[i] > y[i + 1]``
    whose coordinates lie inside the inclusive window given by
    ``threshold``.

    Usage:
        threshold = (xmin, xmax, ymin, ymax)

    ``threshold`` may be a tuple or a list. Trailing bounds may be omitted
    and any bound may be ``None``; missing bounds default to the range of
    the data. A scalar threshold is interpreted as ``xmin``. Entries beyond
    the fourth are ignored.

    Returns an ``(npeak, 2)`` float array. Each row is ``[x, y]`` of one
    peak, ordered by increasing index. The array is empty if no peak is
    found.
    """
    assert isinstance(x, np.ndarray) and isinstance(y, np.ndarray)
    assert x.ndim == 1 and y.ndim == 1
    assert x.shape[0] == y.shape[0]

    if x.shape[0] < 3:
        # No interior points means no peaks. Returning early also avoids
        # calling min()/max() on an empty array.
        return np.zeros((0, 2))

    if threshold is None:
        threshold = ()
    elif isinstance(threshold, list):
        threshold = tuple(threshold)
    elif not isinstance(threshold, tuple):
        threshold = (threshold, )

    defaults = (x.min(), x.max(), y.min(), y.max())
    bounds = list(threshold[:4])
    bounds += [None] * (4 - len(bounds))
    xmin, xmax, ymin, ymax = [default if bound is None else bound
                              for bound, default in zip(bounds, defaults)]

    # Vectorized test over interior points 1 .. len-2.
    xc = x[1:-1]
    yc = y[1:-1]
    is_max = (yc > y[:-2]) & (yc > y[2:])
    in_window = (yc >= ymin) & (yc <= ymax) & (xc >= xmin) & (xc <= xmax)
    idx = np.flatnonzero(is_max & in_window) + 1

    return np.column_stack((x[idx], y[idx])).astype(float)
def lorz_fit(x, y, npeak=1, initpara=None):
    """ Fit curve using Lorentzian function

    Note: currently only valid for one and two Lorentzians

    The Lorentzian function is defined as::

                      A w
        lorz = --------------------- + y0
                (x-x0)**2 + w**2

    where A is the peak amplitude, w is the width, (x0,y0) the peak position

    For ``npeak=2`` the model is the sum of two such terms, with parameters
    ``[A1, x01, y01, w1, A2, x02, y02, w2]``. Only the sum ``y01 + y02`` of
    the two offsets is determined by the data.

    Parameters:

    x, y: ndarray
        Input data to be fitted.
    npeak: int
        Number of Lorentzian peaks (1 or 2).
    initpara: ndarray, optional
        Initial guess for the parameters, ``[A, x0, y0, w]`` per peak.
        Defaults to ``[1, 0, 0, 0.1]`` for one peak. For two peaks the
        second peak defaults to ``[3, 0, 0, 0.1]``.

    Returns:

    yfit: ndarray
        The fitted curve evaluated at ``x``.
    p: ndarray
        The optimized parameters, laid out like ``initpara``.

    Raises:

    ValueError
        If ``npeak`` is not 1 or 2.
    """
    # Validate before fitting. Otherwise an unsupported npeak (including 0)
    # only fails inside leastsq with a misleading message.
    if npeak not in (1, 2):
        raise ValueError('Only 1 or 2 Lorentzian peaks are supported, '
                         'got npeak=%r' % (npeak,))

    def residual(p, x, y):
        err = y - lorz(x, p, npeak)
        return err

    def lorz(x, p, npeak):
        if npeak == 1:
            return p[0] * p[3] / ((x - p[1])**2 + p[3]**2) + p[2]
        if npeak == 2:
            return (p[0] * p[3] / ((x - p[1])**2 + p[3]**2) + p[2]
                    + p[4] * p[7] / ((x - p[5])**2 + p[7]**2) + p[6])
        raise ValueError('Larger than 2 peaks not supported yet!')

    if initpara is None:
        if npeak == 1:
            initpara = np.array([1., 0., 0., 0.1])
        else:
            initpara = np.array([1., 0., 0., 0.1,
                                 3., 0., 0., 0.1])
    p0 = initpara

    result = leastsq(residual, p0, args=(x, y), maxfev=2000)

    yfit = lorz(x, result[0], npeak)

    return yfit, result[0]
def linear_fit(x, y, initpara=None):
    """Least-squares fit of a straight line ``y = a * x + b``.

    Parameters:

    x, y: ndarray
        Data to be fitted.
    initpara: ndarray, optional
        Initial guess ``[a, b]``; defaults to ``[1.0, 1.0]``.

    Returns ``(yfit, p)``: the fitted line evaluated at ``x`` and the
    optimized parameters ``[a, b]``.
    """
    def line(params, xs):
        return params[0] * xs + params[1]

    def misfit(params, xs, ys):
        return ys - line(params, xs)

    start = np.array([1.0, 1.0]) if initpara is None else initpara
    best = leastsq(misfit, start, args=(x, y), maxfev=2000)[0]
    return line(best, x), best
def plot_setfont():
    """Switch matplotlib to 18 pt fonts rendered with LaTeX.

    Updates the global ``matplotlib.pyplot.rcParams``. It sets the axis
    label, base font, legend and tick label sizes to 18, and enables
    ``text.usetex``.
    """
    import matplotlib.pyplot as plt
    params = {'axes.labelsize': 18,
              # 'text.fontsize' is no longer a valid rcParam (updating it
              # raises KeyError); 'font.size' is the base font size.
              'font.size': 18,
              'legend.fontsize': 18,
              'xtick.labelsize': 18,
              'ytick.labelsize': 18,
              'text.usetex': True}
    plt.rcParams.update(params)
def plot_setticks(x=True, y=True):
    """Configure major/minor ticks of the current matplotlib axes.

    For each enabled axis, the major ticks use ``AutoLocator`` and the
    minor ticks are spaced at one fifth of the major tick spacing. For a
    disabled axis, only that axis's minor ticks are turned off. The other
    axis is left untouched.

    Parameters:

    x, y: bool
        Whether to set up minor ticks on the x and y axis, respectively.
    """
    import matplotlib.pyplot as plt

    def _setup_axis(axis, enable):
        # Configure a single Axis object (ax.xaxis or ax.yaxis).
        if not enable:
            # Switch off minor ticks on this axis only. plt.minorticks_off()
            # would also remove the minor ticks of the other axis.
            axis.set_minor_locator(plt.NullLocator())
            return
        axis.set_major_locator(plt.AutoLocator())
        major = axis.get_majorticklocs()
        if len(major) < 2:
            # The spacing cannot be inferred; keep the default minor ticks.
            return
        spacing = (major[-1] - major[0]) / (len(major) - 1)
        axis.set_minor_locator(plt.MultipleLocator(spacing / 5.))

    plt.minorticks_on()
    ax = plt.gca()
    _setup_axis(ax.xaxis, x)
    _setup_axis(ax.yaxis, y)