Skip to main content

scirs2_fft/gpu_fft/
wgpu_backend.rs

1//! wgpu GPU FFT backend.
2//!
3//! This module is compiled only when the `wgpu_fft` feature is enabled.
4//! It exposes `fft_wgpu`, which attempts to:
5//!
6//! 1. Acquire a wgpu adapter and device (GPU).
7//! 2. Upload the input buffer to the GPU.
8//! 3. Execute the Cooley-Tukey radix-2 DIT FFT via a WGSL compute shader
9//!    (`fft_shader.wgsl`) for `log2(n)` passes.
10//! 4. Read the result back to the CPU.
11//!
//! If no GPU adapter is found at runtime (CI, headless server, etc.) the
//! function returns an `FFTError` converted from
//! `FftBackendError::NoAdapter`.  The dispatch layer in [`super::dispatch`]
//! catches that error and falls back to the CPU path, so callers never need
//! to handle the GPU-unavailable case explicitly.
16//!
17//! # Feature gate
18//!
19//! This entire module is behind `#[cfg(feature = "wgpu_fft")]`.
20
21#[cfg(feature = "wgpu_fft")]
22mod inner {
23    use crate::error::FFTError;
24    use scirs2_core::numeric::Complex64;
25    use wgpu::{Backends, Instance, InstanceDescriptor, PowerPreference, RequestAdapterOptions};
26
27    use super::super::kernels::bit_reverse_permute_gpu;
28
29    // ─────────────────────────────────────────────────────────────────────────
30    // Error type
31    // ─────────────────────────────────────────────────────────────────────────
32
33    /// Errors specific to the wgpu FFT back-end.
34    #[derive(Debug, thiserror::Error)]
35    pub enum FftBackendError {
36        /// No compatible GPU adapter was found on this system.
37        #[error("no wgpu adapter available (GPU unavailable or unsupported)")]
38        NoAdapter,
39
40        /// The adapter was found but the device could not be created.
41        #[error("wgpu device creation failed: {0}")]
42        DeviceCreation(String),
43
44        /// A shader compilation error occurred.
45        #[error("WGSL shader compilation failed: {0}")]
46        ShaderCompilation(String),
47
48        /// A buffer operation (upload/readback) failed.
49        #[error("GPU buffer operation failed: {0}")]
50        Buffer(String),
51
52        /// The input length is not a power of two (required by the shader).
53        #[error("wgpu FFT requires a power-of-two input length; got {0}")]
54        NonPowerOfTwo(usize),
55    }
56
57    impl From<FftBackendError> for FFTError {
58        fn from(e: FftBackendError) -> Self {
59            FFTError::BackendError(e.to_string())
60        }
61    }
62
63    // ─────────────────────────────────────────────────────────────────────────
64    // Runtime availability check
65    // ─────────────────────────────────────────────────────────────────────────
66
67    /// Returns `true` when a wgpu adapter appears to be available on this
68    /// system.  This is a best-effort, synchronous check — it should not be
69    /// relied upon in production code without a subsequent `fft_wgpu` call.
70    ///
71    /// # Implementation note
72    ///
73    /// Performs a real wgpu adapter enumeration using `pollster::block_on` to
74    /// drive the async adapter request synchronously.  Returns `false` on any
75    /// headless / CI environment where no GPU adapter is found, so the
76    /// dispatch layer can fall back to the CPU path transparently.
77    pub fn gpu_available() -> bool {
78        let instance_desc = InstanceDescriptor {
79            backends: Backends::all(),
80            flags: wgpu::InstanceFlags::default(),
81            memory_budget_thresholds: Default::default(),
82            backend_options: Default::default(),
83            display: None,
84        };
85        let instance = Instance::new(instance_desc);
86        pollster::block_on(async {
87            instance
88                .request_adapter(&RequestAdapterOptions {
89                    power_preference: PowerPreference::default(),
90                    compatible_surface: None,
91                    force_fallback_adapter: false,
92                })
93                .await
94                .is_ok()
95        })
96    }
97
98    // ─────────────────────────────────────────────────────────────────────────
99    // Internal helpers
100    // ─────────────────────────────────────────────────────────────────────────
101
102    /// Encode FFT uniform-buffer params as raw little-endian bytes.
103    ///
104    /// Layout: `{ n: u32, stage: u32, inverse: u32, _pad: u32 }` (16 bytes).
105    fn encode_params(n: u32, stage: u32, inverse: u32) -> [u8; 16] {
106        let mut out = [0u8; 16];
107        out[0..4].copy_from_slice(&n.to_le_bytes());
108        out[4..8].copy_from_slice(&stage.to_le_bytes());
109        out[8..12].copy_from_slice(&inverse.to_le_bytes());
110        // _pad = 0 (already zero)
111        out
112    }
113
114    /// Serialise a slice of `Complex64` as `array<vec2<f32>>` bytes.
115    ///
116    /// Each complex sample becomes two contiguous `f32` values (real then
117    /// imaginary), each encoded as 4 little-endian bytes, for a total of 8
118    /// bytes per sample.
119    fn complex64_to_bytes(data: &[Complex64]) -> Vec<u8> {
120        let mut out = Vec::with_capacity(data.len() * 8);
121        for c in data {
122            out.extend_from_slice(&(c.re as f32).to_le_bytes());
123            out.extend_from_slice(&(c.im as f32).to_le_bytes());
124        }
125        out
126    }
127
128    /// Deserialise `array<vec2<f32>>` bytes back to `Vec<Complex64>`.
129    fn bytes_to_complex64(bytes: &[u8]) -> Vec<Complex64> {
130        bytes
131            .chunks_exact(8)
132            .map(|chunk| {
133                let re = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as f64;
134                let im = f32::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]) as f64;
135                Complex64::new(re, im)
136            })
137            .collect()
138    }
139
140    // ─────────────────────────────────────────────────────────────────────────
141    // fft_wgpu
142    // ─────────────────────────────────────────────────────────────────────────
143
144    /// Compute an FFT (or IFFT) using the wgpu compute shader pipeline.
145    ///
146    /// `input` must have a **power-of-two length**.  Use
147    /// `super::dispatch::fft_auto_dispatch` for automatic padding.
148    ///
149    /// Returns `Err(FftBackendError::NoAdapter.into())` when no GPU is
150    /// available; the dispatch layer uses this to select the CPU path.
151    ///
152    /// # GPU execution pipeline
153    ///
154    /// 1. `wgpu::Instance::new` → `request_adapter` → `request_device`.
155    /// 2. Bit-reverse permute the input on the CPU.
156    /// 3. Upload the complex data to a storage buffer as `array<vec2<f32>>`.
157    /// 4. Create a uniform buffer for `FFTParams { n, stage, inverse, _pad }`.
158    /// 5. Load `fft_shader.wgsl` via `include_str!`, compile the compute pipeline.
159    /// 6. For each `stage` in `0..log2(n)`: update the uniform buffer with
160    ///    the current stage index via `queue.write_buffer`, encode one compute
161    ///    pass dispatching `ceil(n/2 / 64)` workgroups, submit and poll until
162    ///    the GPU is idle before the next stage.
163    /// 7. Copy the result buffer to a CPU-mappable staging buffer, map and
164    ///    read back the `vec2<f32>` pairs as `Complex64`.
165    /// 8. If `inverse`, scale each sample by `1.0 / n`.
166    pub fn fft_wgpu(input: &[Complex64], inverse: bool) -> Result<Vec<Complex64>, FFTError> {
167        use wgpu::{
168            util::{BufferInitDescriptor, DeviceExt as _},
169            BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor, BindGroupLayoutEntry,
170            BindingType, BufferBindingType, BufferDescriptor, BufferUsages,
171            CommandEncoderDescriptor, ComputePassDescriptor, DeviceDescriptor, Features, Limits,
172            MapMode, ShaderModuleDescriptor, ShaderSource, ShaderStages,
173        };
174
175        let n = input.len();
176        if !n.is_power_of_two() {
177            return Err(FftBackendError::NonPowerOfTwo(n).into());
178        }
179        // n == 0 or n == 1 are degenerate: return as-is (trivial FFT).
180        if n <= 1 {
181            return Ok(input.to_vec());
182        }
183
184        let log2_n = n.trailing_zeros();
185        let inverse_flag: u32 = if inverse { 1 } else { 0 };
186        let byte_len = (n * 8) as u64; // 8 bytes per complex sample (2 × f32)
187
188        // ── Adapter / device acquisition ──────────────────────────────────────
189        let instance = Instance::new(InstanceDescriptor {
190            backends: Backends::all(),
191            flags: wgpu::InstanceFlags::default(),
192            memory_budget_thresholds: Default::default(),
193            backend_options: Default::default(),
194            display: None,
195        });
196
197        let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
198            power_preference: PowerPreference::HighPerformance,
199            compatible_surface: None,
200            force_fallback_adapter: false,
201        }))
202        .map_err(|_| FFTError::from(FftBackendError::NoAdapter))?;
203
204        let (device, queue) = pollster::block_on(adapter.request_device(&DeviceDescriptor {
205            label: Some("scirs2-fft"),
206            required_features: Features::empty(),
207            required_limits: Limits::default(),
208            ..Default::default()
209        }))
210        .map_err(|e| FFTError::from(FftBackendError::DeviceCreation(e.to_string())))?;
211
212        // ── Bit-reverse permutation on the CPU ────────────────────────────────
213        let mut buf = input.to_vec();
214        bit_reverse_permute_gpu(&mut buf);
215
216        // ── Data buffer (storage read_write + COPY_SRC for readback) ──────────
217        let data_bytes = complex64_to_bytes(&buf);
218
219        let buf_data = device.create_buffer_init(&BufferInitDescriptor {
220            label: Some("scirs2-fft-data"),
221            contents: &data_bytes,
222            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
223        });
224
225        // ── Uniform buffer for FFTParams (starts at stage 0) ──────────────────
226        let initial_params = encode_params(n as u32, 0, inverse_flag);
227
228        let buf_params = device.create_buffer_init(&BufferInitDescriptor {
229            label: Some("scirs2-fft-params"),
230            contents: &initial_params,
231            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
232        });
233
234        // ── Staging buffer (CPU readable) ─────────────────────────────────────
235        let buf_staging = device.create_buffer(&BufferDescriptor {
236            label: Some("scirs2-fft-staging"),
237            size: byte_len,
238            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
239            mapped_at_creation: false,
240        });
241
242        // ── Bind group layout matching the shader bindings ────────────────────
243        //    @group(0) @binding(0) var<storage, read_write> data: array<vec2<f32>>;
244        //    @group(0) @binding(1) var<uniform> params: FFTParams;
245        let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
246            label: Some("scirs2-fft-bgl"),
247            entries: &[
248                BindGroupLayoutEntry {
249                    binding: 0,
250                    visibility: ShaderStages::COMPUTE,
251                    ty: BindingType::Buffer {
252                        ty: BufferBindingType::Storage { read_only: false },
253                        has_dynamic_offset: false,
254                        min_binding_size: None,
255                    },
256                    count: None,
257                },
258                BindGroupLayoutEntry {
259                    binding: 1,
260                    visibility: ShaderStages::COMPUTE,
261                    ty: BindingType::Buffer {
262                        ty: BufferBindingType::Uniform,
263                        has_dynamic_offset: false,
264                        min_binding_size: None,
265                    },
266                    count: None,
267                },
268            ],
269        });
270
271        // ── Pipeline layout ───────────────────────────────────────────────────
272        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
273            label: Some("scirs2-fft-layout"),
274            bind_group_layouts: &[Some(&bgl)],
275            ..Default::default()
276        });
277
278        // ── Shader module ─────────────────────────────────────────────────────
279        let shader_src = include_str!("fft_shader.wgsl");
280        let shader_module = device.create_shader_module(ShaderModuleDescriptor {
281            label: Some("scirs2-fft-shader"),
282            source: ShaderSource::Wgsl(shader_src.into()),
283        });
284
285        // ── Compute pipeline ──────────────────────────────────────────────────
286        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
287            label: Some("scirs2-fft-pipeline"),
288            layout: Some(&pipeline_layout),
289            module: &shader_module,
290            entry_point: Some("main"),
291            compilation_options: Default::default(),
292            cache: None,
293        });
294
295        // ── Bind group (static — data buffer and params buffer are fixed) ─────
296        let bind_group = device.create_bind_group(&BindGroupDescriptor {
297            label: Some("scirs2-fft-bg"),
298            layout: &bgl,
299            entries: &[
300                BindGroupEntry {
301                    binding: 0,
302                    resource: buf_data.as_entire_binding(),
303                },
304                BindGroupEntry {
305                    binding: 1,
306                    resource: buf_params.as_entire_binding(),
307                },
308            ],
309        });
310
311        // ── Per-stage dispatch loop ───────────────────────────────────────────
312        // Dispatch ceil(n/2 / 64) workgroups; the shader uses @workgroup_size(64)
313        // and each thread handles exactly one butterfly pair.
314        let workgroups = (n / 2).div_ceil(64) as u32;
315
316        for stage in 0..log2_n {
317            // Update the uniform buffer with the current stage index.
318            let params_bytes = encode_params(n as u32, stage, inverse_flag);
319            queue.write_buffer(&buf_params, 0, &params_bytes);
320
321            let mut encoder =
322                device.create_command_encoder(&CommandEncoderDescriptor { label: None });
323            {
324                let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor {
325                    label: None,
326                    timestamp_writes: None,
327                });
328                pass.set_pipeline(&pipeline);
329                pass.set_bind_group(0, &bind_group, &[]);
330                pass.dispatch_workgroups(workgroups, 1, 1);
331            }
332            queue.submit([encoder.finish()]);
333
334            // Wait for the GPU to finish before updating the stage for the next pass.
335            device
336                .poll(wgpu::PollType::wait_indefinitely())
337                .map_err(|e| {
338                    FFTError::from(FftBackendError::Buffer(format!("GPU poll error: {e:?}")))
339                })?;
340        }
341
342        // ── Copy result from data buffer to staging buffer ────────────────────
343        let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor { label: None });
344        encoder.copy_buffer_to_buffer(&buf_data, 0, &buf_staging, 0, byte_len);
345        queue.submit([encoder.finish()]);
346
347        // ── Map staging buffer and read back ──────────────────────────────────
348        device
349            .poll(wgpu::PollType::wait_indefinitely())
350            .map_err(|e| {
351                FFTError::from(FftBackendError::Buffer(format!(
352                    "GPU poll before map: {e:?}"
353                )))
354            })?;
355
356        let slice = buf_staging.slice(0..byte_len);
357        let (tx, rx) = std::sync::mpsc::channel();
358        slice.map_async(MapMode::Read, move |r| {
359            let _ = tx.send(r);
360        });
361
362        device
363            .poll(wgpu::PollType::wait_indefinitely())
364            .map_err(|e| {
365                FFTError::from(FftBackendError::Buffer(format!(
366                    "GPU poll during map: {e:?}"
367                )))
368            })?;
369
370        rx.recv()
371            .map_err(|_| {
372                FFTError::from(FftBackendError::Buffer(
373                    "channel closed during map_async".into(),
374                ))
375            })?
376            .map_err(|e| {
377                FFTError::from(FftBackendError::Buffer(format!("map_async failed: {e:?}")))
378            })?;
379
380        let mapped = slice.get_mapped_range();
381        let mut result = bytes_to_complex64(&mapped);
382        drop(mapped);
383        buf_staging.unmap();
384
385        // ── Inverse FFT scaling ───────────────────────────────────────────────
386        if inverse {
387            let scale = 1.0 / n as f64;
388            for c in &mut result {
389                c.re *= scale;
390                c.im *= scale;
391            }
392        }
393
394        Ok(result)
395    }
396}
397
398// Re-export the public items when the feature is active.
399#[cfg(feature = "wgpu_fft")]
400pub use inner::{fft_wgpu, gpu_available, FftBackendError};
401
402#[cfg(all(test, feature = "wgpu_fft"))]
403mod tests {
404    use super::{fft_wgpu, gpu_available};
405    use scirs2_core::numeric::Complex64;
406
407    /// Verify that `gpu_available()` completes without panicking and returns a
408    /// valid boolean.  The actual value (`true` or `false`) is environment-
409    /// dependent: CI / headless machines will return `false`, real GPU hosts
410    /// may return `true`.  We only assert that the call completes.
411    #[test]
412    fn test_gpu_available_returns_bool() {
413        let result: bool = gpu_available();
414        // Log the result for diagnostic purposes; never assert the specific value.
415        println!("gpu_available() = {result}");
416    }
417
418    /// An 8-point FFT then IFFT must recover the original input within f32
419    /// floating-point tolerance (~0.01).  On headless / CI machines without a
420    /// GPU the test is silently skipped.
421    #[test]
422    fn test_fft_wgpu_roundtrip_or_skip() {
423        let input: Vec<Complex64> = (0..8).map(|i| Complex64::new(i as f64, 0.0)).collect();
424
425        match fft_wgpu(&input, false) {
426            Err(e)
427                if e.to_string().contains("adapter")
428                    || e.to_string().contains("NoAdapter")
429                    || e.to_string().contains("no wgpu") =>
430            {
431                println!("test_fft_wgpu_roundtrip_or_skip: skipping — no GPU adapter");
432            }
433            Err(e) => panic!("unexpected fft_wgpu error: {e}"),
434            Ok(spectrum) => {
435                assert_eq!(spectrum.len(), input.len());
436                // IFFT to recover
437                match fft_wgpu(&spectrum, true) {
438                    Err(e) => panic!("unexpected ifft_wgpu error: {e}"),
439                    Ok(recovered) => {
440                        for (orig, rec) in input.iter().zip(recovered.iter()) {
441                            assert!(
442                                (orig.re - rec.re).abs() < 0.01,
443                                "re mismatch: {} vs {}",
444                                orig.re,
445                                rec.re
446                            );
447                            assert!(
448                                (orig.im - rec.im).abs() < 0.01,
449                                "im mismatch: {} vs {}",
450                                orig.im,
451                                rec.im
452                            );
453                        }
454                    }
455                }
456            }
457        }
458    }
459
460    /// Non-power-of-two input must always be rejected immediately, regardless
461    /// of whether a GPU adapter is available.
462    #[test]
463    fn test_fft_wgpu_non_power_of_two_rejected() {
464        let input: Vec<Complex64> = vec![Complex64::new(1.0, 0.0); 7];
465        let result = fft_wgpu(&input, false);
466        assert!(
467            result.is_err(),
468            "non-power-of-two input must return an error"
469        );
470    }
471}