Coverage for gpaw/response/mpa_interpolation.py: 99%

112 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

"""
This file contains several routines to do the non-linear
interpolation of poles and residues of response functions
corresponding to the Multipole Approximation (MPA)
developed in Ref. [1].

The implemented solver is the one based on the Pade-Thiele
formula (see App. A of Ref. [1]).

[1] D. A. Leon et al., PRB 104, 115157 (2021)
"""

12from __future__ import annotations 

13from typing import Tuple, no_type_check 

14from gpaw.typing import Array1D, Array2D, Array3D 

15import numpy as np 

16from numpy.linalg import eigvals 

17 

18 

def fit_residue(
    npr_GG: Array2D, omega_w: Array1D, X_wGG: Array3D, E_pGG: Array3D
) -> Array3D:
    """Fit residues for fixed poles by complex linear least squares.

    For every (G, G') element the model
    ``X(w) = sum_p 2 * E_p * R_p / (w**2 - E_p**2)`` is fitted to the
    samples ``X_wGG`` taken at ``omega_w``.  Poles with index
    ``p >= npr_GG[G, G']`` are excluded by zeroing their design-matrix
    column for that element.

    Parameters
    ----------
    npr_GG : Number of poles to keep for each (G, G') element.
    omega_w : Sampling frequencies.
    X_wGG : Response function sampled at ``omega_w``.
    E_pGG : Fixed poles; the first axis indexes the pole.

    Returns
    -------
    R_pGG : Fitted residues, with the same layout as ``E_pGG``.
    """
    npols = len(E_pGG)
    freqs_w = np.asarray(omega_w)

    # Design matrix: build it with axes (w, p, G, G') by broadcasting,
    # then reorder to (G, G', w, p) in a C-ordered complex128 buffer.
    denom_wpGG = (freqs_w**2)[:, None, None, None] - (E_pGG**2)[None]
    design_GGwp = ((2 * E_pGG)[None] / denom_wpGG).transpose(
        (2, 3, 0, 1)).astype(np.complex128, order='C')
    rhs_GGw = X_wGG.transpose((1, 2, 0)).astype(np.complex128, order='C')

    # Columns of poles beyond the per-element pole count are zeroed.
    drop_GGp = np.arange(npols) >= np.asarray(npr_GG)[:, :, None]
    design_GGwp[np.broadcast_to(drop_GGp[:, :, None, :],
                                design_GGwp.shape)] = 0.0

    # Normal equations A^H A R = A^H X for every element.
    proj_GGp = np.einsum('GHwp,GHw->GHp',
                         design_GGwp.conj(), rhs_GGw)
    gram_GGpp = np.einsum('GHwp,GHwo->GHpo',
                          design_GGwp.conj(), design_GGwp)

    if gram_GGpp.shape[-1] != 1:
        try:
            # NumPy 2.0 changed the broadcasting rules of `solve()`, so
            # the right-hand side gets an explicit trailing axis.
            res_GGp = np.linalg.solve(gram_GGpp,
                                      proj_GGp[..., None])[..., 0]
        except np.linalg.LinAlgError:
            # At least one element is singular (e.g. masked poles).
            res_GGp = np.einsum('GHpo,GHo->GHp',
                                np.linalg.pinv(gram_GGpp), proj_GGp)
    else:
        # One pole: the 1x1 system is inverted directly.
        res_GGp = np.einsum('GHpo,GHo->GHp',
                            1 / gram_GGpp, proj_GGp)

    return res_GGp.transpose((2, 0, 1))

59 

60 

