scirs2_fft/gpu_fft/wgpu_backend.rs
1//! wgpu GPU FFT backend.
2//!
3//! This module is compiled only when the `wgpu_fft` feature is enabled.
4//! It exposes `fft_wgpu`, which attempts to:
5//!
6//! 1. Acquire a wgpu adapter and device (GPU).
//! 2. Bit-reverse permute the input on the CPU and upload it to the GPU.
8//! 3. Execute the Cooley-Tukey radix-2 DIT FFT via a WGSL compute shader
9//! (`fft_shader.wgsl`) for `log2(n)` passes.
10//! 4. Read the result back to the CPU.
11//!
//! If no GPU adapter is found at runtime (CI, headless server, etc.) the
//! function returns an `FFTError` converted from `FftBackendError::NoAdapter`.
//! The dispatch layer in [`super::dispatch`] catches that error and falls back
//! to the CPU path, so callers never need to handle the GPU-unavailable case
//! explicitly.
16//!
17//! # Feature gate
18//!
19//! This entire module is behind `#[cfg(feature = "wgpu_fft")]`.
20
#[cfg(feature = "wgpu_fft")]
mod inner {
    use crate::error::FFTError;
    use scirs2_core::numeric::Complex64;
    use wgpu::{Backends, Instance, InstanceDescriptor, PowerPreference, RequestAdapterOptions};

    use super::super::kernels::bit_reverse_permute_gpu;

    // ─────────────────────────────────────────────────────────────────────────
    // Constants
    // ─────────────────────────────────────────────────────────────────────────

    /// Bytes per complex sample on the GPU: one `vec2<f32>` (re, im), i.e. two
    /// 4-byte floats.
    const BYTES_PER_SAMPLE: usize = 8;

    /// Threads per workgroup. Must match `@workgroup_size(64)` in
    /// `fft_shader.wgsl`; each thread handles exactly one butterfly pair.
    const WORKGROUP_SIZE: u32 = 64;

    // ─────────────────────────────────────────────────────────────────────────
    // Error type
    // ─────────────────────────────────────────────────────────────────────────

    /// Errors specific to the wgpu FFT back-end.
    #[derive(Debug, thiserror::Error)]
    pub enum FftBackendError {
        /// No compatible GPU adapter was found on this system.
        #[error("no wgpu adapter available (GPU unavailable or unsupported)")]
        NoAdapter,

        /// The adapter was found but the device could not be created.
        #[error("wgpu device creation failed: {0}")]
        DeviceCreation(String),

        /// A shader compilation error occurred.
        ///
        /// Note: [`fft_wgpu`] does not currently construct this variant; it
        /// does not capture shader-module creation errors itself.
        #[error("WGSL shader compilation failed: {0}")]
        ShaderCompilation(String),

        /// A buffer operation (upload/readback) failed, or the transform is
        /// too large for the device's buffer or dispatch limits.
        #[error("GPU buffer operation failed: {0}")]
        Buffer(String),

        /// The input length is not a power of two (required by the shader).
        #[error("wgpu FFT requires a power-of-two input length; got {0}")]
        NonPowerOfTwo(usize),
    }

    impl From<FftBackendError> for FFTError {
        fn from(e: FftBackendError) -> Self {
            FFTError::BackendError(e.to_string())
        }
    }

    /// Wrap `msg` as an [`FftBackendError::Buffer`] converted into [`FFTError`].
    fn buffer_error(msg: impl Into<String>) -> FFTError {
        FftBackendError::Buffer(msg.into()).into()
    }

    // ─────────────────────────────────────────────────────────────────────────
    // Adapter / device helpers
    // ─────────────────────────────────────────────────────────────────────────

    /// Create a wgpu instance that may use any available backend.
    fn new_instance() -> Instance {
        Instance::new(InstanceDescriptor {
            backends: Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            memory_budget_thresholds: Default::default(),
            backend_options: Default::default(),
            display: None,
        })
    }

