Coverage for gpaw/hybrids/energy.py: 96%
106 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1from __future__ import annotations
2from pathlib import Path
4import numpy as np
5from ase.units import Ha
7from gpaw import GPAW
8from gpaw.new.ase_interface import ASECalculator
9from gpaw.kpt_descriptor import KPointDescriptor
10from gpaw.mpi import serial_comm
11from gpaw.pw.descriptor import PWDescriptor
12from gpaw.pw.lfc import PWLFC
13from gpaw.xc import XC
14from . import parse_name
15from .coulomb import coulomb_interaction
16from .kpts import get_kpt
17from .paw import calculate_paw_stuff
18from .symmetry import Symmetry
19from gpaw.typing import Array1D
def non_self_consistent_energy(calc: ASECalculator | str | Path,
                               xcname: str,
                               ftol: float = 1e-9) -> Array1D:
    """Calculate non self-consistent energy for Hybrid functional.

    Based on a self-consistent DFT calculation (calc).  EXX integrals involving
    states with occupation numbers less than ftol are skipped.

    >>> energies = non_self_consistent_energy('<gpw-file>',
    ...                                       xcname='HSE06')
    >>> e_hyb = energies.sum()

    The correction to the self-consistent energy will be
    ``energies[1:].sum()``.

    The returned energy contributions are (in eV):

    1. DFT total free energy (not extrapolated to zero smearing)
    2. minus DFT XC energy
    3. Hybrid semi-local XC energy
    4. EXX core-core energy
    5. EXX core-valence energy
    6. EXX valence-valence energy

    Parameters:

    calc:
        A calculator object, or a ``str``/``Path`` that is handed to
        ``GPAW(...)`` to load one (with ``txt=None`` and band/k-point
        parallelisation set to 1).
    xcname:
        Name of the hybrid functional; it is split by ``parse_name`` into
        the semi-local functional name, the exact-exchange fraction and
        the ``omega``/``yukawa`` settings of the Coulomb interaction.
    ftol:
        Occupation threshold used to decide how many bands take part in
        the EXX sums.
    """

    if calc == '<gpw-file>':  # for doctest
        return np.zeros(6)

    if isinstance(calc, (str, Path)):
        calc = GPAW(calc, txt=None, parallel={'band': 1, 'kpt': 1})

    assert not isinstance(calc, (str, Path))  # for mypy
    wfs = calc.wfs
    dens = calc.density
    kd = wfs.kd
    setups = wfs.setups
    nspins = wfs.nspins

    # Number of bands to include: for every local k-point count the bands
    # whose weight-normalised occupation exceeds ftol, take the largest
    # count, sum it over the band communicator and finally take the
    # maximum over the k-point communicator.
    nocc = max(((kpt.f_n / kpt.weight) > ftol).sum()
               for kpt in wfs.kpt_u)
    nocc = kd.comm.max_scalar(wfs.bd.comm.sum_scalar(int(nocc)))

    xcname, exx_fraction, omega, yukawa = parse_name(xcname)

    # Semi-local XC energy of the hybrid: PAW corrections for the atoms
    # held in dens.D_asp (summed over the fine-grid communicator) plus the
    # contribution evaluated on the fine grid from the pseudo density.
    xc = XC(xcname)
    exc = 0.0
    for a, D_sp in dens.D_asp.items():
        exc += xc.calculate_paw_correction(setups[a], D_sp)
    exc = dens.finegd.comm.sum_scalar(exc)
    if dens.nt_sg is None:
        # The fine-grid pseudo density is not available yet; build it.
        dens.interpolate_pseudo_density()
    exc += xc.calculate(dens.finegd, dens.nt_sg)

    coulomb = coulomb_interaction(omega, wfs.gd, kd, yukawa=yukawa)
    sym = Symmetry(kd)

    paw_s = calculate_paw_stuff(wfs, dens)

    ecc = sum(setup.ExxC for setup in setups) * exx_fraction
    evc = 0.0
    evv = 0.0
    for spin in range(nspins):
        kpts = [get_kpt(wfs, k, spin, 0, nocc) for k in range(kd.nibzkpts)]
        e1, e2 = calculate_energy(kpts, paw_s[spin],
                                  wfs, sym, coulomb, calc.spos_ac)
        # Same scaling for both terms: EXX fraction times 2 / nspins.
        evc += e1 * exx_fraction * 2 / nspins
        evv += e2 * exx_fraction * 2 / nspins

    # Everything above is multiplied by Ha to give the eV values promised
    # in the docstring.
    return np.array([calc.hamiltonian.e_total_free,
                     -calc.hamiltonian.e_xc,
                     exc,
                     ecc,
                     evc,
                     evv]) * Ha
