Coverage for gpaw/response/mpa_interpolation.py: 99%
112 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1"""
This file contains several routines for the nonlinear
interpolation of poles and residues of response functions,
corresponding to the Multipole Approximation (MPA)
developed in Ref. [1].

The implemented solver is based on the Pade-Thiele
formula (see App. A of Ref. [1]).

[1] D. A. Leon et al., PRB 104, 115157 (2021)
11"""
12from __future__ import annotations
13from typing import Tuple, no_type_check
14from gpaw.typing import Array1D, Array2D, Array3D
15import numpy as np
16from numpy.linalg import eigvals
def fit_residue(
    npr_GG: Array2D, omega_w: Array1D, X_wGG: Array3D, E_pGG: Array3D
) -> Array3D:
    """
    Fit the residues of a multipole model for fixed poles.

    For every matrix element the model

        X(w) = sum_p 2 * E_p * R_p / (w**2 - E_p**2)

    is linear in the residues R_p once the poles E_p are known. The
    residues are therefore obtained from the complex least-squares normal
    equations (A^H A) R = A^H X, with A[w, p] = 2 * E_p / (w**2 - E_p**2).

    Parameters
    ----------
    npr_GG : Number of poles kept for each matrix element. The design
             matrix column of every pole index p >= npr_GG is set to zero,
             so those poles do not take part in the fit.
    omega_w : Frequencies at which X_wGG was sampled.
    X_wGG : Sampled response function, frequency index first.
    E_pGG : Poles, pole index first.

    Returns
    -------
    R_pGG : Fitted residues, pole index first (same shape as E_pGG).
    """
    npols = len(E_pGG)
    omega_w = np.asarray(omega_w)
    nw = len(omega_w)

    # Design matrix A[G, G', w, p] = 2 E_p / (w**2 - E_p**2), built for all
    # frequencies and poles at once by broadcasting (p, w, G, G').
    E_pwGG = E_pGG[:, np.newaxis, :, :]
    omega2_w11 = omega_w[:, np.newaxis, np.newaxis]**2
    A_pwGG = 2 * E_pwGG / (omega2_w11 - E_pwGG**2)
    A_GGwp = A_pwGG.transpose((2, 3, 1, 0)).astype(np.complex128,
                                                   copy=False)

    # Switch off the poles beyond the number of valid ones per element,
    # for every frequency.
    invalid_GGp = np.arange(npols) >= np.asarray(npr_GG)[..., np.newaxis]
    A_GGwp = np.where(invalid_GGp[:, :, np.newaxis, :], 0.0, A_GGwp)

    # Right-hand side with the frequency index last, like A_GGwp.
    b_GGw = X_wGG[:nw].transpose((1, 2, 0)).astype(np.complex128,
                                                   copy=False)

    # Normal equations of the least-squares problem.
    AHb_GGp = np.einsum('GHwp,GHw->GHp', A_GGwp.conj(), b_GGw)
    AHA_GGpp = np.einsum('GHwp,GHwo->GHpo', A_GGwp.conj(), A_GGwp)

    if npols == 1:
        # 1x1 system: multiply by the reciprocal of the single entry.
        R_GGp = (1 / AHA_GGpp[..., 0]) * AHb_GGp
    else:
        try:
            # NumPy 2.0 changed the broadcasting rules of `solve()`; an
            # explicit trailing axis keeps the right-hand side a stack of
            # column vectors under both the old and the new rules.
            R_GGp = np.linalg.solve(
                AHA_GGpp, AHb_GGp[..., np.newaxis])[..., 0]
        except np.linalg.LinAlgError:
            # At least one element has a singular system (e.g. because of
            # masked poles); use the pseudo-inverse for the whole stack.
            R_GGp = np.einsum('GHpo,GHo->GHp',
                              np.linalg.pinv(AHA_GGpp), AHb_GGp)

    return R_GGp.transpose((2, 0, 1))
