Coverage for gpaw/new/solvation.py: 22%

122 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1import numpy as np 

2from ase.units import Ha, Bohr 

3from gpaw.fd_operators import Gradient 

4from gpaw.new.c import add_to_density 

5from gpaw.new.environment import Environment 

6from gpaw.new.poisson import PoissonSolver, PoissonSolverWrapper 

7from gpaw.solvation.poisson import WeightedFDPoissonSolver 

8from gpaw.solvation.cavity import Cavity 

9from gpaw.solvation.dielectric import Dielectric 

10from gpaw.solvation.interactions import Interaction 

11from gpaw.dft import Parameter 

12 

13 

class Solvation(Parameter):
    """Input parameter describing a continuum-solvation setup.

    Bundles a cavity, a dielectric and a list of interactions, each of
    which is created from its dictionary description.
    """
    name = 'solvation'

    def __init__(self, cavity, dielectric, interactions=None):
        """Create the solvation components from their descriptions.

        Parameters
        ----------
        cavity:
            Description handed to ``Cavity.from_dict``.
        dielectric:
            Description handed to ``Dielectric.from_dict``.
        interactions:
            Optional iterable of descriptions, each handed to
            ``Interaction.from_dict``.  ``None`` (or an empty iterable)
            results in an empty interaction list.
        """
        self.cavity = Cavity.from_dict(cavity)
        self.dielectric = Dielectric.from_dict(dielectric)
        self.interactions = [Interaction.from_dict(i)
                             for i in interactions or []]

    def todict(self):
        """Return a dictionary description of this parameter.

        The ``'interactions'`` entry is a list with one dictionary per
        interaction: its own ``todict()`` content plus a ``'name'`` key
        holding the interaction's class name.
        """
        return {'cavity': self.cavity.todict(),
                'dielectric': self.dielectric.todict(),
                'interactions': [
                    {'name': i.__class__.__name__, **i.todict()}
                    for i in self.interactions]}

    def build(self,
              setups,
              grid,
              relpos_ac,
              log,
              comm):
        """Create the ``SolvationEnvironment`` for a concrete system.

        All arguments are forwarded unchanged, together with the cavity,
        dielectric and interactions stored on this object.
        """
        return SolvationEnvironment(
            cavity=self.cavity,
            dielectric=self.dielectric,
            interactions=self.interactions,
            setups=setups, grid=grid, relpos_ac=relpos_ac,
            log=log, comm=comm)

43 

44 

