Coverage for gpaw/test/response/test_jdos.py: 99%

95 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1# General modules 

2import pytest 

3 

4from itertools import product 

5 

6import numpy as np 

7 

8# import matplotlib.pyplot as plt 

9 

10# Script modules 

11from ase.units import Hartree 

12 

13from gpaw import GPAW 

14import gpaw.mpi as mpi 

15from gpaw.response import ResponseGroundStateAdapter 

16from gpaw.response.frequencies import ComplexFrequencyDescriptor 

17from gpaw.response.jdos import JDOSCalculator 

18from gpaw.response.kpoints import KPointFinder 

19from gpaw.test.response.test_chiks import (generate_system_s, 

20 generate_qrel_q, get_q_c, 

21 generate_nblocks_n) 

22from gpaw.test.gpwfile import response_band_cutoff 

23 

24 

@pytest.mark.response
@pytest.mark.kspair
@pytest.mark.parametrize('system,qrel',
                         product(generate_system_s(), generate_qrel_q()))
def test_jdos(in_tmp_dir, gpw_files, system, qrel):
    """Compare JDOSCalculator against the brute-force MyManualJDOS reference.

    For one (wfs, spincomponent) system and one relative q-vector, the joint
    density of states is first evaluated with MyManualJDOS on a serial
    calculator. JDOSCalculator is then run for every combination of
    qsymmetry, bandsummation and nblocks, and each result has to agree with
    the reference within the default pytest.approx tolerance.
    """
    # ---------- Inputs ---------- #

    # Material, spin component and q-vector under test
    wfs, spincomponent = system
    q_c = get_q_c(wfs, qrel)

    # Frequency grid on which the jdos is evaluated, broadened by eta
    omega_w = np.linspace(-10.0, 10.0, 321)
    eta = 0.2
    zd = ComplexFrequencyDescriptor.from_array(omega_w + 1.j * eta)

    # Calculation parameters (none of which should affect the result)
    qsymmetry_s = [True, False]
    bandsummation_b = ['double', 'pairwise']
    nblocks_n = generate_nblocks_n()

    # ---------- Script ---------- #

    # Ground state adapter built from the gpw fixture
    gpw = gpw_files[wfs]
    calc = GPAW(gpw, parallel=dict(domain=1))
    nbands = response_band_cutoff[wfs]
    gs = ResponseGroundStateAdapter(calc)

    # Reference jdos, evaluated by hand on a serial communicator
    serial_calc = GPAW(gpw, communicator=mpi.serial_comm)
    reference = MyManualJDOS(serial_calc)
    jdosref_w = reference.calculate(spincomponent, q_c, omega_w,
                                    eta=eta, nbands=nbands)

    # jdos from JDOSCalculator, for every combination of the parameters
    for qsymmetry in qsymmetry_s:
        for bandsummation in bandsummation_b:
            for nblocks in nblocks_n:
                jdos_calc = JDOSCalculator(
                    gs,
                    nbands=nbands,
                    qsymmetry=qsymmetry,
                    bandsummation=bandsummation,
                    nblocks=nblocks,
                )
                jdos_w = jdos_calc.calculate(spincomponent, q_c, zd).array
                assert jdos_w == pytest.approx(jdosref_w)

                # Visual debugging aid, disabled by default. It needs the
                # matplotlib import at the top of the file to be enabled.
                # NOTE(review): the previous snippet plotted against an
                # undefined `wd`; omega_w is assumed to be the intended
                # frequency axis -- confirm before enabling.
                # plt.subplot()
                # plt.plot(omega_w, jdos_w)
                # plt.plot(omega_w, jdosref_w)
                # plt.title(f'{q_c} {spincomponent}')
                # plt.show()

79 

80 