class Solver:
    """
    Base class for the MPA pole/residue solvers.

    The response X(w) is modelled as a sum of poles, each contributing
    2*E*R/(w**2 - E**2). Every pole needs two sampling frequencies w and
    the values X(w) there; subclasses turn these into the pole positions E
    and the residues R.
    """

    def __init__(self, omega_w: Array1D, threshold=1e-5, epsilon=1e-8):
        """
        Parameters
        ----------
        omega_w : Complex sampling frequencies of the MPA scheme. There are
                  two per pole, so the length must be even.
        threshold : Threshold for small and too close poles.
        epsilon : Magnitude of the small imaginary part given to poles
                  that must sit just off the real axis.
        """
        npoles, remainder = divmod(len(omega_w), 2)
        assert remainder == 0
        self.omega_w = omega_w
        self.npoles = npoles
        self.threshold = threshold
        self.epsilon = epsilon

    def solve(self, X_wGG):
        """
        Interpolate X_wGG, the response sampled at self.omega_w.

        Subclasses return the tuple (E_pGG, R_pGG) of poles and residues,
        with the pole index p first.
        """
        raise NotImplementedError
class SinglePoleSolver(Solver):
    """
    MPA interpolation with a single pole per matrix element, solved
    analytically from two frequency points.
    """
    def __init__(self, omega_w: Array1D):
        Solver.__init__(self, omega_w=omega_w)

    def solve(self, X_wGG: Array3D) -> Tuple[Array3D, Array3D]:
        """
        This interpolates X_wGG using a single pole (E_GG, R_GG).

        Writing X(w) = 2*E*R/(w**2 - E**2) at the two frequencies w0, w1
        and eliminating R gives

            E**2 = (X(w0)*w0**2 - X(w1)*w1**2) / (X(w0) - X(w1)).

        The residue R is then obtained with ``fit_residue``.

        Returns
        -------
        (E_pGG, R_pGG) : Poles and residues, each with a leading pole axis
                         of length 1.
        """
        assert len(X_wGG) == 2

        omega_w = self.omega_w
        E_GG = ((X_wGG[0, :, :] * omega_w[0]**2 -
                 X_wGG[1, :, :] * omega_w[1]**2) /
                (X_wGG[0, :, :] - X_wGG[1, :, :])
                )  # analytical solution for E**2

        def branch_sqrt_inplace(E_GG: Array2D):
            E_GG.real = np.abs(E_GG.real)  # physical pole
            E_GG.imag = -np.abs(E_GG.imag)  # correct time ordering
            E_GG[:] = np.emath.sqrt(E_GG)

        branch_sqrt_inplace(E_GG)
        # NumPy compares complex values lexicographically (real part first),
        # so this effectively selects poles with a real part below the
        # threshold.
        mask = E_GG < self.threshold  # null pole
        E_GG[mask] = self.threshold - 1j * self.epsilon

        E_pGG = E_GG[np.newaxis]
        # Exactly one pole per element takes part in the fit. An integer
        # count avoids comparing integer pole indices with complex numbers.
        R_pGG = fit_residue(
            npr_GG=np.ones(E_GG.shape, dtype=int),
            omega_w=omega_w,
            X_wGG=X_wGG,
            E_pGG=E_pGG)

        return E_pGG, R_pGG
class MultipoleSolver(Solver):
    """
    MPA interpolation with several poles per matrix element, based on the
    Pade-Thiele formula.
    """

    def __init__(self, omega_w: Array1D):
        Solver.__init__(self, omega_w=omega_w)

    def solve(self, X_wGG: Array3D) -> Tuple[Array3D, Array3D]:
        """
        Interpolate X_wGG with self.npoles poles, returning (E_pGG, R_pGG).

        The poles come from the non-linear Pade-Thiele step. With the poles
        fixed, the residues follow from a linear least-squares fit in
        complex variables.
        """
        n_expected = 2 * self.npoles
        assert len(X_wGG) == n_expected

        omega_w = self.omega_w
        # Non-linear part: the poles, solved in the variable z = omega**2.
        poles_GGp, npr_GG = pade_solve(X_wGG, omega_w**2)
        E_pGG = poles_GGp.transpose((2, 0, 1))
        # Linear part: the residues for the fixed poles.
        R_pGG = fit_residue(npr_GG, omega_w, X_wGG, E_pGG)
        return E_pGG, R_pGG