    /// Returns `true` when a wgpu adapter appears to be available on this
    /// system. This is a best-effort, synchronous check — it should not be
    /// relied upon in production code without a subsequent `fft_wgpu` call.
    ///
    /// # Implementation note
    ///
    /// Performs a real wgpu adapter request, using `pollster::block_on` to
    /// drive the async call synchronously (the calling thread blocks). Returns
    /// `false` whenever no adapter is found (e.g. headless / CI machines), so
    /// the dispatch layer can fall back to the CPU path transparently.
    pub fn gpu_available() -> bool {
        let instance = new_instance();
        pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
            power_preference: PowerPreference::default(),
            compatible_surface: None,
            force_fallback_adapter: false,
        }))
        .is_ok()
    }

    /// Acquire a high-performance adapter and create a device/queue pair with
    /// no optional features and default limits.
    ///
    /// # Errors
    ///
    /// `NoAdapter` if no adapter is found, `DeviceCreation` if the device
    /// request fails (both converted into [`FFTError`]).
    fn acquire_device() -> Result<(wgpu::Device, wgpu::Queue), FFTError> {
        let instance = new_instance();

        let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
            power_preference: PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        }))
        .map_err(|_| FFTError::from(FftBackendError::NoAdapter))?;

        pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
            label: Some("scirs2-fft"),
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::default(),
            ..Default::default()
        }))
        .map_err(|e| FFTError::from(FftBackendError::DeviceCreation(e.to_string())))
    }

    /// Block until all submitted GPU work has completed. On failure, returns a
    /// `Buffer` error whose message is `"{context}: {poll error:?}"`.
    fn wait_idle(device: &wgpu::Device, context: &str) -> Result<(), FFTError> {
        device
            .poll(wgpu::PollType::wait_indefinitely())
            .map(|_| ())
            .map_err(|e| buffer_error(format!("{context}: {e:?}")))
    }

    /// Validate that an `n`-point transform fits the device `limits`.
    ///
    /// Returns `(n as u32, data buffer size in bytes, workgroup count)`.
    /// Oversized transforms are rejected here with a `Buffer` error instead of
    /// being handed to wgpu, which reports limit violations as validation
    /// errors rather than through a `Result`.
    fn launch_geometry(n: usize, limits: &wgpu::Limits) -> Result<(u32, u64, u32), FFTError> {
        let too_large =
            || buffer_error(format!("FFT of length {n} exceeds the wgpu device limits"));

        // The shader indexes with u32 and `FFTParams.n` is a u32.
        let n_u32 = u32::try_from(n).map_err(|_| too_large())?;
        // Cannot overflow: n_u32 <= u32::MAX, times 8.
        let byte_len = u64::from(n_u32) * BYTES_PER_SAMPLE as u64;
        // ceil((n/2) / WORKGROUP_SIZE): one thread per butterfly pair.
        let workgroups = (n_u32 / 2).div_ceil(WORKGROUP_SIZE);

        let fits = byte_len <= limits.max_buffer_size
            && byte_len <= u64::from(limits.max_storage_buffer_binding_size)
            && workgroups <= limits.max_compute_workgroups_per_dimension;
        if fits {
            Ok((n_u32, byte_len, workgroups))
        } else {
            Err(too_large())
        }
    }

    // ─────────────────────────────────────────────────────────────────────────
    // Encoding helpers
    // ─────────────────────────────────────────────────────────────────────────

    /// Encode FFT uniform-buffer params as raw little-endian bytes.
    ///
    /// Layout: `{ n: u32, stage: u32, inverse: u32, _pad: u32 }` (16 bytes).
    fn encode_params(n: u32, stage: u32, inverse: u32) -> [u8; 16] {
        let mut out = [0u8; 16];
        out[0..4].copy_from_slice(&n.to_le_bytes());
        out[4..8].copy_from_slice(&stage.to_le_bytes());
        out[8..12].copy_from_slice(&inverse.to_le_bytes());
        // _pad = 0 (already zero)
        out
    }

    /// Serialise a slice of `Complex64` as `array<vec2<f32>>` bytes.
    ///
    /// Each complex sample becomes two contiguous `f32` values (real then
    /// imaginary), each encoded as 4 little-endian bytes, for a total of 8
    /// bytes per sample. The `f64 -> f32` narrowing loses precision.
    fn complex64_to_bytes(data: &[Complex64]) -> Vec<u8> {
        let mut out = Vec::with_capacity(data.len() * BYTES_PER_SAMPLE);
        for c in data {
            out.extend_from_slice(&(c.re as f32).to_le_bytes());
            out.extend_from_slice(&(c.im as f32).to_le_bytes());
        }
        out
    }

    /// Deserialise `array<vec2<f32>>` bytes back to `Vec<Complex64>`.
    ///
    /// Any trailing bytes that do not form a whole 8-byte sample are ignored.
    fn bytes_to_complex64(bytes: &[u8]) -> Vec<Complex64> {
        bytes
            .chunks_exact(BYTES_PER_SAMPLE)
            .map(|chunk| {
                let re = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
                let im = f32::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
                Complex64::new(f64::from(re), f64::from(im))
            })
            .collect()
    }

    /// Copy the first `byte_len` bytes of `src` into a fresh CPU-mappable
    /// staging buffer, map it, and decode the `vec2<f32>` samples.
    fn read_back(
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        src: &wgpu::Buffer,
        byte_len: u64,
    ) -> Result<Vec<Complex64>, FFTError> {
        let staging = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("scirs2-fft-staging"),
            size: byte_len,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        encoder.copy_buffer_to_buffer(src, 0, &staging, 0, byte_len);
        queue.submit([encoder.finish()]);
        wait_idle(device, "GPU poll before map")?;

        let slice = staging.slice(0..byte_len);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            // A failed send only means the receiver is gone (e.g. this
            // function already returned early); nothing to report then.
            let _ = tx.send(r);
        });
        // The map callback runs while the device is polled.
        wait_idle(device, "GPU poll during map")?;

        rx.recv()
            .map_err(|_| buffer_error("channel closed during map_async"))?
            .map_err(|e| buffer_error(format!("map_async failed: {e:?}")))?;

        let mapped = slice.get_mapped_range();
        let samples = bytes_to_complex64(&mapped);
        // The mapped view must be dropped before the buffer can be unmapped.
        drop(mapped);
        staging.unmap();
        Ok(samples)
    }

    // ─────────────────────────────────────────────────────────────────────────
    // fft_wgpu
    // ─────────────────────────────────────────────────────────────────────────

    /// Compute an FFT (or IFFT) using the wgpu compute shader pipeline.
    ///
    /// `input` must have a **power-of-two length**. Lengths 0 and 1 are
    /// trivial and are returned unchanged without touching the GPU. Use
    /// `super::dispatch::fft_auto_dispatch` for automatic padding.
    ///
    /// Samples are narrowed to `f32` for the GPU and widened back to `f64` on
    /// readback, so results carry single-precision rounding error.
    ///
    /// # GPU execution pipeline
    ///
    /// 1. `wgpu::Instance::new` → `request_adapter` → `request_device`.
    /// 2. Check the transform against the device's buffer and dispatch limits.
    /// 3. Bit-reverse permute the input on the CPU.
    /// 4. Upload the complex data to a storage buffer as `array<vec2<f32>>`.
    /// 5. Create a uniform buffer for `FFTParams { n, stage, inverse, _pad }`.
    /// 6. Load `fft_shader.wgsl` via `include_str!`, compile the compute pipeline.
    /// 7. For each `stage` in `0..log2(n)`: update the uniform buffer with
    ///    the current stage index via `queue.write_buffer`, encode one compute
    ///    pass dispatching `ceil(n/2 / 64)` workgroups, submit and poll until
    ///    the GPU is idle before the next stage.
    /// 8. Copy the result buffer to a CPU-mappable staging buffer, map and
    ///    read back the `vec2<f32>` pairs as `Complex64`.
    /// 9. If `inverse`, scale each sample by `1.0 / n`.
    ///
    /// # Errors
    ///
    /// All errors are [`FftBackendError`] variants converted into [`FFTError`]:
    ///
    /// * `NonPowerOfTwo` — `input.len()` is neither 0 nor a power of two.
    /// * `NoAdapter` — no GPU is available; the dispatch layer uses this to
    ///   select the CPU path.
    /// * `DeviceCreation` — the adapter could not create a device.
    /// * `Buffer` — the transform exceeds the device limits, or polling /
    ///   mapping the result failed.
    pub fn fft_wgpu(input: &[Complex64], inverse: bool) -> Result<Vec<Complex64>, FFTError> {
        use wgpu::{
            util::{BufferInitDescriptor, DeviceExt as _},
            BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor, BindGroupLayoutEntry,
            BindingType, BufferBindingType, BufferUsages, CommandEncoderDescriptor,
            ComputePassDescriptor, ShaderModuleDescriptor, ShaderSource, ShaderStages,
        };

        let n = input.len();
        // Lengths 0 and 1 are the identity transform. This check must come
        // before the power-of-two check: `0usize.is_power_of_two()` is false,
        // so an empty input would otherwise be rejected.
        if n <= 1 {
            return Ok(input.to_vec());
        }
        if !n.is_power_of_two() {
            return Err(FftBackendError::NonPowerOfTwo(n).into());
        }

        // ── Adapter / device acquisition ──────────────────────────────────────
        // Acquired before the size checks so that a machine without a GPU
        // always reports `NoAdapter`, whatever the input size.
        let (device, queue) = acquire_device()?;
        let (n_u32, byte_len, workgroups) = launch_geometry(n, &device.limits())?;
        let log2_n = n_u32.trailing_zeros();
        let inverse_flag = u32::from(inverse);

        // ── Bit-reverse permutation on the CPU ────────────────────────────────
        let mut permuted = input.to_vec();
        bit_reverse_permute_gpu(&mut permuted);

        // ── Data buffer (storage read_write + COPY_SRC for readback) ──────────
        let data_bytes = complex64_to_bytes(&permuted);
        let buf_data = device.create_buffer_init(&BufferInitDescriptor {
            label: Some("scirs2-fft-data"),
            contents: &data_bytes,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        });

        // ── Uniform buffer for FFTParams (starts at stage 0) ──────────────────
        let initial_params = encode_params(n_u32, 0, inverse_flag);
        let buf_params = device.create_buffer_init(&BufferInitDescriptor {
            label: Some("scirs2-fft-params"),
            contents: &initial_params,
            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
        });

        // ── Bind group layout matching the shader bindings ────────────────────
        // @group(0) @binding(0) var<storage, read_write> data: array<vec2<f32>>;
        // @group(0) @binding(1) var<uniform> params: FFTParams;
        let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
            label: Some("scirs2-fft-bgl"),
            entries: &[
                BindGroupLayoutEntry {
                    binding: 0,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                BindGroupLayoutEntry {
                    binding: 1,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });

        // ── Pipeline layout ───────────────────────────────────────────────────
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("scirs2-fft-layout"),
            bind_group_layouts: &[Some(&bgl)],
            ..Default::default()
        });

        // ── Shader module ─────────────────────────────────────────────────────
        let shader_src = include_str!("fft_shader.wgsl");
        let shader_module = device.create_shader_module(ShaderModuleDescriptor {
            label: Some("scirs2-fft-shader"),
            source: ShaderSource::Wgsl(shader_src.into()),
        });

        // ── Compute pipeline ──────────────────────────────────────────────────
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("scirs2-fft-pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        // ── Bind group (static — data buffer and params buffer are fixed) ─────
        let bind_group = device.create_bind_group(&BindGroupDescriptor {
            label: Some("scirs2-fft-bg"),
            layout: &bgl,
            entries: &[
                BindGroupEntry {
                    binding: 0,
                    resource: buf_data.as_entire_binding(),
                },
                BindGroupEntry {
                    binding: 1,
                    resource: buf_params.as_entire_binding(),
                },
            ],
        });

        // ── Per-stage dispatch loop ───────────────────────────────────────────
        for stage in 0..log2_n {
            // Update the uniform buffer with the current stage index.
            let params_bytes = encode_params(n_u32, stage, inverse_flag);
            queue.write_buffer(&buf_params, 0, &params_bytes);

            let mut encoder =
                device.create_command_encoder(&CommandEncoderDescriptor { label: None });
            {
                let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor {
                    label: None,
                    timestamp_writes: None,
                });
                pass.set_pipeline(&pipeline);
                pass.set_bind_group(0, &bind_group, &[]);
                pass.dispatch_workgroups(workgroups, 1, 1);
            }
            queue.submit([encoder.finish()]);

            // Wait for the GPU to finish before updating the stage for the next pass.
            wait_idle(&device, "GPU poll error")?;
        }

        // ── Read the transformed data back to the CPU ─────────────────────────
        let mut result = read_back(&device, &queue, &buf_data, byte_len)?;

        // ── Inverse FFT scaling ───────────────────────────────────────────────
        if inverse {
            let scale = 1.0 / n as f64;
            for c in &mut result {
                c.re *= scale;
                c.im *= scale;
            }
        }

        Ok(result)
    }
}
397
// Public surface of this module: lift the backend entry points and error type
// out of `inner` so they are reachable directly from here. Like `inner`
// itself, these exist only when the `wgpu_fft` feature is enabled.
#[cfg(feature = "wgpu_fft")]
pub use self::inner::{fft_wgpu, gpu_available, FftBackendError};
401
#[cfg(all(test, feature = "wgpu_fft"))]
mod tests {
    use super::{fft_wgpu, gpu_available};
    use scirs2_core::numeric::Complex64;