class Solver:
    """
    Base class for the MPA interpolators.

    The response function X(w) is modelled as a sum of poles, each of the
    form 2*E*R/(w**2-E**2).  Every pole needs two sampling frequencies
    (and the corresponding values of X); subclasses return the pole
    positions E and the residues R.
    """

    def __init__(self, omega_w: Array1D, threshold=1e-5, epsilon=1e-8):
        """
        Parameters
        ----------
        omega_w : Complex sampling frequencies of the MPA. The length must
                  be even and equals twice the number of poles.
        threshold : Threshold for small and too close poles.
        epsilon : Precision used for the (otherwise zero) imaginary part
                  of poles.
        """
        nfreqs = len(omega_w)
        assert nfreqs % 2 == 0
        self.omega_w = omega_w
        self.npoles, self.threshold, self.epsilon = (
            nfreqs // 2, threshold, epsilon)

    def solve(self, X_wGG):
        """
        Interpolate X_wGG, a response function sampled at self.omega_w.

        Returns a tuple (E_pGG, R_pGG) of poles and residues, where p is
        the pole index.  Subclasses must implement this.
        """
        raise NotImplementedError

90 

91 

class SinglePoleSolver(Solver):
    """
    Single-pole interpolation of a response function.

    With exactly two sampling frequencies, the pole E of
    X(w) = 2*E*R/(w**2-E**2) follows analytically from the two samples;
    the residue R is then obtained with fit_residue.
    """

    def __init__(self, omega_w: Array1D, threshold=1e-5, epsilon=1e-8):
        """
        Parameters
        ----------
        omega_w : The two complex sampling frequencies.
        threshold : Poles that compare below this value (NumPy orders
                    complex numbers lexicographically, real part first)
                    are treated as null poles and replaced by
                    threshold - 1j * epsilon.
        epsilon : Magnitude of the imaginary part given to null poles.
        """
        Solver.__init__(self, omega_w=omega_w, threshold=threshold,
                        epsilon=epsilon)

    def solve(self, X_wGG: Array3D) -> Tuple[Array3D, Array3D]:
        """
        Interpolate X_wGG using a single pole.

        Returns
        -------
        E_pGG, R_pGG : Pole and residue arrays, each with a leading pole
                       axis of length 1.
        """
        assert len(X_wGG) == 2

        omega_w = self.omega_w
        # Eliminating R from X_k * (w_k**2 - E**2) = 2*E*R (k = 0, 1)
        # gives E**2 in closed form (analytical solution).
        E_GG = ((X_wGG[0, :, :] * omega_w[0]**2 -
                 X_wGG[1, :, :] * omega_w[1]**2) /
                (X_wGG[0, :, :] - X_wGG[1, :, :]))

        # Fold E**2 onto Re >= 0, Im <= 0 before taking the square root.
        E_GG.real = np.abs(E_GG.real)  # physical pole
        E_GG.imag = -np.abs(E_GG.imag)  # correct time ordering
        E_GG = np.emath.sqrt(E_GG)

        null_GG = E_GG < self.threshold  # null pole
        E_GG[null_GG] = self.threshold - 1j * self.epsilon

        E_pGG = E_GG[np.newaxis]
        # A single pole per element: an integer pole count of 1 keeps the
        # only design-matrix column in fit_residue unmasked.
        R_pGG = fit_residue(
            npr_GG=np.ones(E_GG.shape, dtype=int),
            omega_w=omega_w,
            X_wGG=X_wGG,
            E_pGG=E_pGG)

        return E_pGG, R_pGG

124 

125 

class MultipoleSolver(Solver):
    """
    Interpolation with len(omega_w) // 2 poles, based on the Pade-Thiele
    formula.
    """

    def __init__(self, omega_w: Array1D):
        super().__init__(omega_w=omega_w)

    def solve(self, X_wGG: Array3D) -> Tuple[Array3D, Array3D]:
        """
        Interpolate X_wGG using several poles (E_pGG, R_pGG).
        """
        assert len(X_wGG) == 2 * self.npoles

        # Non-linear part of the problem: the poles, obtained from the
        # Pade-Thiele fit in the variable z = omega**2.
        E_GGp, npr_GG = pade_solve(X_wGG, self.omega_w**2)
        poles_pGG = np.moveaxis(E_GGp, 2, 0)

        # Linear part: the residues, from a complex linear least-squares
        # problem with the poles held fixed.
        residues_pGG = fit_residue(npr_GG, self.omega_w, X_wGG, poles_pGG)
        return poles_pGG, residues_pGG

143 

144 

def RESolver(omega_w: Array1D):
    """Return the MPA solver matching the number of sampling frequencies.

    Parameters
    ----------
    omega_w : Complex sampling frequencies; the length must be even and
              non-zero, and equals twice the number of poles.

    Returns
    -------
    A SinglePoleSolver for two frequencies, otherwise a MultipoleSolver.
    """
    assert len(omega_w) % 2 == 0
    # Integer division: the pole count is an int, not a float.
    npoles = len(omega_w) // 2
    assert npoles > 0
    if npoles == 1:
        return SinglePoleSolver(omega_w)
    return MultipoleSolver(omega_w)

153 

154 

def mpa_cond_vectorized(
    npols: int, z_w: Array1D, E_GGp: Array3D, pole_resolution: float = 1e-5
) -> Tuple[Array3D, Array2D]:
    """Condition the pole candidates produced by pade_solve.

    The candidates E_GGp (roots in the variable z, as passed in by
    pade_solve) are square-rooted and rewritten as
    max(|Re|, |Im|) - 1j * min(|Re|, |Im|).  Poles with a real part above
    wmax = 1.5 * max(Re(sqrt(z_w))) are moved to 2 * wmax - 0.01j, the
    poles are sorted along the last axis, and the poles with real part
    below wmax are counted.

    Parameters
    ----------
    npols : Number of poles per (G, G') element (length of the last axis
            of E_GGp).
    z_w : Sampling points used in the Pade fit.
    E_GGp : Pole candidates; the last axis indexes the pole.
    pole_resolution : Real-part separation below which two poles are
                      considered equal.

    Returns
    -------
    E_GGp : Conditioned poles, sorted along the last axis.
    npr_GG : Number of poles with real part below wmax, per element.
    """
    # Poles with real part above wmax are treated as lying outside the
    # sampled frequency range.
    wmax = np.max(np.real(np.emath.sqrt(z_w))) * 1.5

    E_GGp = np.emath.sqrt(E_GGp)
    # Real part := larger of |Re|, |Im|; imaginary part := minus the
    # smaller one. Hence Re >= 0, Im <= 0 and Re >= |Im|.
    args = np.abs(E_GGp.real), np.abs(E_GGp.imag)
    E_GGp = np.maximum(*args) - 1j * np.minimum(*args)
    # NumPy sorts complex values lexicographically (real part first,
    # imaginary part as tie-breaker).
    E_GGp.sort(axis=2)  # Sort according to real part

    for i in range(npols):
        # Move poles beyond wmax to 2 * wmax - 0.01j.
        out_poles_GG = E_GGp[:, :, i].real > wmax
        E_GGp[out_poles_GG, i] = 2 * wmax - 0.01j
        for j in range(i + 1, npols):
            # diff can be negative: pole i may just have been moved to
            # 2 * wmax above while pole j still lies below 2 * wmax.
            diff = E_GGp[:, :, j].real - E_GGp[:, :, i].real
            equal_poles_GG = diff < pole_resolution
            # NOTE(review): this `break` fires as soon as ANY (G, G')
            # element has a close pair (i, j), before the merge below is
            # applied. When no element has a close pair, equal_poles_GG is
            # all False and the merge is a no-op. As written, the merge
            # code therefore never changes E_GGp. Confirm the intent;
            # simply removing the break would also merge pairs with a
            # negative diff (see above).
            if np.sum(equal_poles_GG.ravel()):
                break

            # If the poles are too close, merge them into pole j (mean of
            # the real parts, maximum of the imaginary parts) and move
            # pole i to 2 * wmax - 0.01j so it sorts to the end.
            E_GGp[:, :, j] = np.where(
                equal_poles_GG,
                (E_GGp[:, :, j].real + E_GGp[:, :, i].real) / 2
                + 1j * np.maximum(E_GGp[:, :, i].imag, E_GGp[:, :, j].imag),
                E_GGp[:, :, j],
            )
            E_GGp[equal_poles_GG, i] = 2 * wmax - 0.01j

    E_GGp.sort(axis=2)  # Sort according to real part

    # Number of poles inside the sampled range, per element.
    npr_GG = np.sum(E_GGp.real < wmax, axis=2)
    return E_GGp, npr_GG

188 

189 

@no_type_check
def pade_solve(X_wGG: Array3D, z_w: Array1D) -> Tuple[Array3D, Array2D]:
    """Obtain the MPA poles from a Pade-Thiele fit of X_wGG.

    For every (G, G') element the samples X_wGG[:, G, G'] at the points
    z_w are represented by a Thiele continued fraction with nw // 2
    poles (nw = len(X_wGG)).  The coefficients of the denominator
    polynomial are accumulated with the recurrence
    b_i(z) = b_{i-1}(z) + c_i * (z - z_{i-1}) * b_{i-2}(z).  Its roots,
    computed as eigenvalues of the companion matrix, are passed to
    mpa_cond_vectorized.

    Parameters
    ----------
    X_wGG : Response function at the sampling points, shape (nw, nG1, nG2).
    z_w : Sampling points (MultipoleSolver passes omega_w**2).

    Returns
    -------
    E_GGp : Conditioned poles; the last axis indexes the pole.
    npr_GG : Number of poles retained per element.
    """
    nw, nG1, nG2 = X_wGG.shape
    npols = nw // 2
    nm = npols + 1
    b_GGm = np.zeros((nG1, nG2, nm), dtype=np.complex128)
    b_GGm[..., 0] = 1.0
    bm1_GGm = b_GGm
    # Always work in complex128. The inverse differences below involve
    # the complex z_w; storing them in a real-valued copy of X_wGG would
    # silently discard their imaginary part (ComplexWarning only).
    c_GGw = X_wGG.transpose((1, 2, 0)).astype(np.complex128, order='C')

    for i in range(1, 2 * npols):
        cm1_GGw = np.copy(c_GGw)
        bm2_GGm = np.copy(bm1_GGm)  # b_{i-2}
        bm1_GGm = np.copy(b_GGm)  # b_{i-1}

        current_z = z_w[i - 1]

        # Thiele inverse differences, updated in place for points i, i+1..
        c_GGw[..., i:] = (
            (cm1_GGw[..., i - 1, np.newaxis] - cm1_GGw[..., i:]) /
            ((z_w[i:] - current_z) * cm1_GGw[..., i:])
        )

        # b_i = b_{i-1} + c_i * (z - z_{i-1}) * b_{i-2}; multiplying by z
        # shifts the coefficient array by one.
        b_GGm -= current_z * c_GGw[..., i, np.newaxis] * bm2_GGm
        bm2_GGm[..., 1:] = c_GGw[..., i, np.newaxis] * bm2_GGm[..., :-1]
        b_GGm[..., 1:] += bm2_GGm[..., 1:]

    # Companion matrix of the monic denominator polynomial, in vectorized
    # form. It is equal to the following serial code:
    # for i in range(nG):
    #     for j in range(nG):
    #         companion_GGpp[i, j] = poly.polycompanion(b_GGm[i, j])
    # That is, ones on the sub-diagonal and the negated normalized
    # coefficients in the last column.
    b_GGm /= b_GGm[:, :, -1][..., None]
    companion_GGpp = np.zeros((nG1, nG2, npols, npols),
                              dtype=np.complex128)
    # reshape of this fresh C-contiguous array is a view, so the strided
    # assignment writes the sub-diagonal of companion_GGpp.
    companion_GGpp.reshape((nG1, nG2, -1))[:, :, npols::npols + 1] = 1
    companion_GGpp[:, :, :, -1] = -b_GGm[:, :, :npols]

    E_GGp = eigvals(companion_GGpp)
    E_GGp, npr_GG = mpa_cond_vectorized(npols=npols, z_w=z_w, E_GGp=E_GGp)
    return E_GGp, npr_GG