Skip to main content

scirs2_fft/gpu_fft/
dispatch.rs

1//! Automatic CPU/GPU dispatch layer for FFT computation.
2//!
3//! This module exposes `fft_auto_dispatch`, `fft_batch_gpu`, and
4//! `overlap_save_gpu` as the high-level public API.  The dispatch logic
5//! routes computations to the GPU (via the `wgpu_fft` feature) for large
6//! inputs and falls back to the CPU-based [`GpuFftPipeline`] otherwise.
7//!
8//! # Design
9//!
10//! The naming deliberately avoids clashing with the existing
11//! `GpuFftConfig` / `GpuFftResult` types already present in `types.rs`.
12//! New, orthogonal type names (`AutoDispatchConfig`, `DispatchFftOutput`)
13//! are used throughout this module.
14//!
15//! # Feature flags
16//!
17//! * `wgpu_fft` — enables the wgpu GPU back-end.  When absent (or when
18//!   no adapter is available at runtime) every call transparently falls
19//!   back to the CPU pipeline.
20
21use scirs2_core::numeric::Complex64;
22
23use super::pipeline::GpuFftPipeline;
24use super::types::{FftDirection, GpuFftConfig, GpuFftError, NormalizationMode};
25use crate::error::FFTError;
26
27// ─────────────────────────────────────────────────────────────────────────────
28// Public types
29// ─────────────────────────────────────────────────────────────────────────────
30
/// Settings that steer the auto-dispatch FFT layer ([`fft_auto_dispatch`]).
///
/// This is deliberately separate from [`GpuFftConfig`], which configures the
/// underlying pipeline rather than the CPU/GPU routing decision.
///
/// `PartialEq`/`Eq` are derived so that two configurations can be compared
/// directly (for example in tests or when caching per-configuration state).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutoDispatchConfig {
    /// Minimum transform length (in complex samples, measured *after*
    /// zero-padding to a power of two) at which the dispatch layer considers
    /// routing to a GPU back-end.  Shorter transforms always run on the CPU
    /// regardless of available hardware.
    ///
    /// Default: **4096**.
    pub gpu_threshold: usize,

    /// When `true`, compute the inverse FFT; when `false`, the forward FFT.
    ///
    /// Default: **false**.
    pub inverse: bool,
}
48
49impl Default for AutoDispatchConfig {
50    fn default() -> Self {
51        Self {
52            gpu_threshold: 4096,
53            inverse: false,
54        }
55    }
56}
57
58/// Output produced by [`fft_auto_dispatch`].
59#[derive(Debug)]
60pub struct DispatchFftOutput {
61    /// Complex-valued FFT result (length equals the padded power-of-two
62    /// input size when zero-padding was applied).
63    pub data: Vec<Complex64>,
64
65    /// `true` if the computation was offloaded to a GPU back-end;
66    /// `false` means the CPU pipeline was used.
67    pub used_gpu: bool,
68
69    /// Number of Cooley-Tukey butterfly stages (= log₂(n)).
70    pub n_stages: u32,
71}
72
73// ─────────────────────────────────────────────────────────────────────────────
74// Helpers
75// ─────────────────────────────────────────────────────────────────────────────
76
/// Return the smallest power of two that is ≥ `n` (`0` maps to `1`).
///
/// Delegates to the standard library instead of a hand-rolled shift.  The
/// previous bit trick shifted by the full word width when the result did not
/// fit in a `usize`, which panics in debug builds but silently yields `1` in
/// release builds.
///
/// # Panics
///
/// Panics (in every build profile) if the next power of two does not fit in
/// a `usize`.  Callers in this module pass slice lengths, which cannot get
/// that large.
fn next_power_of_two(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next power of two must fit in usize")
}
85
86/// Map a [`GpuFftError`] to an [`FFTError`].
87fn gpu_err_to_fft(e: GpuFftError) -> FFTError {
88    FFTError::BackendError(e.to_string())
89}
90
91/// Build a default [`GpuFftPipeline`] without extra normalisation.
92///
93/// Note: [`cooley_tukey_gpu`] already applies `1/N` scaling for the inverse
94/// direction internally, so no additional normalisation mode should be set
95/// here — using `NormalizationMode::Backward` would double-scale.
96fn build_pipeline() -> GpuFftPipeline {
97    GpuFftPipeline::new(GpuFftConfig {
98        normalization: NormalizationMode::None,
99        ..GpuFftConfig::default()
100    })
101}
102
103// ─────────────────────────────────────────────────────────────────────────────
104// fft_auto_dispatch
105// ─────────────────────────────────────────────────────────────────────────────
106
107/// Compute an FFT with automatic CPU/GPU dispatch.
108///
109/// # Behaviour
110///
111/// 1. **Zero-pad** `input` to the next power of two when its length is not
112///    already a power of two.  The `data` field in the returned
113///    [`DispatchFftOutput`] has this padded length.
114/// 2. **Route to GPU** when the `wgpu_fft` feature is enabled, the padded
115///    length ≥ `config.gpu_threshold`, and a wgpu adapter is available.
116///    If the adapter is unavailable the call falls through to the CPU path.
117/// 3. **CPU path**: the existing [`GpuFftPipeline`] is used (pure Rust,
118///    always available).
119///
120/// # Errors
121///
122/// Returns [`FFTError`] if the pipeline or any kernel call fails.
123pub fn fft_auto_dispatch(
124    input: &[Complex64],
125    config: &AutoDispatchConfig,
126) -> Result<DispatchFftOutput, FFTError> {
127    let n_padded = next_power_of_two(input.len().max(2));
128    let n_stages = n_padded.trailing_zeros();
129    let direction = if config.inverse {
130        FftDirection::Inverse
131    } else {
132        FftDirection::Forward
133    };
134
135    // Zero-pad into the working buffer.
136    let mut buf = Vec::with_capacity(n_padded);
137    buf.extend_from_slice(input);
138    buf.resize(n_padded, Complex64::new(0.0, 0.0));
139
140    // Try the wgpu path first (compile-time + runtime guarded).
141    #[cfg(feature = "wgpu_fft")]
142    {
143        if n_padded >= config.gpu_threshold {
144            match super::wgpu_backend::fft_wgpu(&buf, config.inverse) {
145                Ok(result) => {
146                    return Ok(DispatchFftOutput {
147                        data: result,
148                        used_gpu: true,
149                        n_stages,
150                    });
151                }
152                Err(_) => {
153                    // GPU unavailable at runtime — fall through to CPU.
154                }
155            }
156        }
157    }
158
159    // CPU path — always available.
160    let pipeline = build_pipeline();
161    pipeline
162        .execute(&mut buf, n_padded, direction)
163        .map_err(gpu_err_to_fft)?;
164
165    Ok(DispatchFftOutput {
166        data: buf,
167        used_gpu: false,
168        n_stages,
169    })
170}
171
172// ─────────────────────────────────────────────────────────────────────────────
173// fft_batch_gpu
174// ─────────────────────────────────────────────────────────────────────────────
175
176/// GPU batch FFT: compute many same-size transforms efficiently.
177///
178/// All input slices **must have the same length**; if they differ the
179/// shortest common power-of-two is used as the padded size (each input is
180/// zero-padded individually).
181///
182/// Returns one output spectrum per input slice.  The spectra have length
183/// equal to the padded size (next power of two of the longest input).
184///
185/// # Errors
186///
187/// * [`FFTError::ValueError`] – if `inputs` is empty.
188/// * [`FFTError::BackendError`] – if any pipeline call fails.
189pub fn fft_batch_gpu(inputs: &[Vec<Complex64>]) -> Result<Vec<Vec<Complex64>>, FFTError> {
190    if inputs.is_empty() {
191        return Err(FFTError::ValueError(
192            "batch input must contain at least one signal".into(),
193        ));
194    }
195
196    let max_len = inputs.iter().map(|v| v.len()).max().unwrap_or(0);
197    let n_padded = next_power_of_two(max_len.max(2));
198
199    // Build complex batches, padding as needed.
200    let mut batch: Vec<Vec<Complex64>> = inputs
201        .iter()
202        .map(|signal| {
203            let mut buf = Vec::with_capacity(n_padded);
204            buf.extend_from_slice(signal);
205            buf.resize(n_padded, Complex64::new(0.0, 0.0));
206            buf
207        })
208        .collect();
209
210    let pipeline = build_pipeline();
211    let result = pipeline
212        .execute_batch(&mut batch, FftDirection::Forward)
213        .map_err(gpu_err_to_fft)?;
214
215    Ok(result.outputs)
216}
217
218// ─────────────────────────────────────────────────────────────────────────────
219// Tests
220// ─────────────────────────────────────────────────────────────────────────────
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Absolute tolerance for comparisons against exact expected values.
    const EPS: f64 = 1e-7;

    // ── Default config has threshold 4096 ────────────────────────────────────

    #[test]
    fn gpu_fft_config_default_threshold_4096() {
        let cfg = AutoDispatchConfig::default();
        assert_eq!(cfg.gpu_threshold, 4096);
        assert!(!cfg.inverse);
    }

    // ── CPU path is taken for small inputs ───────────────────────────────────

    #[test]
    fn gpu_fft_auto_dispatch_cpu_path_correct() {
        // 8-point impulse: FFT should be all-ones.
        let input: Vec<Complex64> = {
            let mut v = vec![Complex64::new(0.0, 0.0); 8];
            v[0] = Complex64::new(1.0, 0.0);
            v
        };

        let config = AutoDispatchConfig {
            gpu_threshold: 4096, // 8 << 4096 → CPU path
            inverse: false,
        };

        let out = fft_auto_dispatch(&input, &config).expect("dispatch failed");
        assert!(!out.used_gpu, "small input must use CPU");
        assert_eq!(out.n_stages, 3); // log2(8) = 3

        for (k, c) in out.data.iter().enumerate() {
            assert!(
                (c.re - 1.0).abs() < EPS,
                "bin {k} re = {} (expected 1.0)",
                c.re
            );
            assert!(c.im.abs() < EPS, "bin {k} im = {} (expected 0.0)", c.im);
        }
    }

    // ── Non-power-of-two gets padded to next power of two ────────────────────

    #[test]
    fn fft_power_of_two_padding_correct() {
        // 6-element input → padded to 8.
        let input: Vec<Complex64> = (0..6).map(|i| Complex64::new(i as f64, 0.0)).collect();
        let config = AutoDispatchConfig::default();
        let out = fft_auto_dispatch(&input, &config).expect("dispatch failed");
        assert_eq!(out.data.len(), 8, "padded length must be 8");
        // The stage count must describe the padded size, not the raw input.
        assert_eq!(out.n_stages, 3, "log2(8) = 3 stages expected");
    }

    // ── Forward then inverse gives back the original ─────────────────────────

    #[test]
    fn gpu_fft_auto_dispatch_roundtrip() {
        let n = 16;
        let original: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new((i as f64 * PI / 8.0).sin(), 0.0))
            .collect();

        let config_fwd = AutoDispatchConfig {
            gpu_threshold: 4096,
            inverse: false,
        };
        let config_inv = AutoDispatchConfig {
            gpu_threshold: 4096,
            inverse: true,
        };

        let forward = fft_auto_dispatch(&original, &config_fwd).expect("forward");
        let recovered = fft_auto_dispatch(&forward.data, &config_inv).expect("inverse");

        // 16 is already a power of two, so no padding may be introduced.
        assert_eq!(recovered.data.len(), original.len());

        // The inverse transform applies the 1/N scaling internally; the
        // pipeline itself is built with `NormalizationMode::None`, so the
        // round trip must reproduce the input without extra rescaling.
        for (i, (orig, rec)) in original.iter().zip(recovered.data.iter()).enumerate() {
            assert!(
                (orig.re - rec.re).abs() < 1e-6,
                "index {i}: {:.6} vs {:.6}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < 1e-6,
                "index {i} im: {:.6} vs {:.6}",
                orig.im,
                rec.im
            );
        }
    }

    // ── Batch results match individual transforms ────────────────────────────

    #[test]
    fn gpu_fft_batch_results_match_individual() {
        let n = 16;
        let signals: Vec<Vec<Complex64>> = (0..8_u64)
            .map(|k| {
                (0..n)
                    .map(|i| Complex64::new(i as f64 + k as f64, 0.0))
                    .collect()
            })
            .collect();

        let config = AutoDispatchConfig::default();

        // Individual
        let individual: Vec<Vec<Complex64>> = signals
            .iter()
            .map(|s| fft_auto_dispatch(s, &config).expect("individual").data)
            .collect();

        // Batch
        let batch = fft_batch_gpu(&signals).expect("batch");

        assert_eq!(batch.len(), signals.len());
        for (sig_idx, (ind, bat)) in individual.iter().zip(batch.iter()).enumerate() {
            assert_eq!(ind.len(), bat.len(), "signal {sig_idx} length mismatch");
            for (bin, (a, b)) in ind.iter().zip(bat.iter()).enumerate() {
                assert!(
                    (a.re - b.re).abs() < 1e-6,
                    "signal {sig_idx} bin {bin} re: {:.8} vs {:.8}",
                    a.re,
                    b.re
                );
                assert!(
                    (a.im - b.im).abs() < 1e-6,
                    "signal {sig_idx} bin {bin} im: {:.8} vs {:.8}",
                    a.im,
                    b.im
                );
            }
        }
    }

    // ── Batch rejects empty input ────────────────────────────────────────────

    #[test]
    fn gpu_fft_batch_rejects_empty() {
        let result = fft_batch_gpu(&[]);
        // The documented contract is a `ValueError`, not just "some error".
        assert!(
            matches!(result, Err(FFTError::ValueError(_))),
            "empty batch must be rejected with FFTError::ValueError"
        );
    }
}