scirs2_fft/
backend.rs

1//! FFT Backend System
2//!
3//! This module provides a pluggable backend system for FFT implementations,
4//! similar to SciPy's backend model. This allows users to choose between
5//! different FFT implementations at runtime.
6
7use crate::error::{FFTError, FFTResult};
8use rustfft::FftPlanner;
9use scirs2_core::numeric::Complex64;
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex, OnceLock};
12
13/// FFT Backend trait for implementing different FFT algorithms
14pub trait FftBackend: Send + Sync {
15    /// Name of the backend
16    fn name(&self) -> &str;
17
18    /// Description of the backend
19    fn description(&self) -> &str;
20
21    /// Check if this backend is available
22    fn is_available(&self) -> bool;
23
24    /// Perform forward FFT
25    fn fft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()>;
26
27    /// Perform inverse FFT
28    fn ifft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()>;
29
30    /// Perform FFT with specific size (may be cached)
31    fn fft_sized(
32        &self,
33        input: &[Complex64],
34        output: &mut [Complex64],
35        size: usize,
36    ) -> FFTResult<()>;
37
38    /// Perform inverse FFT with specific size (may be cached)
39    fn ifft_sized(
40        &self,
41        input: &[Complex64],
42        output: &mut [Complex64],
43        size: usize,
44    ) -> FFTResult<()>;
45
46    /// Check if this backend supports a specific feature
47    fn supports_feature(&self, feature: &str) -> bool;
48}
49
50/// RustFFT backend implementation
51pub struct RustFftBackend {
52    planner: Arc<Mutex<FftPlanner<f64>>>,
53}
54
55impl RustFftBackend {
56    pub fn new() -> Self {
57        Self {
58            planner: Arc::new(Mutex::new(FftPlanner::new())),
59        }
60    }
61}
62
63impl Default for RustFftBackend {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl FftBackend for RustFftBackend {
70    fn name(&self) -> &str {
71        "rustfft"
72    }
73
74    fn description(&self) -> &str {
75        "Pure Rust FFT implementation using RustFFT library"
76    }
77
78    fn is_available(&self) -> bool {
79        true
80    }
81
82    fn fft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
83        self.fft_sized(input, output, input.len())
84    }
85
86    fn ifft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
87        self.ifft_sized(input, output, input.len())
88    }
89
90    fn fft_sized(
91        &self,
92        input: &[Complex64],
93        output: &mut [Complex64],
94        size: usize,
95    ) -> FFTResult<()> {
96        if input.len() != size || output.len() != size {
97            return Err(FFTError::ValueError(
98                "Input and output sizes must match the specified size".to_string(),
99            ));
100        }
101
102        // Get cached plan from the planner
103        let mut planner = self.planner.lock().expect("Operation failed");
104        let fft = planner.plan_fft_forward(size);
105
106        // Convert to rustfft's Complex type
107        let mut buffer: Vec<rustfft::num_complex::Complex<f64>> = input
108            .iter()
109            .map(|&c| rustfft::num_complex::Complex::new(c.re, c.im))
110            .collect();
111
112        // Perform FFT
113        fft.process(&mut buffer);
114
115        // Copy to output
116        for (i, &c) in buffer.iter().enumerate() {
117            output[i] = Complex64::new(c.re, c.im);
118        }
119
120        Ok(())
121    }
122
123    fn ifft_sized(
124        &self,
125        input: &[Complex64],
126        output: &mut [Complex64],
127        size: usize,
128    ) -> FFTResult<()> {
129        if input.len() != size || output.len() != size {
130            return Err(FFTError::ValueError(
131                "Input and output sizes must match the specified size".to_string(),
132            ));
133        }
134
135        // Get cached plan from the planner
136        let mut planner = self.planner.lock().expect("Operation failed");
137        let fft = planner.plan_fft_inverse(size);
138
139        // Convert to rustfft's Complex type
140        let mut buffer: Vec<rustfft::num_complex::Complex<f64>> = input
141            .iter()
142            .map(|&c| rustfft::num_complex::Complex::new(c.re, c.im))
143            .collect();
144
145        // Perform IFFT
146        fft.process(&mut buffer);
147
148        // Copy to output with normalization
149        let scale = 1.0 / size as f64;
150        for (i, &c) in buffer.iter().enumerate() {
151            output[i] = Complex64::new(c.re * scale, c.im * scale);
152        }
153
154        Ok(())
155    }
156
157    fn supports_feature(&self, feature: &str) -> bool {
158        matches!(feature, "1d_fft" | "2d_fft" | "nd_fft" | "cached_plans")
159    }
160}
161
162/// Backend manager for FFT operations
163pub struct BackendManager {
164    backends: Arc<Mutex<HashMap<String, Arc<dyn FftBackend>>>>,
165    current_backend: Arc<Mutex<String>>,
166}
167
168impl BackendManager {
169    /// Create a new backend manager with default backends
170    pub fn new() -> Self {
171        let mut backends = HashMap::new();
172
173        // Add default RustFFT backend
174        let rustfft_backend = Arc::new(RustFftBackend::new()) as Arc<dyn FftBackend>;
175        backends.insert("rustfft".to_string(), rustfft_backend);
176
177        Self {
178            backends: Arc::new(Mutex::new(backends)),
179            current_backend: Arc::new(Mutex::new("rustfft".to_string())),
180        }
181    }
182
183    /// Register a new backend
184    pub fn register_backend(&self, name: String, backend: Arc<dyn FftBackend>) -> FFTResult<()> {
185        let mut backends = self.backends.lock().expect("Operation failed");
186        if backends.contains_key(&name) {
187            return Err(FFTError::ValueError(format!(
188                "Backend '{name}' already exists"
189            )));
190        }
191        backends.insert(name, backend);
192        Ok(())
193    }
194
195    /// Get available backends
196    pub fn list_backends(&self) -> Vec<String> {
197        let backends = self.backends.lock().expect("Operation failed");
198        backends.keys().cloned().collect()
199    }
200
201    /// Set the current backend
202    pub fn set_backend(&self, name: &str) -> FFTResult<()> {
203        let backends = self.backends.lock().expect("Operation failed");
204        if !backends.contains_key(name) {
205            return Err(FFTError::ValueError(format!("Backend '{name}' not found")));
206        }
207
208        // Check if backend is available
209        if let Some(backend) = backends.get(name) {
210            if !backend.is_available() {
211                return Err(FFTError::ValueError(format!(
212                    "Backend '{name}' is not available"
213                )));
214            }
215        }
216
217        *self.current_backend.lock().expect("Operation failed") = name.to_string();
218        Ok(())
219    }
220
221    /// Get current backend name
222    pub fn get_backend_name(&self) -> String {
223        self.current_backend
224            .lock()
225            .expect("Operation failed")
226            .clone()
227    }
228
229    /// Get current backend
230    pub fn get_backend(&self) -> Arc<dyn FftBackend> {
231        let current_name = self.current_backend.lock().expect("Operation failed");
232        let backends = self.backends.lock().expect("Operation failed");
233        backends
234            .get(&*current_name)
235            .cloned()
236            .expect("Current backend should always exist")
237    }
238
239    /// Get backend info
240    pub fn get_backend_info(&self, name: &str) -> Option<BackendInfo> {
241        let backends = self.backends.lock().expect("Operation failed");
242        backends.get(name).map(|backend| BackendInfo {
243            name: backend.name().to_string(),
244            description: backend.description().to_string(),
245            available: backend.is_available(),
246        })
247    }
248}
249
250impl Default for BackendManager {
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
/// Snapshot describing one registered backend, as produced by
/// `BackendManager::get_backend_info`.
#[derive(Debug, Clone)]
pub struct BackendInfo {
    /// Identifier reported by the backend's `name()`.
    pub name: String,
    /// Text reported by the backend's `description()`.
    pub description: String,
    /// Value of the backend's `is_available()` when the snapshot was taken.
    pub available: bool,
}

impl std::fmt::Display for BackendInfo {
    /// Renders as `<name> - <description> (available)` or
    /// `<name> - <description> (not available)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let status = match self.available {
            true => "available",
            false => "not available",
        };
        write!(f, "{} - {} ({status})", self.name, self.description)
    }
}
279
280/// Global backend manager instance
281static GLOBAL_BACKEND_MANAGER: OnceLock<BackendManager> = OnceLock::new();
282
283/// Get the global backend manager
284#[allow(dead_code)]
285pub fn get_backend_manager() -> &'static BackendManager {
286    GLOBAL_BACKEND_MANAGER.get_or_init(BackendManager::new)
287}
288
289/// Initialize global backend manager with custom configuration
290#[allow(dead_code)]
291pub fn init_backend_manager(manager: BackendManager) -> Result<(), &'static str> {
292    GLOBAL_BACKEND_MANAGER
293        .set(manager)
294        .map_err(|_| "Global backend _manager already initialized")
295}
296
297/// List available backends
298#[allow(dead_code)]
299pub fn list_backends() -> Vec<String> {
300    get_backend_manager().list_backends()
301}
302
303/// Set the current backend
304#[allow(dead_code)]
305pub fn set_backend(name: &str) -> FFTResult<()> {
306    get_backend_manager().set_backend(name)
307}
308
309/// Get current backend name
310#[allow(dead_code)]
311pub fn get_backend_name() -> String {
312    get_backend_manager().get_backend_name()
313}
314
315/// Get backend information
316#[allow(dead_code)]
317pub fn get_backend_info(name: &str) -> Option<BackendInfo> {
318    get_backend_manager().get_backend_info(name)
319}
320
321/// Context manager for temporarily using a different backend
322pub struct BackendContext {
323    previous_backend: String,
324    manager: &'static BackendManager,
325}
326
327impl BackendContext {
328    /// Create a new backend context
329    pub fn new(_backendname: &str) -> FFTResult<Self> {
330        let manager = get_backend_manager();
331        let previous_backend = manager.get_backend_name();
332
333        // Set the new backend
334        manager.set_backend(_backendname)?;
335
336        Ok(Self {
337            previous_backend,
338            manager,
339        })
340    }
341}
342
343impl Drop for BackendContext {
344    fn drop(&mut self) {
345        // Restore previous backend
346        let _ = self.manager.set_backend(&self.previous_backend);
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts element-wise equality of two complex slices within `1e-12`.
    fn assert_close(actual: &[Complex64], expected: &[Complex64]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a.re - e.re).abs() < 1e-12 && (a.im - e.im).abs() < 1e-12,
                "mismatch at index {i}"
            );
        }
    }

    #[test]
    fn test_rustfft_backend() {
        let backend = RustFftBackend::new();
        assert_eq!(backend.name(), "rustfft");
        assert!(backend.is_available());
        assert!(backend.supports_feature("1d_fft"));
        assert!(!backend.supports_feature("unknown_feature"));
    }

    #[test]
    fn test_rustfft_forward_of_impulse_is_flat() {
        let backend = RustFftBackend::new();
        let zero = Complex64::new(0.0, 0.0);
        let input = [Complex64::new(1.0, 0.0), zero, zero, zero];
        let mut output = [zero; 4];
        backend.fft(&input, &mut output).expect("forward FFT failed");
        assert_close(&output, &[Complex64::new(1.0, 0.0); 4]);
    }

    #[test]
    fn test_rustfft_round_trip_restores_input() {
        let backend = RustFftBackend::new();
        let input: Vec<Complex64> = (0..8)
            .map(|k| Complex64::new(k as f64, -(k as f64) / 2.0))
            .collect();
        let mut spectrum = vec![Complex64::new(0.0, 0.0); input.len()];
        let mut restored = vec![Complex64::new(0.0, 0.0); input.len()];
        backend.fft(&input, &mut spectrum).expect("forward FFT failed");
        backend.ifft(&spectrum, &mut restored).expect("inverse FFT failed");
        assert_close(&restored, &input);
    }

    #[test]
    fn test_rustfft_rejects_mismatched_lengths() {
        let backend = RustFftBackend::new();
        let input = [Complex64::new(1.0, 0.0); 4];
        let mut short_output = [Complex64::new(0.0, 0.0); 3];
        assert!(backend.fft(&input, &mut short_output).is_err());
        let mut output = [Complex64::new(0.0, 0.0); 4];
        assert!(backend.fft_sized(&input, &mut output, 8).is_err());
        assert!(backend.ifft_sized(&input, &mut output, 2).is_err());
    }

    #[test]
    fn test_backend_manager() {
        let manager = BackendManager::new();

        // Check default backend
        assert_eq!(manager.get_backend_name(), "rustfft");

        // List backends
        let backends = manager.list_backends();
        assert!(backends.contains(&"rustfft".to_string()));

        // Get backend info
        let info = manager
            .get_backend_info("rustfft")
            .expect("default backend should be registered");
        assert!(info.available);
        assert!(manager.get_backend_info("missing").is_none());
    }

    #[test]
    fn test_register_and_select_backend() {
        let manager = BackendManager::new();
        manager
            .register_backend("alias".to_string(), Arc::new(RustFftBackend::new()))
            .expect("registering a new name should succeed");
        assert!(manager
            .register_backend("alias".to_string(), Arc::new(RustFftBackend::new()))
            .is_err());

        manager
            .set_backend("alias")
            .expect("registered backend should be selectable");
        assert_eq!(manager.get_backend_name(), "alias");
        assert_eq!(manager.get_backend().name(), "rustfft");

        // A failed selection leaves the current backend untouched.
        assert!(manager.set_backend("missing").is_err());
        assert_eq!(manager.get_backend_name(), "alias");
    }

    #[test]
    fn test_backend_info_display() {
        let manager = BackendManager::new();
        let info = manager
            .get_backend_info("rustfft")
            .expect("default backend should be registered");
        assert_eq!(
            info.to_string(),
            "rustfft - Pure Rust FFT implementation using RustFFT library (available)"
        );
    }

    #[test]
    fn test_backend_context() {
        let manager = get_backend_manager();
        let original = manager.get_backend_name();

        {
            let _ctx = BackendContext::new("rustfft").expect("Operation failed");
            assert_eq!(manager.get_backend_name(), "rustfft");
        }

        // Backend should be restored
        assert_eq!(manager.get_backend_name(), original);
    }

    #[test]
    fn test_backend_context_rejects_unknown_backend() {
        let manager = get_backend_manager();
        let original = manager.get_backend_name();
        assert!(BackendContext::new("missing").is_err());
        assert_eq!(manager.get_backend_name(), original);
    }
}