scirs2_fft/gpu_fft/dispatch.rs
1//! Automatic CPU/GPU dispatch layer for FFT computation.
2//!
3//! This module exposes `fft_auto_dispatch`, `fft_batch_gpu`, and
4//! `overlap_save_gpu` as the high-level public API. The dispatch logic
5//! routes computations to the GPU (via the `wgpu_fft` feature) for large
6//! inputs and falls back to the CPU-based [`GpuFftPipeline`] otherwise.
7//!
8//! # Design
9//!
10//! The naming deliberately avoids clashing with the existing
11//! `GpuFftConfig` / `GpuFftResult` types already present in `types.rs`.
12//! New, orthogonal type names (`AutoDispatchConfig`, `DispatchFftOutput`)
13//! are used throughout this module.
14//!
15//! # Feature flags
16//!
17//! * `wgpu_fft` — enables the wgpu GPU back-end. When absent (or when
18//! no adapter is available at runtime) every call transparently falls
19//! back to the CPU pipeline.
20
21use scirs2_core::numeric::Complex64;
22
23use super::pipeline::GpuFftPipeline;
24use super::types::{FftDirection, GpuFftConfig, GpuFftError, NormalizationMode};
25use crate::error::FFTError;
26
27// ─────────────────────────────────────────────────────────────────────────────
28// Public types
29// ─────────────────────────────────────────────────────────────────────────────
30
/// Configuration for the auto-dispatch FFT layer.
///
/// Distinct from [`GpuFftConfig`] which configures the underlying pipeline.
#[derive(Debug, Clone)]
pub struct AutoDispatchConfig {
    /// Minimum **padded** transform length (in complex samples) at which the
    /// dispatch layer considers routing to a GPU back-end.
    ///
    /// The comparison is made *after* the input has been zero-padded to the
    /// next power of two, so an input shorter than this threshold can still
    /// be offloaded when its padded length reaches it (for example, 3000
    /// samples pad to 4096 and meet the default threshold). Transforms whose
    /// padded length is below the threshold always run on the CPU.
    ///
    /// Default: **4096**.
    pub gpu_threshold: usize,

    /// Perform an inverse FFT instead of the default forward FFT.
    ///
    /// Default: **false**.
    pub inverse: bool,
}

