Coverage for gpaw/utilities/ps2ae.py: 66%
152 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from math import pi, sqrt
2from warnings import warn
3from typing import Optional, List, Dict
5import numpy as np
6from ase.units import Bohr, Ha
8from gpaw.calculator import GPAW
9from gpaw.atom.shapefunc import shape_functions
10from gpaw.fftw import get_efficient_fft_size
11from gpaw.grid_descriptor import GridDescriptor
12from gpaw.lfc import LocalizedFunctionsCollection as LFC
13from gpaw.utilities import h2gpts
14from gpaw.pw.descriptor import PWDescriptor
15from gpaw.mpi import serial_comm
16from gpaw.setup import Setup
17from gpaw.spline import Spline
18from gpaw.typing import Array3D
class Interpolator:
    """Interpolate arrays from grid ``gd1`` onto grid ``gd2``.

    Both grid descriptors are wrapped in ``PWDescriptor`` objects (built
    with 0.0 as first argument and the given dtype); the actual transfer
    is delegated to ``PWDescriptor.interpolate``.
    """
    def __init__(self, gd1, gd2, dtype=float):
        source, target = (PWDescriptor(0.0, gd, dtype) for gd in (gd1, gd2))
        self.pd1 = source
        self.pd2 = target

    def interpolate(self, a_r):
        """Return ``a_r`` (an array on ``gd1``) interpolated onto ``gd2``.

        Only the first element of what ``PWDescriptor.interpolate``
        returns is passed on.
        """
        result = self.pd1.interpolate(a_r, self.pd2)
        return result[0]