def RESolver(omega_w: Array1D):
    """
    Create the MPA solver that matches the number of sampling frequencies.

    Parameters
    ----------
    omega_w : Complex sampling frequencies. There are two per pole, so the
              length must be even and non-zero.

    Returns
    -------
    A SinglePoleSolver for one pole, a MultipoleSolver otherwise.
    """
    assert len(omega_w) % 2 == 0
    # Integer pole count, computed the same way as Solver.npoles.
    npoles = len(omega_w) // 2
    assert npoles > 0
    if npoles == 1:
        return SinglePoleSolver(omega_w)
    return MultipoleSolver(omega_w)
def mpa_cond_vectorized(
    npols: int, z_w: Array1D, E_GGp: Array3D, pole_resolution: float = 1e-5
) -> Tuple[Array3D, Array2D]:
    """
    Turn pole candidates in the z = omega**2 variable into conditioned
    poles in frequency.

    Parameters
    ----------
    npols : Number of pole candidates per matrix element (the loop below
            runs over the last axis of E_GGp up to this count).
    z_w : Interpolation points in the z variable. The cutoff wmax is
          1.5 times the largest real part of sqrt(z_w).
    E_GGp : Pole candidates in the z variable, pole index last. They are
            square-rooted here.
    pole_resolution : Real-part separation below which two poles are
                      flagged as equal.

    Returns
    -------
    E_GGp : Poles sorted along the last axis. Poles whose real part
            exceeds wmax are replaced by 2 * wmax - 0.01j.
    npr_GG : Number of poles per element with real part below wmax.
    """
    # Cutoff: 1.5 x the largest real part of sqrt(z_w).
    wmax = np.max(np.real(np.emath.sqrt(z_w))) * 1.5

    E_GGp = np.emath.sqrt(E_GGp)
    # The larger of |Re| and |Im| becomes the real part. Minus the smaller
    # becomes the imaginary part, so Re >= 0 and Im <= 0.
    args = np.abs(E_GGp.real), np.abs(E_GGp.imag)
    E_GGp = np.maximum(*args) - 1j * np.minimum(*args)
    E_GGp.sort(axis=2)  # Sort according to real part (complex sort is lexicographic)

    for i in range(npols):
        # Move poles beyond the cutoff to 2 * wmax.
        out_poles_GG = E_GGp[:, :, i].real > wmax
        E_GGp[out_poles_GG, i] = 2 * wmax - 0.01j
        for j in range(i + 1, npols):
            # diff is also negative where pole i was just moved to 2 * wmax,
            # which counts as "equal" too.
            diff = E_GGp[:, :, j].real - E_GGp[:, :, i].real
            equal_poles_GG = diff < pole_resolution
            if np.sum(equal_poles_GG.ravel()):
                # NOTE(review): this leaves the j-loop as soon as ANY element
                # has a flagged pair. The merge below therefore only runs
                # when equal_poles_GG is all False, so it never changes
                # E_GGp. If merging close poles is intended, this condition
                # needs revisiting -- TODO confirm; results depend on it.
                break

            # if the poles are too close, move them to the end, set value
            # to sort to the end of array (e.g., 2x wmax)
            E_GGp[:, :, j] = np.where(
                equal_poles_GG,
                (E_GGp[:, :, j].real + E_GGp[:, :, i].real) / 2
                + 1j * np.maximum(E_GGp[:, :, i].imag, E_GGp[:, :, j].imag),
                E_GGp[:, :, j],
            )
            E_GGp[equal_poles_GG, i] = 2 * wmax - 0.01j

    E_GGp.sort(axis=2)  # Sort according to real part

    # Poles moved to 2 * wmax are not counted.
    npr_GG = np.sum(E_GGp.real < wmax, axis=2)
    return E_GGp, npr_GG