impl Default for AutoDispatchConfig {
    /// Forward transform, GPU considered from a padded length of 4096 upward.
    fn default() -> Self {
        Self {
            gpu_threshold: 4096,
            inverse: false,
        }
    }
}
57
58/// Output produced by [`fft_auto_dispatch`].
59#[derive(Debug)]
60pub struct DispatchFftOutput {
61 /// Complex-valued FFT result (length equals the padded power-of-two
62 /// input size when zero-padding was applied).
63 pub data: Vec<Complex64>,
64
65 /// `true` if the computation was offloaded to a GPU back-end;
66 /// `false` means the CPU pipeline was used.
67 pub used_gpu: bool,
68
69 /// Number of Cooley-Tukey butterfly stages (= log₂(n)).
70 pub n_stages: u32,
71}
72
73// ─────────────────────────────────────────────────────────────────────────────
74// Helpers
75// ─────────────────────────────────────────────────────────────────────────────
76
/// Return the smallest power of two that is ≥ `n` (`0` maps to `1`).
///
/// Thin wrapper over the standard library's [`usize::next_power_of_two`],
/// kept so call sites in this module read uniformly. Callers here always pass
/// values far below `usize::MAX / 2`, so the std method's overflow behaviour
/// is never reached.
fn next_power_of_two(n: usize) -> usize {
    n.next_power_of_two()
}
85
86/// Map a [`GpuFftError`] to an [`FFTError`].
87fn gpu_err_to_fft(e: GpuFftError) -> FFTError {
88 FFTError::BackendError(e.to_string())
89}
90
91/// Build a default [`GpuFftPipeline`] without extra normalisation.
92///
93/// Note: [`cooley_tukey_gpu`] already applies `1/N` scaling for the inverse
94/// direction internally, so no additional normalisation mode should be set
95/// here — using `NormalizationMode::Backward` would double-scale.
96fn build_pipeline() -> GpuFftPipeline {
97 GpuFftPipeline::new(GpuFftConfig {
98 normalization: NormalizationMode::None,
99 ..GpuFftConfig::default()
100 })
101}
102
103// ─────────────────────────────────────────────────────────────────────────────
104// fft_auto_dispatch
105// ─────────────────────────────────────────────────────────────────────────────
106
107/// Compute an FFT with automatic CPU/GPU dispatch.
108///
109/// # Behaviour
110///
111/// 1. **Zero-pad** `input` to the next power of two when its length is not
112/// already a power of two. The `data` field in the returned
113/// [`DispatchFftOutput`] has this padded length.
114/// 2. **Route to GPU** when the `wgpu_fft` feature is enabled, the padded
115/// length ≥ `config.gpu_threshold`, and a wgpu adapter is available.
116/// If the adapter is unavailable the call falls through to the CPU path.
117/// 3. **CPU path**: the existing [`GpuFftPipeline`] is used (pure Rust,
118/// always available).
119///
120/// # Errors
121///
122/// Returns [`FFTError`] if the pipeline or any kernel call fails.
123pub fn fft_auto_dispatch(
124 input: &[Complex64],
125 config: &AutoDispatchConfig,
126) -> Result<DispatchFftOutput, FFTError> {
127 let n_padded = next_power_of_two(input.len().max(2));
128 let n_stages = n_padded.trailing_zeros();
129 let direction = if config.inverse {
130 FftDirection::Inverse
131 } else {
132 FftDirection::Forward
133 };
134
135 // Zero-pad into the working buffer.
136 let mut buf = Vec::with_capacity(n_padded);
137 buf.extend_from_slice(input);
138 buf.resize(n_padded, Complex64::new(0.0, 0.0));
139
140 // Try the wgpu path first (compile-time + runtime guarded).
141 #[cfg(feature = "wgpu_fft")]
142 {
143 if n_padded >= config.gpu_threshold {
144 match super::wgpu_backend::fft_wgpu(&buf, config.inverse) {
145 Ok(result) => {
146 return Ok(DispatchFftOutput {
147 data: result,
148 used_gpu: true,
149 n_stages,
150 });
151 }
152 Err(_) => {
153 // GPU unavailable at runtime — fall through to CPU.
154 }
155 }
156 }
157 }
158
159 // CPU path — always available.
160 let pipeline = build_pipeline();
161 pipeline
162 .execute(&mut buf, n_padded, direction)
163 .map_err(gpu_err_to_fft)?;
164
165 Ok(DispatchFftOutput {
166 data: buf,
167 used_gpu: false,
168 n_stages,
169 })
170}
171
172// ─────────────────────────────────────────────────────────────────────────────
173// fft_batch_gpu
174// ─────────────────────────────────────────────────────────────────────────────
175
176/// GPU batch FFT: compute many same-size transforms efficiently.
177///
178/// All input slices **must have the same length**; if they differ the
179/// shortest common power-of-two is used as the padded size (each input is
180/// zero-padded individually).
181///
182/// Returns one output spectrum per input slice. The spectra have length
183/// equal to the padded size (next power of two of the longest input).
184///
185/// # Errors
186///
187/// * [`FFTError::ValueError`] – if `inputs` is empty.
188/// * [`FFTError::BackendError`] – if any pipeline call fails.
189pub fn fft_batch_gpu(inputs: &[Vec<Complex64>]) -> Result<Vec<Vec<Complex64>>, FFTError> {
190 if inputs.is_empty() {
191 return Err(FFTError::ValueError(
192 "batch input must contain at least one signal".into(),
193 ));
194 }
195
196 let max_len = inputs.iter().map(|v| v.len()).max().unwrap_or(0);
197 let n_padded = next_power_of_two(max_len.max(2));
198
199 // Build complex batches, padding as needed.
200 let mut batch: Vec<Vec<Complex64>> = inputs
201 .iter()
202 .map(|signal| {
203 let mut buf = Vec::with_capacity(n_padded);
204 buf.extend_from_slice(signal);
205 buf.resize(n_padded, Complex64::new(0.0, 0.0));
206 buf
207 })
208 .collect();
209
210 let pipeline = build_pipeline();
211 let result = pipeline
212 .execute_batch(&mut batch, FftDirection::Forward)
213 .map_err(gpu_err_to_fft)?;
214
215 Ok(result.outputs)
216}
217
218// ─────────────────────────────────────────────────────────────────────────────
219// Tests
220// ─────────────────────────────────────────────────────────────────────────────
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    const EPS: f64 = 1e-7;

    // ── Padding helper rounds up to a power of two ───────────────────────────

    #[test]
    fn next_power_of_two_rounds_up() {
        assert_eq!(next_power_of_two(0), 1);
        assert_eq!(next_power_of_two(1), 1);
        assert_eq!(next_power_of_two(2), 2);
        assert_eq!(next_power_of_two(3), 4);
        assert_eq!(next_power_of_two(6), 8);
        assert_eq!(next_power_of_two(4096), 4096);
        assert_eq!(next_power_of_two(4097), 8192);
    }

    // ── Default config has threshold 4096 ────────────────────────────────────

    #[test]
    fn gpu_fft_config_default_threshold_4096() {
        let cfg = AutoDispatchConfig::default();
        assert_eq!(cfg.gpu_threshold, 4096);
        assert!(!cfg.inverse);
    }

    // ── CPU path is taken for small inputs ───────────────────────────────────

    #[test]
    fn gpu_fft_auto_dispatch_cpu_path_correct() {
        // 8-point impulse: FFT should be all-ones.
        let input: Vec<Complex64> = {
            let mut v = vec![Complex64::new(0.0, 0.0); 8];
            v[0] = Complex64::new(1.0, 0.0);
            v
        };

        let config = AutoDispatchConfig {
            gpu_threshold: 4096, // 8 << 4096 → CPU path
            inverse: false,
        };

        let out = fft_auto_dispatch(&input, &config).expect("dispatch failed");
        assert!(!out.used_gpu, "small input must use CPU");
        assert_eq!(out.n_stages, 3); // log2(8) = 3

        for (k, c) in out.data.iter().enumerate() {
            assert!(
                (c.re - 1.0).abs() < EPS,
                "bin {k} re = {} (expected 1.0)",
                c.re
            );
            assert!(c.im.abs() < EPS, "bin {k} im = {} (expected 0.0)", c.im);
        }
    }

    // ── Non-power-of-two gets padded to next power of two ────────────────────

    #[test]
    fn fft_power_of_two_padding_correct() {
        // 6-element input → padded to 8.
        let input: Vec<Complex64> = (0..6).map(|i| Complex64::new(i as f64, 0.0)).collect();
        let config = AutoDispatchConfig::default();
        let out = fft_auto_dispatch(&input, &config).expect("dispatch failed");
        assert_eq!(out.data.len(), 8, "padded length must be 8");
        assert_eq!(out.n_stages, 3); // log2(8) = 3

        // The DC bin is the plain sum of the samples (0+1+…+5 = 15); the
        // padding zeros contribute nothing.
        assert!(
            (out.data[0].re - 15.0).abs() < EPS,
            "DC re = {} (expected 15.0)",
            out.data[0].re
        );
        assert!(out.data[0].im.abs() < EPS, "DC im = {}", out.data[0].im);
    }

    // ── Forward then inverse gives back the original ─────────────────────────

    #[test]
    fn gpu_fft_auto_dispatch_roundtrip() {
        let n = 16;
        let original: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new((i as f64 * PI / 8.0).sin(), 0.0))
            .collect();

        let config_fwd = AutoDispatchConfig {
            gpu_threshold: 4096,
            inverse: false,
        };
        let config_inv = AutoDispatchConfig {
            gpu_threshold: 4096,
            inverse: true,
        };

        let forward = fft_auto_dispatch(&original, &config_fwd).expect("forward");
        let recovered = fft_auto_dispatch(&forward.data, &config_inv).expect("inverse");
        assert_eq!(recovered.data.len(), n);

        // `build_pipeline` uses `NormalizationMode::None`; the 1/N factor of
        // the inverse transform is applied by the underlying Cooley-Tukey
        // kernel, so the round trip must reproduce the input exactly.
        for (i, (orig, rec)) in original.iter().zip(recovered.data.iter()).enumerate() {
            assert!(
                (orig.re - rec.re).abs() < 1e-6,
                "index {i}: {:.6} vs {:.6}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < 1e-6,
                "index {i} im: {:.6} vs {:.6}",
                orig.im,
                rec.im
            );
        }
    }

    // ── Batch results match individual transforms ────────────────────────────

    #[test]
    fn gpu_fft_batch_results_match_individual() {
        let n = 16;
        let signals: Vec<Vec<Complex64>> = (0..8_u64)
            .map(|k| {
                (0..n)
                    .map(|i| Complex64::new(i as f64 + k as f64, 0.0))
                    .collect()
            })
            .collect();

        let config = AutoDispatchConfig::default();

        // Individual
        let individual: Vec<Vec<Complex64>> = signals
            .iter()
            .map(|s| fft_auto_dispatch(s, &config).expect("individual").data)
            .collect();

        // Batch
        let batch = fft_batch_gpu(&signals).expect("batch");

        assert_eq!(batch.len(), signals.len());
        for (sig_idx, (ind, bat)) in individual.iter().zip(batch.iter()).enumerate() {
            assert_eq!(ind.len(), bat.len(), "signal {sig_idx} length mismatch");
            for (bin, (a, b)) in ind.iter().zip(bat.iter()).enumerate() {
                assert!(
                    (a.re - b.re).abs() < 1e-6,
                    "signal {sig_idx} bin {bin} re: {:.8} vs {:.8}",
                    a.re,
                    b.re
                );
                assert!(
                    (a.im - b.im).abs() < 1e-6,
                    "signal {sig_idx} bin {bin} im: {:.8} vs {:.8}",
                    a.im,
                    b.im
                );
            }
        }
    }

    // ── Batch pads ragged inputs to the longest one ──────────────────────────

    #[test]
    fn gpu_fft_batch_pads_ragged_inputs_to_longest() {
        let signals = vec![
            vec![Complex64::new(1.0, 0.0); 3],
            vec![Complex64::new(1.0, 0.0); 5],
        ];

        let batch = fft_batch_gpu(&signals).expect("batch");
        assert_eq!(batch.len(), 2);
        for (idx, spectrum) in batch.iter().enumerate() {
            assert_eq!(spectrum.len(), 8, "signal {idx}: longest input 5 pads to 8");
        }

        // DC bin equals the number of unit samples in each signal.
        assert!((batch[0][0].re - 3.0).abs() < EPS, "DC0 = {}", batch[0][0].re);
        assert!((batch[1][0].re - 5.0).abs() < EPS, "DC1 = {}", batch[1][0].re);
    }

    // ── Batch rejects empty input ────────────────────────────────────────────

    #[test]
    fn gpu_fft_batch_rejects_empty() {
        let result = fft_batch_gpu(&[]);
        assert!(result.is_err());
    }
}
362}