Coverage for gpaw/new/environment.py: 43%

67 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2import numpy as np 

3from gpaw.new.poisson import PoissonSolver 

4from gpaw.core import UGArray, UGDesc, PWArray 

5from ase.units import Ha 

6from gpaw.new.ibzwfs import IBZWaveFunctions 

7 

8 

class Environment:
    """Environment object.

    Used for jellium, solvation, solvated jellium model, ...

    This base class is the "no environment" default: every hook leaves the
    inputs untouched, the environment energy is zero, the forces are zero
    and :meth:`post_scf_convergence` always reports convergence.
    Subclasses override only the hooks they need.
    """

    def __init__(self, natoms: int):
        self.natoms = natoms
        # Environment charge; always 0.0 here (subclasses such as Jellium
        # assign their own value).
        self.charge = 0.0

    def create_poisson_solver(self, *, grid, xp, solver) -> PoissonSolver:
        """Return the solver built from *solver* for *grid* and *xp*.

        The base environment does not wrap or modify the solver.
        """
        built_solver = solver.build(grid=grid, xp=xp)
        return built_solver

    def post_scf_convergence(self,
                             ibzwfs: IBZWaveFunctions,
                             nelectrons: float,
                             occ_calc,
                             mixer,
                             log) -> bool:
        """Allow for environment to "converge".

        Nothing is adjusted here, so all arguments are ignored and True
        (converged) is returned immediately.
        """
        is_converged = True
        return is_converged

    def update1(self, nt_r) -> None:
        """Hook called right before solving the Poisson equation.

        No-op in the base class.
        """

    def update1pw(self, nt_g) -> None:
        """PW-mode hook called right before solving the Poisson equation.

        No-op in the base class.
        """

    def update2(self, nt_r, vHt_r, vt_sr) -> float:
        """Calculate environment energy (always zero for the base class)."""
        energy = 0.0
        return energy

    def forces(self, nt_r, vHt_r):
        """Return an all-zero ``(natoms, 3)`` force array."""
        shape = (self.natoms, 3)
        return np.zeros(shape)

44 

45 

class Jellium(Environment):
    """Environment that subtracts a jellium background from the density.

    The background is stored on the real-space grid as ``mask_r`` (the
    jellium mask divided by the jellium volume). Scaled by ``self.charge``,
    it is subtracted from the pseudo density right before the Poisson
    equation is solved (see :meth:`update1` / :meth:`update1pw`).
    """

    def __init__(self,
                 jellium,
                 natoms: int,
                 grid: UGDesc):
        super().__init__(natoms)
        self.grid = grid
        self.charge = jellium.charge
        # NOTE(review): presumably mask_g is 1 inside the jellium region and
        # 0 outside, so mask/volume integrates to one -- confirm.
        self.mask_r = grid.from_data(jellium.mask_g / jellium.volume)
        # Plane-wave version of the mask, built lazily by update1pw():
        # 'undefined' before the first call; afterwards either the
        # transformed mask or 'ready' (first call received nt_g=None).
        self.mask_g: PWArray | str = 'undefined'

    def update1(self, nt_r: UGArray) -> None:
        """Subtract ``charge * mask_r`` from *nt_r* in place."""
        nt_r.data -= self.mask_r.data * self.charge

    def update1pw(self, nt_g: PWArray | None) -> None:
        """PW-mode hook: subtract ``charge * mask_g`` from *nt_g* in place.

        On the first call the real-space mask is gathered. If *nt_g* is
        given, the mask is Fourier transformed onto ``nt_g.desc``. Calls
        with ``nt_g=None`` (presumably processes that do not hold the
        gathered density -- confirm) only record the state ``'ready'`` and
        return without modifying anything.
        """
        # Check the type first: mask_g may be a PWArray, and comparing an
        # array-like object to a string with ``==`` must not be relied on.
        if isinstance(self.mask_g, str) and self.mask_g == 'undefined':
            # gather() runs on every call path of the first invocation,
            # whether or not nt_g is None.
            mask_r = self.mask_r.gather()
            if nt_g is not None:
                self.mask_g = mask_r.fft(pw=nt_g.desc)
            else:
                self.mask_g = 'ready'
        if nt_g is None:
            return
        assert not isinstance(self.mask_g, str)
        nt_g.data -= self.mask_g.data * self.charge

