Coverage for gpaw/test/response/test_mpa_vectorization.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1import numpy as np 

2from gpaw.response.mpa_interpolation import fit_residue, RESolver 

3from .mpa_interpolation_scalar import (mpa_R_fit as fit_residue_fortran, 

4 mpa_RE_solver, Xeval) 

5 

6 

def test_residues(in_tmp_dir):
    """Compare the vectorized ``fit_residue`` with the scalar ``mpa_R_fit``.

    For every (G, G') element a random pole/residue model is generated and
    evaluated with ``Xeval`` on ``2 * npols`` complex frequencies.  The
    residues are then fitted element by element with the scalar reference
    (imported as ``fit_residue_fortran``) and, in one call, with the
    vectorized ``fit_residue``.  The two results are compared indirectly,
    through the response functions they reproduce on the same grid.
    """
    nG = 5
    npols = 10
    Omega_GGp = np.empty((nG, nG, npols), dtype=np.complex128)
    residues_GGp = np.empty((nG, nG, npols), dtype=np.complex128)
    X_GGw = np.empty((nG, nG, 2 * npols), dtype=np.complex128)
    R_fortran_GGp = np.empty((nG, nG, npols), dtype=np.complex128)
    # Frequency grid on [0, 5] shifted by +0.1i off the real axis.
    omega_w = np.linspace(0., 5., 2 * npols) + 0.1j

    rng = np.random.default_rng(seed=1)
    for g1 in range(nG):
        for g2 in range(nG):
            # Poles with real part in [5.5, 5.55) and imaginary part -0.01.
            Omega_GGp[g1, g2] = rng.random(npols) * 0.05 + 5.5 - 0.01j
            residues_GGp[g1, g2] = rng.random(npols)
            X_GGw[g1, g2] = Xeval(Omega_GGp[g1, g2],
                                  residues_GGp[g1, g2],
                                  omega_w)
            # Scalar reference fit for this single matrix element.
            R_fortran_GGp[g1, g2] = fit_residue_fortran(npols, npols, omega_w,
                                                        X_GGw[g1, g2],
                                                        Omega_GGp[g1, g2])

    # The vectorized routine is called with the frequency/pole axis moved
    # to the front, so the (G, G', w/p) arrays are transposed in and out.
    R_pGG = fit_residue(np.ones((nG, nG)) * npols,
                        omega_w,
                        X_GGw.transpose(2, 0, 1),
                        Omega_GGp.transpose(2, 0, 1))

    R_GGp = R_pGG.transpose(1, 2, 0)

    # Compare the residues through the response functions they reproduce.
    X_fit_GGw = Xeval(Omega_GGp, R_GGp, omega_w)
    X_fortran_fit_GGw = Xeval(Omega_GGp, R_fortran_GGp, omega_w)
    assert np.allclose(X_fit_GGw, X_fortran_fit_GGw, atol=1e-6)

    if 0:  # Debugging aid: change to 1 to plot one matrix element.
        from matplotlib import pyplot as plt
        g1, g2 = 0, 0

        # omega_w is complex; plot against its real part explicitly rather
        # than letting matplotlib drop the imaginary part (with a warning).
        omega_real_w = omega_w.real

        plt.plot(omega_real_w, X_GGw[g1, g2].real, 'k', ls='--')
        plt.plot(omega_real_w, X_GGw[g1, g2].imag, 'gray', ls='--')

        plt.plot(omega_real_w, X_fit_GGw[g1, g2].real)
        plt.plot(omega_real_w, X_fit_GGw[g1, g2].imag)

        plt.plot(omega_real_w, X_fortran_fit_GGw[g1, g2].real, ls=':')
        plt.plot(omega_real_w, X_fortran_fit_GGw[g1, g2].imag, ls=':')
        plt.show()

52 

53 

