1use crate::error::{FFTError, FFTResult};
8use rustfft::FftPlanner;
9use scirs2_core::numeric::Complex64;
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex, OnceLock};
12
13pub trait FftBackend: Send + Sync {
15 fn name(&self) -> &str;
17
18 fn description(&self) -> &str;
20
21 fn is_available(&self) -> bool;
23
24 fn fft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()>;
26
27 fn ifft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()>;
29
30 fn fft_sized(
32 &self,
33 input: &[Complex64],
34 output: &mut [Complex64],
35 size: usize,
36 ) -> FFTResult<()>;
37
38 fn ifft_sized(
40 &self,
41 input: &[Complex64],
42 output: &mut [Complex64],
43 size: usize,
44 ) -> FFTResult<()>;
45
46 fn supports_feature(&self, feature: &str) -> bool;
48}
49
50pub struct RustFftBackend {
52 planner: Arc<Mutex<FftPlanner<f64>>>,
53}
54
55impl RustFftBackend {
56 pub fn new() -> Self {
57 Self {
58 planner: Arc::new(Mutex::new(FftPlanner::new())),
59 }
60 }
61}
62
63impl Default for RustFftBackend {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl FftBackend for RustFftBackend {
70 fn name(&self) -> &str {
71 "rustfft"
72 }
73
74 fn description(&self) -> &str {
75 "Pure Rust FFT implementation using RustFFT library"
76 }
77
78 fn is_available(&self) -> bool {
79 true
80 }
81
82 fn fft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
83 self.fft_sized(input, output, input.len())
84 }
85
86 fn ifft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
87 self.ifft_sized(input, output, input.len())
88 }
89
90 fn fft_sized(
91 &self,
92 input: &[Complex64],
93 output: &mut [Complex64],
94 size: usize,
95 ) -> FFTResult<()> {
96 if input.len() != size || output.len() != size {
97 return Err(FFTError::ValueError(
98 "Input and output sizes must match the specified size".to_string(),
99 ));
100 }
101
102 let mut planner = self.planner.lock().expect("Operation failed");
104 let fft = planner.plan_fft_forward(size);
105
106 let mut buffer: Vec<rustfft::num_complex::Complex<f64>> = input
108 .iter()
109 .map(|&c| rustfft::num_complex::Complex::new(c.re, c.im))
110 .collect();
111
112 fft.process(&mut buffer);
114
115 for (i, &c) in buffer.iter().enumerate() {
117 output[i] = Complex64::new(c.re, c.im);
118 }
119
120 Ok(())
121 }
122
123 fn ifft_sized(
124 &self,
125 input: &[Complex64],
126 output: &mut [Complex64],
127 size: usize,
128 ) -> FFTResult<()> {
129 if input.len() != size || output.len() != size {
130 return Err(FFTError::ValueError(
131 "Input and output sizes must match the specified size".to_string(),
132 ));
133 }
134
135 let mut planner = self.planner.lock().expect("Operation failed");
137 let fft = planner.plan_fft_inverse(size);
138
139 let mut buffer: Vec<rustfft::num_complex::Complex<f64>> = input
141 .iter()
142 .map(|&c| rustfft::num_complex::Complex::new(c.re, c.im))
143 .collect();
144
145 fft.process(&mut buffer);
147
148 let scale = 1.0 / size as f64;
150 for (i, &c) in buffer.iter().enumerate() {
151 output[i] = Complex64::new(c.re * scale, c.im * scale);
152 }
153
154 Ok(())
155 }
156
157 fn supports_feature(&self, feature: &str) -> bool {
158 matches!(feature, "1d_fft" | "2d_fft" | "nd_fft" | "cached_plans")
159 }
160}
161
162pub struct BackendManager {
164 backends: Arc<Mutex<HashMap<String, Arc<dyn FftBackend>>>>,
165 current_backend: Arc<Mutex<String>>,
166}
167
168impl BackendManager {
169 pub fn new() -> Self {
171 let mut backends = HashMap::new();
172
173 let rustfft_backend = Arc::new(RustFftBackend::new()) as Arc<dyn FftBackend>;
175 backends.insert("rustfft".to_string(), rustfft_backend);
176
177 Self {
178 backends: Arc::new(Mutex::new(backends)),
179 current_backend: Arc::new(Mutex::new("rustfft".to_string())),
180 }
181 }
182
183 pub fn register_backend(&self, name: String, backend: Arc<dyn FftBackend>) -> FFTResult<()> {
185 let mut backends = self.backends.lock().expect("Operation failed");
186 if backends.contains_key(&name) {
187 return Err(FFTError::ValueError(format!(
188 "Backend '{name}' already exists"
189 )));
190 }
191 backends.insert(name, backend);
192 Ok(())
193 }
194
195 pub fn list_backends(&self) -> Vec<String> {
197 let backends = self.backends.lock().expect("Operation failed");
198 backends.keys().cloned().collect()
199 }
200
201 pub fn set_backend(&self, name: &str) -> FFTResult<()> {
203 let backends = self.backends.lock().expect("Operation failed");
204 if !backends.contains_key(name) {
205 return Err(FFTError::ValueError(format!("Backend '{name}' not found")));
206 }
207
208 if let Some(backend) = backends.get(name) {
210 if !backend.is_available() {
211 return Err(FFTError::ValueError(format!(
212 "Backend '{name}' is not available"
213 )));
214 }
215 }
216
217 *self.current_backend.lock().expect("Operation failed") = name.to_string();
218 Ok(())
219 }
220
221 pub fn get_backend_name(&self) -> String {
223 self.current_backend
224 .lock()
225 .expect("Operation failed")
226 .clone()
227 }
228
229 pub fn get_backend(&self) -> Arc<dyn FftBackend> {
231 let current_name = self.current_backend.lock().expect("Operation failed");
232 let backends = self.backends.lock().expect("Operation failed");
233 backends
234 .get(&*current_name)
235 .cloned()
236 .expect("Current backend should always exist")
237 }
238
239 pub fn get_backend_info(&self, name: &str) -> Option<BackendInfo> {
241 let backends = self.backends.lock().expect("Operation failed");
242 backends.get(name).map(|backend| BackendInfo {
243 name: backend.name().to_string(),
244 description: backend.description().to_string(),
245 available: backend.is_available(),
246 })
247 }
248}
249
250impl Default for BackendManager {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
/// Snapshot of a backend's metadata, as produced by
/// `BackendManager::get_backend_info`.
#[derive(Debug, Clone)]
pub struct BackendInfo {
    /// The backend's self-reported name.
    pub name: String,
    /// The backend's one-line description.
    pub description: String,
    /// Availability reported by the backend when the snapshot was taken.
    pub available: bool,
}

impl std::fmt::Display for BackendInfo {
    /// Renders `"<name> - <description> (available|not available)"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let status = if self.available {
            "available"
        } else {
            "not available"
        };
        write!(f, "{} - {} ({})", self.name, self.description, status)
    }
}
279
280static GLOBAL_BACKEND_MANAGER: OnceLock<BackendManager> = OnceLock::new();
282
283#[allow(dead_code)]
285pub fn get_backend_manager() -> &'static BackendManager {
286 GLOBAL_BACKEND_MANAGER.get_or_init(BackendManager::new)
287}
288
289#[allow(dead_code)]
291pub fn init_backend_manager(manager: BackendManager) -> Result<(), &'static str> {
292 GLOBAL_BACKEND_MANAGER
293 .set(manager)
294 .map_err(|_| "Global backend _manager already initialized")
295}
296
297#[allow(dead_code)]
299pub fn list_backends() -> Vec<String> {
300 get_backend_manager().list_backends()
301}
302
303#[allow(dead_code)]
305pub fn set_backend(name: &str) -> FFTResult<()> {
306 get_backend_manager().set_backend(name)
307}
308
309#[allow(dead_code)]
311pub fn get_backend_name() -> String {
312 get_backend_manager().get_backend_name()
313}
314
315#[allow(dead_code)]
317pub fn get_backend_info(name: &str) -> Option<BackendInfo> {
318 get_backend_manager().get_backend_info(name)
319}
320
321pub struct BackendContext {
323 previous_backend: String,
324 manager: &'static BackendManager,
325}
326
327impl BackendContext {
328 pub fn new(_backendname: &str) -> FFTResult<Self> {
330 let manager = get_backend_manager();
331 let previous_backend = manager.get_backend_name();
332
333 manager.set_backend(_backendname)?;
335
336 Ok(Self {
337 previous_backend,
338 manager,
339 })
340 }
341}
342
343impl Drop for BackendContext {
344 fn drop(&mut self) {
345 let _ = self.manager.set_backend(&self.previous_backend);
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing floating-point transform results.
    const TOL: f64 = 1e-12;

    /// Asserts element-wise closeness of two complex slices.
    fn assert_close(actual: &[Complex64], expected: &[Complex64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a.re - e.re).abs() < TOL && (a.im - e.im).abs() < TOL,
                "{a:?} != {e:?}"
            );
        }
    }

    #[test]
    fn test_rustfft_backend() {
        let backend = RustFftBackend::new();
        assert_eq!(backend.name(), "rustfft");
        assert!(backend.is_available());
        assert!(backend.supports_feature("1d_fft"));
    }

    #[test]
    fn test_rustfft_impulse_transforms_to_constant() {
        let backend = RustFftBackend::new();
        let mut input = vec![Complex64::new(0.0, 0.0); 4];
        input[0] = Complex64::new(1.0, 0.0);
        let mut output = vec![Complex64::new(0.0, 0.0); 4];

        backend
            .fft(&input, &mut output)
            .expect("fft of equally sized buffers");

        assert_close(&output, &[Complex64::new(1.0, 0.0); 4]);
    }

    #[test]
    fn test_rustfft_round_trip() {
        let backend = RustFftBackend::new();
        let input = vec![
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, -1.0),
            Complex64::new(0.5, 0.0),
            Complex64::new(-2.0, 4.0),
            Complex64::new(0.0, 0.25),
        ];
        let mut spectrum = vec![Complex64::new(0.0, 0.0); input.len()];
        let mut restored = spectrum.clone();

        backend.fft(&input, &mut spectrum).expect("forward transform");
        backend.ifft(&spectrum, &mut restored).expect("inverse transform");

        // ifft is normalized by 1/n, so the round trip reproduces the input.
        assert_close(&restored, &input);
    }

    #[test]
    fn test_rustfft_rejects_mismatched_sizes() {
        let backend = RustFftBackend::new();
        let input = vec![Complex64::new(1.0, 0.0); 4];

        let mut short_output = vec![Complex64::new(0.0, 0.0); 3];
        assert!(backend.fft(&input, &mut short_output).is_err());

        let mut output = vec![Complex64::new(0.0, 0.0); 4];
        assert!(backend.fft_sized(&input, &mut output, 8).is_err());
        assert!(backend.ifft_sized(&input, &mut output, 3).is_err());
    }

    #[test]
    fn test_backend_manager() {
        let manager = BackendManager::new();

        assert_eq!(manager.get_backend_name(), "rustfft");

        let backends = manager.list_backends();
        assert!(backends.contains(&"rustfft".to_string()));

        let info = manager
            .get_backend_info("rustfft")
            .expect("built-in backend is registered");
        assert!(info.available);
    }

    #[test]
    fn test_backend_manager_registration_and_selection() {
        let manager = BackendManager::new();
        let alt: Arc<dyn FftBackend> = Arc::new(RustFftBackend::new());

        // Duplicate names are rejected.
        assert!(manager
            .register_backend("rustfft".to_string(), Arc::clone(&alt))
            .is_err());
        manager
            .register_backend("alt".to_string(), alt)
            .expect("fresh name registers");

        // Unknown names are rejected and leave the selection untouched.
        assert!(manager.set_backend("missing").is_err());
        assert_eq!(manager.get_backend_name(), "rustfft");

        manager.set_backend("alt").expect("registered and available");
        assert_eq!(manager.get_backend_name(), "alt");
        assert_eq!(manager.get_backend().name(), "rustfft");
    }

    #[test]
    fn test_backend_context() {
        let manager = get_backend_manager();
        let original = manager.get_backend_name();

        {
            let _ctx = BackendContext::new("rustfft").expect("built-in backend");
            assert_eq!(manager.get_backend_name(), "rustfft");
        }

        assert_eq!(manager.get_backend_name(), original);

        // A failed switch leaves the current backend unchanged.
        assert!(BackendContext::new("no-such-backend").is_err());
        assert_eq!(manager.get_backend_name(), original);
    }
}