class SolvationEnvironment(Environment):
    """Environment adding continuum-solvation terms to a calculation.

    Owns a cavity, a dielectric and a list of interactions and uses them
    to build a dielectric-weighted Poisson solver, to add extra terms to
    the effective potential and to compute force contributions.
    """

    def __init__(self,
                 *,
                 cavity,
                 dielectric,
                 interactions=None,
                 setups, grid, relpos_ac, log, comm):
        """Attach the solvation components to ``grid`` and the atoms.

        Parameters
        ----------
        cavity, dielectric:
            Solvation components; they receive the grid descriptor
            ``grid._gd`` and are allocated here.
        interactions:
            Optional list of interaction objects, treated like the
            cavity.  ``None`` means no interactions.
        setups:
            Iterable of objects with a ``symbol`` attribute, one per atom.
        grid:
            Grid description providing ``_gd``, ``cell``, ``pbc``,
            ``comm``, ``zeros()`` and ``from_data()``.
        relpos_ac:
            Scaled atomic positions.
        log:
            Passed on to the ``update_atoms`` methods.
        comm:
            Communicator passed to ``cavity.communicate_vol_surf``.
        """
        self.cavity = cavity
        self.dielectric = dielectric
        self.interactions = interactions or []
        finegd = grid._gd
        self.grid = grid
        self.comm = comm
        self.cavity.set_grid_descriptor(finegd)
        self.dielectric.set_grid_descriptor(finegd)
        for ia in self.interactions:
            ia.set_grid_descriptor(finegd)
        self.cavity.allocate()
        self.dielectric.allocate()
        for ia in self.interactions:
            ia.allocate()
        from ase import Atoms
        # NOTE(review): cell is scaled by ase.units.Bohr, presumably
        # converting grid.cell from Bohr to Angstrom -- confirm.
        self.atoms = Atoms([setup.symbol for setup in setups],
                           scaled_positions=relpos_ac,
                           cell=grid.cell * Bohr,
                           pbc=grid.pbc)
        self.cavity.update_atoms(self.atoms, log)
        for ia in self.interactions:
            ia.update_atoms(self.atoms, log)
        self.grad_v = [Gradient(grid, v, 1.0, n=3) for v in range(3)]
        # Zero-initialised (not merely allocated) so that update2() adds
        # a well-defined array to the potential even if no interaction
        # reports a change the first time it is called.
        self.vt_ia_r = grid.zeros()
        # Not known until update2() has been called.
        self.e_interactions = np.nan
        super().__init__(len(self.atoms))

    def interaction_energy(self):
        """Return the stored interaction energy multiplied by ``Ha``.

        The value is NaN until ``update2`` has run.
        """
        return self.e_interactions * Ha

    def create_poisson_solver(self, grid, *, xp, **kwargs) -> PoissonSolver:
        """Return a wrapped ``WeightedFDPoissonSolver`` using the dielectric.

        The solver is set up on ``self.grid._gd``; the ``grid``, ``xp``
        and extra keyword arguments are accepted but not used.
        """
        psolver = WeightedFDPoissonSolver()
        psolver.set_dielectric(self.dielectric)
        psolver.set_grid_descriptor(self.grid._gd)
        return PoissonSolverWrapper(psolver)

    def update1(self, nt_r, kin_en_using_band=True):
        """Update the cavity, and the dielectric if the cavity changed.

        Stores the result of ``cavity.update`` in ``self.cavity_changed``
        for later use by ``update2``.  ``kin_en_using_band`` is not used.
        """
        density = DensityWrapper(nt_r)
        self.cavity_changed = self.cavity.update(self.atoms, density)
        if self.cavity_changed:
            self.cavity.update_vol_surf()
            self.dielectric.update(self.cavity)

    def update2(self, nt_r, vHt_r, vt_sr):
        """Add the solvation terms to ``vt_sr`` and return the energy.

        Requires that ``update1`` has been called before, because
        ``self.cavity_changed`` is read here.

        Parameters
        ----------
        nt_r:
            Density-like object with a ``data`` attribute.
        vHt_r:
            Potential whose squared gradient enters the dielectric term.
        vt_sr:
            Object whose ``data`` is iterated; every item is modified
            in place.

        Returns
        -------
        The sum of the interaction energies ``ia.E`` after summation
        over ``self.grid.comm`` (not multiplied by ``Ha``).
        """
        if self.cavity.depends_on_el_density:
            del_g_del_n_g = self.cavity.del_g_del_n_g
            del_eps_del_g_g = self.dielectric.del_eps_del_g_g
            Veps = -1 / (8 * np.pi) * del_eps_del_g_g * del_g_del_n_g
            Veps *= grad_squared(vHt_r, self.grad_v).data
            for vt_r in vt_sr.data:
                vt_r += Veps

        density = DensityWrapper(nt_r)
        # The cavity is only handed to the interactions when it changed
        # in the preceding update1() call.
        ia_changed = [
            ia.update(
                self.atoms,
                density,
                self.cavity if self.cavity_changed else None)
            for ia in self.interactions]
        if any(ia_changed):
            # Rebuild the summed interaction potential from scratch.
            self.vt_ia_r.data.fill(.0)
            for ia in self.interactions:
                if ia.depends_on_el_density:
                    self.vt_ia_r.data += ia.delta_E_delta_n_g
                if self.cavity.depends_on_el_density:
                    self.vt_ia_r.data += (ia.delta_E_delta_g_g *
                                          self.cavity.del_g_del_n_g)
        if self.interactions:
            for vt_r in vt_sr.data:
                vt_r += self.vt_ia_r.data
        Eias = np.array([ia.E for ia in self.interactions])
        self.grid.comm.sum(Eias)
        self.e_interactions = Eias.sum()

        self.cavity.communicate_vol_surf(self.comm)

        # NOTE(review): the atoms are dropped after every update, so the
        # next cavity/interaction update receives None -- presumably
        # meaning "atoms unchanged"; confirm against Cavity.update.
        self.atoms = None
        return self.e_interactions

    def forces(self, nt_r, vHt_r):
        """Return an ``(natoms, 3)`` array of solvation force terms.

        The array starts at zero, receives the dielectric contribution
        from ``add_el_force_correction`` and then, per interaction, the
        terms arising from a position-dependent cavity and from a
        position-dependent interaction.

        NOTE(review): all integrals use ``skip_sum=True``, presumably
        leaving the parallel reduction to the caller -- confirm.
        """
        F_av = np.zeros((self.natoms, 3))
        add_el_force_correction(
            nt_r, vHt_r, self.grad_v, self.cavity, self.dielectric, F_av)

        density = DensityWrapper(nt_r)

        for ia in self.interactions:
            if self.cavity.depends_on_atomic_positions:
                delta_E_delta_g_r = self.grid.from_data(
                    ia.delta_E_delta_g_g)
                for a, F_v in enumerate(F_av):
                    del_g_del_r_vg = self.grid.from_data(
                        self.cavity.get_del_r_vg(a, density))
                    F_v -= delta_E_delta_g_r.integrate(del_g_del_r_vg,
                                                       skip_sum=True)

            if ia.depends_on_atomic_positions:
                for a, F_v in enumerate(F_av):
                    del_E_del_r_vr = self.grid.from_data(
                        ia.get_del_r_vg(a, density))
                    F_v -= del_E_del_r_vr.integrate(skip_sum=True)

        return F_av

