Coverage for gpaw/hybrids/eigenvalues.py: 99%

166 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1"""Calculate non self-consistent eigenvalues for hybrid functionals.""" 

2from __future__ import annotations 

3import functools 

4import json 

5from pathlib import Path 

6from typing import Generator, List, Optional, Tuple, Union 

7 

8import numpy as np 

9from ase.units import Ha 

10from gpaw.calculator import GPAW as GPAWOld 

11from gpaw import GPAW 

12from gpaw.new.ase_interface import ASECalculator 

13from gpaw.kpt_descriptor import KPointDescriptor 

14from gpaw.mpi import serial_comm 

15from gpaw.pw.descriptor import PWDescriptor 

16from gpaw.pw.lfc import PWLFC 

17from gpaw.typing import Array3D 

18from gpaw.xc import XC 

19from gpaw.xc.kernel import XCNull 

20from gpaw.xc.tools import vxc 

21 

22from gpaw.hybrids import parse_name 

23from gpaw.hybrids.coulomb import coulomb_interaction 

24from gpaw.hybrids.kpts import RSKPoint, get_kpt, to_real_space 

25from gpaw.hybrids.paw import calculate_paw_stuff 

26from gpaw.hybrids.symmetry import Symmetry 

27 

28 

def non_self_consistent_eigenvalues(
        calc: Union[GPAWOld, ASECalculator, str, Path],
        xcname: str,
        n1: int = 0,
        n2: int = 0,
        kpt_indices: Optional[List[int]] = None,
        snapshot: Optional[Union[str, Path]] = None,
        ftol: float = 1e-9) -> tuple[Array3D,
                                     Array3D,
                                     Array3D]:
    """Calculate non self-consistent eigenvalues for a hybrid functional.

    Based on a self-consistent DFT calculation (calc), given either as a
    calculator object or as the path to a gpw-file.  Only eigenvalues n1
    to n2 - 1 for the IBZ indices in kpt_indices are calculated (default
    is all bands and all k-points); a non-positive n2 is counted from the
    number of bands, so n2=0 means "up to the last band".  EXX integrals
    involving states with occupation numbers less than ftol are skipped.

    Use snapshot='name.json' to write a snapshot after each finished
    k-point.  If that file already exists, it is read and only the
    missing k-points are calculated.

    Returns three (nspins, nkpts, n2 - n1)-shaped ndarrays
    with contributions to the eigenvalues in eV:

    >>> nsceigs = non_self_consistent_eigenvalues
    >>> eig_dft, vxc_dft, vxc_hyb = nsceigs('<gpw-file>', xcname='PBE0')
    >>> eig_hyb = eig_dft - vxc_dft + vxc_hyb

    Raises ValueError if an existing snapshot file was written for a
    different number of spins or k-points.
    """

    if not isinstance(calc, (GPAWOld, ASECalculator)):
        if calc == '<gpw-file>':  # for doctest
            return (np.zeros((1, 1, 1)),
                    np.zeros((1, 1, 1)),
                    np.zeros((1, 1, 1)))
        calc = GPAW(Path(calc), txt=None, parallel={'band': 1, 'kpt': 1})

    assert isinstance(calc, (GPAWOld, ASECalculator))
    wfs = calc.wfs

    # A non-positive n2 counts from the end (n2=0 -> all bands).
    if n2 <= 0:
        n2 += wfs.bd.nbands

    if kpt_indices is None:
        kpt_indices = np.arange(wfs.kd.nibzkpts).tolist()

    path = Path(snapshot) if snapshot is not None else None

    e_dft_sin = np.zeros(0)
    v_dft_sin = np.zeros(0)

    # sl=semilocal, nl=nonlocal
    v_hyb_sl_sin = np.zeros(0)
    v_hyb_nl_sin: Optional[List[List[np.ndarray]]] = None

    if path:
        e_dft_sin, v_dft_sin, v_hyb_sl_sin, v_hyb_nl_sin = read_snapshot(path)
        # Refuse to resume from a snapshot written for a different number
        # of spins or k-points: combining it with new results would give
        # wrong (or ragged) arrays without any warning.
        nkpts = len(kpt_indices)
        mismatch = (e_dft_sin.size > 0 and
                    e_dft_sin.shape[:2] != (wfs.nspins, nkpts))
        if v_hyb_nl_sin is not None:
            mismatch = mismatch or len(v_hyb_nl_sin) != wfs.nspins or any(
                len(v_hyb_nl_in) > nkpts for v_hyb_nl_in in v_hyb_nl_sin)
        if mismatch:
            raise ValueError(
                f'Snapshot {path} was written for a different number of '
                'spins or k-points; remove it or use another snapshot name.')

    xcname, exx_fraction, omega, yukawa = parse_name(xcname)

    if v_dft_sin.size == 0:
        xc = XC(xcname)
        e_dft_sin, v_dft_sin, v_hyb_sl_sin = _semi_local(
            calc, xc, n1, n2, kpt_indices)
        write_snapshot(e_dft_sin, v_dft_sin, v_hyb_sl_sin, v_hyb_nl_sin,
                       path, wfs.world)

    # Non-local hybrid contribution
    if v_hyb_nl_sin is None:
        v_hyb_nl_sin = [[] for s in range(wfs.nspins)]

    # Find missing indices (k-points not yet stored in the snapshot):
    kpt_indices_s = [kpt_indices[len(v_hyb_nl_in):]
                     for v_hyb_nl_in in v_hyb_nl_sin]

    if any(len(indices) > 0 for indices in kpt_indices_s):
        for s, v_hyb_nl_n in _non_local(calc, n1, n2, kpt_indices_s,
                                        ftol, omega, yukawa):
            v_hyb_nl_sin[s].append(v_hyb_nl_n * exx_fraction)
            # Checkpoint after every finished k-point.
            write_snapshot(e_dft_sin, v_dft_sin, v_hyb_sl_sin, v_hyb_nl_sin,
                           path, wfs.world)

    return (e_dft_sin * Ha,
            v_dft_sin * Ha,
            (v_hyb_sl_sin + v_hyb_nl_sin) * Ha)

