Coverage for gpaw/hybrids/energy.py: 96%

106 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1from __future__ import annotations 

2from pathlib import Path 

3 

4import numpy as np 

5from ase.units import Ha 

6 

7from gpaw import GPAW 

8from gpaw.new.ase_interface import ASECalculator 

9from gpaw.kpt_descriptor import KPointDescriptor 

10from gpaw.mpi import serial_comm 

11from gpaw.pw.descriptor import PWDescriptor 

12from gpaw.pw.lfc import PWLFC 

13from gpaw.xc import XC 

14from . import parse_name 

15from .coulomb import coulomb_interaction 

16from .kpts import get_kpt 

17from .paw import calculate_paw_stuff 

18from .symmetry import Symmetry 

19from gpaw.typing import Array1D 

20 

21 

def non_self_consistent_energy(calc: ASECalculator | str | Path,
                               xcname: str,
                               ftol=1e-9) -> Array1D:
    """Evaluate the energy of a hybrid functional non self-consistently.

    The evaluation starts from a self-consistent DFT calculation given as
    *calc*: either a calculator object or the name of a gpw-file, which is
    then opened with ``GPAW(...)``.  States with occupation numbers below
    *ftol* are left out of the EXX integrals.

    >>> energies = non_self_consistent_energy('<gpw-file>',
    ...                                       xcname='HSE06')
    >>> e_hyb = energies.sum()

    ``energies[1:].sum()`` is the correction to the self-consistent
    energy.

    Six contributions are returned (all in eV):

    1. DFT total free energy (not extrapolated to zero smearing)
    2. minus DFT XC energy
    3. Hybrid semi-local XC energy
    4. EXX core-core energy
    5. EXX core-valence energy
    6. EXX valence-valence energy
    """
    # Placeholder file name used by the doctest above: skip all real work.
    if calc == '<gpw-file>':
        return np.zeros(6)

    if isinstance(calc, (str, Path)):
        calc = GPAW(calc, txt=None, parallel={'band': 1, 'kpt': 1})
    assert not isinstance(calc, (str, Path))  # type narrowing for mypy

    wfs = calc.wfs
    dens = calc.density
    kd = wfs.kd
    setups = wfs.setups
    nspins = wfs.nspins

    # Largest count, over the k-points held in wfs.kpt_u, of bands with
    # f_n / weight > ftol; the count is then summed over the band
    # communicator and maximised over the k-point communicator.
    occupied_counts = [((kpt.f_n / kpt.weight) > ftol).sum()
                       for kpt in wfs.kpt_u]
    nocc = max(occupied_counts)
    nocc = kd.comm.max_scalar(wfs.bd.comm.sum_scalar(int(nocc)))

    xcname, exx_fraction, omega, yukawa = parse_name(xcname)

    # XC energy of the functional named by parse_name(): PAW corrections
    # summed over atoms first, then the contribution on the fine grid.
    xc = XC(xcname)
    exc = sum((xc.calculate_paw_correction(setups[a], D_sp)
               for a, D_sp in dens.D_asp.items()),
              0.0)
    exc = dens.finegd.comm.sum_scalar(exc)
    if dens.nt_sg is None:
        dens.interpolate_pseudo_density()
    exc += xc.calculate(dens.finegd, dens.nt_sg)

    coulomb = coulomb_interaction(omega, wfs.gd, kd, yukawa=yukawa)
    sym = Symmetry(kd)
    paw_s = calculate_paw_stuff(wfs, dens)

    ecc = sum(setup.ExxC for setup in setups) * exx_fraction

    evc = 0.0
    evv = 0.0
    for spin in range(nspins):
        kpts = [get_kpt(wfs, k, spin, 0, nocc)
                for k in range(kd.nibzkpts)]
        e_vc, e_vv = calculate_energy(kpts, paw_s[spin],
                                      wfs, sym, coulomb, calc.spos_ac)
        evc += e_vc * exx_fraction * 2 / nspins
        evv += e_vv * exx_fraction * 2 / nspins

    ham = calc.hamiltonian
    contributions = [ham.e_total_free,
                     -ham.e_xc,
                     exc,
                     ecc,
                     evc,
                     evv]
    # Multiply by ase.units.Ha so that the result is in eV.
    return np.array(contributions) * Ha

96 

97 

