Coverage for gpaw/response/tool.py: 37%

161 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import sys 

2 

3import numpy as np 

4from scipy.optimize import leastsq 

5 

6from ase.units import Ha 

7import gpaw.mpi as mpi 

8from gpaw.response.integrators import Domain 

9from gpaw.response.qpd import SingleQPWDescriptor 

10from gpaw.response.pair import KPointPairFactory, get_gs_and_context 

11 

12 

def check_degenerate_bands(filename, etol):
    """Report nearly degenerate neighbouring bands with an occupation jump.

    Loads a GPAW calculation from ``filename`` and prints the number of
    electrons, bands and IBZ k-points. Then, for every IBZ k-point ``k``
    and band ``n >= 1``, prints ``k, n, e[n], e[n - 1]`` whenever the
    occupation drops by more than 1e-5 from band ``n - 1`` to band ``n``
    while the two eigenvalues differ by less than ``etol``.

    ``etol`` must be given in the units returned by
    ``calc.get_eigenvalues``.
    """
    from gpaw import GPAW

    calc = GPAW(filename, txt=None)
    print('Number of Electrons :', calc.get_number_of_electrons())
    nkpts = calc.get_ibz_k_points().shape[0]
    nbands = calc.get_number_of_bands()
    print('Number of Bands :', nbands)
    print('Number of ibz-kpoints :', nkpts)

    eig_kn = np.array([calc.get_eigenvalues(k) for k in range(nkpts)])
    occ_kn = np.array([calc.get_occupation_numbers(k)
                       for k in range(nkpts)])

    for k, (eig_n, occ_n) in enumerate(zip(eig_kn, occ_kn)):
        for n in range(1, nbands):
            # Only compare energies once an occupation jump is found.
            if (occ_n[n - 1] - occ_n[n] > 1e-5
                    and np.abs(eig_n[n] - eig_n[n - 1]) < etol):
                print(k, n, eig_n[n], eig_n[n - 1])

29 

30 

def get_orbitals(calc):
    """Get LCAO orbitals on the 3D real-space grid.

    Builds ``BasisFunctions`` from the basis functions of every setup in
    ``calc.wfs.setups``, places them at ``calc.spos_ac`` and projects an
    identity coefficient matrix onto the grid with ``lcao_to_grid``. The
    result is an array from ``calc.wfs.gd.zeros`` with one grid function
    per orbital.

    NOTE(review): the orbital count is taken from
    ``calc.get_number_of_bands()``. This presumably only matches the
    number of basis functions when nbands equals the basis size; confirm
    before using with a reduced band count.
    """
    from gpaw.lfc import BasisFunctions

    wfs = calc.wfs
    basis_functions_a = [setup.basis_functions_J for setup in wfs.setups]
    bfs = BasisFunctions(wfs.gd, basis_functions_a, wfs.kd.comm, cut=True)
    bfs.set_positions(calc.spos_ac)

    norb = calc.get_number_of_bands()
    orbitals_MG = wfs.gd.zeros(norb)
    # Identity coefficients: row M of the output is basis function M.
    bfs.lcao_to_grid(np.identity(norb), orbitals_MG, q=-1)
    return orbitals_MG

46 

47 