    /// Absolute tolerance for comparisons against f32 GPU arithmetic.
    const TOL: f64 = 0.01;

    /// True when `err` looks like the "no GPU adapter" error, meaning a
    /// GPU-dependent test should be skipped on this machine.
    fn is_no_adapter_error(err: &impl std::fmt::Display) -> bool {
        let msg = err.to_string();
        msg.contains("adapter") || msg.contains("NoAdapter") || msg.contains("no wgpu")
    }

    /// Assert that `actual` is within `TOL` of `expected` in both components.
    fn assert_close(actual: Complex64, expected: Complex64, what: &str) {
        assert!(
            (actual.re - expected.re).abs() < TOL && (actual.im - expected.im).abs() < TOL,
            "{what}: expected {expected:?}, got {actual:?}"
        );
    }

    /// Verify that `gpu_available()` completes without panicking and returns a
    /// valid boolean. The actual value (`true` or `false`) is environment-
    /// dependent: CI / headless machines will return `false`, real GPU hosts
    /// may return `true`. We only assert that the call completes.
    #[test]
    fn test_gpu_available_returns_bool() {
        let result: bool = gpu_available();
        // Log the result for diagnostic purposes; never assert the specific value.
        println!("gpu_available() = {result}");
    }

    /// An 8-point FFT must produce the expected DC and Nyquist bins, and the
    /// IFFT must recover the original input within f32 floating-point
    /// tolerance (~0.01). On headless / CI machines without a GPU the test is
    /// silently skipped.
    #[test]
    fn test_fft_wgpu_roundtrip_or_skip() {
        let input: Vec<Complex64> = (0..8).map(|i| Complex64::new(i as f64, 0.0)).collect();

        let spectrum = match fft_wgpu(&input, false) {
            Ok(spectrum) => spectrum,
            Err(e) if is_no_adapter_error(&e) => {
                println!("test_fft_wgpu_roundtrip_or_skip: skipping — no GPU adapter");
                return;
            }
            Err(e) => panic!("unexpected fft_wgpu error: {e}"),
        };
        assert_eq!(spectrum.len(), input.len());

        // For x[k] = k, k = 0..8:
        //   X[0] = sum of k          = 28
        //   X[4] = sum of (-1)^k * k = -4
        // Both are purely real and independent of the twiddle sign convention.
        assert_close(spectrum[0], Complex64::new(28.0, 0.0), "X[0]");
        assert_close(spectrum[4], Complex64::new(-4.0, 0.0), "X[4]");

        let recovered =
            fft_wgpu(&spectrum, true).unwrap_or_else(|e| panic!("unexpected ifft_wgpu error: {e}"));
        // `zip` below would silently ignore a short result, so pin the length.
        assert_eq!(recovered.len(), input.len());
        for (orig, rec) in input.iter().zip(&recovered) {
            assert!(
                (orig.re - rec.re).abs() < TOL,
                "re mismatch: {} vs {}",
                orig.re,
                rec.re
            );
            assert!(
                (orig.im - rec.im).abs() < TOL,
                "im mismatch: {} vs {}",
                orig.im,
                rec.im
            );
        }
    }

    /// Non-power-of-two input must always be rejected immediately, regardless
    /// of whether a GPU adapter is available.
    #[test]
    fn test_fft_wgpu_non_power_of_two_rejected() {
        let input: Vec<Complex64> = vec![Complex64::new(1.0, 0.0); 7];
        let result = fft_wgpu(&input, false);
        assert!(
            result.is_err(),
            "non-power-of-two input must return an error"
        );
    }
}