class MyManualJDOS:
    """Brute-force reference implementation of the joint density of states.

    Eigenvalues and occupation numbers are read directly from the k-point
    objects of a GPAW calculator and every band transition k -> k + q is
    summed explicitly, using Lorentzian-broadened delta functions.
    """

    def __init__(self, calc):
        """Store the calculator and set up k-point lookup and weighting.

        Parameters
        ----------
        calc : calculator object exposing ``wfs.nspins``, ``wfs.kd``,
            ``wfs.gd`` and ``wfs.kpt_u``.
        """
        wfs = calc.wfs
        self.calc = calc
        self.nspins = wfs.nspins
        self.kd = wfs.kd

        bzk_kc = self.kd.bzk_kc
        self.kptfinder = KPointFinder(bzk_kc)
        # Every k-point in bzk_kc contributes 1 / (volume * number of points)
        self.kweight = 1 / (wfs.gd.volume * len(bzk_kc))

    def calculate(self, spincomponent, q_c, omega_w,
                  eta=0.2,
                  nbands=None):
        r"""Calculate the joint density of states:

                       __  __
                    1  \   \
        g_j(q, ω) = ‾  /   /  (f_nks - f_mk+qs') δ(ω-[ε_mk+qs' - ε_nks])
                    V  ‾‾  ‾‾
                       k   n,m

        for a given spin component specifying the spin transitions s -> s'.

        Parameters
        ----------
        spincomponent : str
            '00' or '+-'; any other value raises ValueError.
        q_c : array-like
            q-vector, added to each k-point to locate k + q.
        omega_w : np.ndarray
            Frequencies at which the jdos is evaluated; divided by Hartree
            before use.
        eta : float
            Lorentzian broadening; divided by Hartree before use.
        nbands : int
            Number of bands to include. Must be an int, so the default of
            None is rejected by get_transitions.

        Returns
        -------
        np.ndarray
            The jdos, one value per entry of omega_w.
        """
        q_c = np.asarray(q_c)
        # Internal frequencies in Hartree
        omega_w = omega_w / Hartree
        eta = eta / Hartree
        # Accumulator for the sum over k-points
        jdos_w = np.zeros_like(omega_w)

        for K1, k1_c in enumerate(self.kd.bzk_kc):
            # de = e2 - e1, df = f2 - f1
            de_t, df_t = self.get_transitions(K1, k1_c, q_c,
                                              spincomponent, nbands)

            if self.nspins == 1:
                # Spin-paired calculation: double the occupation differences
                df_t = 2 * df_t

            # Broadened delta function for each (frequency, transition) pair
            delta_wt = self.delta(omega_w, eta, de_t)
            # The sign flip turns df = f2 - f1 into f1 - f2
            jdos_wt = (-df_t)[np.newaxis] * delta_wt

            # Sum over transitions
            jdos_w += jdos_wt.sum(axis=1)

        return self.kweight * jdos_w

    @staticmethod
    def delta(omega_w, eta, de_t):
        r"""Create lorentzian delta-functions

        ~          1        η
        δ(ω-Δε) =  ‾  ‾‾‾‾‾‾‾‾‾‾‾‾‾‾
                   π  (ω-Δε)^2 + η^2

        Returns an array with one row per frequency in omega_w and one
        column per transition energy in de_t.
        """
        # x_wt[w, t] = omega_w[w] - de_t[t]
        x_wt = np.subtract.outer(omega_w, de_t)
        prefactor = eta / np.pi
        return prefactor / (x_wt**2. + eta**2.)

    def get_transitions(self, K1, k1_c, q_c, spincomponent, nbands):
        """Collect energy and occupation differences for k1 -> k1 + q.

        Parameters
        ----------
        K1 : int
            Index of the initial k-point, used to index kd.bz2ibz_k.
        k1_c : np.ndarray
            Coordinates of the initial k-point.
        q_c : np.ndarray
            q-vector added to k1_c.
        spincomponent : str
            '00' pairs every spin with itself; '+-' pairs spin 0 with
            spin 1.
        nbands : int
            Number of bands taken from each k-point.

        Returns
        -------
        de_t, df_t : np.ndarray
            Flattened differences e2 - e1 and f2 - f1 over all band pairs,
            concatenated over the spin pairs.

        Raises
        ------
        ValueError
            If spincomponent is neither '00' nor '+-'.
        """
        assert isinstance(nbands, int)

        # Spin pairs (s1, s2) taking part in the requested component
        if spincomponent == '00':
            nspin_channels = 1 if self.nspins == 1 else 2
            spin_pairs = [(s, s) for s in range(nspin_channels)]
        elif spincomponent == '+-':
            spin_pairs = [(0, 1)]
        else:
            raise ValueError(spincomponent)

        # Find k_c + q_c
        K2 = self.kptfinder.find(k1_c + q_c)

        bz2ibz_k = self.kd.bz2ibz_k
        kpt_u = self.calc.wfs.kpt_u
        de_parts = []
        df_parts = []
        for s1, s2 in spin_pairs:
            # Composite u=(s,k) indices select the KPoint objects
            kpt1 = kpt_u[bz2ibz_k[K1] * self.nspins + s1]
            kpt2 = kpt_u[bz2ibz_k[K2] * self.nspins + s2]

            # Eigenenergies, and occupation numbers divided by the k-point
            # weight
            eps1_n = kpt1.eps_n[:nbands]
            eps2_n = kpt2.eps_n[:nbands]
            f1_n = kpt1.f_n[:nbands] / kpt1.weight
            f2_n = kpt2.f_n[:nbands] / kpt2.weight

            # All band pairs, with the band index of k + q varying slowest
            de_parts.append(np.subtract.outer(eps2_n, eps1_n).ravel())
            df_parts.append(np.subtract.outer(f2_n, f1_n).ravel())

        return np.concatenate(de_parts), np.concatenate(df_parts)