Coverage for gpaw/utilities/ps2ae.py: 66%

152 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from math import pi, sqrt 

2from warnings import warn 

3from typing import Optional, List, Dict 

4 

5import numpy as np 

6from ase.units import Bohr, Ha 

7 

8from gpaw.calculator import GPAW 

9from gpaw.atom.shapefunc import shape_functions 

10from gpaw.fftw import get_efficient_fft_size 

11from gpaw.grid_descriptor import GridDescriptor 

12from gpaw.lfc import LocalizedFunctionsCollection as LFC 

13from gpaw.utilities import h2gpts 

14from gpaw.pw.descriptor import PWDescriptor 

15from gpaw.mpi import serial_comm 

16from gpaw.setup import Setup 

17from gpaw.spline import Spline 

18from gpaw.typing import Array3D 

19 

20 

class Interpolator:
    """Resample real-space arrays from one grid descriptor to another.

    Both grids are wrapped in ``PWDescriptor`` objects (constructed with a
    first argument of 0.0 -- presumably the cutoff; confirm against
    ``PWDescriptor``) and the resampling itself is delegated to
    ``PWDescriptor.interpolate``.
    """

    def __init__(self, gd1, gd2, dtype=float):
        # Source (gd1) and target (gd2) plane-wave descriptors, built in
        # that order.
        self.pd1, self.pd2 = (PWDescriptor(0.0, grid, dtype)
                              for grid in (gd1, gd2))

    def interpolate(self, a_r):
        """Return *a_r* (given on the source grid) on the target grid."""
        # PWDescriptor.interpolate returns a tuple; only its first element
        # (the array on the target grid) is used here.
        resampled = self.pd1.interpolate(a_r, self.pd2)
        return resampled[0]

28 

29 

# Number of points used when building radial splines (passed as
# ``points=`` to ``rgd.spline``).
POINTS: int = 200

31 

32 