def calculate_energy(kpts, paw, wfs, sym, coulomb, spos_ac):
    """Accumulate the EXX core-valence and valence-valence sums.

    Parameters:

    kpts:
        Sequence of k-point objects; each must provide ``psit.pd`` (only
        read from the first one), ``proj``, ``f_n``, ``weight`` and what
        ``calculate_exx_for_pair`` needs.
    paw:
        Object with per-atom matrices ``VV_aii`` and ``VC_aii`` and the
        ``Delta_aiiL`` tensors.
    wfs:
        Wave-function object; ``wfs.world`` is the communicator used for
        the final reduction and ``wfs.setups`` supplies ``ghat_l``.
    sym:
        Symmetry helper whose ``pairs(kpts, wfs, spos_ac)`` yields the
        k-point pairs together with a multiplicity ``count``.
    coulomb:
        Object whose ``get_potential(pd)`` returns the interaction on the
        plane-wave grid of the pair.
    spos_ac:
        Atomic positions, forwarded to ``sym.pairs`` and to
        ``PWLFC.set_positions`` (presumably scaled positions - confirm).

    Returns the tuple ``(exxvc, exxvv)``, each summed over ``wfs.world``.
    The caller multiplies the results by ``Ha``, so they are presumably
    in Hartree here - confirm.
    """
    pd = kpts[0].psit.pd
    gd = pd.gd.new_descriptor(comm=serial_comm)
    comm = wfs.world

    exxvv = 0.0
    exxvc = 0.0

    # Atom-centred terms: <P| VV |P> and <P| VC |P> per band, weighted by
    # the occupations and the k-point weight.
    for kpt in kpts:
        for a, VV_ii in paw.VV_aii.items():
            P_ni = kpt.proj[a]
            vv_n = np.einsum('ni, ij, nj -> n',
                             P_ni.conj(), VV_ii, P_ni).real
            vc_n = np.einsum('ni, ij, nj -> n',
                             P_ni.conj(), paw.VC_aii[a], P_ni).real
            exxvv -= vv_n.dot(kpt.f_n) * kpt.weight
            exxvc -= vc_n.dot(kpt.f_n) * kpt.weight

    # Pair terms.  Only the two k-point objects and the multiplicity of
    # each item yielded by sym.pairs are used here.
    for _, _, _, k1, k2, count in sym.pairs(kpts, wfs, spos_ac):
        q_c = k2.k_c - k1.k_c
        qd = KPointDescriptor([-q_c])

        # Plane-wave descriptor and compensation charges for this q.
        pd12 = PWDescriptor(pd.ecut, gd, pd.dtype, kd=qd)
        ghat = PWLFC([data.ghat_l for data in wfs.setups], pd12)
        ghat.set_positions(spos_ac)

        v_G = coulomb.get_potential(pd12)
        e_nn = calculate_exx_for_pair(k1, k2, ghat, v_G, comm,
                                      paw.Delta_aiiL)
        e_nn *= count
        e = k1.f_n.dot(e_nn).dot(k2.f_n) / sym.kd.nbzkpts**2
        exxvv -= 0.5 * e

    exxvv = comm.sum_scalar(exxvv)
    exxvc = comm.sum_scalar(exxvc)

    return exxvc, exxvv
def calculate_exx_for_pair(k1,
                           k2,
                           ghat,
                           v_G,
                           comm,
                           Delta_aiiL):
    """Return the band-pair matrix ``e_nn`` for the k-point pair (k1, k2).

    For every band pair ``(n1, n2)`` the product ``u1 * conj(u2)`` is
    transformed with ``ghat.pd.fft``, ``ghat.add`` is applied with the
    coefficients ``Q_annL`` built from the projections and ``Delta_aiiL``,
    and the entry is ``ghat.pd.integrate(rho_G, v_G * rho_G).real``.

    Parameters:

    k1, k2:
        k-point objects providing ``u_nR`` (one entry per band) and
        ``proj`` (indexable by atom number).
    ghat:
        Object providing ``add(rho_nG, {a: Q_nL})`` and a ``pd`` attribute
        with ``empty``, ``fft`` and ``integrate``.
    v_G:
        Array multiplied elementwise onto each ``rho_G``.
    comm:
        Communicator; only ``size`` and ``rank`` are read.
    Delta_aiiL:
        Sequence of per-atom arrays contracted as ``'ijL'``.

    Returns a real ``(len(k1.u_nR), len(k2.u_nR))`` array.  When ``k1 is
    k2`` only this rank's share of the (symmetric) matrix is filled in and
    the remaining entries stay zero; otherwise every rank fills the whole
    matrix.
    """

    N1 = len(k1.u_nR)
    N2 = len(k2.u_nR)

    size = comm.size
    rank = comm.rank

    Q_annL = [np.einsum('mi, ijL, nj -> mnL',
                        k1.proj[a],
                        Delta_iiL,
                        k2.proj[a].conj())
              for a, Delta_iiL in enumerate(Delta_aiiL)]

    # Work buffer: big enough for the largest n2 chunk handled per n1.
    if k1 is k2:
        n2max = (N1 + size - 1) // size
    else:
        n2max = N2

    e_nn = np.zeros((N1, N2))
    rho_nG = ghat.pd.empty(n2max, k1.u_nR.dtype)

    for n1, u1_R in enumerate(k1.u_nR):
        if k1 is k2:
            # Symmetric case: only n2 >= n1 is needed.  Those bands are
            # split into chunks of B (rounded up), one chunk per rank.
            B = (N1 - n1 + size - 1) // size
            n2a = min(n1 + rank * B, N2)
            n2b = min(n2a + B, N2)
        else:
            n2a = 0
            n2b = N2

        for n2, rho_G in enumerate(rho_nG[:n2b - n2a], n2a):
            rho_G[:] = ghat.pd.fft(u1_R * k2.u_nR[n2].conj())

        ghat.add(rho_nG[:n2b - n2a],
                 {a: Q_nnL[n1, n2a:n2b]
                  for a, Q_nnL in enumerate(Q_annL)})

        for n2, rho_G in enumerate(rho_nG[:n2b - n2a], n2a):
            vrho_G = v_G * rho_G
            e = ghat.pd.integrate(rho_G, vrho_G).real
            e_nn[n1, n2] = e
            if k1 is k2:
                # Mirror into the lower triangle.
                e_nn[n2, n1] = e

    return e_nn