def calculate_energy(kpts, paw, wfs, sym, coulomb, spos_ac):
    """Accumulate the EXX energy terms for one list of k-points.

    Two kinds of terms are collected:

    * for every k-point and every atom in ``paw.VV_aii``, the
      occupation- and weight-weighted expectation values of ``VV_ii``
      and ``paw.VC_aii[a]`` taken with the projections ``kpt.proj[a]``;
    * for every k-point pair yielded by ``sym.pairs(...)``, the pair
      matrix from ``calculate_exx_for_pair`` contracted with the
      occupations of both k-points.

    Both totals are summed over ``wfs.world`` before being returned.

    Returns the tuple ``(exxvc, exxvv)``.
    """
    world = wfs.world
    pd0 = kpts[0].psit.pd
    # Same grid as the wave functions, but on the serial communicator.
    serial_gd = pd0.gd.new_descriptor(comm=serial_comm)

    exxvc = 0.0
    exxvv = 0.0

    # Terms built from the projections only.
    for kpt in kpts:
        for a, VV_ii in paw.VV_aii.items():
            P_ni = kpt.proj[a]
            vv_n = np.einsum('ni, ij, nj -> n',
                             P_ni.conj(), VV_ii, P_ni).real
            vc_n = np.einsum('ni, ij, nj -> n',
                             P_ni.conj(), paw.VC_aii[a], P_ni).real
            exxvv -= vv_n.dot(kpt.f_n) * kpt.weight
            exxvc -= vc_n.dot(kpt.f_n) * kpt.weight

    # Terms from pairs of k-points; the first three fields of each item
    # yielded by sym.pairs() are not needed here.
    for _i1, _i2, _s, k1, k2, count in sym.pairs(kpts, wfs, spos_ac):
        q_c = k2.k_c - k1.k_c
        pd12 = PWDescriptor(pd0.ecut, serial_gd, pd0.dtype,
                            kd=KPointDescriptor([-q_c]))

        ghat = PWLFC([data.ghat_l for data in wfs.setups], pd12)
        ghat.set_positions(spos_ac)

        e_nn = calculate_exx_for_pair(k1, k2, ghat,
                                      coulomb.get_potential(pd12),
                                      world, paw.Delta_aiiL)
        e_nn *= count
        exxvv -= 0.5 * (k1.f_n.dot(e_nn).dot(k2.f_n) /
                        sym.kd.nbzkpts**2)

    # Reduce over all ranks (valence-valence first, then core-valence).
    exxvv = world.sum_scalar(exxvv)
    exxvc = world.sum_scalar(exxvc)

    return exxvc, exxvv

134 

135 

def calculate_exx_for_pair(k1,
                           k2,
                           ghat,
                           v_G,
                           comm,
                           Delta_aiiL):
    """Return the matrix ``e_nn`` of pair integrals between k1 and k2.

    For a band ``n1`` of *k1* and a band ``n2`` of *k2* the entry is the
    real part of ``ghat.pd.integrate(rho_G, v_G * rho_G)``, where
    ``rho_G`` is ``ghat.pd.fft(u1_R * conj(u2_R))`` after ``ghat.add``
    has been applied with coefficients built from the projections and
    ``Delta_aiiL``.

    When *k1* and *k2* are the same object, only ``n2 >= n1`` is
    evaluated, the ``n2`` range of every ``n1`` is split into blocks
    with one block per rank of *comm*, and each value is mirrored into
    ``e_nn[n2, n1]``.  Entries belonging to other ranks stay zero.
    Otherwise every ``(n1, n2)`` entry is evaluated by the caller's rank.

    The returned array has shape ``(len(k1.u_nR), len(k2.u_nR))``.
    """
    same_kpt = k1 is k2
    nbands1 = len(k1.u_nR)
    nbands2 = len(k2.u_nR)
    nranks = comm.size
    myrank = comm.rank
    pd = ghat.pd

    # Coefficients handed to ghat.add(), one (m, n, L) array per atom.
    Q_annL = []
    for a, Delta_iiL in enumerate(Delta_aiiL):
        Q_annL.append(np.einsum('mi, ijL, nj -> mnL',
                                k1.proj[a],
                                Delta_iiL,
                                k2.proj[a].conj()))

    # Work buffer large enough for the biggest block of n2 values.
    if same_kpt:
        nbuffer = (nbands1 + nranks - 1) // nranks
    else:
        nbuffer = nbands2
    e_nn = np.zeros((nbands1, nbands2))
    rho_nG = pd.empty(nbuffer, k1.u_nR.dtype)

    for n1, u1_R in enumerate(k1.u_nR):
        if same_kpt:
            # Split the remaining bands n1..N1-1 into nranks blocks.
            blocksize = (nbands1 - n1 + nranks - 1) // nranks
            start = min(n1 + myrank * blocksize, nbands2)
            stop = min(start + blocksize, nbands2)
        else:
            start, stop = 0, nbands2

        work_nG = rho_nG[:stop - start]

        for n2, rho_G in zip(range(start, stop), work_nG):
            rho_G[:] = pd.fft(u1_R * k2.u_nR[n2].conj())

        Q_anL = {a: Q_nnL[n1, start:stop]
                 for a, Q_nnL in enumerate(Q_annL)}
        ghat.add(work_nG, Q_anL)

        for n2, rho_G in zip(range(start, stop), work_nG):
            value = pd.integrate(rho_G, v_G * rho_G).real
            e_nn[n1, n2] = value
            if same_kpt:
                e_nn[n2, n1] = value

    return e_nn