Coverage for gpaw/spline.py: 75%

55 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1 

2# Copyright (C) 2003 CAMP 

3# Please see the accompanying LICENSE file for further information. 

4 

5import numpy as np 

6 

7from gpaw import debug 

8from gpaw.utilities import is_contiguous 

9import gpaw.cgpaw as cgpaw 

10 

11 

class Spline:
    """Python wrapper around a compiled radial spline (``cgpaw.Spline``).

    The wrapped object describes a radial function with angular momentum
    quantum number ``l``.  As documented in ``from_data``, the tabulated
    radial part is multiplied by a real solid spherical harmonic
    (r^l * Y_lm).
    """

    def __init__(self, spline):
        """Wrap an existing compiled spline object.

        Parameters
        ----------
        spline:
            A ``cgpaw.Spline`` instance (or an object exposing the same
            methods).
        """
        self.spline = spline
        # Cache l as a plain attribute; __setstate__ needs it to rebuild
        # the compiled object after unpickling.
        self.l = self.get_angular_momentum_number()

    @classmethod
    def from_data(cls, l, rmax, f_g):
        """Create a spline from tabulated radial values.

        Parameters
        ----------
        l: int
            Angular momentum quantum number.
        rmax: float
            Radial cutoff; must be positive.
        f_g: array_like
            Radial part of the function on the grid, from r=0 to r=rmax.
            The full function is this radial part multiplied by a real
            solid spherical harmonic (r^l * Y_lm).

        The last value is forced to zero, so the function vanishes at the
        cutoff.  The caller's array is not modified.
        """
        assert rmax > 0.0
        # np.array copies by default, so zeroing the last element below
        # does not change the values of the caller's array.
        f_g = np.array(f_g, float)
        f_g[-1] = 0.0
        return cls(cgpaw.Spline(l, rmax, f_g))

    def get_cutoff(self):
        """Return the radial cutoff."""
        return self.spline.get_cutoff()

    def get_angular_momentum_number(self):
        """Return the angular momentum quantum number."""
        return self.spline.get_angular_momentum_number()

    def get_npoints(self):
        """Return the number of radial points of the compiled spline."""
        return self.spline.get_npoints()

    def __repr__(self):
        return ('Spline(l={}, rmax={:.2f}, ...)'
                .format(self.get_angular_momentum_number(),
                        self.get_cutoff()))

    def get_value_and_derivative(self, r):
        """Return the value and derivative at radius ``r``."""
        return self.spline.get_value_and_derivative(r)

    def __call__(self, r):
        """Evaluate the spline at a single non-negative radius ``r``."""
        assert r >= 0.0
        return self.spline(r)

    def map(self, r_x):
        """Map f(r) onto a given radial grid.

        Parameters
        ----------
        r_x: array_like
            Radii at which to evaluate the spline.

        Returns
        -------
        out_x: ndarray of float
            Spline values with the same shape as the (coerced) ``r_x``.
        """
        # The compiled routine fills out_x in place and is handed raw
        # arrays, so give it a C-contiguous float64 input.  The output
        # buffer (empty_like) then has the matching dtype and layout.
        # This is a no-op (no copy) for arrays that already qualify.
        # NOTE(review): float64 is assumed to be what cgpaw expects.
        r_x = np.ascontiguousarray(r_x, dtype=float)
        out_x = np.empty_like(r_x)
        self.spline.map(r_x, out_x)
        return out_x

    def __getstate__(self):
        """Return picklable state.

        The compiled spline is replaced by ``(rmax, values)``, where the
        values are sampled on ``get_npoints()`` equidistant points in
        [0, rmax].
        """
        state = self.__dict__.copy()
        rmax = self.get_cutoff()
        state['spline'] = (
            rmax,
            self.map(np.linspace(0.0, rmax, self.get_npoints())))
        return state

    def __setstate__(self, state):
        """Rebuild the compiled spline from the state of ``__getstate__``."""
        rmax, f_g = state['spline']
        state['spline'] = cgpaw.Spline(state['l'], rmax, f_g)
        self.__dict__.update(state)

    def get_functions(self, gd, start_c, end_c, spos_c):
        """Evaluate the function on the grid box [start_c, end_c).

        Parameters
        ----------
        gd:
            Grid descriptor; ``gd.h_cv`` and ``gd.cell_cv`` are used.
        start_c, end_c: ndarray of int
            Grid-point indices of the box corners.
        spos_c: array_like
            Scaled position of the function's centre.

        Returns
        -------
        work_mB: ndarray
            Array of shape ``(nm,) + tuple(end_c - start_c)``.  It is
            zero-initialised and filled only where the compiled code
            returns values.  ``nm`` is presumably the number of m
            components (2l+1); this is not checked here.
        """
        h_cv = gd.h_cv
        size_c = end_c - start_c
        # start_c is the new origin, so translate gd.beg_c to start_c.
        origin_c = np.array([0, 0, 0])
        pos_v = np.dot(spos_c, gd.cell_cv) - np.dot(start_c, h_cv)
        A_gm, G_b = cgpaw.spline_to_grid(self.spline,
                                         origin_c,
                                         size_c,
                                         pos_v,
                                         h_cv,
                                         size_c,
                                         origin_c)

        if debug:
            # G_b appears to hold (begin, end) index pairs.
            assert G_b.ndim == 1 and G_b.shape[0] % 2 == 0
            assert is_contiguous(G_b, np.intc)
            # Compare tuple with tuple.  Comparing a shape tuple with a
            # numpy scalar only works via elementwise broadcasting and
            # yields an array instead of a bool.
            npoints = int(np.sum(G_b[1::2] - G_b[::2]))
            assert A_gm.shape[:-1] == (npoints,)

        indices_gm, _ng, nm = self.spline.get_indices_from_zranges(
            start_c, end_c, G_b)
        work_mB = np.zeros((nm,) + tuple(size_c), dtype=A_gm.dtype)
        np.put(work_mB, indices_gm, A_gm)
        return work_mB