class PS2AE:
    """Transform PS to AE wave functions.

    Interpolates PS wave functions to a fine grid and adds PAW
    corrections in order to obtain true AE wave functions.

    The public getters return arrays on the serial grid ``self.gd``,
    converted to Angstrom/eV based units (Ang**-1.5, Ang**-3, eV).
    """
    def __init__(self,
                 calc: GPAW,
                 grid_spacing: float = 0.05,
                 n: int = 2,
                 h: Optional[float] = None  # deprecated
                 ):
        """Create transformation object.

        calc: GPAW calculator object
            The calculator that has the wave functions.
        grid_spacing: float
            Desired grid-spacing in Angstrom.
        n: int
            Force number of points to be a multiple of n.
        h: float
            Deprecated alias for grid_spacing; overrides it when given.
        """
        if h is not None:
            # stacklevel=2 attributes the warning to the caller's line.
            warn('Please use grid_spacing=... instead of h=...',
                 stacklevel=2)
            grid_spacing = h

        self.calc = calc
        gd = calc.wfs.gd

        # Serial copy of the wave-function grid (interpolation source).
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)

        # Descriptor for the final grid:
        N_c = h2gpts(grid_spacing / Bohr, gd.cell_cv)
        N_c = np.array([get_efficient_fft_size(N, n) for N in N_c])
        gd2 = self.gd = GridDescriptor(N_c, gd.cell_cv, comm=serial_comm)
        self.interpolator = Interpolator(gd1, gd2, self.calc.wfs.dtype)

        # PAW correction functions; built lazily by the ``dphi`` property.
        self._dphi: Optional[LFC] = None

        # Volume element of the final grid in Angstrom**3.
        self.dv = self.gd.dv * Bohr**3

    @property
    def dphi(self) -> LFC:
        """Localized functions phi - phit for every atom (built once).

        On first access, radial splines of the difference between the
        setups' phi_jg and phit_jg are placed on ``self.gd`` at the
        calculator's scaled positions; later accesses reuse the result.
        """
        if self._dphi is not None:
            return self._dphi

        # Atoms sharing a setup object share the same list of splines.
        splines: Dict[Setup, List[Spline]] = {}
        dphi_aj = []
        for setup in self.calc.wfs.setups:
            dphi_j = splines.get(setup)
            if dphi_j is None:
                # Truncate 10% beyond the largest of setup.rcut_j.
                rcut = max(setup.rcut_j) * 1.1
                gcut = setup.rgd.ceil(rcut)
                dphi_j = []
                for l, phi_g, phit_g in zip(setup.l_j,
                                            setup.data.phi_jg,
                                            setup.data.phit_jg):
                    dphi_g = (phi_g - phit_g)[:gcut]
                    dphi_j.append(setup.rgd.spline(dphi_g, rcut, l,
                                                   points=POINTS))
                splines[setup] = dphi_j
            dphi_aj.append(dphi_j)

        self._dphi = LFC(self.gd, dphi_aj, kd=self.calc.wfs.kd.copy(),
                         dtype=self.calc.wfs.dtype)
        self._dphi.set_positions(self.calc.spos_ac)

        return self._dphi

    def get_wave_function(self,
                          n: int,
                          k: int = 0,
                          s: int = 0,
                          ae: bool = True,
                          periodic: bool = False) -> Array3D:
        """Interpolate wave function.

        Returns 3-d array in units of Ang**-1.5.

        n: int
            Band index.
        k: int
            K-point index.
        s: int
            Spin index.
        ae: bool
            Add PAW correction to get an all-electron wave function.
        periodic:
            Return periodic part of wave-function, u(r), instead of
            psi(r)=exp(ikr)u(r).
        """
        # Periodic part of the pseudo wave function; multiplying by
        # Bohr**1.5 undoes the Ang**-1.5 scaling before interpolation.
        u_r = self.calc.get_pseudo_wave_function(n, k, s,
                                                 periodic=True)
        u_R = self.interpolator.interpolate(u_r * Bohr**1.5)

        k_c = self.calc.wfs.kd.ibzk_kc[k]
        gamma = np.isclose(k_c, 0.0).all()

        # Phase factor on the fine grid (just 1.0 at the Gamma point).
        if gamma:
            eikr_R = 1.0
        else:
            eikr_R = self.gd.plane_wave(k_c)

        if ae:
            dphi = self.dphi
            wfs = self.calc.wfs
            P_nI = wfs.collect_projections(k, s)

            # Only rank 0 of wfs.world adds the correction; the result
            # is then broadcast to every rank.
            if wfs.world.rank == 0:
                # Add the correction to psi = phase * u, then divide the
                # phase out again.
                psi_R = u_R * eikr_R
                # Split band n's flat projection row into per-atom slices
                # of length setup.ni.
                P_ai = {}
                I1 = 0
                for a, setup in enumerate(wfs.setups):
                    I2 = I1 + setup.ni
                    P_ai[a] = P_nI[n, I1:I2]
                    I1 = I2
                dphi.add(psi_R, P_ai, k)
                u_R = psi_R / eikr_R

            wfs.world.broadcast(u_R, 0)

        if periodic:
            return u_R * Bohr**-1.5
        else:
            return u_R * eikr_R * Bohr**-1.5

    def get_pseudo_density(self,
                           add_compensation_charges: bool = True
                           ) -> Array3D:
        """Interpolate pseudo density.

        Returns the pseudo density summed over the first ``nspins``
        components of ``nt_sG``, on ``self.gd``, in units of Ang**-3.

        add_compensation_charges: bool
            Also add the ``ghat_l`` shape functions weighted by the
            multipole moments ``Q_aL`` (with ``Nv / sqrt(4 * pi)`` added
            to each atom's first moment).
        """
        dens = self.calc.density
        gd1 = dens.gd
        assert gd1.comm.size == 1, (
            'get_pseudo_density() requires an undistributed density grid')
        interpolator = Interpolator(gd1, self.gd)
        dens_r = dens.nt_sG[:dens.nspins].sum(axis=0)
        dens_R = interpolator.interpolate(dens_r)

        if add_compensation_charges:
            dens.calculate_multipole_moments()
            ghat = LFC(self.gd, [setup.ghat_l for setup in dens.setups],
                       integral=sqrt(4 * pi))
            ghat.set_positions(self.calc.spos_ac)
            # Copy so that dens.Q_aL itself is left untouched.
            Q_aL = {}
            for a, Q_L in dens.Q_aL.items():
                Q_aL[a] = Q_L.copy()
                Q_aL[a][0] += dens.setups[a].Nv / (4 * pi)**0.5
            ghat.add(dens_R, Q_aL)

        return dens_R / Bohr**3

    def get_electrostatic_potential(self,
                                    ae: bool = True,
                                    rcgauss: float = 0.02) -> Array3D:
        """Interpolate electrostatic potential.

        Return value in eV.

        ae: bool
            Add PAW correction to get the all-electron potential.
        rcgauss: float
            Width of gaussian (in Angstrom) used to represent the nuclear
            charge.
        """
        gd = self.calc.hamiltonian.finegd
        # Work in Hartree on the grid; converted back to eV on return.
        v_r = self.calc.get_electrostatic_potential() / Ha
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)
        interpolator = Interpolator(gd1, self.gd)
        v_R = interpolator.interpolate(v_r)

        if ae:
            self.add_potential_correction(v_R, rcgauss / Bohr)

        return v_R * Ha

    def add_potential_correction(self,
                                 v_R: Array3D,
                                 rcgauss: float) -> None:
        """Add the PAW correction to a potential on ``self.gd`` in place.

        v_R: Array3D
            Potential on ``self.gd`` in Hartree; modified in place and
            broadcast from rank 0 of the density grid communicator.
        rcgauss: float
            Width (in Bohr) of the gaussian used for each nuclear charge.
        """
        dens = self.calc.density
        # Move the atomic data to the serial atom partition; the original
        # partition is restored in ``finally`` even if a step fails.
        dens.D_asp.redistribute(dens.atom_partition.as_serial())
        dens.Q_aL.redistribute(dens.atom_partition.as_serial())
        try:
            dv_a1 = []
            for a, D_sp in dens.D_asp.items():
                setup = dens.setups[a]
                c = setup.xc_correction
                rgd = c.rgd
                params = setup.data.shape_function.copy()
                params['lmax'] = 0
                ghat_g = shape_functions(rgd, **params)[0]
                Z_g = (shape_functions(rgd, 'gauss', rcgauss, lmax=0)[0] *
                       setup.Z)
                # Sum only the first nspins components, as is done for
                # nt_sG in get_pseudo_density; identical to summing all
                # rows whenever D_sp has exactly nspins rows.
                D_q = np.dot(D_sp[:dens.nspins].sum(0), c.B_pqL[:, :, 0])
                dn_g = np.dot(D_q, (c.n_qg - c.nt_qg)) * sqrt(4 * pi)
                dn_g += 4 * pi * (c.nc_g - c.nct_g)
                dn_g -= Z_g
                dn_g -= dens.Q_aL[a][0] * ghat_g * sqrt(4 * pi)
                dv_g = rgd.poisson(dn_g) / sqrt(4 * pi)
                dv_g[1:] /= rgd.r_g[1:]
                # Skip the division at the first grid point (presumably
                # r = 0) and copy the neighbouring value instead.
                dv_g[0] = dv_g[1]
                dv_g[-1] = 0.0
                dv_a1.append([rgd.spline(dv_g, points=POINTS)])
        finally:
            dens.D_asp.redistribute(dens.atom_partition)
            dens.Q_aL.redistribute(dens.atom_partition)

        # Only ranks that received atoms above build and add corrections.
        if dv_a1:
            dv = LFC(self.gd, dv_a1)
            dv.set_positions(self.calc.spos_ac)
            dv.add(v_R)
        dens.gd.comm.broadcast(v_R, 0)

