Coverage for gpaw/fftw.py: 86%
237 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-12 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-12 00:18 +0000
1"""
2Python wrapper for FFTW3 library
3================================
5.. autoclass:: FFTPlans
7"""
8from __future__ import annotations
10import weakref
11from types import ModuleType
13import numpy as np
14from scipy.fft import fftn, ifftn, irfftn, rfftn
16import gpaw.cgpaw as cgpaw
17from gpaw.utilities import as_complex_dtype, as_real_dtype
18from gpaw.new.c import pw_insert_gpu
19from gpaw.new import trace
20from gpaw.typing import Array1D, Array3D, DTypeLike, IntVector
21from gpaw.gpu import is_hip
# FFTW planner flags.  The numeric values mirror the FFTW_ESTIMATE,
# FFTW_MEASURE, FFTW_PATIENT and FFTW_EXHAUSTIVE constants of fftw3.h.
ESTIMATE = 64
MEASURE = 0
PATIENT = 32
EXHAUSTIVE = 8

# Cache used by create_plans(): maps a key tuple to a weak reference to the
# plans object, so the cache itself never keeps a plans object alive.
_plan_cache: dict[tuple, weakref.ReferenceType] = {}
def have_fftw() -> bool:
    """Tell whether the compiled ``cgpaw`` module exposes ``FFTWPlan``."""
    try:
        cgpaw.FFTWPlan
    except AttributeError:
        return False
    return True
def check_fft_size(n: int, factors=(2, 3, 5, 7)) -> bool:
    """Check if n is an efficient fft size.

    Efficient means that n can be factored into small primes (2, 3, 5, 7).

    Parameters
    ----------
    n:
        Candidate size.  Values smaller than 1 are never efficient.
    factors:
        Allowed factors.  Any sequence of integers greater than 1.

    >>> check_fft_size(17)
    False
    >>> check_fft_size(18)
    True
    >>> check_fft_size(0)
    False
    """
    # Guard against n <= 0: zero is divisible by every factor and would
    # otherwise never be reduced to 1.
    if n < 1:
        return False
    # Strip every allowed factor completely.  Once a factor no longer
    # divides n it can never divide a quotient of n, so one pass suffices.
    for x in factors:
        while n % x == 0:
            n //= x
    return bool(n == 1)
