Coverage for gpaw/tddft/solvers/base.py: 88%

17 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from abc import ABC, abstractmethod 

2 

3 

class BaseSolver(ABC):
    """Abstract base class for linear-equation solvers.

    Concrete subclasses implement :meth:`solve`, which solves a set of
    linear equations A.x = b.

    Parameters
    ----------
    tolerance: float
        Tolerance for the norm of the residual ||b - A.x||^2.
    max_iterations: integer
        Maximum number of iterations.
    eps: float
        Magnitude below which a quantity (e.g. abs(rho) or abs(omega)) is
        regarded as zero, meaning the method breaks down. Must not exceed
        ``tolerance``.

    Raises
    ------
    ValueError
        If ``eps <= tolerance`` does not hold (this includes NaN inputs).

    Note
    ----
    The tolerance should not be smaller than the attainable accuracy, which
    is of order kappa(A) * eps, where kappa(A) is the (spectral) condition
    number of the matrix. The maximum number of iterations should be
    significantly less than the matrix size, approximately
    0.5 * sqrt(kappa) * ln(2 / tolerance). A small number is treated as zero
    if its magnitude is smaller than ``eps``.
    """

    def __init__(self,
                 tolerance=1e-8,
                 max_iterations=1000,
                 eps=1e-15):
        self.tol = tolerance
        self.max_iter = max_iterations

        # Phrased as "not (eps <= tolerance)" rather than "eps > tolerance"
        # so that a NaN in either argument is rejected too.
        if not (eps <= tolerance):
            # "%le" is accepted by Python: the "l" length modifier is ignored,
            # so this formats like "%e".
            message = ("Invalid tolerance (tol = %le < eps = %le)."
                       % (tolerance, eps))
            raise ValueError(message)
        self.eps = eps

        # Initialised to -1 here. NOTE(review): presumably concrete solvers
        # store the iteration count of the last solve() here -- confirm.
        self.iterations = -1

    def todict(self):
        """Return the solver's class name and parameters as a dictionary."""
        return dict(name=self.__class__.__name__,
                    tolerance=self.tol,
                    max_iterations=self.max_iter,
                    eps=self.eps)

    def initialize(self, gd, timer):
        """Attach runtime objects to the solver.

        Parameters
        ----------
        gd: GridDescriptor
            Grid descriptor for the coarse (pseudo-wavefunction) grid.
        timer: Timer
            Timer.
        """
        self.gd, self.timer = gd, timer

    @abstractmethod
    def solve(self, A, x, b):
        """Solve a set of linear equations A.x = b.

        Parameters
        ----------
        A
            Matrix A.
        x
            Initial guess x_0 on entry; holds the result on exit.
        b
            Right-hand side (multi)vector.
        """
        raise NotImplementedError()