def get_bz_transitions(filename, q_c, bzk_kc,
                       spins='all',
                       ecut=50, txt=sys.stdout):
    """Set up what is needed to evaluate transitions at wave vector q_c.

    Parameters:

    filename:
        Ground-state file passed to ``get_gs_and_context``.
    q_c:
        Wave vector (scaled coordinates) of the linear response.
    bzk_kc:
        k-points (scaled coordinates) whose transitions are wanted. They
        are converted to Cartesian coordinates (``bzk_kv``) for the
        ``Domain``.
    spins:
        ``'all'`` or an iterable of spin indices, each in
        ``range(nspins)``.
    ecut:
        Plane-wave cutoff in eV; converted to Hartree internally.
    txt:
        Output stream handed to ``get_gs_and_context``.

    Returns ``(kptpair_factory, qpd, domain)``.

    Raises ValueError if a requested spin index is out of range.
    """
    ecut /= Ha

    gs, context = get_gs_and_context(filename, txt=txt, world=mpi.world,
                                     timer=None)

    kptpair_factory = KPointPairFactory(gs=gs, context=context)
    qpd = SingleQPWDescriptor.from_q(q_c, ecut, gs.gd)
    # Scaled -> Cartesian: k_v = 2 pi * k_c . icell_cv
    bzk_kv = np.dot(bzk_kc, qpd.gd.icell_cv) * 2 * np.pi

    nspins = kptpair_factory.gs.nspins
    # isinstance guard: comparing a numpy array of spins to a string is
    # ambiguous.
    if isinstance(spins, str) and spins == 'all':
        spins = range(nspins)
    else:
        # Explicit check instead of assert so it survives ``python -O``.
        for spin in spins:
            if spin not in range(nspins):
                raise ValueError(
                    f'Spin index {spin} is out of range for a calculation '
                    f'with {nspins} spin channel(s)')

    domain = Domain(bzk_kv, spins)
    return kptpair_factory, qpd, domain

73 

74 

def get_chi0_integrand(kptpair_factory, qpd, n_n, m_m, point):
    """Return pair densities plus occupation and energy data for one point.

    For the k-point and spin carried by ``point``, a k-point pair is built
    with ``kptpair_factory.get_kpoint_pair``. Only the first and last
    entries of ``n_n`` and ``m_m`` set the band ranges (presumably these
    are contiguous ascending band indices; confirm against callers).

    Returns ``(pair_density, df_nm, eps_n, eps_m)`` where:

    - ``pair_density`` comes from ``get_optical_pair_density`` when
      ``qpd.optical_limit`` is true, otherwise from ``get_pair_density``;
    - ``df_nm`` is ``kptpair.get_occupation_differences()``;
    - ``eps_n`` and ``eps_m`` are the eigenvalues of ``kptpair.kpt1`` and
      ``kptpair.kpt2``.
    """
    optical_limit = qpd.optical_limit
    # point.kpt_c is treated as a Cartesian vector and mapped back to
    # scaled coordinates (k_c = cell_cv . k_v / 2 pi) to find its index K.
    k_c = np.dot(qpd.gd.cell_cv, point.kpt_c) / (2 * np.pi)
    K = kptpair_factory.gs.kpoints.kptfinder.find(k_c)

    pair_calc = kptpair_factory.pair_calculator()
    kptpair = kptpair_factory.get_kpoint_pair(
        qpd, point.spin, K,
        n_n[0], n_n[-1] + 1,
        m_m[0], m_m[-1] + 1)
    pawcorr = pair_calc.gs.pair_density_paw_corrections(qpd)

    df_nm = kptpair.get_occupation_differences()
    eps_n = kptpair.kpt1.eps_n
    eps_m = kptpair.kpt2.eps_n

    compute_pair_density = (pair_calc.get_optical_pair_density
                            if optical_limit
                            else pair_calc.get_pair_density)
    pair_density = compute_pair_density(qpd, kptpair, n_n, m_m,
                                        pawcorr=pawcorr)
    return pair_density, df_nm, eps_n, eps_m

109 

110 

def get_degeneracy_matrix(eps_n, tol=1.e-3):
    """Group neighbouring (near-)degenerate eigenvalues.

    Bands are scanned in order. Band ``n`` joins the group of band
    ``n - 1`` when ``abs(eps_n[n - 1] - eps_n[n]) < tol``. Because each
    band is compared with its neighbour, a chain of small steps forms one
    group even if its ends differ by more than ``tol``.

    Returns ``(degmat, eps_N)``:

    - ``degmat`` has one row per group, with 1 at the member band indices
      and 0 elsewhere, so ``degmat @ x_n`` sums ``x_n`` over each group;
    - ``eps_N`` holds the first eigenvalue of each group.
    """
    nbands = len(eps_n)
    rows = []
    eps_N = []
    start = 0
    while start < nbands:
        stop = start + 1
        while stop < nbands and abs(eps_n[stop - 1] - eps_n[stop]) < tol:
            stop += 1
        row = [0] * nbands
        row[start:stop] = [1] * (stop - start)
        rows.append(row)
        eps_N.append(eps_n[start])
        start = stop

    return np.array(rows), np.array(eps_N)