71 

72 

class FixedPotentialJellium(Jellium):
    """Jellium whose charge is tuned so the Fermi level hits a target.

    The target Fermi level is ``-workfunction``. Whenever
    :meth:`post_scf_convergence` finds the Fermi level outside the
    tolerance, the jellium charge and the electron count are changed by
    the same amount.
    """

    def __init__(self,
                 jellium,
                 natoms: int,
                 grid: UGDesc,
                 workfunction: float,  # eV
                 tolerance: float = 0.001):  # eV
        """Adjust jellium charge to get the desired Fermi-level.

        Parameters
        ----------
        jellium:
            Jellium description, forwarded to :class:`Jellium`.
        natoms:
            Number of atoms.
        grid:
            Real-space grid descriptor.
        workfunction:
            Target work function in eV; the target Fermi level is
            ``-workfunction``.
        tolerance:
            Accepted deviation of the Fermi level from the target, in eV.
        """
        super().__init__(jellium, natoms, grid)
        # Both stored in Hartree (converted from eV).
        self.workfunction = workfunction / Ha
        self.tolerance = tolerance / Ha
        # Charge, Fermi-level (Hartree) history, one entry per call of
        # post_scf_convergence() that did not meet the tolerance:
        self.history: list[tuple[float, float]] = []

    def post_scf_convergence(self,
                             ibzwfs: IBZWaveFunctions,
                             nelectrons: float,
                             occ_calc,
                             mixer,
                             log) -> bool:
        """Check the Fermi level and adjust the jellium charge if needed.

        Returns True if ``ibzwfs.fermi_level`` is within ``self.tolerance``
        of ``-self.workfunction``. Otherwise the following happens, and
        False is returned (presumably so the caller runs another SCF cycle):

        * the jellium charge and *nelectrons* are shifted by the same
          amount;
        * occupations are recomputed with the new electron count;
        * the mixer is reset.
        """
        fl1 = ibzwfs.fermi_level
        log(f'charge: {self.charge:.6f} |e|, Fermi-level: {fl1 * Ha:.3f} eV')
        fl = -self.workfunction
        if abs(fl1 - fl) <= self.tolerance:
            return True
        self.history.append((self.charge, fl1))
        dc = self._charge_step(fl)
        self.charge += dc
        # Change the electron count by the same amount as the jellium charge.
        nelectrons += dc
        ibzwfs.calculate_occs(occ_calc, nelectrons)
        mixer.reset()
        return False

    def _charge_step(self, fl: float) -> float:
        """Return the charge change for the next step toward Fermi level *fl*.

        With two or more history entries, a secant update through the last
        two (charge, Fermi level) points is used, limited in size to the
        previous charge change. The first step, and any step where the last
        two Fermi levels are identical (so the secant slope is undefined),
        use the linear estimate from :meth:`_linear_step` instead.
        """
        c1, fl1 = self.history[-1]
        if len(self.history) == 1:
            return self._linear_step(fl1, fl)
        c2, fl2 = self.history[-2]
        if fl1 == fl2:
            # The secant below would divide by zero: ZeroDivisionError for
            # Python floats, inf/nan (and hence a nan charge) for numpy
            # floats.
            return self._linear_step(fl1, fl)
        # Secant: charge at which the line through the last two points
        # reaches the target Fermi level.
        c = c2 + (fl - fl2) / (fl1 - fl2) * (c1 - c2)
        dc = c - c1
        # Do not take a step larger than the previous charge change.
        if abs(dc) > abs(c2 - c1):
            dc *= abs((c2 - c1) / dc)
        return dc

    def _linear_step(self, fl1: float, fl: float) -> float:
        """Return a charge change proportional to the Fermi-level error.

        The proportionality factor is ``0.02`` times the area of the
        parallelogram spanned by the xy components of the first two cell
        vectors. A Fermi level above the target gives a negative change.
        """
        area = abs(np.linalg.det(self.grid.cell_cv[:2, :2]))
        return -(fl1 - fl) * area * 0.02