Coverage for gpaw/new/pw/builder.py: 82%

183 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1from __future__ import annotations 

2 

3import os 

4import warnings 

5from functools import cached_property 

6 

7import numpy as np 

8from ase.units import Ha 

9 

10from gpaw.core import PWDesc, UGDesc 

11from gpaw.core.domain import Domain 

12from gpaw.core.matrix import Matrix 

13from gpaw.core.plane_waves import PWArray 

14from gpaw.new import zips 

15from gpaw.new.builder import create_uniform_grid 

16from gpaw.new.gpw import as_double_precision 

17from gpaw.new.pw.bloechl_poisson import BloechlPAWPoissonSolver 

18from gpaw.new.pw.hamiltonian import PWHamiltonian, SpinorPWHamiltonian 

19from gpaw.new.pw.hybrids import PWHybridHamiltonian 

20from gpaw.new.pw.paw_poisson import SlowPAWPoissonSolver 

21from gpaw.new.pw.poisson import make_poisson_solver 

22from gpaw.new.pw.pot_calc import PlaneWavePotentialCalculator 

23from gpaw.new.pwfd.builder import PWFDDFTComponentsBuilder 

24from gpaw.new.xc import create_functional 

25from gpaw.typing import Array1D 

26 

27 

class PWDFTComponentsBuilder(PWFDDFTComponentsBuilder):
    """Builder for the DFT components of a plane-wave (PW) mode calculation.

    Specializes :class:`PWFDDFTComponentsBuilder` with plane-wave
    descriptors, interpolation to a fine grid (FFT by default), PW Poisson
    solvers and PW Hamiltonians.
    """

    def __init__(self,
                 atoms,
                 params,
                 *,
                 comm=None,
                 log=None):
        mode = params.mode
        # Plane-wave cutoff: ``mode.ecut`` is in eV, stored here in Hartree.
        self.ecut = mode.ecut / Ha
        # NOTE(review): mode.dedecut is not used here (open question).
        super().__init__(atoms, params, comm=comm, log=log)

        # Caches filled lazily by get_pseudo_core_densities() and
        # get_pseudo_core_ked().
        self._nct_ag = None
        self._tauct_ag = None

        if self._omp_num_threads() > 1:
            warnings.warn(
                'Using OMP_NUM_THREADS>1 in PW-mode is not useful!')
        # We should just distribute the atoms evenly, but that is not
        # compatible with LCAO initialization!
        # return AtomDistribution.from_number_of_atoms(len(self.relpos_ac),
        #                                              self.communicators['d'])

    @staticmethod
    def _omp_num_threads() -> int:
        """Return the (outermost) thread count from ``OMP_NUM_THREADS``.

        The variable may legally be a comma-separated list with one entry per
        nesting level (e.g. ``"4,2"``); only the first entry is used.  An
        unset, empty or unparsable value counts as one thread, so that a
        malformed environment variable cannot abort the calculation setup
        (it is only used to decide whether to emit a warning).
        """
        value = os.environ.get('OMP_NUM_THREADS', '').split(',')[0]
        try:
            return int(value or '1')
        except ValueError:
            return 1

    def create_uniform_grids(self):
        """Create the coarse real-space grid and the fine grid.

        Returns ``(grid, fine_grid)``.  The fine grid has twice as many grid
        points as the coarse grid along each axis.  The interpolation method
        defaults to ``'fft'`` when ``params.interpolation`` is falsy.
        """
        grid = create_uniform_grid(
            'pw',
            self.params.gpts,
            self.atoms.cell,
            self.atoms.pbc,
            self.ibz.symmetries,
            h=self.params.h,
            interpolation=self.params.interpolation or 'fft',
            ecut=self.ecut,
            comm=self.communicators['d'])
        fine_grid = grid.new(size=grid.size_c * 2)
        # decomposition=[2 * d for d in grid.decomposition]
        return grid, fine_grid

    def create_wf_description(self) -> Domain:
        """Plane-wave descriptor for the wave functions (cutoff self.ecut)."""
        return PWDesc(ecut=self.ecut,
                      cell=self.grid.cell,
                      comm=self.grid.comm,
                      dtype=self.dtype)

    def create_xc_functional(self):
        """Create the XC functional, evaluated on the fine grid."""
        return create_functional(self._xc,
                                 self.fine_grid, self.xp)

    @cached_property
    def interpolation_desc(self):
        """Plane-wave set used for interpolating from coarse to fine grid."""
        # By default, the size of the grid used for the FFTs (self.grid)
        # will accommodate G-vectors up to 2 * self.ecut, but the grid-size
        # could have been set using h=... or gpts=..., so never exceed what
        # the grid can represent.
        ecut = min(2 * self.ecut, self.grid.ekin_max())
        return PWDesc(ecut=ecut,
                      cell=self.grid.cell,
                      comm=self.grid.comm)

    @cached_property
    def electrostatic_potential_desc(self):
        """Plane-wave set for the electrostatic potential.

        The fast (Bloechl) Poisson solver reuses ``interpolation_desc``;
        otherwise a larger cutoff of ``8 * self.ecut`` is used.
        """
        if self.fast_poisson_solver:
            return self.interpolation_desc
        return self.interpolation_desc.new(ecut=8 * self.ecut)

    @cached_property
    def fast_poisson_solver(self) -> bool:
        """Whether the fast (Bloechl) PAW Poisson solver is used.

        Requires ``fast=True`` in the Poisson-solver parameters AND that
        every setup has ``data`` with a Gaussian compensation-charge shape
        function; otherwise the slow solver is used.
        """
        if not self.params.poissonsolver.params.get('fast', False):
            return False
        # Only works for Gaussian compensation charges at the moment:
        return all(hasattr(setup, 'data') and
                   setup.data.shape_function['type'] == 'gauss'
                   for setup in self.setups)

    def get_pseudo_core_densities(self):
        """Return the pseudo core densities (created once, then cached)."""
        if self._nct_ag is None:
            self._nct_ag = self.setups.create_pseudo_core_densities(
                self.interpolation_desc, self.relpos_ac, self.atomdist,
                xp=self.xp)
        return self._nct_ag

    def get_pseudo_core_ked(self):
        """Return the pseudo core kinetic-energy densities (cached)."""
        # NOTE(review): unlike get_pseudo_core_densities(), no ``xp`` is
        # passed here -- confirm whether create_pseudo_core_ked() needs it
        # for GPU runs.
        if self._tauct_ag is None:
            self._tauct_ag = self.setups.create_pseudo_core_ked(
                self.interpolation_desc, self.relpos_ac, self.atomdist)
        return self._tauct_ag

    def create_poisson_solver(self, env):
        """Create the PAW Poisson solver for the given environment.

        Returns a :class:`BloechlPAWPoissonSolver` when
        ``fast_poisson_solver`` is true, otherwise a
        :class:`SlowPAWPoissonSolver`.  Both wrap a plane-wave solver made by
        ``make_poisson_solver()``.
        """
        psparams = self.params.poissonsolver.params.copy()
        # 'fast' only selects the PAW solver flavour below; it is not a
        # parameter of make_poisson_solver().  Remove it *before* applying
        # the default so that {'fast': True} alone still gets the same
        # default strength as an empty parameter dict.
        psparams.pop('fast', None)
        if not psparams:
            psparams = {'strength': 1.0}

        fast = self.fast_poisson_solver
        grid = self.grid if fast else self.fine_grid

        pw = self.electrostatic_potential_desc
        ps = make_poisson_solver(pw,
                                 grid,
                                 self.charge,
                                 env,
                                 **psparams)

        if fast:
            cutoff_a = [s.data.shape_function['rc'] for s in self.setups]
            return BloechlPAWPoissonSolver(
                pw, cutoff_a, ps, self.relpos_ac, self.atomdist, self.xp)

        return SlowPAWPoissonSolver(
            self.interpolation_desc,
            self.setups,
            ps, self.relpos_ac, self.atomdist, self.xp)

    def create_potential_calculator(self):
        """Create the plane-wave potential calculator.

        The environment is created on the fine grid and shared between the
        Poisson solver and the potential calculator.
        """
        env = self.create_environment(self.fine_grid)
        return PlaneWavePotentialCalculator(
            self.grid, self.fine_grid,
            self.interpolation_desc,
            self.setups,
            self.xc,
            self.create_poisson_solver(env),
            relpos_ac=self.relpos_ac,
            atomdist=self.atomdist,
            soc=self.soc,
            xp=self.xp,
            environment=env,
            extensions=self.get_extensions())

    def create_hamiltonian_operator(self, blocksize=10):
        """Create the Hamiltonian operator acting on the wave functions.

        * collinear, no exact exchange: :class:`PWHamiltonian`
        * collinear, hybrid functional: :class:`PWHybridHamiltonian`
        * spinors (``ncomponents == 4``): :class:`SpinorPWHamiltonian`

        ``blocksize`` is not used by this implementation.
        """
        if self.ncomponents < 4:
            if self.xc.exx_fraction == 0.0:
                return PWHamiltonian(self.grid, self.wf_desc, self.xp)
            # Hybrids: no domain or k-point parallelization, and the bands
            # must divide evenly over the band communicator.
            assert self.communicators['d'].size == 1
            assert self.communicators['k'].size == 1
            assert self.nbands % self.communicators['b'].size == 0
            return PWHybridHamiltonian(
                self.grid, self.wf_desc, self.xc, self.setups,
                self.relpos_ac, self.atomdist,
                comp_charge_in_real_space=self.params.experimental.get(
                    'ccirs'))
        return SpinorPWHamiltonian(self.qspiral_v)

    def convert_wave_functions_from_uniform_grid(self,
                                                 C_nM: Matrix,
                                                 basis_set,
                                                 kpt_c,
                                                 q):
        """Turn LCAO coefficients into plane-wave expanded wave functions.

        Parameters
        ----------
        C_nM:
            LCAO coefficients; the first dimension of ``C_nM.dist.shape`` is
            the number of bands on this rank.
        basis_set:
            Basis set used by the real-space route (``lcao_to_grid()``).
        kpt_c:
            k-point used to build the k-dependent grid/PW descriptors.
        q:
            Index passed on to ``basis_set.lcao_to_grid()``.

        Two routes exist:

        * fast (default ``fast_pw_init=True``, collinear only): expand the
          basis functions directly in plane waves and multiply by ``C_nM``;
        * otherwise: evaluate the wave functions on the real-space grid,
          multiply by exp(-ik.r) for complex dtypes and FFT to plane waves.

        The result is returned on the ``self.xp`` array backend.
        """
        use_fast_init = self.params.experimental.get('fast_pw_init', True)
        if use_fast_init and self.ncomponents < 4:
            from gpaw.core.pwacf import PWAtomCenteredFunctions
            pw = self.wf_desc.new(kpt=kpt_c)
            phit_aJG = PWAtomCenteredFunctions(
                [setup.basis_functions_J for setup in self.setups],
                self.relpos_ac,
                pw,
                atomdist=self.atomdist,
                xp=self.xp)
            psit_nG = pw.empty(self.nbands,
                               comm=self.communicators['b'],
                               xp=self.xp)
            mynbands, _ = C_nM.dist.shape
            phit_aJG.multiply(C_nM.to_xp(self.xp).to_dtype(pw.dtype),
                              out_nG=psit_nG[:mynbands])
            return psit_nG

        is_complex = np.issubdtype(self.dtype, np.complexfloating)
        # The real-space LCAO route always works in double precision.
        lcao_dtype = complex if is_complex else float

        grid = self.grid.new(kpt=kpt_c, dtype=lcao_dtype)
        pw = self.wf_desc.new(kpt=kpt_c, dtype=lcao_dtype)

        if is_complex:
            # Phase factor exp(-ik.r) applied before the FFT.
            emikr_R = grid.eikr(-kpt_c)

        mynbands, M = C_nM.dist.shape
        if self.ncomponents < 4:
            psit_nG = pw.empty(self.nbands, self.communicators['b'])
            psit_nR = grid.zeros(mynbands)
            basis_set.lcao_to_grid(C_nM.data, psit_nR.data, q)

            for psit_R, psit_G in zips(psit_nR, psit_nG, strict=False):
                if is_complex:
                    psit_R.data *= emikr_R
                psit_R.fft(out=psit_G)

            if self.dtype != lcao_dtype:
                # The requested dtype differs from the double-precision LCAO
                # dtype (presumably single precision): copy into a
                # descriptor with the requested dtype.
                pw_correct = self.wf_desc.new(kpt=kpt_c, dtype=self.dtype)
                psit2_nG = pw_correct.empty(self.nbands,
                                            self.communicators['b'])
                psit2_nG.data[:] = psit_nG.data
                return psit2_nG.to_xp(self.xp)
            return psit_nG.to_xp(self.xp)

        # Spinors: two components per band; the coefficients of the two
        # components are the two halves of the M axis.
        # NOTE(review): assumes a complex dtype for spinors, since emikr_R
        # is only defined in that case.
        psit_nsG = pw.empty((self.nbands, 2), self.communicators['b'])
        psit_sR = grid.empty(2)
        C_nsM = C_nM.data.reshape((mynbands, 2, M // 2))
        for psit_sG, C_sM in zips(psit_nsG, C_nsM, strict=False):
            psit_sR.data[:] = 0.0
            basis_set.lcao_to_grid(C_sM, psit_sR.data, q)
            psit_sR.data *= emikr_R
            for psit_G, psit_R in zips(psit_sG, psit_sR):
                psit_R.fft(out=psit_G)
        # Move to the requested array backend, like the collinear branch.
        return psit_nsG.to_xp(self.xp)

    def read_ibz_wave_functions(self, reader):
        """Read plane-wave coefficients from a gpw-file reader.

        Returns the base-class result unchanged when the file has no
        ``'coefficients'``.  Otherwise the coefficients are attached to each
        IBZ wave-function object (``wfs.psit_nX``) with the following
        handling:

        * scaling depends on the file version;
        * single-precision data is converted to double precision;
        * bands are distributed over the band communicator.
        """
        ibzwfs = super().read_ibz_wave_functions(reader)

        if 'coefficients' not in reader.wave_functions:
            return ibzwfs

        singlep = reader.get('precision', 'double') == 'single'
        # Scale factor applied to the stored coefficients:
        c = reader.bohr**1.5
        if reader.version < 0:
            c = 1  # very old gpw file
        elif reader.version < 4:
            c /= self.grid.size_c.prod()

        index_kG = reader.wave_functions.indices

        if self.ncomponents == 4:
            shape = (self.nbands, 2)
        else:
            shape = (self.nbands,)

        for wfs in ibzwfs:
            pw = self.wf_desc.new(kpt=wfs.kpt_c)
            if wfs.spin == 0:
                check_g_vector_ordering(self.grid, pw, index_kG[wfs.k])

            index = (wfs.spin, wfs.k) if self.ncomponents != 4 else (wfs.k,)
            data = reader.wave_functions.proxy('coefficients', *index)
            data.scale = c
            data.length_of_last_dimension = pw.shape[-1]

            if self.communicators['w'].size == 1 and not singlep:
                # Serial, double precision: build the array from the proxy
                # with a temporarily changed shape; always restore it, even
                # if from_data() fails.
                orig_shape = data.shape
                data.shape = shape + pw.shape
                try:
                    wfs.psit_nX = pw.from_data(data)
                finally:
                    data.shape = orig_shape
            else:
                band_comm = self.communicators['b']
                wfs.psit_nX = PWArray(pw, shape, comm=band_comm)
                # Contiguous blocks of ceil(nbands / size) bands per rank.
                mynbands = (self.nbands +
                            band_comm.size - 1) // band_comm.size
                n1 = min(band_comm.rank * mynbands, self.nbands)
                n2 = min((band_comm.rank + 1) * mynbands, self.nbands)
                if pw.comm.rank == 0:
                    assert wfs.psit_nX.mydims[0] == n2 - n1
                    data = data[n1:n2]  # read from file
                else:
                    # Only domain-rank 0 reads; other ranks pass None and
                    # presumably receive their part in scatter_from().
                    data = [None] * (n2 - n1)
                for psit_G, array in zips(wfs.psit_nX, data):
                    if singlep:
                        psit_G.scatter_from(as_double_precision(array))
                    else:
                        psit_G.scatter_from(array)

        return ibzwfs

295 

def check_g_vector_ordering(grid: UGDesc,
                            pw: PWDesc,
                            index_G: Array1D) -> None:
    """Check that G-vector indices read from a file match ``pw``.

    The leading entries of ``index_G`` must equal ``pw.indices(shape)``.
    Here ``shape`` is the size of ``grid``, with the last axis reduced to
    ``n // 2 + 1`` when ``pw`` has a real dtype.  Every remaining entry
    must be -1 (presumably padding).  Mismatches fail via ``assert``.
    """
    shape = tuple(grid.size)
    if np.issubdtype(pw.dtype, np.floating):
        # Real wave functions: only half of the last FFT axis is used.
        shape = shape[:2] + (shape[2] // 2 + 1,)
    expected_G = pw.indices(shape)
    nexpected = len(expected_G)
    head_G = index_G[:nexpected]
    tail_G = index_G[nexpected:]
    assert (expected_G == head_G).all()
    assert (tail_G == -1).all()