@no_type_check
def pade_solve(X_wGG: Array3D, z_w: Array1D) -> Tuple[Array3D, Array2D]:
    """
    Find the MPA poles of every matrix element with the Pade-Thiele method.

    A Thiele continued fraction is built through the points (z_w, X_wGG).
    The roots of its denominator polynomial B(z) are found as eigenvalues
    of a companion matrix. mpa_cond_vectorized then converts them into
    conditioned poles.

    Parameters
    ----------
    X_wGG : Function values, frequency index first. The number of poles is
            npols = len(X_wGG) // 2.
    z_w : Interpolation points (MultipoleSolver passes omega_w**2).

    Returns
    -------
    E_GGp : Conditioned poles, pole index last.
    npr_GG : Number of poles per element below the cutoff applied in
             mpa_cond_vectorized.
    """
    nw, nG1, nG2 = X_wGG.shape
    npols = nw // 2
    nm = npols + 1  # B(z) has degree npols -> npols + 1 coefficients
    # Denominator coefficients in ascending powers of z; start with B = 1.
    b_GGm = np.zeros((nG1, nG2, nm), dtype=np.complex128)
    b_GGm[..., 0] = 1.0
    # Previous denominator. It aliases b_GGm here, but it is copied at the
    # top of the loop before b_GGm is modified.
    bm1_GGm = b_GGm
    # Thiele inverse differences, updated in place one order per iteration.
    c_GGw = X_wGG.transpose((1, 2, 0)).copy()

    for i in range(1, 2 * npols):
        # Values from the previous (m1) and second-previous (m2) iteration.
        cm1_GGw = np.copy(c_GGw)
        bm2_GGm = np.copy(bm1_GGm)
        bm1_GGm = np.copy(b_GGm)

        current_z = z_w[i - 1]

        # Inverse differences for all k >= i:
        # c_i(z_k) = (c_{i-1}(z_{i-1}) - c_{i-1}(z_k))
        #            / ((z_k - z_{i-1}) * c_{i-1}(z_k))
        c_GGw[..., i:] = (
            (cm1_GGw[..., i - 1, np.newaxis] - cm1_GGw[..., i:]) /
            ((z_w[i:] - current_z) * cm1_GGw[..., i:])
        )

        # B_i(z) = B_{i-1}(z) + c_i * (z - z_{i-1}) * B_{i-2}(z).
        # Shifting the coefficients up by one index multiplies by z.
        b_GGm -= current_z * c_GGw[..., i, np.newaxis] * bm2_GGm
        bm2_GGm[..., 1:] = c_GGw[..., i, np.newaxis] * bm2_GGm[..., :-1]
        b_GGm[..., 1:] += bm2_GGm[..., 1:]

    companion_GGpp = np.zeros((nG1, nG2, npols, npols),
                              dtype=np.complex128)

    # Create a poly companion matrix in vectorized form
    # Equal to following serial code
    # for i in range(nG):
    #     for j in range(nG):
    #         companion_GGpp[i, j] = poly.polycompanion(b_GGm[i, j])
    # Make B monic. NumPy buffers the overlapping divisor view, so every
    # coefficient is divided by the original leading coefficient.
    b_GGm /= b_GGm[:, :, -1][..., None]
    # reshape of the contiguous zeros array is a view; the stride
    # npols + 1 starting at npols puts ones on the subdiagonal.
    companion_GGpp.reshape((nG1, nG2, -1))[:, :, npols::npols + 1] = 1
    companion_GGpp[:, :, :, -1] = -b_GGm[:, :, :npols]

    # Eigenvalues of the companion matrix are the roots of B(z).
    E_GGp = eigvals(companion_GGpp)
    E_GGp, npr_GG = mpa_cond_vectorized(npols=npols, z_w=z_w, E_GGp=E_GGp)
    return E_GGp, npr_GG