157 

158 

def add_el_force_correction(nt_r, vHt_r, grad_v, cavity, dielectric, F_av):
    """Add the dielectric force contribution to ``F_av`` in place.

    Nothing happens unless ``cavity.depends_on_atomic_positions`` is
    true.  Otherwise the squared gradient of ``vHt_r`` is scaled by
    ``dielectric.del_eps_del_g_g / (8 * pi)`` and, for every atom, its
    integral against ``cavity.get_del_r_vg(atom, density)`` is added to
    the matching row of ``F_av``.  Returns ``None``.
    """
    if cavity.depends_on_atomic_positions:
        # XXX grad_vHt_g inexact in bmgs
        weight_r = grad_squared(vHt_r, grad_v)
        weight_r.data *= 1 / (8 * np.pi) * dielectric.del_eps_del_g_g

        wrapped_density = DensityWrapper(nt_r)

        for atom_index, force_v in enumerate(F_av):
            cavity_gradient_vr = weight_r.desc.from_data(
                cavity.get_del_r_vg(atom_index, wrapped_density))
            # In-place update of the row of F_av.
            force_v += weight_r.integrate(cavity_gradient_vr, skip_sum=True)

172 

173 

class DensityWrapper:
    """Expose the ``data`` of a density-like object as ``nt_g``.

    Used in this file to hand a density to the cavity and interaction
    objects.
    """

    def __init__(self, nt_r):
        """Store ``nt_r.data`` (the same object, no copy) as ``nt_g``."""
        raw_values = nt_r.data
        self.nt_g = raw_values

177 

178 

def grad_squared(a_r, grad_v):
    """Accumulate a result over all gradient operators applied to ``a_r``.

    Each operator in ``grad_v`` is applied to ``a_r``, writing into a
    scratch array; ``add_to_density(1, scratch, result)`` then folds the
    scratch data into the result, which starts at zero and is returned.

    NOTE(review): judging by the names, ``add_to_density`` presumably
    adds the square of the scratch data, making the result the squared
    gradient -- confirm against gpaw.new.c.
    """
    total_r = a_r.desc.zeros()
    # One scratch array, reused for every direction.
    scratch_r = a_r.new()
    for apply_gradient in grad_v:
        apply_gradient(a_r, scratch_r)
        add_to_density(1, scratch_r.data, total_r.data)
    return total_r