# Number of sample points passed as ``points=`` to ``rgd.spline`` when
# radial functions are turned into splines in this module.
POINTS: int = 200
class PS2AE:
    """Transform PS (pseudo) to AE (all-electron) wave functions.

    Interpolates PS wave functions to a fine grid and adds PAW
    corrections in order to obtain true AE wave functions.
    """
    def __init__(self,
                 calc: GPAW,
                 grid_spacing: float = 0.05,
                 n: int = 2,
                 h=None  # deprecated
                 ):
        """Create transformation object.

        calc: GPAW calculator object
            The calculator that has the wave functions.
        grid_spacing: float
            Desired grid-spacing in Angstrom.
        n: int
            Force number of points to be a multiple of n.
        h: float, optional
            Deprecated alias for ``grid_spacing``. If given, a warning is
            issued and its value replaces ``grid_spacing``.
        """
        if h is not None:
            # stacklevel=2 makes the warning point at the caller's line.
            warn('Please use grid_spacing=... instead of h=...',
                 stacklevel=2)
            grid_spacing = h

        self.calc = calc
        gd = calc.wfs.gd

        # Serial copy of the calculator's wave-function grid:
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)

        # Descriptor for the final grid:
        N_c = h2gpts(grid_spacing / Bohr, gd.cell_cv)
        N_c = np.array([get_efficient_fft_size(N, n) for N in N_c])
        gd2 = self.gd = GridDescriptor(N_c, gd.cell_cv, comm=serial_comm)
        self.interpolator = Interpolator(gd1, gd2, self.calc.wfs.dtype)

        self._dphi: Optional[LFC] = None  # PAW correction, built lazily

        # Volume element of the final grid, converted with Bohr**3:
        self.dv = self.gd.dv * Bohr**3

    @property
    def dphi(self) -> LFC:
        """PAW correction functions (phi - phit) on the final grid.

        Built on first access and cached in ``self._dphi``. Atoms that
        share the same ``Setup`` object share the same list of splines.
        """
        if self._dphi is not None:
            return self._dphi

        splines: Dict[Setup, List[Spline]] = {}
        dphi_aj = []
        for setup in self.calc.wfs.setups:
            dphi_j = splines.get(setup)
            if dphi_j is None:
                # Cut the radial functions 10% beyond the largest rcut_j:
                rcut = max(setup.rcut_j) * 1.1
                gcut = setup.rgd.ceil(rcut)
                dphi_j = []
                for l, phi_g, phit_g in zip(setup.l_j,
                                            setup.data.phi_jg,
                                            setup.data.phit_jg):
                    dphi_g = (phi_g - phit_g)[:gcut]
                    dphi_j.append(setup.rgd.spline(dphi_g, rcut, l,
                                                   points=POINTS))
                splines[setup] = dphi_j
            dphi_aj.append(dphi_j)

        self._dphi = LFC(self.gd, dphi_aj, kd=self.calc.wfs.kd.copy(),
                         dtype=self.calc.wfs.dtype)
        self._dphi.set_positions(self.calc.spos_ac)

        return self._dphi

    def get_wave_function(self,
                          n: int,
                          k: int = 0,
                          s: int = 0,
                          ae: bool = True,
                          periodic: bool = False) -> Array3D:
        """Interpolate wave function.

        Returns 3-d array in units of Ang**-1.5.

        n: int
            Band index.
        k: int
            K-point index.
        s: int
            Spin index.
        ae: bool
            Add PAW correction to get an all-electron wave function.
        periodic:
            Return periodic part of wave function, u(r), instead of
            psi(r)=exp(ikr)u(r).
        """
        u_r = self.calc.get_pseudo_wave_function(n, k, s,
                                                 periodic=True)
        u_R = self.interpolator.interpolate(u_r * Bohr**1.5)

        k_c = self.calc.wfs.kd.ibzk_kc[k]
        gamma = np.isclose(k_c, 0.0).all()

        if gamma:
            eikr_R = 1.0
        else:
            eikr_R = self.gd.plane_wave(k_c)

        if ae:
            # Evaluated on every rank before the rank-0 branch below.
            dphi = self.dphi
            wfs = self.calc.wfs
            P_nI = wfs.collect_projections(k, s)

            if wfs.world.rank == 0:
                # The PAW correction is added to the full Bloch function
                # psi = exp(ikr) u, then the phase is divided out again.
                psi_R = u_R * eikr_R
                P_ai = {}
                I1 = 0
                for a, setup in enumerate(wfs.setups):
                    I2 = I1 + setup.ni
                    P_ai[a] = P_nI[n, I1:I2]
                    I1 = I2
                dphi.add(psi_R, P_ai, k)
                u_R = psi_R / eikr_R

            wfs.world.broadcast(u_R, 0)

        if periodic:
            return u_R * Bohr**-1.5
        else:
            return u_R * eikr_R * Bohr**-1.5

    def get_pseudo_density(self,
                           add_compensation_charges: bool = True) -> Array3D:
        """Interpolate pseudo density.

        The spin components ``nt_sG[:nspins]`` are summed and interpolated
        onto the final grid. The result is divided by ``Bohr**3`` before
        it is returned.

        add_compensation_charges: bool
            Also add the compensation charges ``ghat`` with the multipole
            moments ``Q_aL``, whose L=0 entry is increased by
            ``setups[a].Nv / sqrt(4 * pi)``.

        Only works when the density grid is not domain-distributed
        (asserted).
        """
        dens = self.calc.density
        gd1 = dens.gd
        assert gd1.comm.size == 1
        interpolator = Interpolator(gd1, self.gd)
        dens_r = dens.nt_sG[:dens.nspins].sum(axis=0)
        dens_R = interpolator.interpolate(dens_r)

        if add_compensation_charges:
            dens.calculate_multipole_moments()
            ghat = LFC(self.gd, [setup.ghat_l for setup in dens.setups],
                       integral=sqrt(4 * pi))
            ghat.set_positions(self.calc.spos_ac)
            Q_aL = {}
            for a, Q_L in dens.Q_aL.items():
                # Copy so that dens.Q_aL itself is left untouched.
                Q_aL[a] = Q_L.copy()
                Q_aL[a][0] += dens.setups[a].Nv / sqrt(4 * pi)
            ghat.add(dens_R, Q_aL)

        return dens_R / Bohr**3

    def get_electrostatic_potential(self,
                                    ae: bool = True,
                                    rcgauss: float = 0.02) -> Array3D:
        """Interpolate electrostatic potential.

        Return value in eV.

        ae: bool
            Add PAW correction to get the all-electron potential.
        rcgauss: float
            Width of gaussian (in Angstrom) used to represent the nuclear
            charge.
        """
        gd = self.calc.hamiltonian.finegd
        # Work in Hartree internally; converted back to eV on return.
        v_r = self.calc.get_electrostatic_potential() / Ha
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)
        interpolator = Interpolator(gd1, self.gd)
        v_R = interpolator.interpolate(v_r)

        if ae:
            self.add_potential_correction(v_R, rcgauss / Bohr)

        return v_R * Ha

    def add_potential_correction(self,
                                 v_R: Array3D,
                                 rcgauss: float) -> None:
        """Add the PAW correction to ``v_R`` in place.

        v_R:
            Potential on the final grid. ``get_electrostatic_potential``
            passes it in Hartree.
        rcgauss:
            Width of the Gaussian nuclear charge. ``get_electrostatic_potential``
            passes it in Bohr.

        For each atom, a radial charge-density difference is built from
        ``D_sp``, the core densities, a Gaussian nuclear charge and the
        L=0 compensation charge. It is solved with ``rgd.poisson`` and the
        resulting radial potential is added to ``v_R``. Afterwards
        ``v_R`` is broadcast from rank 0 over ``dens.gd.comm``.
        """
        dens = self.calc.density
        # Temporarily collect all atomic data on a serial atom partition.
        dens.D_asp.redistribute(dens.atom_partition.as_serial())
        dens.Q_aL.redistribute(dens.atom_partition.as_serial())

        try:
            dv_a1 = []
            for a, D_sp in dens.D_asp.items():
                setup = dens.setups[a]
                c = setup.xc_correction
                rgd = c.rgd
                params = setup.data.shape_function.copy()
                params['lmax'] = 0
                ghat_g = shape_functions(rgd, **params)[0]
                Z_g = (shape_functions(rgd, 'gauss', rcgauss, lmax=0)[0] *
                       setup.Z)
                D_q = np.dot(D_sp.sum(0), c.B_pqL[:, :, 0])
                dn_g = np.dot(D_q, (c.n_qg - c.nt_qg)) * sqrt(4 * pi)
                dn_g += 4 * pi * (c.nc_g - c.nct_g)
                dn_g -= Z_g
                dn_g -= dens.Q_aL[a][0] * ghat_g * sqrt(4 * pi)
                dv_g = rgd.poisson(dn_g) / sqrt(4 * pi)
                # Divide by r; the r=0 point is copied from its neighbour
                # and the outermost point is forced to zero.
                dv_g[1:] /= rgd.r_g[1:]
                dv_g[0] = dv_g[1]
                dv_g[-1] = 0.0
                dv_a1.append([rgd.spline(dv_g, points=POINTS)])
        finally:
            # Always restore the original atom distribution, even if the
            # radial correction above raised.
            dens.D_asp.redistribute(dens.atom_partition)
            dens.Q_aL.redistribute(dens.atom_partition)

        if dv_a1:
            dv = LFC(self.gd, dv_a1)
            dv.set_positions(self.calc.spos_ac)
            dv.add(v_R)
        # Outside the ``if``: every rank must take part in the broadcast,
        # including ranks that held no atoms in the serial partition.
        dens.gd.comm.broadcast(v_R, 0)