240 

241 

def interpolate_weight(calc, weight, h=0.05, n=2):
    """Interpolate a CDFT weight function onto a finer serial grid.

    The weight is taken to be defined on ``calc.density.finegd``: it is
    collected (and broadcast) and zero-padded with that descriptor before
    being interpolated onto a new serial grid.

    calc:
        GPAW calculator whose ``density.finegd`` holds *weight*.
    weight:
        Weight function on ``calc.density.finegd``.
    h: float
        Desired grid spacing of the output grid in Angstrom.
    n: int
        Force number of points along each axis to be a multiple of n.

    Returns the weight function on the new grid.
    """
    finegd = calc.density.finegd

    # Gather the full array (on all ranks) and pad it with the fine grid.
    padded = finegd.zero_pad(finegd.collect(weight, broadcast=True))

    # Serial descriptor with the fine grid's shape: interpolation source.
    source_gd = GridDescriptor(finegd.N_c, finegd.cell_cv, comm=serial_comm)
    source_w = np.zeros_like(padded)
    source_gd.distribute(padded, source_w)

    # Target grid: requested spacing, sizes passed through
    # get_efficient_fft_size.
    target_N_c = np.array([get_efficient_fft_size(N, n)
                           for N in h2gpts(h / Bohr, finegd.cell_cv)])
    target_gd = GridDescriptor(target_N_c, finegd.cell_cv, comm=serial_comm)

    return Interpolator(source_gd, target_gd).interpolate(source_w)