def get_efficient_fft_size(N: int, n=1, factors=(2, 3, 5, 7)) -> int:
    """Return smallest efficient fft size.

    Must be greater than or equal to N and divisible by n.

    Parameters
    ----------
    N:
        Lower bound for the returned size.
    n:
        Positive integer that the returned size must be a multiple of.
    factors:
        Allowed factors, forwarded to :func:`check_fft_size`.

    >>> get_efficient_fft_size(17)
    18
    """
    # Round N up to the nearest multiple of n (ceiling division), but never
    # start below n itself: a start value of 0 can not be an fft size.
    N = max(-(-N // n), 1) * n
    while not check_fft_size(N, factors):
        N += n
    return N
def empty(shape, dtype=float):
    """numpy.empty() equivalent with 16 byte alignment.

    Parameters
    ----------
    shape:
        Shape of the returned array.
    dtype:
        Must be a complex floating point type (enforced by an assertion).

    Returns
    -------
    Uninitialized complex array of the given shape whose first element is
    located at an address that is a multiple of 16 bytes.
    """
    assert np.issubdtype(dtype, np.complexfloating)

    real_dtype = as_real_dtype(dtype)
    complex_dtype = as_complex_dtype(dtype)

    N = np.prod(shape)
    # Over-allocate by just under 16 bytes worth of real numbers so that an
    # aligned window of 2 * N reals always fits inside the buffer.
    a = np.empty(2 * N + 16 // real_dtype.itemsize - 1, real_dtype)
    # Number of reals to skip in order to reach the next 16-byte boundary.
    # The distance to the boundary is (-address) % 16, not address % 16;
    # the two only coincide for remainders 0 and 8.
    offset = (-a.ctypes.data % 16) // real_dtype.itemsize
    a = a[offset:2 * N + offset].view(complex_dtype)
    a.shape = shape
    return a
def create_plans(size_c: IntVector,
                 dtype: DTypeLike,
                 flags: int = MEASURE,
                 xp: ModuleType = np) -> FFTPlans:
    """Return plan-objects for FFT and inverse FFT, reusing cached ones.

    A live plans object created earlier with the same size, dtype, flags
    and array module is returned as is.  Otherwise a new object is built:
    the CuPy variant when ``xp`` is not numpy, the FFTW variant when FFTW
    is available, and the fallback variant in all other cases.
    """
    cache_key = (tuple(size_c), dtype, flags, xp)

    # The cache holds weak references; a dead one dereferences to None.
    ref = _plan_cache.get(cache_key)
    if ref is not None:
        cached = ref()
        if cached is not None:
            return cached

    plans: FFTPlans
    if xp is not np:
        plans = CuPyFFTPlans(size_c, dtype)
    elif have_fftw():
        plans = FFTWPlans(size_c, dtype, flags)
    else:
        plans = NumpyFFTPlans(size_c, dtype)

    _plan_cache[cache_key] = weakref.ref(plans)
    return plans
class FFTPlans:
    """Base class owning the work arrays ``tmp_R`` and ``tmp_Q``.

    Subclasses provide :meth:`fft` and :meth:`ifft`, which transform
    between the two work arrays.
    """

    def __init__(self,
                 size_c: IntVector,
                 dtype: DTypeLike,
                 empty=empty):
        self.shape: tuple[int, ...]

        if not np.issubdtype(dtype, np.floating):
            # Not a real dtype: one array serves as both work arrays.
            self.shape = tuple(size_c)
            self.tmp_Q = empty(size_c, dtype)
            self.tmp_R = self.tmp_Q
            return

        # Real dtype: the last axis of tmp_Q only has size // 2 + 1
        # entries, and tmp_R is a real-valued view of the same memory.
        last = size_c[2]
        self.shape = (size_c[0], size_c[1], last // 2 + 1)
        self.tmp_Q = empty(self.shape, as_complex_dtype(dtype))
        self.tmp_R = self.tmp_Q.view(dtype)[:, :, :last]

    def fft(self) -> None:
        """Do FFT from ``tmp_R`` to ``tmp_Q``.

        >>> plans = create_plans([4, 1, 1], float)
        >>> plans.tmp_R[:, 0, 0] = [1, 0, 1, 0]
        >>> plans.fft()
        >>> plans.tmp_Q[:, 0, 0]
        array([2.+0.j, 0.+0.j, 2.+0.j, 0.+0.j])
        """
        raise NotImplementedError

    def ifft(self) -> None:
        """Do inverse FFT from ``tmp_Q`` to ``tmp_R``.

        >>> plans = create_plans([4, 1, 1], complex)
        >>> plans.tmp_Q[:, 0, 0] = [0, 1j, 0, 0]
        >>> plans.ifft()
        >>> plans.tmp_R[:, 0, 0]
        array([ 0.+1.j, -1.+0.j,  0.-1.j,  1.+0.j])
        """
        raise NotImplementedError

    def ifft_sphere(self, coef_G, pw, out_R):
        """Paste ``coef_G`` into ``tmp_Q``, inverse transform, scatter.

        When ``coef_G`` is None only ``out_R.scatter_from(None)`` is
        called.
        """
        if coef_G is None:
            out_R.scatter_from(None)
            return

        pw.paste(coef_G, self.tmp_Q)

        if np.issubdtype(pw.dtype, np.floating):
            # Complete the plane with last index 0 using complex conjugates
            # of mirrored entries.  The statements read and write the same
            # plane, so their order is kept as is.
            # NOTE(review): presumably this restores the symmetry needed
            # for a real-valued result -- confirm.
            plane = self.tmp_Q[:, :, 0]
            half0, half1 = (size // 2 - 1
                            for size in out_R.desc.size_c[:2])
            plane[0, -half1:] = plane[0, half1:0:-1].conj()
            plane[half0:0:-1, -half1:] = plane[-half0:, half1:0:-1].conj()
            plane[-half0:, -half1:] = plane[half0:0:-1, half1:0:-1].conj()
            plane[-half0:, 0] = plane[half0:0:-1, 0].conj()

        self.ifft()
        out_R.scatter_from(self.tmp_R)

    def fft_sphere(self, in_R, pw):
        """Transform ``in_R.data`` and return the coefficients cut by ``pw``.

        The result is scaled by one over the number of elements in
        ``tmp_R``.
        """
        self.tmp_R[:] = in_R.data
        self.fft()
        scale = 1 / self.tmp_R.size
        return pw.cut(self.tmp_Q) * scale
class FFTWPlans(FFTPlans):
    """3d transforms carried out by FFTW3 through ``cgpaw``."""

    def __init__(self, size_c, dtype, flags=MEASURE):
        if not have_fftw():
            raise ImportError('Not compiled with FFTW.')
        super().__init__(size_c, dtype)
        # The third argument is the sign passed to the planner:
        # -1 for the tmp_R -> tmp_Q plan and +1 for the reverse plan.
        forward_sign, backward_sign = -1, 1
        self._fftplan = cgpaw.FFTWPlan(
            self.tmp_R, self.tmp_Q, forward_sign, flags)
        self._ifftplan = cgpaw.FFTWPlan(
            self.tmp_Q, self.tmp_R, backward_sign, flags)

    def fft(self):
        """Execute the ``tmp_R`` -> ``tmp_Q`` plan."""
        cgpaw.FFTWExecute(self._fftplan)

    def ifft(self):
        """Execute the ``tmp_Q`` -> ``tmp_R`` plan."""
        cgpaw.FFTWExecute(self._ifftplan)

    def __del__(self):
        # A plan attribute is missing if execution stopped during FFTW
        # planning, so only destroy the plans that were actually created.
        for name in ('_fftplan', '_ifftplan'):
            if hasattr(self, name):
                cgpaw.FFTWDestroy(getattr(self, name))
class NumpyFFTPlans(FFTPlans):
    """Fallback used when FFTW is unavailable (uses ``scipy.fft``)."""

    def fft(self):
        """Transform ``tmp_R`` into ``tmp_Q``."""
        real_input = np.issubdtype(self.tmp_R.dtype, np.floating)
        forward = rfftn if real_input else fftn
        self.tmp_Q[:] = forward(self.tmp_R, overwrite_x=True)

    def ifft(self):
        """Transform ``tmp_Q`` into ``tmp_R`` with ``norm='forward'``."""
        real_output = np.issubdtype(self.tmp_R.dtype, np.floating)
        backward = irfftn if real_output else ifftn
        self.tmp_R[:] = backward(self.tmp_Q, self.tmp_R.shape,
                                 norm='forward', overwrite_x=True)
def rfftn_patch(tmp_R):
    """Replacement for ``rfftn`` built on a full ``fftn``.

    Runs the complete transform and keeps only the first
    ``tmp_R.shape[-1] // 2 + 1`` entries along the third axis.
    """
    from gpaw.gpu import cupyx
    kept = tmp_R.shape[-1] // 2 + 1
    full_Q = cupyx.scipy.fft.fftn(tmp_R)
    return full_Q[:, :, :kept]
def irfftn_patch(B, shape):
    """Replacement for ``irfftn`` built on a full ``ifftn``.

    ``B`` is copied into the leading part of a complex array of the given
    ``shape``; the trailing part of the last axis is filled with complex
    conjugates of ``B`` taken at negated indices.  The real part of the
    full inverse transform is returned.
    """
    from gpaw.gpu import cupyx
    import cupy as xp
    n0, n1, n2 = B.shape[0], B.shape[1], B.shape[2]

    full = xp.empty(shape, dtype=complex)
    full[:, :, :n2] = B

    # Negated index arrays, shaped so that they broadcast against each
    # other over the three axes.
    neg0 = -xp.arange(n0)[:, None, None]
    neg1 = -xp.arange(n1)[None, :, None]
    neg2 = -xp.arange(1, n2)[None, None, :]
    full[:, :, -(n2 - 1):] = B[neg0, neg1, neg2].conj()

    return cupyx.scipy.fft.ifftn(full).real
class CuPyFFTPlans(FFTPlans):
    """Transforms carried out with ``cupyx.scipy.fft``.

    When ``is_hip`` is true the real transforms go through
    ``rfftn_patch`` / ``irfftn_patch`` instead of ``rfftn`` / ``irfftn``.
    """

    def __init__(self,
                 size_c: IntVector,
                 dtype: DTypeLike):
        from gpaw.core import PWDesc
        from gpaw.gpu import cupy as cp
        self.dtype = dtype
        super().__init__(size_c, dtype, empty=cp.empty)
        # Index arrays returned by indices(), one per plane-wave descriptor.
        self.Q_G_cache: dict[PWDesc, Array1D] = {}

    @trace(gpu=True)
    def fft(self):
        """Transform ``tmp_R`` into ``tmp_Q``."""
        from gpaw.gpu import cupyx
        if self.tmp_R.dtype != float:
            self.tmp_Q[:] = cupyx.scipy.fft.fftn(self.tmp_R)
        elif is_hip:
            self.tmp_Q[:] = rfftn_patch(self.tmp_R)
        else:
            self.tmp_Q[:] = cupyx.scipy.fft.rfftn(self.tmp_R)

    @trace(gpu=True)
    def ifft(self):
        """Transform ``tmp_Q`` into ``tmp_R`` without 1/N normalization."""
        from gpaw.gpu import cupyx
        if self.tmp_R.dtype != float:
            self.tmp_R[:] = cupyx.scipy.fft.ifftn(
                self.tmp_Q, self.tmp_R.shape,
                norm='forward',
                overwrite_x=True)
        elif is_hip:
            # The patch goes through a plain ifftn(); multiplying by the
            # number of elements matches the norm='forward' branches.
            self.tmp_R[:] = irfftn_patch(self.tmp_Q, self.tmp_R.shape) \
                * self.tmp_R.size
        else:
            self.tmp_R[:] = cupyx.scipy.fft.irfftn(
                self.tmp_Q, self.tmp_R.shape,
                norm='forward',
                overwrite_x=True)

    def indices(self, pw):
        """Return ``pw.indices(self.shape)`` as a CuPy array (cached)."""
        from gpaw.gpu import cupy as cp
        if pw not in self.Q_G_cache:
            self.Q_G_cache[pw] = cp.asarray(pw.indices(self.shape))
        return self.Q_G_cache[pw]

    @trace
    def ifft_sphere(self, coef_G, pw, out_R):
        """Insert ``coef_G`` into ``tmp_Q`` and inverse transform.

        With a communicator of size 1 the result is written straight into
        ``out_R.data``; otherwise it goes to ``tmp_R`` and is scattered.
        When ``coef_G`` is None only ``out_R.scatter_from(None)`` is
        called.
        """
        from gpaw.gpu import cupyx

        if coef_G is None:
            out_R.scatter_from(None)
            return

        array_R = out_R.data if out_R.desc.comm.size == 1 else self.tmp_R
        array_Q = self.tmp_Q

        # Clear the work array before the coefficients are inserted.
        array_Q[:] = 0.0
        Q_G = self.indices(pw)

        assert np.issubdtype(array_Q.dtype, np.complexfloating)
        assert np.issubdtype(coef_G.dtype, np.complexfloating)
        pw_insert_gpu(coef_G,
                      Q_G,
                      1.0,
                      array_Q.ravel(),
                      *out_R.desc.size_c)

        if np.issubdtype(self.dtype, np.complexfloating):
            array_R[:] = cupyx.scipy.fft.ifftn(
                array_Q, array_Q.shape,
                norm='forward', overwrite_x=True)
        elif is_hip:
            array_R[:] = irfftn_patch(array_Q, out_R.desc.global_shape())\
                * array_R.size
        else:
            array_R[:] = cupyx.scipy.fft.irfftn(
                array_Q, out_R.desc.global_shape(),
                norm='forward', overwrite_x=True)

        if out_R.desc.comm.size > 1:
            out_R.scatter_from(array_R)

    @trace
    def fft_sphere(self, in_R, pw):
        """Transform ``in_R`` and return the coefficients selected by ``pw``.

        The coefficients are scaled by ``1 / in_R.size``.
        """
        from gpaw.gpu import cupyx
        if np.issubdtype(self.dtype, np.complexfloating):
            out_Q = cupyx.scipy.fft.fftn(in_R)
        elif is_hip:
            out_Q = rfftn_patch(in_R)
        else:
            # CuPy bug? rfftn fails on non-aligned arrays.
            # Work on a copy in that case, and warn about it.
            misaligned = in_R.data.ptr % 16 != 0
            if misaligned:
                in_R = in_R.copy()
                from warnings import warn
                warn('Circumventing GPU array alignment problem '
                     'with copy at rfftn.')
            out_Q = cupyx.scipy.fft.rfftn(in_R)

        Q_G = self.indices(pw)
        return out_Q.ravel()[Q_G] * (1 / in_R.size)
336# The rest of this file will be removed in the future ...
def check_fftw_inputs(in_R, out_R):
    """Assert that two arrays form a valid input/output pair.

    Both arrays must be three-dimensional with dtype float or complex.
    Two complex arrays must have equal shapes.  Otherwise exactly one
    array must be real; the complex one must then match it on the first
    two axes and have ``1 + n // 2`` entries along the last axis, where
    ``n`` is the length of the last axis of the real array.
    """
    for array in (in_R, out_R):
        # Note: Arrays not necessarily contiguous due to 16-byte alignment
        assert array.ndim == 3  # We can perhaps relax this requirement
        assert array.dtype in (float, complex)

    if in_R.dtype == complex and out_R.dtype == complex:
        assert in_R.shape == out_R.shape
        return

    # One real and one complex:
    if in_R.dtype == float:
        real, cplx = in_R, out_R
    else:
        real, cplx = out_R, in_R
    assert cplx.dtype == complex
    assert real.shape[:2] == cplx.shape[:2]
    assert cplx.shape[2] == 1 + real.shape[2] // 2
class FFTPlan:
    """Base class for a single 3d FFT from ``in_R`` to ``out_R``."""

    def __init__(self,
                 in_R: Array3D,
                 out_R: Array3D,
                 sign: int,
                 flags: int = MEASURE):
        # Reject incompatible array pairs before storing anything.
        check_fftw_inputs(in_R, out_R)
        self.in_R, self.out_R = in_R, out_R
        self.sign, self.flags = sign, flags

    def execute(self) -> None:
        """Run the transform; implemented by subclasses."""
        raise NotImplementedError
class FFTWPlan(FFTPlan):
    """FFTW3 3d transform.

    Raises
    ------
    ImportError
        If ``cgpaw`` was not compiled with FFTW.
    AssertionError
        If ``in_R`` and ``out_R`` do not pass ``check_fftw_inputs``.
    """

    def __init__(self, in_R, out_R, sign, flags=MEASURE):
        if not have_fftw():
            raise ImportError('Not compiled with FFTW.')
        # Validate the arrays (done by the base class) BEFORE handing them
        # to the C planner, so that an incompatible pair is rejected in
        # Python instead of reaching compiled code.
        FFTPlan.__init__(self, in_R, out_R, sign, flags)
        self._ptr = cgpaw.FFTWPlan(in_R, out_R, sign, flags)

    def execute(self):
        """Execute the plan created in ``__init__``."""
        cgpaw.FFTWExecute(self._ptr)

    def __del__(self):
        # _ptr is missing if __init__ raised before the plan was created;
        # it is reset to None afterwards so a plan is destroyed only once.
        if getattr(self, '_ptr', None):
            cgpaw.FFTWDestroy(self._ptr)
            self._ptr = None
class NumpyFFTPlan(FFTPlan):
    """Fallback implementation based on ``numpy.fft``."""

    def execute(self):
        """Transform ``in_R`` into ``out_R``.

        The kind of transform follows from the dtypes and ``sign``.  The
        inverse transforms are multiplied by the number of output
        elements after the call to numpy.
        """
        source, target = self.in_R, self.out_R
        all_axes = [0, 1, 2]

        if source.dtype == float:
            target[:] = np.fft.rfftn(source)
            return

        if target.dtype == float:
            target[:] = np.fft.irfftn(source, target.shape, all_axes)
            target *= target.size
            return

        if self.sign == 1:
            target[:] = np.fft.ifftn(source, target.shape, all_axes)
            target *= target.size
            return

        target[:] = np.fft.fftn(source)
def create_plan(in_R: Array3D,
                out_R: Array3D,
                sign: int,
                flags: int = MEASURE) -> FFTPlan:
    """Build an ``FFTWPlan`` if FFTW is available, else a ``NumpyFFTPlan``."""
    plan_class = FFTWPlan if have_fftw() else NumpyFFTPlan
    return plan_class(in_R, out_R, sign, flags)