Coverage for gpaw/new/pw/builder.py: 82%
183 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1from __future__ import annotations
3import os
4import warnings
5from functools import cached_property
7import numpy as np
8from ase.units import Ha
10from gpaw.core import PWDesc, UGDesc
11from gpaw.core.domain import Domain
12from gpaw.core.matrix import Matrix
13from gpaw.core.plane_waves import PWArray
14from gpaw.new import zips
15from gpaw.new.builder import create_uniform_grid
16from gpaw.new.gpw import as_double_precision
17from gpaw.new.pw.bloechl_poisson import BloechlPAWPoissonSolver
18from gpaw.new.pw.hamiltonian import PWHamiltonian, SpinorPWHamiltonian
19from gpaw.new.pw.hybrids import PWHybridHamiltonian
20from gpaw.new.pw.paw_poisson import SlowPAWPoissonSolver
21from gpaw.new.pw.poisson import make_poisson_solver
22from gpaw.new.pw.pot_calc import PlaneWavePotentialCalculator
23from gpaw.new.pwfd.builder import PWFDDFTComponentsBuilder
24from gpaw.new.xc import create_functional
25from gpaw.typing import Array1D
class PWDFTComponentsBuilder(PWFDDFTComponentsBuilder):
    """Build the DFT components for a plane-wave (PW) mode calculation.

    Specializes :class:`PWFDDFTComponentsBuilder` with plane-wave
    wave-function descriptors, FFT-based grids, plane-wave Poisson solvers
    and plane-wave Hamiltonians.
    """

    def __init__(self,
                 atoms,
                 params,
                 *,
                 comm=None,
                 log=None):
        mode = params.mode
        # Plane-wave cutoff converted from eV to Hartree.  It is assigned
        # before the base-class constructor runs; presumably that constructor
        # calls methods here that read self.ecut (e.g. create_uniform_grids)
        # -- keep this order.
        self.ecut = mode.ecut / Ha
        # NOTE(review): mode.dedecut is not handled here (original note:
        # "mode.dedecut ???").
        super().__init__(atoms, params, comm=comm, log=log)

        # Lazily filled caches, see get_pseudo_core_densities() and
        # get_pseudo_core_ked().
        self._nct_ag = None
        self._tauct_ag = None

        nthreads = int(os.environ.get('OMP_NUM_THREADS', '') or '1')
        if nthreads > 1:
            warnings.warn(
                'Using OMP_NUM_THREADS>1 in PW-mode is not useful!')
        # NOTE: Ideally the atoms would just be distributed evenly, but that
        # is not compatible with LCAO initialization:
        # return AtomDistribution.from_number_of_atoms(len(self.relpos_ac),
        #                                              self.communicators['d'])

    def create_uniform_grids(self):
        """Create the coarse FFT grid and a fine grid with twice the points.

        The coarse grid is built from ``params.gpts``/``params.h``, the cell,
        the symmetries and the plane-wave cutoff.  The fine grid has
        ``2 * grid.size_c`` points along each axis.

        Returns
        -------
        (grid, fine_grid)
        """
        grid = create_uniform_grid(
            'pw',
            self.params.gpts,
            self.atoms.cell,
            self.atoms.pbc,
            self.ibz.symmetries,
            h=self.params.h,
            interpolation=self.params.interpolation or 'fft',
            ecut=self.ecut,
            comm=self.communicators['d'])
        fine_grid = grid.new(size=grid.size_c * 2)
        # decomposition=[2 * d for d in grid.decomposition]
        return grid, fine_grid

    def create_wf_description(self) -> Domain:
        """Return the plane-wave descriptor used for the wave functions."""
        return PWDesc(ecut=self.ecut,
                      cell=self.grid.cell,
                      comm=self.grid.comm,
                      dtype=self.dtype)

    def create_xc_functional(self):
        """Create the XC functional, evaluated on the fine grid."""
        return create_functional(self._xc,
                                 self.fine_grid, self.xp)

    @cached_property
    def interpolation_desc(self):
        """Plane-wave set used for interpolating from coarse to fine grid."""
        # By default, the size of the grid used for the FFTs (self.grid)
        # will accommodate G-vectors up to 2 * self.ecut, but the grid size
        # could have been set using h=... or gpts=..., so never exceed what
        # the grid can represent.
        ecut = min(2 * self.ecut, self.grid.ekin_max())
        return PWDesc(ecut=ecut,
                      cell=self.grid.cell,
                      comm=self.grid.comm)

    @cached_property
    def electrostatic_potential_desc(self):
        """Plane-wave set used for the electrostatic potential.

        With the fast Poisson solver this is :attr:`interpolation_desc`
        itself; otherwise a set derived from it with cutoff
        ``8 * self.ecut``.
        """
        if self.fast_poisson_solver:
            return self.interpolation_desc
        return self.interpolation_desc.new(ecut=8 * self.ecut)

    @cached_property
    def fast_poisson_solver(self) -> bool:
        """Whether the fast (Bloechl) PAW Poisson solver is used.

        It must be requested with ``fast=True`` in the Poisson-solver
        parameters, and is only enabled when every setup has a ``data``
        attribute whose compensation-charge shape function is Gaussian.
        """
        if not self.params.poissonsolver.params.get('fast', False):
            return False
        # Only works for gaussian compensation charges at the moment:
        return all(hasattr(setup, 'data') and
                   setup.data.shape_function['type'] == 'gauss'
                   for setup in self.setups)

    def get_pseudo_core_densities(self):
        """Return (creating and caching on first use) the pseudo core
        densities, expanded in :attr:`interpolation_desc`."""
        if self._nct_ag is None:
            self._nct_ag = self.setups.create_pseudo_core_densities(
                self.interpolation_desc, self.relpos_ac, self.atomdist,
                xp=self.xp)
        return self._nct_ag

    def get_pseudo_core_ked(self):
        """Return (creating and caching on first use) the pseudo core
        kinetic-energy densities, expanded in :attr:`interpolation_desc`."""
        if self._tauct_ag is None:
            self._tauct_ag = self.setups.create_pseudo_core_ked(
                self.interpolation_desc, self.relpos_ac, self.atomdist)
        return self._tauct_ag

    def create_poisson_solver(self, env):
        """Create the PAW-aware Poisson solver.

        Parameters
        ----------
        env:
            Environment object forwarded to ``make_poisson_solver``.

        Returns
        -------
        A :class:`BloechlPAWPoissonSolver` on the coarse grid when
        :attr:`fast_poisson_solver` is true, otherwise a
        :class:`SlowPAWPoissonSolver` on the fine grid.
        """
        psparams = self.params.poissonsolver.params.copy() or {'strength': 1.0}
        # 'fast' is consumed by fast_poisson_solver, not by the solver itself.
        psparams.pop('fast', False)

        if self.fast_poisson_solver:
            grid = self.grid
        else:
            grid = self.fine_grid

        pw = self.electrostatic_potential_desc
        ps = make_poisson_solver(pw,
                                 grid,
                                 self.charge,
                                 env,
                                 **psparams)

        if self.fast_poisson_solver:
            cutoff_a = [s.data.shape_function['rc'] for s in self.setups]
            return BloechlPAWPoissonSolver(
                pw, cutoff_a, ps, self.relpos_ac, self.atomdist, self.xp)

        return SlowPAWPoissonSolver(
            self.interpolation_desc,
            self.setups,
            ps, self.relpos_ac, self.atomdist, self.xp)

    def create_potential_calculator(self):
        """Create the plane-wave potential calculator.

        A single environment on the fine grid is shared between the Poisson
        solver and the calculator.
        """
        env = self.create_environment(self.fine_grid)
        return PlaneWavePotentialCalculator(
            self.grid, self.fine_grid,
            self.interpolation_desc,
            self.setups,
            self.xc,
            self.create_poisson_solver(env),
            relpos_ac=self.relpos_ac,
            atomdist=self.atomdist,
            soc=self.soc,
            xp=self.xp,
            environment=env,
            extensions=self.get_extensions())

    def create_hamiltonian_operator(self, blocksize=10):
        """Create the Hamiltonian operator.

        Returns :class:`PWHamiltonian` for fewer than 4 components without
        exact exchange, :class:`PWHybridHamiltonian` when the XC functional
        has a non-zero exact-exchange fraction (no domain or k-point
        parallelization; bands must divide evenly over the band
        communicator), and :class:`SpinorPWHamiltonian` for 4 components.

        ``blocksize`` is accepted for interface compatibility but unused.
        """
        if self.ncomponents < 4:
            if self.xc.exx_fraction == 0.0:
                return PWHamiltonian(self.grid, self.wf_desc, self.xp)
            assert self.communicators['d'].size == 1
            assert self.communicators['k'].size == 1
            assert self.nbands % self.communicators['b'].size == 0
            return PWHybridHamiltonian(
                self.grid, self.wf_desc, self.xc, self.setups,
                self.relpos_ac, self.atomdist,
                comp_charge_in_real_space=self.params.experimental.get(
                    'ccirs'))
        return SpinorPWHamiltonian(self.qspiral_v)

    def convert_wave_functions_from_uniform_grid(self,
                                                 C_nM: Matrix,
                                                 basis_set,
                                                 kpt_c,
                                                 q):
        """Turn LCAO coefficients into plane-wave wave functions.

        Parameters
        ----------
        C_nM:
            LCAO coefficients; presumably this rank's bands (rows) times all
            basis functions (columns) -- see ``C_nM.dist.shape``.
        basis_set:
            Object providing ``lcao_to_grid`` (used by the grid-based path).
        kpt_c:
            k-point in scaled coordinates.
        q:
            Index forwarded to ``basis_set.lcao_to_grid`` (presumably the
            local k-point index).

        Returns
        -------
        Plane-wave array with ``self.nbands`` bands over the band
        communicator (shape ``(nbands, 2)`` for spinors), on ``self.xp``.
        """
        if self.params.experimental.get('fast_pw_init', True):
            if self.ncomponents < 4:
                # Fast path: expand the atom-centered basis functions
                # directly in plane waves; no real-space grid is involved.
                from gpaw.core.pwacf import PWAtomCenteredFunctions
                pw = self.wf_desc.new(kpt=kpt_c)
                phit_aJG = PWAtomCenteredFunctions(
                    [setup.basis_functions_J for setup in self.setups],
                    self.relpos_ac,
                    pw,
                    atomdist=self.atomdist,
                    xp=self.xp)
                psit_nG = pw.empty(self.nbands,
                                   comm=self.communicators['b'],
                                   xp=self.xp)
                mynbands, _ = C_nM.dist.shape
                phit_aJG.multiply(C_nM.to_xp(self.xp).to_dtype(pw.dtype),
                                  out_nG=psit_nG[:mynbands])
                return psit_nG

        # Grid-based path: evaluate the basis functions on the real-space
        # grid (in double precision) and FFT to plane waves.
        is_complex = np.issubdtype(self.dtype, np.complexfloating)
        lcao_dtype = complex if is_complex else float

        grid = self.grid.new(kpt=kpt_c, dtype=lcao_dtype)
        pw = self.wf_desc.new(kpt=kpt_c, dtype=lcao_dtype)

        # Phase factor exp(-ik.r) applied before the FFT; presumably
        # lcao_to_grid() yields full Bloch functions.  Only complex wave
        # functions carry a phase.
        emikr_R = grid.eikr(-kpt_c) if is_complex else None

        mynbands, M = C_nM.dist.shape
        if self.ncomponents < 4:
            psit_nG = pw.empty(self.nbands, self.communicators['b'])
            psit_nR = grid.zeros(mynbands)
            basis_set.lcao_to_grid(C_nM.data, psit_nR.data, q)

            for psit_R, psit_G in zips(psit_nR, psit_nG, strict=False):
                if emikr_R is not None:
                    psit_R.data *= emikr_R
                psit_R.fft(out=psit_G)

            if self.dtype != lcao_dtype:
                # Copy into an array with the requested dtype.
                pw_correct = self.wf_desc.new(kpt=kpt_c, dtype=self.dtype)
                psit2_nG = pw_correct.empty(self.nbands,
                                            self.communicators['b'])
                psit2_nG.data[:] = psit_nG.data
                return psit2_nG.to_xp(self.xp)
            return psit_nG.to_xp(self.xp)

        # Spinors: each band's coefficients hold two blocks of M // 2.
        psit_nsG = pw.empty((self.nbands, 2), self.communicators['b'])
        psit_sR = grid.empty(2)
        C_nsM = C_nM.data.reshape((mynbands, 2, M // 2))
        for psit_sG, C_sM in zips(psit_nsG, C_nsM, strict=False):
            psit_sR.data[:] = 0.0
            basis_set.lcao_to_grid(C_sM, psit_sR.data, q)
            if emikr_R is not None:
                psit_sR.data *= emikr_R
            for psit_G, psit_R in zips(psit_sG, psit_sR):
                psit_R.fft(out=psit_G)
        # Move to the array backend, like the other code paths do.
        return psit_nsG.to_xp(self.xp)

    def read_ibz_wave_functions(self, reader):
        """Read IBZ wave functions, including plane-wave coefficients.

        The base-class reader runs first.  If the file has plane-wave
        ``coefficients``, they are attached to every wave-function object as
        ``psit_nX`` (scaled by a file-version dependent factor).
        """
        ibzwfs = super().read_ibz_wave_functions(reader)

        if 'coefficients' not in reader.wave_functions:
            return ibzwfs

        singlep = reader.get('precision', 'double') == 'single'

        # Scale factor applied to the stored coefficients (data.scale below).
        c = reader.bohr**1.5
        if reader.version < 0:
            c = 1  # very old gpw file
        elif reader.version < 4:
            c /= self.grid.size_c.prod()

        index_kG = reader.wave_functions.indices

        if self.ncomponents == 4:
            shape = (self.nbands, 2)
        else:
            shape = (self.nbands,)

        band_comm = self.communicators['b']
        for wfs in ibzwfs:
            pw = self.wf_desc.new(kpt=wfs.kpt_c)
            if wfs.spin == 0:
                check_g_vector_ordering(self.grid, pw, index_kG[wfs.k])

            # Spinor data is indexed by k-point only.
            index = (wfs.spin, wfs.k) if self.ncomponents != 4 else (wfs.k,)
            data = reader.wave_functions.proxy('coefficients', *index)
            data.scale = c
            data.length_of_last_dimension = pw.shape[-1]

            if self.communicators['w'].size == 1 and not singlep:
                # Wrap the file data directly, temporarily reshaped.
                orig_shape = data.shape
                data.shape = shape + pw.shape
                wfs.psit_nX = pw.from_data(data)
                data.shape = orig_shape
            else:
                wfs.psit_nX = PWArray(pw, shape, comm=band_comm)
                mynbands = ((self.nbands + band_comm.size - 1) //
                            band_comm.size)
                n1 = min(band_comm.rank * mynbands, self.nbands)
                n2 = min((band_comm.rank + 1) * mynbands, self.nbands)
                if pw.comm.rank == 0:
                    assert wfs.psit_nX.mydims[0] == n2 - n1
                    data = data[n1:n2]  # read from file
                else:
                    # Non-root ranks only receive; they have no data.
                    data = [None] * (n2 - n1)
                for psit_G, array in zips(wfs.psit_nX, data):
                    # Never hand the None placeholders to the converter.
                    if singlep and array is not None:
                        array = as_double_precision(array)
                    psit_G.scatter_from(array)

        return ibzwfs
def check_g_vector_ordering(grid: UGDesc,
                            pw: PWDesc,
                            index_G: Array1D) -> None:
    """Check that G-vector indices read from a file match ``pw``'s ordering.

    Parameters
    ----------
    grid:
        Uniform grid whose ``size`` defines the FFT box.
    pw:
        Plane-wave descriptor; ``pw.indices(size)`` gives the expected
        indices.  For real (``np.floating``) dtypes the last axis of the box
        is reduced to ``size[2] // 2 + 1``, the layout of a real FFT.
    index_G:
        Stored indices: the expected indices followed by ``-1`` padding.

    Raises
    ------
    ValueError
        If ``index_G`` is shorter than expected, its leading entries differ
        from the expected indices, or any padding entry is not ``-1``.
        Explicit exceptions are used instead of ``assert`` so the check
        still runs under ``python -O``.
    """
    size = tuple(grid.size)
    if np.issubdtype(pw.dtype, np.floating):
        size = (size[0], size[1], size[2] // 2 + 1)
    index0_G = pw.indices(size)
    nG = len(index0_G)
    index_G = np.asarray(index_G)
    if len(index_G) < nG:
        raise ValueError(
            f'Too few G-vector indices in file: {len(index_G)} < {nG}')
    if not np.array_equal(index0_G, index_G[:nG]):
        raise ValueError('G-vector ordering in file does not match '
                         'the current plane-wave descriptor')
    if not (index_G[nG:] == -1).all():
        raise ValueError('Unexpected G-vector indices after the first '
                         f'{nG} entries (expected -1 padding)')