Coverage for gpaw/spline.py: 75%
55 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
2# Copyright (C) 2003 CAMP
3# Please see the accompanying LICENSE file for further information.
5import numpy as np
7from gpaw import debug
8from gpaw.utilities import is_contiguous
9import gpaw.cgpaw as cgpaw
class Spline:
    """Thin Python wrapper around a low-level ``cgpaw.Spline`` object.

    The wrapped object is kept in ``self.spline`` and its angular momentum
    quantum number is cached in ``self.l`` at construction time.  Instances
    can be pickled: the low-level object is replaced by ``(rmax, values)``
    in the pickled state and rebuilt from that on unpickling.
    """

    def __init__(self, spline):
        """Wrap an already constructed low-level spline object."""
        self.spline = spline
        self.l = self.get_angular_momentum_number()

    @classmethod
    def from_data(cls, l, rmax, f_g):
        """Create a spline from tabulated radial data.

        The integer l gives the angular momentum quantum number and
        f_g contains the spline values from r=0 to r=rmax.

        The array f_g gives the radial part of the function on the grid.
        The radial function is multiplied by a real solid spherical harmonics
        (r^l * Y_lm).

        rmax must be positive (checked with an assertion).  The last value
        of f_g is replaced by zero before the spline is built; the caller's
        array is left untouched.
        """
        assert rmax > 0.0
        # np.array() makes a copy, so zeroing the end point below does not
        # change the values of the input array.
        values_g = np.array(f_g, float)
        values_g[-1] = 0.0
        return cls(cgpaw.Spline(l, rmax, values_g))

    def get_cutoff(self):
        """Return the radial cutoff."""
        return self.spline.get_cutoff()

    def get_angular_momentum_number(self):
        """Return the angular momentum quantum number."""
        return self.spline.get_angular_momentum_number()

    def get_npoints(self):
        """Return the number of points reported by the wrapped spline."""
        return self.spline.get_npoints()

    def __repr__(self):
        l = self.get_angular_momentum_number()
        rmax = self.get_cutoff()
        return 'Spline(l={}, rmax={:.2f}, ...)'.format(l, rmax)

    def get_value_and_derivative(self, r):
        """Return the value and derivative."""
        return self.spline.get_value_and_derivative(r)

    def __call__(self, r):
        """Evaluate the wrapped spline at the non-negative radius r."""
        assert r >= 0.0
        return self.spline(r)

    def map(self, r_x):
        """Map f(r) onto a given radial grid.

        r_x must be a C-contiguous array (checked with an assertion).
        A new array with the same shape and dtype as r_x is returned; it is
        filled in by the wrapped spline's ``map`` method.
        """
        # Validate the input before allocating the output buffer.
        assert r_x.flags.c_contiguous
        out_x = np.empty_like(r_x)
        self.spline.map(r_x, out_x)
        return out_x

    def __getstate__(self):
        """Return a picklable copy of the instance state.

        The low-level spline object is replaced by the tuple
        ``(rmax, f_g)``, where f_g holds the spline evaluated on
        ``get_npoints()`` equidistant radii from 0 to rmax.
        """
        state = self.__dict__.copy()
        rmax = self.get_cutoff()
        r_g = np.linspace(0.0, rmax, self.get_npoints())
        state['spline'] = (rmax, self.map(r_g))
        return state

    def __setstate__(self, state):
        """Restore the instance from a state made by ``__getstate__``.

        The low-level spline is rebuilt from the stored ``(rmax, f_g)``
        tuple and the stored ``l``.
        """
        rmax, f_g = state['spline']
        # Work on a copy so that the dictionary handed in by the caller
        # keeps its picklable (rmax, f_g) entry.
        restored = dict(state)
        restored['spline'] = cgpaw.Spline(restored['l'], rmax, f_g)
        self.__dict__.update(restored)

    def get_functions(self, gd, start_c, end_c, spos_c):
        """Return the spline's functions on the grid box start_c..end_c.

        gd must provide ``h_cv`` and ``cell_cv`` (the only attributes read
        here); presumably it is a grid descriptor and spos_c a scaled
        position -- confirm against callers.  start_c and end_c must
        support array subtraction.

        Returns a zero-initialised array of shape
        ``(nm,) + tuple(end_c - start_c)`` into which the values computed by
        ``cgpaw.spline_to_grid`` have been written at the flat indices given
        by the wrapped spline's ``get_indices_from_zranges``.
        """
        h_cv = gd.h_cv
        # start_c is the new origin so we translate gd.beg_c to start_c
        origin_c = np.array([0, 0, 0])
        size_c = end_c - start_c
        pos_v = np.dot(spos_c, gd.cell_cv) - np.dot(start_c, h_cv)
        A_gm, G_b = cgpaw.spline_to_grid(self.spline,
                                         origin_c,
                                         size_c,
                                         pos_v,
                                         h_cv,
                                         size_c,
                                         origin_c)

        if debug:
            assert G_b.ndim == 1 and G_b.shape[0] % 2 == 0
            assert is_contiguous(G_b, np.intc)
            # The leading dimension of A_gm must equal the summed
            # differences of consecutive G_b pairs (presumably begin/end
            # indices of z-ranges).  Compare tuple against tuple instead of
            # relying on NumPy broadcasting a tuple against a scalar.
            assert A_gm.shape[:-1] == (np.sum(G_b[1::2] - G_b[::2]),)

        indices_gm, _, nm = self.spline.get_indices_from_zranges(start_c,
                                                                 end_c, G_b)
        work_mB = np.zeros((nm,) + tuple(size_c), dtype=A_gm.dtype)
        # np.put addresses work_mB through its flattened index space.
        np.put(work_mB, indices_gm, A_gm)
        return work_mB