Coverage for gpaw/response/pair_transitions.py: 89%

103 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4 

5 

class PairTransitions:
    """Bookkeeping object for transitions in band and spin indices.

    Every transition between band and spin indices (for a given pair of
    k-points k and k + q) is labelled by a single composite transition
    index t:

        t (composite transition index): (n, s) -> (n', s')
    """

    def __init__(self, n1_t, n2_t, s1_t, s2_t):
        """Construct the PairTransitions object.

        Parameters
        ----------
        n1_t : np.array
            Band index of k-point k for each transition t.
        n2_t : np.array
            Band index of k-point k + q for each transition t.
        s1_t : np.array
            Spin index of k-point k for each transition t.
        s2_t : np.array
            Spin index of k-point k + q for each transition t.

        All four arrays must have the same length (checked with assert).
        """
        self.n1_t, self.n2_t = n1_t, n2_t
        self.s1_t, self.s2_t = s1_t, s2_t

        # The number of transitions is defined by n1_t (see __len__); the
        # remaining index arrays must agree with it.
        for other_t in (n2_t, s1_t, s2_t):
            assert len(other_t) == len(self)

    def __len__(self):
        return len(self.n1_t)

    def get_band_indices(self):
        """Return the band indices (n1_t, n2_t) at k and k + q."""
        return self.n1_t, self.n2_t

    def get_spin_indices(self):
        """Return the spin indices (s1_t, s2_t) at k and k + q."""
        return self.s1_t, self.s2_t

    def get_intraband_mask(self):
        """Get mask for selecting intraband transitions.

        A transition counts as intraband when it leaves both the band and
        the spin index unchanged, i.e. n = n' and s = s'.
        """
        same_band_t = self.n1_t == self.n2_t
        same_spin_t = self.s1_t == self.s2_t
        return same_band_t & same_spin_t

    @classmethod
    def from_transitions_domain_arguments(cls, spincomponent,
                                          nbands, nocc1, nocc2, nspins,
                                          bandsummation) -> PairTransitions:
        """Generate the band and spin transitions integration domain.

        The integration domain is determined by the spin rotation (from spin
        index s to spin index s'), the number of bands and spins in the
        underlying ground state calculation as well as the band summation
        scheme.

        The integration domain automatically excludes transitions in which
        both bands are fully occupied, as well as transitions in which both
        bands are completely unoccupied.

        Parameters
        ----------
        spincomponent : str
            Spin component (μν) of the pair function.
            Currently, '00', 'uu', 'dd', '+-' and '-+' are implemented.
        nbands : int
            Number of bands to include (band indices 0, ..., nbands - 1).
        nocc1 : int
            Number of completely filled bands in the ground state calculation
        nocc2 : int
            Number of non-empty bands in the ground state calculation
        nspins : int
            Number of spin channels in the ground state calculation (1 or 2)
        bandsummation : str
            Band (and spin) summation scheme for pairs of Kohn-Sham orbitals
            'pairwise': sum over pairs of bands (and spins)
            'double': double sum over band (and spin) indices.
        """
        # (n1_M, n2_M): band pairs with null transitions already removed
        band_domain = get_band_transitions_domain(
            bandsummation, nbands, nocc1=nocc1, nocc2=nocc2)
        # (s1_S, s2_S): spin pairs for the requested spin component
        spin_domain = get_spin_transitions_domain(
            bandsummation, spincomponent, nspins)

        # Combine both into the composite index t
        return cls(*transitions_in_composite_index(*band_domain,
                                                    *spin_domain))

94 

95 

def get_band_transitions_domain(bandsummation, nbands, nocc1=None, nocc2=None):
    """Get all pairs of bands to sum over.

    The raw pair domain is chosen by the band summation scheme, after which
    pairs between which no transition is possible are filtered out.

    Parameters
    ----------
    bandsummation : str
        Band summation method ('pairwise' or 'double')
    nbands : int
        number of bands
    nocc1 : int or None
        number of completely filled bands (None: no filtering on this)
    nocc2 : int or None
        number of non-empty bands (None: no filtering on this)

    Returns
    -------
    n1_M : ndarray
        band index 1, M = (n1, n2) composite index
    n2_M : ndarray
        band index 2, M = (n1, n2) composite index
    """
    domain_M = create_get_band_transitions_domain(bandsummation)(nbands)
    # Drop pairs in which both bands are filled or both are empty
    return remove_null_transitions(*domain_M, nocc1=nocc1, nocc2=nocc2)

122 

123 

def create_get_band_transitions_domain(bandsummation):
    """Creator component deciding how to carry out band summation.

    Returns the domain generator for 'pairwise' or 'double' summation and
    raises ValueError(bandsummation) for any other scheme.
    """
    if bandsummation not in ('pairwise', 'double'):
        raise ValueError(bandsummation)
    if bandsummation == 'pairwise':
        return get_pairwise_band_transitions_domain
    return get_double_band_transitions_domain

131 

132 

def get_double_band_transitions_domain(nbands):
    """Make a simple double sum over all nbands x nbands band pairs.

    The first band index runs fastest, e.g. for nbands = 2:
    n_M = [0, 1, 0, 1] and m_M = [0, 0, 1, 1].
    """
    band_n = np.arange(0, nbands)
    count = len(band_n)  # 0 for non-positive nbands
    # tile -> first index cycles fastest; repeat -> second index outermost
    return np.tile(band_n, count), np.repeat(band_n, count)

141 

142 

