Coverage for gpaw/fftw.py: 86%

237 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

1""" 

2Python wrapper for FFTW3 library 

3================================ 

4 

5.. autoclass:: FFTPlans 

6 

7""" 

8from __future__ import annotations 

9 

10import weakref 

11from types import ModuleType 

12 

13import numpy as np 

14from scipy.fft import fftn, ifftn, irfftn, rfftn 

15 

16import gpaw.cgpaw as cgpaw 

17from gpaw.utilities import as_complex_dtype, as_real_dtype 

18from gpaw.new.c import pw_insert_gpu 

19from gpaw.new import trace 

20from gpaw.typing import Array1D, Array3D, DTypeLike, IntVector 

21from gpaw.gpu import is_hip 

22 

# Planner flags passed through to the compiled FFTW wrapper.
# NOTE(review): the values look like FFTW_ESTIMATE, FFTW_MEASURE,
# FFTW_PATIENT and FFTW_EXHAUSTIVE from fftw3.h -- confirm against the header.
ESTIMATE = 64
MEASURE = 0
PATIENT = 32
EXHAUSTIVE = 8

# Weak references to plan objects that have been created, keyed by the
# tuple built in create_plans().  Holding only weak references means the
# cache itself never keeps a plan alive.
_plan_cache: dict[tuple, weakref.ReferenceType] = {}

29 

30 

def have_fftw() -> bool:
    """Tell whether the compiled extension was built with FFTW.

    The answer is simply whether the ``cgpaw`` module exposes an
    ``FFTWPlan`` attribute.
    """
    compiled_with_fftw = hasattr(cgpaw, 'FFTWPlan')
    return compiled_with_fftw

34 

35 

def check_fft_size(n: int, factors=(2, 3, 5, 7)) -> bool:
    """Check if n is an efficient fft size.

    Efficient means that n can be written as a product of the given
    factors (by default the small primes 2, 3, 5 and 7).

    Parameters
    ----------
    n:
        Candidate size.  Zero and negative values are never efficient.
    factors:
        Allowed factors, tried in the given order.  Entries smaller than
        2 are ignored because dividing by them can never reduce n.

    Returns
    -------
    bool
        True if repeatedly dividing n by the factors ends at 1.

    >>> check_fft_size(17)
    False
    >>> check_fft_size(18)
    True
    >>> check_fft_size(0)
    False
    """
    # 0 is divisible by everything and would never shrink; negative
    # numbers can never be reduced to 1.
    if n < 1:
        return False
    for x in factors:
        if x < 2:
            continue
        # Strip this factor completely before moving on to the next one.
        while n % x == 0:
            n //= x
    return bool(n == 1)

53 

54 