109 

110 

def _semi_local(calc: GPAWOld | ASECalculator,
                xc,
                n1: int,
                n2: int,
                kpt_indices: List[int]
                ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Semi-local ingredients of the eigenvalues, divided by Ha.

    Returns (e_dft_sin, v_dft_sin, v_hyb_sl_sin) for bands n1:n2 and the
    IBZ k-points in kpt_indices:

    * e_dft_sin: DFT eigenvalues from calc.get_eigenvalues().
    * v_dft_sin: expectation values of the DFT xc-potential.
    * v_hyb_sl_sin: expectation values for ``xc``, the semi-local part of
      the hybrid; all zeros when its kernel is an XCNull.
    """
    wfs = calc.wfs
    e_dft_sin = np.array([
        [calc.get_eigenvalues(k, spin)[n1:n2] for k in kpt_indices]
        for spin in range(wfs.nspins)])

    def expectation_values(*xc_args):
        # vxc() without an xc argument uses the calculator's own functional.
        return vxc(calc.gs_adapter(), *xc_args, n1=n1, n2=n2)[:, kpt_indices]

    v_dft_sin = expectation_values()
    if isinstance(xc.kernel, XCNull):
        v_hyb_sl_sin = np.zeros_like(v_dft_sin)
    else:
        v_hyb_sl_sin = expectation_values(xc)
    return tuple(array / Ha
                 for array in (e_dft_sin, v_dft_sin, v_hyb_sl_sin))

128 

129 

def _non_local(calc: GPAWOld | ASECalculator,
               n1: int,
               n2: int,
               kpt_indices_s: List[List[int]],
               ftol: float,
               omega: float,
               yukawa: bool) -> Generator[Tuple[int, np.ndarray], None, None]:
    """Generate non-local exchange contributions one k-point at a time.

    For every spin s and every IBZ index in kpt_indices_s[s], yields
    (s, v_n). Here v_n is the contribution for bands n1:n2, already summed
    over wfs.world. Spins with an empty index list are skipped.

    The partner states are bands 0:nocc of every IBZ k-point. nocc is the
    largest number of bands with f_n / weight > ftol on any k-point of a
    rank, summed over the band communicator and maximized over the
    k-point communicator.
    """
    wfs = calc.wfs
    kd = wfs.kd
    density = calc.density

    nocc_local = max(((kpt.f_n / kpt.weight) > ftol).sum()
                     for kpt in wfs.kpt_u)
    nocc = kd.comm.max_scalar(wfs.bd.comm.sum_scalar(int(nocc_local)))

    coulomb = coulomb_interaction(omega, wfs.gd, kd, yukawa=yukawa)
    sym = Symmetry(kd)
    paw_s = calculate_paw_stuff(wfs, density)

    for spin, indices in enumerate(kpt_indices_s):
        if len(indices) == 0:
            continue
        partners = [get_kpt(wfs, ibz2, spin, 0, nocc)
                    for ibz2 in range(kd.nibzkpts)]
        for ibz1 in indices:
            v_n = _calculate_eigenvalues(
                get_kpt(wfs, ibz1, spin, n1, n2), partners, paw_s[spin],
                kd, coulomb, sym, wfs, calc.spos_ac)
            # Each rank holds only its share; combine before yielding.
            wfs.world.sum(v_n)
            yield spin, v_n

160 

161 

def _calculate_eigenvalues(kpt1, kpts2, paw, kd, coulomb, sym, wfs, spos_ac):
    """Non-local (exchange) contribution for the bands of kpt1.

    For every band of kpt1 and every band held by the k-points in kpts2,
    a pair density is built for each BZ k-point generated from kpts2 by
    symmetry. Its integral against v_G * rho_G (v_G from
    coulomb.get_potential) is divided by kd.nbzkpts. It is then weighted
    with the occupation numbers f_n of the partner bands and subtracted.
    Finally, PAW corrections from paw.VV_aii and paw.VC_aii are
    subtracted.

    The (kpt1 band) x (kpts2 band) work is split over the ranks of
    wfs.world with layout(). Each rank fills only its slice
    E_n[n1a:n1b], so the caller has to sum the result over wfs.world.

    Parameters
    ----------
    kpt1:
        k-point whose bands' eigenvalue contributions are wanted.
    kpts2:
        One k-point per IBZ k-point (len(kpts2) == kd.nibzkpts) holding
        the partner bands.
    paw:
        PAW data for the current spin (Delta_aiiL, VV_aii, VC_aii).
    kd, coulomb, sym, wfs, spos_ac:
        k-point descriptor, Coulomb interaction, symmetry helper, wave
        functions and scaled atomic positions.

    Returns
    -------
    E_n: ndarray with one entry per band of kpt1.
    """
    pd = kpt1.psit.pd
    # Same grid as pd but with a serial communicator.
    gd = pd.gd.new_descriptor(comm=serial_comm)
    comm = wfs.world
    size = comm.size
    rank = comm.rank

    nsym = len(kd.symmetry.op_scc)
    assert len(kpts2) == kd.nibzkpts

    N1 = len(kpt1.psit.array)
    N2 = len(kpts2[0].psit.array)

    # Distribute the N1 x N2 band-pair matrix over a size1 x size2 grid of
    # ranks (rank = rank1 * size2 + rank2) with block sizes
    # B1 = ceil(N1 / size1) and B2 = ceil(N2 / size2). This rank handles
    # bands n1a:n1b of kpt1 and n2a:n2b of every k-point in kpts2.
    size1, size2 = layout(N1, N2, size)
    assert size1 * size2 == size
    B1 = (N1 + size1 - 1) // size1
    B2 = (N2 + size2 - 1) // size2
    rank1, rank2 = divmod(rank, size2)
    n1a = min(B1 * rank1, N1)
    n1b = min(n1a + B1, N1)
    n2a = min(B2 * rank2, N2)
    n2b = min(n2a + B2, N2)

    u1_nR = to_real_space(kpt1.psit, n1a, n1b)
    proj1all = kpt1.proj.broadcast()
    proj1 = proj1all.view(n1a, n1b)

    E_n = np.zeros(N1)
    # e_n is a view (basic slicing) into E_n, so "e_n -= ..." below
    # updates this rank's part of E_n in place.
    e_n = E_n[n1a:n1b]
    e_nn = np.empty((n1b - n1a, n2b - n2a))

    for i2, kpt2 in enumerate(kpts2):
        u2_nR = to_real_space(kpt2.psit, n2a, n2b)
        rskpt0 = RSKPoint(u2_nR,
                          kpt2.proj.broadcast().view(n2a, n2b),
                          kpt2.f_n[n2a:n2b],
                          kpt2.k_c,
                          kpt2.weight)
        # Visit every BZ index k with kd.bz2ibz_k[k] == i2.
        for k, i in enumerate(kd.bz2ibz_k):
            if i != i2:
                continue
            # Symmetry index, offset by nsym when time reversal is used
            # for BZ k-point k.
            s = kd.sym_k[k] + kd.time_reversal_k[k] * nsym
            rskpt2 = sym.apply_symmetry(s, rskpt0, wfs, spos_ac)
            q_c = rskpt2.k_c - kpt1.k_c
            # Plane-wave descriptor (single k-point -q_c), compensation
            # charges and Coulomb potential for the pair densities.
            qd = KPointDescriptor([-q_c])
            pd12 = PWDescriptor(pd.ecut, gd, pd.dtype, kd=qd)
            ghat = PWLFC([data.ghat_l for data in wfs.setups], pd12)
            ghat.set_positions(spos_ac)
            v_G = coulomb.get_potential(pd12)
            # Per-atom (n1, n2, L) coefficients passed to ghat.add below.
            Q_annL = [np.einsum('mi, ijL, nj -> mnL',
                                proj1[a],
                                Delta_iiL,
                                rskpt2.proj[a].conj())
                      for a, Delta_iiL in enumerate(paw.Delta_aiiL)]
            rho_nG = ghat.pd.empty(n2b - n2a, u1_nR.dtype)

            # n1/n2 here are local band indices, not the function's
            # n1/n2 arguments elsewhere in this module.
            for n1, u1_R in enumerate(u1_nR):
                # Pair densities u1 * conj(u2) in plane waves ...
                for u2_R, rho_G in zip(rskpt2.u_nR, rho_nG):
                    rho_G[:] = ghat.pd.fft(u1_R * u2_R.conj())

                # ... plus compensation charges.
                ghat.add(rho_nG,
                         {a: Q_nnL[n1] for a, Q_nnL in enumerate(Q_annL)})

                for n2, rho_G in enumerate(rho_nG):
                    vrho_G = v_G * rho_G
                    e = ghat.pd.integrate(rho_G, vrho_G).real
                    e_nn[n1, n2] = e / kd.nbzkpts
            # Weight with the occupation numbers of the partner bands.
            e_n -= e_nn.dot(rskpt2.f_n)

    # PAW corrections.
    # NOTE(review): this loop runs on every rank and updates all of E_n,
    # while the caller sums E_n over wfs.world; presumably paw.VV_aii only
    # holds atoms owned by this rank -- confirm in calculate_paw_stuff.
    for a, VV_ii in paw.VV_aii.items():
        P_ni = proj1all[a]
        vv_n = np.einsum('ni, ij, nj -> n',
                         P_ni.conj(), VV_ii, P_ni).real
        if paw.VC_aii[a] is not None:
            vc_n = np.einsum('ni, ij, nj -> n',
                             P_ni.conj(), paw.VC_aii[a], P_ni).real
        else:
            vc_n = 0.0
        E_n -= (2 * vv_n + vc_n)

    return E_n

243 

244 

def write_snapshot(e_dft_sin: np.ndarray,
                   v_dft_sin: np.ndarray,
                   v_hyb_sl_sin: np.ndarray,
                   v_hyb_nl_sin: Optional[List[List[np.ndarray]]],
                   path: Optional[Path],
                   comm) -> None:
    """Write to json-file what has been calculated so far.

    Only rank 0 of comm writes, and only when a path is given.  The
    non-local part is stored only when it is not None.

    The JSON is first written to a temporary sibling file and then
    renamed onto ``path``. An interruption during writing, e.g. a job
    hitting its wall-time limit, therefore cannot leave a truncated
    snapshot that json.loads() in read_snapshot() would fail to parse.
    """
    if comm.rank != 0 or not path:
        return

    dct = {'e_dft_sin': e_dft_sin.tolist(),
           'v_dft_sin': v_dft_sin.tolist(),
           'v_hyb_sl_sin': v_hyb_sl_sin.tolist()}
    if v_hyb_nl_sin is not None:
        dct['v_hyb_nl_sin'] = [[v_n.tolist()
                                for v_n in v_in]
                               for v_in in v_hyb_nl_sin]

    tmp = path.with_name(path.name + '.tmp')
    tmp.write_text(json.dumps(dct, indent=0))
    # Same directory -> same filesystem, so the rename is atomic on POSIX.
    tmp.replace(path)

261 

262 

def read_snapshot(snapshot: Path
                  ) -> Tuple[np.ndarray,
                             np.ndarray,
                             np.ndarray,
                             Optional[List[List[np.ndarray]]]]:
    """Load a snapshot previously stored with write_snapshot().

    Returns (e_dft_sin, v_dft_sin, v_hyb_sl_sin, v_hyb_nl_sin).
    v_hyb_nl_sin is None when the file has no non-local part. When the
    file does not exist, three empty arrays of shape (1, 1, 0) and None
    are returned.
    """
    if not snapshot.is_file():
        return np.array([[[]]]), np.array([[[]]]), np.array([[[]]]), None

    data = json.loads(snapshot.read_text())
    stored_nl = data.get('v_hyb_nl_sin')
    nl = (None if stored_nl is None else
          [[np.array(values) for values in per_spin]
           for per_spin in stored_nl])
    e, v, sl = (np.array(data[key])
                for key in ('e_dft_sin', 'v_dft_sin', 'v_hyb_sl_sin'))
    return e, v, sl, nl

281 

282 

@functools.lru_cache()
def layout(n1: int, n2: int, size: int) -> Tuple[int, int]:
    """Choose an s1 x s2 == size process grid for an n1 x n2 matrix.

    Every factorization of size is scored by the product of the busy
    fractions (1 - idle) along both dimensions.  The best-scoring
    (s1, s2) is returned; equal scores are resolved in favour of the
    larger s1.

    >>> layout(10, 10, 8)
    (4, 2)
    """
    def score(s1: int) -> Tuple[float, int, int]:
        s2 = size // s1
        return (1 - idle(n1, s1)) * (1 - idle(n2, s2)), s1, s2

    divisors = [s1 for s1 in range(1, size + 1) if size % s1 == 0]
    _, best_s1, best_s2 = max(score(s1) for s1 in divisors)
    return best_s1, best_s2

300 

301 

def idle(n: int, s: int) -> float:
    """Idle fraction (helper function for layout() function).

    Splitting n items over s blocks of equal size b = ceil(n / s)
    allocates b * s slots, of which b * s - n stay unused.  Returns that
    unused fraction.  With nothing to distribute (b == 0, i.e. n == 0)
    no slots are allocated and 0.0 is returned instead of dividing 0 by 0.
    """
    b = (n + s - 1) // s
    if b == 0:
        return 0.0
    return 1 - n / (b * s)