def get_pairwise_band_transitions_domain(nbands):
    """Make a sum over all pairs of bands (n, m) with n <= m.

    Pairs are ordered with n outermost, e.g. for nbands = 3:
    n_M = [0, 0, 0, 1, 1, 2] and m_M = [0, 1, 2, 1, 2, 2].

    Integer index arrays are returned even when the domain is empty
    (nbands = 0), so the result can always be used for indexing.
    """
    # Like range(0, nbands), treat a non-positive nbands as an empty domain
    nbands = max(nbands, 0)
    # Row/column indices of the upper triangle (diagonal included), returned
    # by numpy in row-major order, i.e. with n as the outer index
    n_M, m_M = np.triu_indices(nbands)
    return n_M, m_M

154 

155 

def remove_null_transitions(n1_M, n2_M, nocc1=None, nocc2=None):
    """Remove pairs of bands, between which transitions are impossible.

    A pair (n1, n2) is dropped when both bands are fully occupied
    (n1 < nocc1 and n2 < nocc1) or when both bands are completely
    unoccupied (n1 >= nocc2 and n2 >= nocc2). A criterion is skipped when
    its threshold is None.

    Parameters
    ----------
    n1_M, n2_M : array_like
        Band indices of each pair; both must have the same length.
    nocc1 : int or None
        Number of completely filled bands.
    nocc2 : int or None
        Number of non-empty bands.

    Returns
    -------
    n1_newM, n2_newM : ndarray
        The retained pairs in their original order. The input dtype is
        preserved, so integer band indices stay integer even when every
        pair has been removed.
    """
    n1_M = np.asarray(n1_M)
    n2_M = np.asarray(n2_M)
    keep_M = np.ones(n1_M.shape, dtype=bool)
    if nocc1 is not None:
        keep_M &= ~((n1_M < nocc1) & (n2_M < nocc1))  # both fully occupied
    if nocc2 is not None:
        keep_M &= ~((n1_M >= nocc2) & (n2_M >= nocc2))  # both empty
    return n1_M[keep_M], n2_M[keep_M]

169 

170 

def get_spin_transitions_domain(bandsummation, spincomponent, nspins):
    """Get structure of the sum over spins.

    The spin pair generator is chosen by the band summation scheme and then
    evaluated for the requested spin component.

    Parameters
    ----------
    bandsummation : str
        Band summation method ('pairwise' or 'double')
    spincomponent : str
        Spin component (μν) of the pair function.
        Currently, '00', 'uu', 'dd', '+-' and '-+' are implemented.
    nspins : int
        number of spin channels in ground state calculation

    Returns
    -------
    s1_S : ndarray
        spin index 1, S = (s1, s2) composite index
    s2_S : ndarray
        spin index 2, S = (s1, s2) composite index
    """
    spin_domain_fn = create_get_spin_transitions_domain(bandsummation)
    s1_S, s2_S = spin_domain_fn(spincomponent, nspins)
    return s1_S, s2_S

194 

195 

def create_get_spin_transitions_domain(bandsummation):
    """Creator component deciding how to carry out spin summation.

    Returns the spin domain generator for 'pairwise' or 'double' summation
    and raises ValueError(bandsummation) for any other scheme.
    """
    if bandsummation not in ('pairwise', 'double'):
        raise ValueError(bandsummation)
    if bandsummation == 'pairwise':
        return get_pairwise_spin_transitions_domain
    return get_double_spin_transitions_domain

203 

204 

def get_double_spin_transitions_domain(spincomponent, nspins):
    """Usual spin rotations forward in time.

    Returns arrays (s1_S, s2_S) of spin index pairs s -> s' contributing to
    the given spin component. With a single spin channel only '00' is
    supported; otherwise ValueError(spincomponent, nspins) is raised. For
    any other nspins, unknown components raise ValueError(spincomponent).
    """
    if nspins == 1:
        if spincomponent != '00':
            raise ValueError(spincomponent, nspins)
        return np.array([0]), np.array([0])

    # (s1_S, s2_S) for each supported spin component; index 0 is spin up
    # and index 1 is spin down
    spin_pairs = {
        '00': ([0, 1], [0, 1]),
        'uu': ([0], [0]),
        'dd': ([1], [1]),
        '+-': ([0], [1]),  # spin up -> spin down
        '-+': ([1], [0]),  # spin down -> spin up
    }
    if spincomponent not in spin_pairs:
        raise ValueError(spincomponent)
    s1_S, s2_S = spin_pairs[spincomponent]
    return np.array(s1_S), np.array(s2_S)

233 

234 

def get_pairwise_spin_transitions_domain(spincomponent, nspins):
    """In a sum over pairs, transitions including a spin rotation may have to
    include terms, propagating backwards in time.

    For '+-' and '-+' both spin flips, 0 -> 1 and 1 -> 0, are returned
    (requires nspins == 2, checked with assert). All other components use
    the same spin domain as the double sum.
    """
    if spincomponent not in ['+-', '-+']:
        return get_double_spin_transitions_domain(spincomponent, nspins)
    assert nspins == 2
    # Forward (up -> down) and backward (down -> up) spin rotation
    return np.array([0, 1]), np.array([1, 0])

243 

244 

def transitions_in_composite_index(n1_M, n2_M, s1_S, s2_S):
    """Use a composite index t for transitions (n, s) -> (n', s').

    Every band pair M is combined with every spin pair S. The spin pair is
    the outer (slow) index and the band pair the inner (fast) index of t.
    """
    # 'ij' indexing with the spin array first puts S on axis 0 and M on
    # axis 1, so row-major raveling makes M run fastest
    s1_SM, n1_SM = np.meshgrid(s1_S, n1_M, indexing='ij')
    s2_SM, n2_SM = np.meshgrid(s2_S, n2_M, indexing='ij')
    return n1_SM.ravel(), n2_SM.ravel(), s1_SM.ravel(), s2_SM.ravel()