Coverage for gpaw/atom/shapefunc.py: 30%
43 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import numpy as np
2from .radialgd import RadialGridDescriptor
def shape_functions(rgd: 'RadialGridDescriptor',
                    type: str,
                    rc: float,
                    lmax: int) -> np.ndarray:
    """Shape functions for compensation charges.

    One radial function is built for every angular momentum
    ``l = 0, ..., lmax`` on the grid ``rgd.r_g``.  Each function is then
    rescaled so that ``rgd.integrate(g_lg[l], l)`` equals ``4 * pi``
    (assuming ``rgd.integrate`` is linear in its first argument).

    Parameters
    ----------
    rgd:
        Radial grid descriptor.  Only ``zeros``, ``r_g``, ``ceil`` and
        ``integrate`` are used.
    type:
        Functional form, one of:

        * ``'gauss'``: ``r**l * exp(-(r / rc)**2)``.
        * ``'sinc'``: ``r**l * sinc(r / rc)**2``, set to zero from grid
          index ``rgd.ceil(rc)`` onwards.
        * ``'bessel'``: sum of two spherical Bessel functions ``j_l``, set
          to zero from grid index ``rgd.ceil(rc)`` onwards.  Only
          available for ``lmax <= 2``.
    rc:
        Width of the Gaussian, or cutoff radius for the ``'sinc'`` and
        ``'bessel'`` forms.
    lmax:
        Highest angular momentum to generate.

    Returns
    -------
    np.ndarray
        The array allocated by ``rgd.zeros(lmax + 1)``, indexed as
        ``g_lg[l, g]`` with ``g`` running over the radial grid points.

    Raises
    ------
    ValueError
        If ``type`` is not one of the supported names, or if ``'bessel'``
        is requested with ``lmax`` larger than the tabulated roots allow.
    """
    g_lg = rgd.zeros(lmax + 1)
    r_g = rgd.r_g

    if type == 'gauss':
        g_lg[0] = 4 / rc**3 / np.sqrt(np.pi) * np.exp(-(r_g / rc)**2)
        # Each higher l multiplies the previous function by r (times a
        # constant that the final normalization removes again).
        for l in range(1, lmax + 1):
            g_lg[l] = 2.0 / (2 * l + 1) / rc**2 * r_g * g_lg[l - 1]
    elif type == 'sinc':
        # np.sinc(x) is sin(pi x) / (pi x), so it vanishes at r = rc.
        g_lg[0] = np.sinc(r_g / rc)**2
        g_lg[0, rgd.ceil(rc):] = 0.0
        # l = 0 is already truncated, so the products below are as well.
        for l in range(1, lmax + 1):
            g_lg[l] = r_g * g_lg[l - 1]
    elif type == 'bessel':
        # First two positive roots of j_l for l = 0, 1 and 2.
        roots = [[3.141592653589793, 6.283185307179586],
                 [4.493409457909095, 7.7252518369375],
                 [5.76345919689455, 9.095011330476355]]
        if lmax >= len(roots):
            raise ValueError(
                f"'bessel' shape functions are only tabulated for "
                f'lmax <= {len(roots) - 1}, got lmax={lmax}')

        from scipy.special import spherical_jn as jn

        for l in range(lmax + 1):
            q1, q2 = (x0 / rc for x0 in roots[l])
            # Both terms are zero at r = rc because q * rc is a root of j_l.
            # alpha is chosen so that the radial derivative of the sum is
            # zero at r = rc too (third argument True -> derivative).
            alpha = -q1 / q2 * jn(l, q1 * rc, True) / jn(l, q2 * rc, True)
            g_lg[l] = jn(l, q1 * r_g) + alpha * jn(l, q2 * r_g)
        g_lg[:, rgd.ceil(rc):] = 0.0
    else:
        # Previously "1 / 0": fail with a message that names the problem.
        raise ValueError(
            f'Unknown shape function type {type!r}; '
            "expected 'gauss', 'sinc' or 'bessel'")

    # Normalize: afterwards rgd.integrate(g_lg[l], l) == 4 * pi.
    for l in range(lmax + 1):
        g_lg[l] /= rgd.integrate(g_lg[l], l) / (4 * np.pi)

    return g_lg
if __name__ == '__main__':
    # Demo: plot the normalized shape functions for l = 0, 1 and 2.
    from .radialgd import EquidistantRadialGridDescriptor as RGD
    from scipy.special import spherical_jn as jn
    from scipy.optimize import root
    import matplotlib.pyplot as plt

    r = np.linspace(0, 1.2, 200)

    if False:
        # Disabled helper: find roots of spherical Bessel functions, print
        # them and plot j_l scaled so that the root falls at r = 1.
        for ell in range(3):
            for index, guess in enumerate([3, 6]):
                solution = root(lambda x: jn(ell, x), guess + 1.5 * ell)
                zero = solution['x']
                print(ell, index, zero.item())
                plt.plot(r, jn(ell, r * zero), label=str(zero))
        plt.legend()
        plt.show()

    grid = RGD(0.01, 120)
    # The first form in each pass uses rc = 0.3, the remaining ones rc = 1.0.
    # NOTE(review): this assumes the original "rc = 1.0" sat inside the
    # inner loop -- confirm against the properly indented source.
    cases = (('gauss', 0.3), ('sinc', 1.0), ('bessel', 1.0))
    for ell in range(3):
        for name, radius in cases:
            curves = shape_functions(grid, name, radius, ell)
            plt.plot(grid.r_g, curves[ell], label=name)

    plt.legend()
    plt.show()