133 

134 

def get_individual_transition_strengths(n_nmG, df_nm, G1, G2):
    """Return Re[df_nm * n_nmG[:, :, G1] * conj(n_nmG[:, :, G2])].

    Gives the real part of each (n, m) transition's contribution to the
    (G1, G2) element, weighted by the occupation differences ``df_nm``.
    """
    rho1_nm = n_nmG[:, :, G1]
    rho2_nm = np.conj(n_nmG[:, :, G2])
    return np.real(df_nm * rho1_nm * rho2_nm)

137 

138 

def find_peaks(x, y, threshold=None):
    """Find strict local maxima of a sampled curve inside a window.

    Sample ``i`` (never the first or last one) is reported as a peak when
    ``y[i]`` is strictly larger than both neighbours and
    ``xmin <= x[i] <= xmax`` and ``ymin <= y[i] <= ymax``.

    Usage:
        threshold = (xmin, xmax, ymin, ymax)

    Parameters:

    x, y: 1D numpy arrays of equal length
    threshold: None, a scalar, or a tuple/list ``(xmin, xmax, ymin, ymax)``
        A scalar is taken as ``xmin``. Trailing entries may be omitted
        and any entry may be ``None``. Missing entries default to the data
        range. Entries beyond the fourth are ignored.

    Returns a float array of shape ``(npeak, 2)`` with rows
    ``[x_peak, y_peak]`` in order of increasing index.

    Raises TypeError if x or y is not an ndarray. Raises ValueError if
    they are not 1D or differ in length.
    """
    # Explicit checks rather than assert so they survive ``python -O``.
    if not (isinstance(x, np.ndarray) and isinstance(y, np.ndarray)):
        raise TypeError('x and y must be numpy arrays')
    if x.ndim != 1 or y.ndim != 1:
        raise ValueError('x and y must be one-dimensional')
    if x.shape[0] != y.shape[0]:
        raise ValueError('x and y must have the same length')

    if x.shape[0] < 3:
        # A peak needs a neighbour on each side, so there can be none.
        return np.zeros((0, 2))

    if threshold is None:
        threshold = ()
    elif isinstance(threshold, (tuple, list)):
        threshold = tuple(threshold)
    else:
        threshold = (threshold,)

    # Missing (or None) window limits default to the data range.
    defaults = (x.min(), x.max(), y.min(), y.max())
    xmin, xmax, ymin, ymax = (
        threshold[i] if i < len(threshold) and threshold[i] is not None
        else default
        for i, default in enumerate(defaults))

    # Interior samples and their strict-local-maximum / window test.
    x_i = x[1:-1]
    y_i = y[1:-1]
    is_peak = ((y_i > y[:-2]) & (y_i > y[2:])
               & (y_i >= ymin) & (y_i <= ymax)
               & (x_i >= xmin) & (x_i <= xmax))
    return np.column_stack((x_i[is_peak], y_i[is_peak])).astype(float)

185 

186 