def test_poles(in_tmp_dir):
    """Compare the vectorized ``RESolver`` with the scalar ``mpa_RE_solver``.

    Each (G, G') element is built from ``npols`` random poles/residues with
    ``Xeval`` and sampled on ``2 * npols_mpa`` complex frequencies.  That
    data is then fitted with ``npols_mpa`` poles, both element by element
    with the scalar reference solver and in one call with ``RESolver``.
    Only the fitted pole positions are asserted to agree.
    """
    nG = 5
    npols = 100  # number of poles in the model response
    wmax = 2
    Omega_GGp = np.empty((nG, nG, npols), dtype=np.complex128)
    residues_GGp = np.empty((nG, nG, npols), dtype=np.complex128)

    npols_mpa = 6  # number of poles in the fit
    omega_p = np.linspace(0, wmax, npols_mpa)
    # Two sampling lines in the complex plane: Im = 0.1 and Im = 1.
    omega_w = np.concatenate((omega_p + 0.1j, omega_p + 1.j))

    X_GGw = np.empty((nG, nG, 2 * npols_mpa), dtype=np.complex128)
    E_fortran_GGp = np.empty((nG, nG, npols_mpa), dtype=np.complex128)
    R_fortran_GGp = np.empty((nG, nG, npols_mpa), dtype=np.complex128)

    rng = np.random.default_rng(seed=2)
    for g1 in range(nG):
        for g2 in range(nG):
            Omega_GGp[g1, g2] = rng.normal(1, 0.5, npols) - 0.05j
            residues_GGp[g1, g2] = 0.1 + rng.random(npols)
            X_GGw[g1, g2] = Xeval(Omega_GGp[g1, g2],
                                  residues_GGp[g1, g2],
                                  omega_w)

            R_fortran_GGp[g1, g2], E_fortran_GGp[g1, g2], _, _ = (
                mpa_RE_solver(npols_mpa, omega_w, X_GGw[g1, g2]))
            # Order the reference poles (and their residues) by real part so
            # they can be compared element-wise with the vectorized result.
            # NOTE(review): this assumes RESolver returns its poles in the
            # same order; the assertion below depends on it.
            ind = np.argsort(E_fortran_GGp[g1, g2].real)
            E_fortran_GGp[g1, g2] = E_fortran_GGp[g1, g2, ind]
            R_fortran_GGp[g1, g2] = R_fortran_GGp[g1, g2, ind]

    # Vectorized solve with the frequency axis first; transpose back to
    # (G, G', p) for comparison.
    E_pGG, R_pGG = RESolver(omega_w).solve(X_GGw.transpose(2, 0, 1))
    E_GGp = E_pGG.transpose(1, 2, 0)
    R_GGp = R_pGG.transpose(1, 2, 0)  # only used by the debug block below

    # Only the poles are asserted: asserting R or X fails due to ill
    # conditioning in np.linalg.solve.
    assert np.allclose(E_GGp, E_fortran_GGp, rtol=1e-4, atol=1e-6)

    if 0:  # Debugging aid: change to 1 to compare X and plot one element.
        from matplotlib import pyplot as plt

        omega_grid = np.linspace(0., wmax, 100) + 0.01j
        X_fit_GGw = Xeval(E_GGp, R_GGp, omega_grid)
        X_fortran_fit_GGw = Xeval(E_fortran_GGp, R_fortran_GGp, omega_grid)
        a = np.allclose(X_fit_GGw, X_fortran_fit_GGw, rtol=1e-4, atol=1e-6)
        print('assert X', a)

        X_num_GGw = Xeval(Omega_GGp, residues_GGp, omega_grid)
        g1, g2 = 0, 0
        plt.plot(omega_grid.real, X_num_GGw[g1, g2].real, 'k', ls='--')
        plt.plot(omega_grid.real, X_num_GGw[g1, g2].imag, 'gray', ls='--')

        plt.plot(omega_grid.real, X_fit_GGw[g1, g2].real)
        plt.plot(omega_grid.real, X_fit_GGw[g1, g2].imag)
        plt.plot(omega_grid.real, X_fortran_fit_GGw[g1, g2].real, ls=':')
        plt.plot(omega_grid.real, X_fortran_fit_GGw[g1, g2].imag, ls=':')
        plt.show()