Coverage for gpaw/new/pw/pot_calc.py: 91%

120 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1import numpy as np 

2from gpaw.core import PWDesc 

3from gpaw.mpi import broadcast_float 

4from gpaw.new import zips, spinsum, trace 

5from gpaw.new.pot_calc import PotentialCalculator 

6from gpaw.new.pw.stress import calculate_stress 

7from gpaw.setup import Setups 

8 

9 

class PlaneWavePotentialCalculator(PotentialCalculator):
    """Potential calculator built around a plane-wave descriptor.

    The instance keeps FFT plans for ``grid`` and ``fine_grid``, the local
    potentials created by the setups (``vbar_ag``), those potentials added
    into a zero-initialised plane-wave array (``vbar_g``) and a gathered
    copy of that array (``vbar0_g``).
    """

    def __init__(self,
                 grid,
                 fine_grid,
                 pw: PWDesc,
                 setups: Setups,
                 xc,
                 poisson_solver,
                 *,
                 relpos_ac,
                 atomdist,
                 environment,
                 extensions,
                 soc=False,
                 xp=np):
        """Store descriptors, build FFT plans and the local potential.

        ``xp`` is the array module handed to every allocation and plan
        (defaults to numpy).  The remaining arguments are forwarded to the
        base class or to ``setups.create_local_potentials``.
        """
        # These two are assigned before the base-class constructor runs.
        self.xp = xp
        self.pw = pw
        super().__init__(xc, poisson_solver, setups,
                         relpos_ac=relpos_ac,
                         environment=environment,
                         extensions=extensions,
                         soc=soc)

        self.vbar_ag = setups.create_local_potentials(
            pw, relpos_ac, atomdist, xp)

        self.fftplan = grid.fft_plans(xp=xp)
        self.fftplan2 = fine_grid.fft_plans(xp=xp)

        self.grid = grid
        self.fine_grid = fine_grid

        vbar_g = pw.zeros(xp=xp)
        self.vbar_g = vbar_g
        self.vbar_ag.add_to(vbar_g)
        self.vbar0_g = vbar_g.gather()

        # Cache used by the force and stress code paths:
        self._nt_g = None
        self._vt_g = None
        self._dedtaut_g = None

    def interpolate(self, a_R, a_r=None):
        """Return ``a_R.interpolate(...)`` onto ``self.fine_grid``.

        ``a_r`` is passed through as the ``out`` argument.
        """
        coarse_plan, fine_plan = self.fftplan, self.fftplan2
        return a_R.interpolate(coarse_plan, fine_plan,
                               grid=self.fine_grid, out=a_r)

    def restrict(self, a_r, a_R=None):
        """Return ``a_r.fft_restrict(...)`` onto ``self.grid``.

        ``a_R`` is passed through as the ``out`` argument.
        """
        fine_plan, coarse_plan = self.fftplan2, self.fftplan
        return a_r.fft_restrict(fine_plan, coarse_plan,
                                grid=self.grid, out=a_R)

    def _interpolate_density(self, nt_sR):
        """Interpolate every component of ``nt_sR`` to the fine grid.

        Returns ``(nt_sr, nt0_g)``.  ``nt0_g`` is only built on rank 0 of
        the plane-wave communicator; it is ``None`` on all other ranks.
        """
        xp = self.xp
        is_master = self.vbar_g.desc.comm.rank == 0
        nt_sr = self.fine_grid.empty(nt_sR.dims, xp=xp)

        nt0_g = None
        if is_master:
            pw0 = self.poisson_solver.pwg0
            indices = xp.asarray(pw0.indices(self.fftplan.shape))
            nt0_g = pw0.zeros(xp=xp)

        # 1 -> 1, 2 -> 2, 4 -> 1.  Presumably the four-component case is
        # non-collinear with a single density component - TODO confirm.
        ndensities = nt_sR.dims[0] % 3
        for spin, (nt_R, nt_r) in enumerate(zips(nt_sR, nt_sr)):
            self.interpolate(nt_R, nt_r)
            if is_master and spin < ndensities:
                # NOTE(review): relies on fftplan.tmp_Q still holding data
                # from the interpolate() call just above - confirm.
                nt0_g.data += xp.asarray(
                    self.fftplan.tmp_Q.ravel()[indices])

        return nt_sr, nt0_g

    def _interpolate_and_calculate_xc(self, xc, density):
        """Interpolate the density (and ``taut_sR`` if present), call xc.

        Returns ``(nt_sr, nt0_g, taut_sr, e_xc, vxct_sr, dedtaut_sr)``.
        """
        nt_sr, nt0_g = self._interpolate_density(density.nt_sR)

        taut_sR = density.taut_sR
        taut_sr = None if taut_sR is None else self.interpolate(taut_sR)

        e_xc, vxct_sr, dedtaut_sr = xc.calculate(nt_sr, taut_sr)

        return nt_sr, nt0_g, taut_sr, e_xc, vxct_sr, dedtaut_sr

    def calculate_non_selfconsistent_exc(self, xc, density):
        """Return only the energy that ``xc.calculate`` gives for density."""
        results = self._interpolate_and_calculate_xc(xc, density)
        _, _, _, e_xc, _, _ = results
        return e_xc

    @trace
    def calculate_pseudo_potential(self, density, ibzwfs, vHt_h=None):
        """Compute the potential on the coarse grid and energy terms.

        Returns a 6-tuple: a dict with the keys ``'coulomb'``, ``'zero'``,
        ``'xc'`` and ``'external'``, followed by ``vt_sR``, ``dedtaut_sr``,
        ``vHt_h``, ``V_aL`` and ``e_stress`` (the sum of the coulomb and
        zero energies).  ``ibzwfs`` is not used in this method.
        """
        nt_sr, nt0_g, taut_sr, e_xc, vxct_sr, dedtaut_sr = (
            self._interpolate_and_calculate_xc(self.xc, density))

        comm = self.vbar_g.desc.comm
        is_master = comm.rank == 0

        # Only rank 0 holds nt0_g; every rank then receives e_zero.
        e_zero = 0.0
        if is_master:
            nt0_g.data *= 1 / np.prod(density.nt_sR.desc.size_c)
            e_zero = self.vbar0_g.integrate(nt0_g)
        e_zero = broadcast_float(float(e_zero), comm)

        vt0_g = self.vbar0_g.copy() if is_master else None

        self.environment.update1pw(nt0_g)

        Q_aL = density.calculate_compensation_charge_coefficients()

        e_coulomb, vHt_h, V_aL = self.poisson_solver.solve(
            nt0_g, Q_aL, vt0_g, vHt_h)

        vt0_R = None
        if is_master:
            vt0_R = vt0_g.ifft(
                plan=self.fftplan,
                grid=density.nt_sR.desc.new(comm=None))

        # Rank 0 supplies the data, all other ranks pass None.
        vt_sR = density.nt_sR.new()
        vt_sR[0].scatter_from(vt0_R)
        ndensities = density.ndensities
        if ndensities == 2:
            vt_sR.data[1] = vt_sR.data[0]
        vt_sR.data[ndensities:] = 0.0

        # No external-potential energy is evaluated here; zero is reported.
        e_external = 0.0

        # Add the restricted xc potential to each component.
        vtmp_R = vt_sR.desc.empty(xp=self.xp)
        for vt_R, vxct_r in zips(vt_sR, vxct_sr):
            self.restrict(vxct_r, vtmp_R)
            vt_R.data += vtmp_R.data

        # Drop cached force/stress arrays.
        self._reset()

        e_stress = e_coulomb + e_zero

        energies = {'coulomb': e_coulomb,
                    'zero': e_zero,
                    'xc': e_xc,
                    'external': e_external}
        return energies, vt_sR, dedtaut_sr, vHt_h, V_aL, e_stress

    def move(self, relpos_ac, atomdist):
        """Update for new atomic positions and rebuild ``vbar_g``."""
        super().move(relpos_ac, atomdist)
        for movable in (self.poisson_solver, self.vbar_ag):
            movable.move(relpos_ac, atomdist)

        # Zero in place, then add the moved local potentials again.
        vbar_g = self.vbar_g
        vbar_g.data[:] = 0.0
        self.vbar_ag.add_to(vbar_g)
        self.vbar0_g = vbar_g.gather()
        self._reset()

    def _reset(self):
        """Forget the arrays cached by ``_force_stress_helper``."""
        self._vt_g = None
        self._nt_g = None
        self._dedtaut_g = None

    def _force_stress_helper(self, density, potential):
        """Return ``(vt_g, nt_g, dedtaut_g)``, computing them at most once.

        The result is cached on the instance until ``_reset`` is called, so
        that forces and stress can share the work.  ``dedtaut_g`` is
        ``None`` when ``potential.dedtaut_sR`` is ``None``.
        """
        if self._vt_g is None:
            def to_pw(a_R):
                return a_R.fft(self.fftplan, pw=self.pw)

            nt_R = spinsum(density.nt_sR)
            vt_R = spinsum(potential.vt_sR, mean=True)
            self._vt_g = to_pw(vt_R)
            self._nt_g = to_pw(nt_R)

            dedtaut_sR = potential.dedtaut_sR
            if dedtaut_sR is None:
                self._dedtaut_g = None
            else:
                self._dedtaut_g = to_pw(spinsum(dedtaut_sR, mean=True))

        return self._vt_g, self._nt_g, self._dedtaut_g

    def force_contributions(self, Q_aL, density, potential):
        """Return the five force terms as a tuple.

        Order: Poisson-solver term, ``nct_aX`` derivative, ``tauct_aX``
        derivative (``None`` without ``dedtaut_g``), ``vbar_ag`` derivative
        and ``self.extensions_force_av``.

        Raises RuntimeError if ``potential.vHt_x`` is ``None``.
        """
        vHt_x = potential.vHt_x
        if vHt_x is None:
            raise RuntimeError(ERROR.format(thing='forces'))

        vt_g, nt_g, dedtaut_g = self._force_stress_helper(density, potential)

        Ftauct_av = (None if dedtaut_g is None
                     else density.tauct_aX.derivative(dedtaut_g))

        F_poisson = self.poisson_solver.force_contribution(Q_aL, vHt_x, nt_g)
        F_nct = density.nct_aX.derivative(vt_g)
        F_vbar = self.vbar_ag.derivative(nt_g)
        return F_poisson, F_nct, Ftauct_av, F_vbar, self.extensions_force_av

    def stress(self, ibzwfs, density, potential):
        """Return the result of ``calculate_stress`` for this calculator.

        Raises RuntimeError if ``potential.vHt_x`` is ``None``.
        """
        if potential.vHt_x is None:
            raise RuntimeError(ERROR.format(thing='stress'))
        pw_arrays = self._force_stress_helper(density, potential)
        vt_g, nt_g, dedtaut_g = pw_arrays
        return calculate_stress(self, ibzwfs, density, potential,
                                vt_g, nt_g, dedtaut_g)

210 

211 

# Message template for the RuntimeError raised when the quantity named by
# ``{thing}`` cannot be computed; fill it in with ``ERROR.format(thing=...)``.
ERROR = (
    'Unable to calculate {thing}. Are you restarting from an old '
    'gpw-file? In that case, calculate the {thing} before writing '
    'the gpw-file or switch to new GPAW.')