def get_efficient_fft_size(N: int, n=1, factors=(2, 3, 5, 7)) -> int:
    """Return smallest efficient fft size.

    Must be greater than or equal to N and divisible by n.

    Parameters
    ----------
    N:
        Lower bound for the size.  Values below 1 are treated as 1, since
        no efficient size is smaller than 1.
    n:
        The result is a multiple of this number.  It must be positive,
        and some multiple of it must be efficient; otherwise the search
        never terminates (for example n=11 with the default factors).
    factors:
        Allowed factors, forwarded to check_fft_size().

    >>> get_efficient_fft_size(17)
    18
    """
    # Round max(N, 1) up to the nearest multiple of n (ceiling division
    # written with floor division on the negated value).
    N = -(-max(N, 1) // n) * n
    while not check_fft_size(N, factors):
        N += n
    return N

67 

68 

def empty(shape, dtype=float):
    """numpy.empty() equivalent with 16 byte alignment.

    Parameters
    ----------
    shape:
        Shape of the returned array.
    dtype:
        Must be a complex floating-point dtype; this is asserted.
        NOTE(review): the default ``float`` does not pass that assertion,
        so callers always have to pass a complex dtype explicitly.

    Returns
    -------
    Uninitialised complex array of the requested shape whose first
    element starts on a 16 byte boundary.  It is a view into a slightly
    larger real-valued buffer.
    """
    assert np.issubdtype(dtype, np.complexfloating)

    real_dtype = as_real_dtype(dtype)
    complex_dtype = as_complex_dtype(dtype)

    # int(): np.prod(()) is the float 1.0, which np.empty() rejects as a size.
    N = int(np.prod(shape))
    # Allocate enough extra real elements that a window of 2 * N reals
    # starting on a 16 byte boundary always fits inside the buffer.
    a = np.empty(2 * N + 16 // real_dtype.itemsize - 1, real_dtype)
    # Number of real elements to skip to reach the next 16 byte boundary.
    # The distance to the boundary is (-address) % 16, not address % 16;
    # the two only coincide for 8 byte elements.
    offset = (-a.ctypes.data % 16) // real_dtype.itemsize
    a = a[offset:2 * N + offset].view(complex_dtype)
    a.shape = shape
    return a

82 

83 

def create_plans(size_c: IntVector,
                 dtype: DTypeLike,
                 flags: int = MEASURE,
                 xp: ModuleType = np) -> FFTPlans:
    """Return plan-objects for FFT and inverse FFT, reusing live ones.

    Plans are cached under ``(tuple(size_c), dtype, flags, xp)``.  The
    cache stores weak references only, so a cached entry is reused just
    as long as somebody else still holds the plan.

    The back-end is chosen from ``xp``: anything other than numpy gives
    CuPyFFTPlans; with numpy, FFTWPlans is used when have_fftw() is true
    and NumpyFFTPlans otherwise.  ``flags`` is only passed to FFTWPlans.
    """
    cache_key = (tuple(size_c), dtype, flags, xp)

    # A stored weak reference may point to a plan that is already gone.
    ref = _plan_cache.get(cache_key)
    if ref is not None:
        alive = ref()
        if alive is not None:
            return alive

    # Nothing usable in the cache: build a new plan.
    new_plans: FFTPlans
    if xp is np:
        if have_fftw():
            new_plans = FFTWPlans(size_c, dtype, flags)
        else:
            new_plans = NumpyFFTPlans(size_c, dtype)
    else:
        new_plans = CuPyFFTPlans(size_c, dtype)

    _plan_cache[cache_key] = weakref.ref(new_plans)
    return new_plans

105 

106 

class FFTPlans:
    """Base class owning the work arrays used for 3-d FFTs.

    ``tmp_R`` is the array that :meth:`fft` reads and :meth:`ifft` writes;
    ``tmp_Q`` is the array that :meth:`fft` writes and :meth:`ifft` reads.
    The two transforms themselves are left to subclasses.
    """
    def __init__(self,
                 size_c: IntVector,
                 dtype: DTypeLike,
                 empty=empty):
        """Allocate ``tmp_Q`` and ``tmp_R``.

        size_c:
            Grid size along the three axes.
        dtype:
            A real floating-point dtype selects the real/half-complex
            layout below, anything else the plain complex layout.
        empty:
            Allocator called as ``empty(shape, dtype)``.
        """
        self.shape: tuple[int, ...]

        if np.issubdtype(dtype, np.floating):
            # Real data: only size_c[2] // 2 + 1 complex values are kept
            # along the last axis.
            self.shape = (size_c[0], size_c[1], size_c[2] // 2 + 1)
            self.tmp_Q = empty(self.shape, as_complex_dtype(dtype))
            # tmp_R is a real-valued view of the same memory as tmp_Q,
            # cropped to size_c[2] entries along the last axis.
            self.tmp_R = self.tmp_Q.view(dtype)[:, :, :size_c[2]]
        else:
            self.shape = tuple(size_c)
            self.tmp_Q = empty(size_c, dtype)
            # Complex data: both names refer to one and the same array.
            self.tmp_R = self.tmp_Q

    def fft(self) -> None:
        """Do FFT from ``tmp_R`` to ``tmp_Q``.

        >>> plans = create_plans([4, 1, 1], float)
        >>> plans.tmp_R[:, 0, 0] = [1, 0, 1, 0]
        >>> plans.fft()
        >>> plans.tmp_Q[:, 0, 0]
        array([2.+0.j, 0.+0.j, 2.+0.j, 0.+0.j])
        """
        raise NotImplementedError

    def ifft(self) -> None:
        """Do inverse FFT from ``tmp_Q`` to ``tmp_R``.

        >>> plans = create_plans([4, 1, 1], complex)
        >>> plans.tmp_Q[:, 0, 0] = [0, 1j, 0, 0]
        >>> plans.ifft()
        >>> plans.tmp_R[:, 0, 0]
        array([ 0.+1.j, -1.+0.j,  0.-1.j,  1.+0.j])
        """
        raise NotImplementedError

    def ifft_sphere(self, coef_G, pw, out_R):
        """Inverse transform the coefficients ``coef_G`` into ``out_R``.

        With ``coef_G`` being None nothing is transformed and only
        ``out_R.scatter_from(None)`` is called.  Otherwise the
        coefficients are handed to ``pw.paste`` together with ``tmp_Q``,
        :meth:`ifft` is run and ``tmp_R`` is passed to
        ``out_R.scatter_from``.
        """
        if coef_G is None:
            out_R.scatter_from(None)
            return
        pw.paste(coef_G, self.tmp_Q)

        if np.issubdtype(pw.dtype, np.floating):
            # Work on the plane with last-axis index 0 and overwrite parts
            # of it with complex conjugates of reversed slices of itself.
            # NOTE(review): presumably this restores the conjugate
            # symmetry the complex-to-real transform expects -- confirm.
            # The statements write into the array they read from, so
            # their order matters.
            t = self.tmp_Q[:, :, 0]
            n, m = (s // 2 - 1 for s in out_R.desc.size_c[:2])
            t[0, -m:] = t[0, m:0:-1].conj()
            t[n:0:-1, -m:] = t[-n:, m:0:-1].conj()
            t[-n:, -m:] = t[n:0:-1, m:0:-1].conj()
            t[-n:, 0] = t[n:0:-1, 0].conj()
        self.ifft()
        out_R.scatter_from(self.tmp_R)

    def fft_sphere(self, in_R, pw):
        """Transform ``in_R.data`` and return the coefficients cut by ``pw``.

        The data is copied into ``tmp_R``, :meth:`fft` is run, and the
        result of ``pw.cut(tmp_Q)`` is scaled by ``1 / tmp_R.size``.
        """
        self.tmp_R[:] = in_R.data
        self.fft()
        coefs = pw.cut(self.tmp_Q) * (1 / self.tmp_R.size)
        return coefs

166 

167 

class FFTWPlans(FFTPlans):
    """3-d transforms executed by FFTW3 through the ``cgpaw`` extension."""
    def __init__(self, size_c, dtype, flags=MEASURE):
        """Allocate the work arrays and create both FFTW plans.

        Raises ImportError when have_fftw() reports that FFTW is missing.
        """
        if not have_fftw():
            raise ImportError('Not compiled with FFTW.')
        super().__init__(size_c, dtype)
        # Sign arguments handed to the planner for the two directions.
        forward, backward = -1, 1
        self._fftplan = cgpaw.FFTWPlan(
            self.tmp_R, self.tmp_Q, forward, flags)
        self._ifftplan = cgpaw.FFTWPlan(
            self.tmp_Q, self.tmp_R, backward, flags)

    def fft(self):
        """Execute the plan going from ``tmp_R`` to ``tmp_Q``."""
        cgpaw.FFTWExecute(self._fftplan)

    def ifft(self):
        """Execute the plan going from ``tmp_Q`` to ``tmp_R``."""
        cgpaw.FFTWExecute(self._ifftplan)

    def __del__(self):
        """Destroy whichever plans were actually created."""
        # Attributes will not exist if execution stops during FFTW planning
        for attr in ('_fftplan', '_ifftplan'):
            if hasattr(self, attr):
                cgpaw.FFTWDestroy(getattr(self, attr))

189 

190 

class NumpyFFTPlans(FFTPlans):
    """Fallback used without FFTW; transforms are done by ``scipy.fft``."""
    def fft(self):
        """Transform ``tmp_R`` and store the result in ``tmp_Q``.

        A real ``tmp_R`` goes through ``rfftn``, a complex one through
        ``fftn``.
        """
        real_input = np.issubdtype(self.tmp_R.dtype, np.floating)
        transform = rfftn if real_input else fftn
        self.tmp_Q[:] = transform(self.tmp_R, overwrite_x=True)

    def ifft(self):
        """Inverse transform ``tmp_Q`` and store the result in ``tmp_R``.

        ``norm='forward'`` is passed to scipy so that the inverse
        transform itself applies no scaling.
        """
        real_output = np.issubdtype(self.tmp_R.dtype, np.floating)
        inverse = irfftn if real_output else ifftn
        self.tmp_R[:] = inverse(self.tmp_Q, self.tmp_R.shape,
                                norm='forward', overwrite_x=True)

206 

207 

def rfftn_patch(tmp_R):
    """Stand-in for ``rfftn`` built from a full ``fftn``.

    The complete transform of ``tmp_R`` is computed and only the first
    ``tmp_R.shape[-1] // 2 + 1`` entries along the third axis are
    returned.
    """
    from gpaw.gpu import cupyx
    full_Q = cupyx.scipy.fft.fftn(tmp_R)
    n_kept = tmp_R.shape[-1] // 2 + 1
    return full_Q[:, :, :n_kept]

211 

212 

def irfftn_patch(B, shape):
    """Stand-in for ``irfftn`` built from a full ``ifftn``.

    A complex array of the requested ``shape`` is filled from the
    half-size input ``B``: the leading part along the third axis is a
    copy of ``B``, the trailing part is the complex conjugate of ``B``
    taken at negated indices along all three axes.  The real part of the
    full inverse transform is returned.
    """
    from gpaw.gpu import cupyx
    import cupy as xp
    n_half = B.shape[2]
    full_Q = xp.empty(shape, dtype=complex)
    full_Q[:, :, :n_half] = B

    # Negated index arrays, shaped so that they broadcast against each
    # other into one index per element of the mirrored part.
    neg0 = -xp.arange(B.shape[0])[:, None, None]
    neg1 = -xp.arange(B.shape[1])[None, :, None]
    neg2 = -xp.arange(1, n_half)[None, None, :]
    mirrored = B[neg0, neg1, neg2].conj()
    full_Q[:, :, -(n_half - 1):] = mirrored

    return cupyx.scipy.fft.ifftn(full_Q).real

223 

224 

225class CuPyFFTPlans(FFTPlans): 

226 def __init__(self, 

227 size_c: IntVector, 

228 dtype: DTypeLike): 

229 from gpaw.core import PWDesc 

230 from gpaw.gpu import cupy as cp 

231 self.dtype = dtype 

232 super().__init__(size_c, dtype, empty=cp.empty) 

233 self.Q_G_cache: dict[PWDesc, Array1D] = {} 

234 

235 @trace(gpu=True) 

236 def fft(self): 

237 from gpaw.gpu import cupyx 

238 if self.tmp_R.dtype == float: 

239 if is_hip: 

240 self.tmp_Q[:] = rfftn_patch(self.tmp_R) 

241 else: 

242 self.tmp_Q[:] = cupyx.scipy.fft.rfftn(self.tmp_R) 

243 else: 

244 self.tmp_Q[:] = cupyx.scipy.fft.fftn(self.tmp_R) 

245 

246 @trace(gpu=True) 

247 def ifft(self): 

248 from gpaw.gpu import cupyx 

249 if self.tmp_R.dtype == float: 

250 if is_hip: 

251 self.tmp_R[:] = irfftn_patch(self.tmp_Q, self.tmp_R.shape) \ 

252 * self.tmp_R.size 

253 else: 

254 self.tmp_R[:] = cupyx.scipy.fft.irfftn( 

255 self.tmp_Q, self.tmp_R.shape, 

256 norm='forward', 

257 overwrite_x=True) 

258 else: 

259 self.tmp_R[:] = cupyx.scipy.fft.ifftn( 

260 self.tmp_Q, self.tmp_R.shape, 

261 norm='forward', 

262 overwrite_x=True) 

263 

264 def indices(self, pw): 

265 from gpaw.gpu import cupy as cp 

266 Q_G = self.Q_G_cache.get(pw) 

267 if Q_G is None: 

268 Q_G = cp.asarray(pw.indices(self.shape)) 

269 self.Q_G_cache[pw] = Q_G 

270 return Q_G 

271 

272 @trace 

273 def ifft_sphere(self, coef_G, pw, out_R): 

274 from gpaw.gpu import cupyx 

275 

276 if coef_G is None: 

277 out_R.scatter_from(None) 

278 return 

279 

280 if out_R.desc.comm.size == 1: 

281 array_R = out_R.data 

282 else: 

283 array_R = self.tmp_R 

284 array_Q = self.tmp_Q 

285 

286 array_Q[:] = 0.0 

287 Q_G = self.indices(pw) 

288 

289 assert np.issubdtype(array_Q.dtype, np.complexfloating) 

290 assert np.issubdtype(coef_G.dtype, np.complexfloating) 

291 pw_insert_gpu(coef_G, 

292 Q_G, 

293 1.0, 

294 array_Q.ravel(), 

295 *out_R.desc.size_c) 

296 

297 if np.issubdtype(self.dtype, np.complexfloating): 

298 array_R[:] = cupyx.scipy.fft.ifftn( 

299 array_Q, array_Q.shape, 

300 norm='forward', overwrite_x=True) 

301 else: 

302 if is_hip: 

303 array_R[:] = irfftn_patch(array_Q, out_R.desc.global_shape())\ 

304 * array_R.size 

305 else: 

306 array_R[:] = cupyx.scipy.fft.irfftn( 

307 array_Q, out_R.desc.global_shape(), 

308 norm='forward', overwrite_x=True) 

309 

310 if out_R.desc.comm.size > 1: 

311 out_R.scatter_from(array_R) 

312 

313 @trace 

314 def fft_sphere(self, in_R, pw): 

315 from gpaw.gpu import cupyx 

316 if np.issubdtype(self.dtype, np.complexfloating): 

317 out_Q = cupyx.scipy.fft.fftn(in_R) 

318 else: 

319 if is_hip: 

320 out_Q = rfftn_patch(in_R) 

321 else: 

322 # CuPy bug? rfftn fails on non-aligned arrays 

323 # To that end, make a copy. However, display a warning. 

324 if in_R.data.ptr % 16: 

325 in_R = in_R.copy() 

326 from warnings import warn 

327 warn('Circumventing GPU array alignment problem ' 

328 'with copy at rfftn.') 

329 out_Q = cupyx.scipy.fft.rfftn(in_R) 

330 

331 Q_G = self.indices(pw) 

332 coef_G = out_Q.ravel()[Q_G] * (1 / in_R.size) 

333 return coef_G 

334 

335 

336# The rest of this file will be removed in the future ... 

337 

def check_fftw_inputs(in_R, out_R):
    """Assert that two arrays can serve as input and output of a 3-d FFT.

    Both arrays must be three-dimensional with dtype float or complex.
    Two complex arrays must have equal shapes.  Otherwise exactly one of
    them must be real; with the real shape being (a, b, c), the complex
    shape has to be (a, b, c // 2 + 1).

    Raises AssertionError when a requirement is not met.
    """
    for array in (in_R, out_R):
        # Note: Arrays not necessarily contiguous due to 16-byte alignment
        assert array.ndim == 3  # We can perhaps relax this requirement
        assert array.dtype == float or array.dtype == complex

    both_complex = in_R.dtype == out_R.dtype == complex
    if both_complex:
        assert in_R.shape == out_R.shape
        return

    # One real and one complex:
    if in_R.dtype == float:
        real, cplx = in_R, out_R
    else:
        real, cplx = out_R, in_R
    assert cplx.dtype == complex
    assert real.shape[:2] == cplx.shape[:2]
    assert cplx.shape[2] == 1 + real.shape[2] // 2

352 

353 

class FFTPlan:
    """Base class describing one 3-d FFT between two fixed arrays."""
    def __init__(self,
                 in_R: Array3D,
                 out_R: Array3D,
                 sign: int,
                 flags: int = MEASURE):
        """Validate the array pair and remember all arguments.

        The arrays are checked with check_fftw_inputs() before anything
        is stored on the instance.
        """
        check_fftw_inputs(in_R, out_R)
        self.in_R, self.out_R = in_R, out_R
        self.sign, self.flags = sign, flags

    def execute(self) -> None:
        """Run the transform; to be provided by subclasses."""
        raise NotImplementedError

369 

370 

class FFTWPlan(FFTPlan):
    """FFTW3 3d transform."""
    def __init__(self, in_R, out_R, sign, flags=MEASURE):
        """Validate the arrays, then create the FFTW plan for them.

        Raises ImportError when have_fftw() reports that FFTW is missing,
        and whatever the base-class validation raises for an unsuitable
        pair of arrays.
        """
        if not have_fftw():
            raise ImportError('Not compiled with FFTW.')
        # Run the shape/dtype checks first, so that arrays failing them
        # are never handed to the compiled planner.
        FFTPlan.__init__(self, in_R, out_R, sign, flags)
        self._ptr = cgpaw.FFTWPlan(in_R, out_R, sign, flags)

    def execute(self):
        """Execute the plan created in the constructor."""
        cgpaw.FFTWExecute(self._ptr)

    def __del__(self):
        """Destroy the plan once, if one was created."""
        # _ptr is missing when the constructor failed before planning;
        # it is reset to None after the plan has been destroyed.
        if getattr(self, '_ptr', None):
            cgpaw.FFTWDestroy(self._ptr)
            self._ptr = None

386 

387 

class NumpyFFTPlan(FFTPlan):
    """Fallback doing the transform with ``numpy.fft``."""
    def execute(self):
        """Transform ``in_R`` into ``out_R``.

        The transform is picked from the dtypes and ``sign``: real input
        gives ``rfftn``, real output ``irfftn``, ``sign == 1``
        ``ifftn`` and anything else ``fftn``.  The output of the two
        inverse transforms is multiplied by the number of elements in
        ``out_R`` afterwards.
        """
        src, dst = self.in_R, self.out_R

        if src.dtype == float:
            dst[:] = np.fft.rfftn(src)
            return

        if dst.dtype == float:
            dst[:] = np.fft.irfftn(src, dst.shape, [0, 1, 2])
        elif self.sign == 1:
            dst[:] = np.fft.ifftn(src, dst.shape, [0, 1, 2])
        else:
            dst[:] = np.fft.fftn(src)
            return

        # Only the two inverse transforms reach this point.
        dst *= dst.size

405 

406 

def create_plan(in_R: Array3D,
                out_R: Array3D,
                sign: int,
                flags: int = MEASURE) -> FFTPlan:
    """Create a plan for one transform from ``in_R`` to ``out_R``.

    An FFTWPlan is returned when have_fftw() is true, a NumpyFFTPlan
    otherwise; all arguments are forwarded unchanged.
    """
    plan_class = FFTWPlan if have_fftw() else NumpyFFTPlan
    return plan_class(in_R, out_R, sign, flags)