def lorz_fit(x, y, npeak=1, initpara=None, maxfev=2000):
    """ Fit a curve with a sum of Lorentzian functions.

    Each peak is described by four parameters ``[A, x0, y0, w]`` and
    contributes::

                    A w
        lorz = --------------------- + y0
               (x-x0)**2 + w**2

    where A is the peak amplitude, w is the width, (x0, y0) the peak
    position and offset. The fitted model is the sum over ``npeak`` peaks.

    Parameters:

    x, y: ndarray
        Input data to fit.
    npeak: int
        Number of Lorentzians, at least 1.
    initpara: array-like or None
        Initial guess of length ``4 * npeak``, laid out as
        ``[A_1, x0_1, y0_1, w_1, A_2, x0_2, ...]``. The default is
        ``[1 + 2 i, 0, 0, 0.1]`` for peak ``i = 0, 1, ...``. This equals
        the previous defaults for one and two peaks.
    maxfev: int
        Maximum number of function evaluations passed to
        ``scipy.optimize.leastsq``.

    Returns:

    yfit: ndarray
        Fitted model evaluated at ``x``.
    p: ndarray
        Optimal parameters, laid out like ``initpara``.

    Raises ValueError if ``npeak < 1`` or ``initpara`` has the wrong
    length.
    """
    if npeak < 1:
        raise ValueError('npeak must be at least 1')

    def lorz(x, p):
        # Accumulate peak by peak (term, then offset) in the same order as
        # the former one- and two-peak formulas, so results are unchanged.
        total = 0.0
        for A, x0, y0, w in np.reshape(p, (npeak, 4)):
            total = total + A * w / ((x - x0)**2 + w**2) + y0
        return total

    def residual(p, x, y):
        return y - lorz(x, p)

    if initpara is None:
        initpara = np.array([[1. + 2. * i, 0., 0., 0.1]
                             for i in range(npeak)]).ravel()
    p0 = np.asarray(initpara, dtype=float)
    if p0.size != 4 * npeak:
        raise ValueError(f'initpara must have {4 * npeak} entries '
                         f'for npeak={npeak}, got {p0.size}')

    result = leastsq(residual, p0, args=(x, y), maxfev=maxfev)

    yfit = lorz(x, result[0])

    return yfit, result[0]

239 

240 

def linear_fit(x, y, initpara=None):
    """Fit ``y = a * x + b`` by non-linear least squares.

    ``initpara`` is the starting guess ``[a, b]``; it defaults to
    ``[1.0, 1.0]``. ``scipy.optimize.leastsq`` is run with at most 2000
    function evaluations.

    Returns ``(yfit, params)``: the fitted line evaluated at ``x`` and the
    optimal ``[a, b]``.
    """
    def model(x_, params):
        slope, intercept = params[0], params[1]
        return slope * x_ + intercept

    def residual(params, x_, y_):
        return y_ - model(x_, params)

    start = np.array([1.0, 1.0]) if initpara is None else initpara
    best = leastsq(residual, start, args=(x, y), maxfev=2000)[0]
    return model(x, best), best

257 

258 

def plot_setfont():
    """Use 18 pt fonts for plot text and enable LaTeX rendering.

    Updates the global ``matplotlib.pyplot.rcParams``. ``text.usetex``
    requires a working LaTeX installation.
    """
    import matplotlib.pyplot as plt
    params = {'axes.labelsize': 18,
              # The old 'text.fontsize' key was an alias of 'font.size'.
              # It is no longer a valid rcParam, and rcParams.update()
              # raises KeyError on it.
              'font.size': 18,
              'legend.fontsize': 18,
              'xtick.labelsize': 18,
              'ytick.labelsize': 18,
              'text.usetex': True}
    plt.rcParams.update(params)

269 

270 

271def plot_setticks(x=True, y=True): 

272 import matplotlib.pyplot as plt 

273 plt.minorticks_on() 

274 ax = plt.gca() 

275 if x: 

276 ax.xaxis.set_major_locator(plt.AutoLocator()) 

277 x_major = ax.xaxis.get_majorticklocs() 

278 dx_minor = (x_major[-1] - x_major[0]) / (len(x_major) - 1) / 5. 

279 ax.xaxis.set_minor_locator(plt.MultipleLocator(dx_minor)) 

280 else: 

281 plt.minorticks_off() 

282 

283 if y: 

284 ax.yaxis.set_major_locator(plt.AutoLocator()) 

285 y_major = ax.yaxis.get_majorticklocs() 

286 dy_minor = (y_major[-1] - y_major[0]) / (len(y_major) - 1) / 5. 

287 ax.yaxis.set_minor_locator(plt.MultipleLocator(dy_minor)) 

288 else: 

289 plt.minorticks_off()