def interpolate_weight(calc, weight, h=0.05, n=2):
    """Interpolate a CDFT weight function onto a new, finer serial grid.

    calc:
        GPAW calculator. The weight is defined on ``calc.density.finegd``.
    weight:
        Weight function on that fine grid. It is collected with
        ``finegd.collect(..., broadcast=True)`` first, so it may
        presumably be domain-distributed.
    h: float
        Target grid spacing in Angstrom (converted with ``Bohr``).
    n: int
        Passed to ``get_efficient_fft_size`` for each axis, as for the
        grid-point counts in ``PS2AE``.

    Returns the weight interpolated onto the new grid.
    """
    finegd = calc.density.finegd

    # Gather the array on all ranks and zero-pad it.
    full = finegd.zero_pad(finegd.collect(weight, broadcast=True))

    source_gd = GridDescriptor(finegd.N_c, finegd.cell_cv, comm=serial_comm)
    w = np.zeros_like(full)
    source_gd.distribute(full, w)

    target_N_c = np.array([get_efficient_fft_size(N, n)
                           for N in h2gpts(h / Bohr, finegd.cell_cv)])
    target_gd = GridDescriptor(target_N_c, finegd.cell_cv, comm=serial_comm)

    return Interpolator(source_gd, target_gd).interpolate(w)