1use crate::error::{FFTError, FFTResult};
7use crate::sparse_fft::{
8 SparseFFTAlgorithm, SparseFFTConfig, SparseFFTResult, SparsityEstimationMethod, WindowFunction,
9};
10use scirs2_core::numeric::Complex64;
11use scirs2_core::numeric::NumCast;
12use std::fmt::Debug;
13use std::time::Instant;
14
/// Compute backend requested for sparse FFT processing.
///
/// In this module every variant is currently served by the CPU implementation;
/// the GPU variants only change the reported device description.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GPUBackend {
    /// NVIDIA CUDA.
    CUDA,
    /// AMD ROCm / HIP.
    HIP,
    /// SYCL.
    SYCL,
    /// Plain CPU execution.
    CPUFallback,
}
27
28#[derive(Debug, Clone)]
30pub struct GPUSparseFFTConfig {
31 pub base_config: SparseFFTConfig,
33 pub backend: GPUBackend,
35 pub device_id: i32,
37 pub batch_size: usize,
39 pub max_memory: usize,
41 pub use_mixed_precision: bool,
43 pub use_inplace: bool,
45 pub stream_count: usize,
47}
48
49impl Default for GPUSparseFFTConfig {
50 fn default() -> Self {
51 Self {
52 base_config: SparseFFTConfig::default(),
53 backend: GPUBackend::CPUFallback,
54 device_id: -1,
55 batch_size: 1,
56 max_memory: 0,
57 use_mixed_precision: false,
58 use_inplace: true,
59 stream_count: 1,
60 }
61 }
62}
63
64pub struct GPUSparseFFT {
66 _config: GPUSparseFFTConfig,
68 gpu_initialized: bool,
70 device_info: Option<String>,
72}
73
74impl GPUSparseFFT {
75 pub fn new(config: GPUSparseFFTConfig) -> Self {
77 Self {
78 _config: config,
79 gpu_initialized: false,
80 device_info: None,
81 }
82 }
83
84 pub fn with_default_config() -> Self {
86 Self::new(GPUSparseFFTConfig::default())
87 }
88
89 fn initialize_gpu(&mut self) -> FFTResult<()> {
91 match self._config.backend {
95 GPUBackend::CUDA => {
96 self.device_info = Some("CUDA GPU device (simulated)".to_string());
98 }
99 GPUBackend::HIP => {
100 self.device_info = Some("ROCm GPU device (simulated)".to_string());
102 }
103 GPUBackend::SYCL => {
104 self.device_info = Some("SYCL device (simulated)".to_string());
106 }
107 GPUBackend::CPUFallback => {
108 self.device_info = Some("CPU fallback device".to_string());
109 }
110 }
111
112 self.gpu_initialized = true;
113 Ok(())
114 }
115
116 pub fn get_device_info(&mut self) -> FFTResult<String> {
118 if !self.gpu_initialized {
119 self.initialize_gpu()?;
120 }
121
122 Ok(self
123 .device_info
124 .clone()
125 .unwrap_or_else(|| "Unknown device".to_string()))
126 }
127
128 pub fn sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
130 where
131 T: NumCast + Copy + Debug + 'static,
132 {
133 if !self.gpu_initialized {
134 self.initialize_gpu()?;
135 }
136
137 let start = Instant::now();
140
141 let signal_complex: Vec<Complex64> = signal
143 .iter()
144 .map(|&val| {
145 let val_f64 = NumCast::from(val).ok_or_else(|| {
146 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
147 })?;
148 Ok(Complex64::new(val_f64, 0.0))
149 })
150 .collect::<FFTResult<Vec<_>>>()?;
151
152 let mut cpu_processor = crate::sparse_fft::SparseFFT::new(self._config.base_config.clone());
155 let result = cpu_processor.sparse_fft(&signal_complex)?;
156
157 let computation_time = start.elapsed();
159
160 Ok(SparseFFTResult {
162 values: result.values,
163 indices: result.indices,
164 estimated_sparsity: result.estimated_sparsity,
165 computation_time,
166 algorithm: self._config.base_config.algorithm,
167 })
168 }
169
170 pub fn batch_sparse_fft<T>(&mut self, signals: &[Vec<T>]) -> FFTResult<Vec<SparseFFTResult>>
172 where
173 T: NumCast + Copy + Debug + 'static,
174 {
175 if !self.gpu_initialized {
176 self.initialize_gpu()?;
177 }
178
179 signals
181 .iter()
182 .map(|signal| self.sparse_fft(signal))
183 .collect()
184 }
185}
186
187#[allow(dead_code)]
204pub fn gpu_sparse_fft<T>(
205 signal: &[T],
206 k: usize,
207 backend: GPUBackend,
208 algorithm: Option<SparseFFTAlgorithm>,
209 window_function: Option<WindowFunction>,
210) -> FFTResult<SparseFFTResult>
211where
212 T: NumCast + Copy + Debug + 'static,
213{
214 let base_config = SparseFFTConfig {
216 estimation_method: SparsityEstimationMethod::Manual,
217 sparsity: k,
218 algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
219 window_function: window_function.unwrap_or(WindowFunction::None),
220 ..SparseFFTConfig::default()
221 };
222
223 let gpu_config = GPUSparseFFTConfig {
225 base_config,
226 backend,
227 ..GPUSparseFFTConfig::default()
228 };
229
230 let mut processor = GPUSparseFFT::new(gpu_config);
232 processor.sparse_fft(signal)
233}
234
235#[allow(dead_code)]
249pub fn gpu_batch_sparse_fft<T>(
250 signals: &[Vec<T>],
251 k: usize,
252 backend: GPUBackend,
253 algorithm: Option<SparseFFTAlgorithm>,
254 window_function: Option<WindowFunction>,
255) -> FFTResult<Vec<SparseFFTResult>>
256where
257 T: NumCast + Copy + Debug + 'static,
258{
259 let base_config = SparseFFTConfig {
261 estimation_method: SparsityEstimationMethod::Manual,
262 sparsity: k,
263 algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
264 window_function: window_function.unwrap_or(WindowFunction::None),
265 ..SparseFFTConfig::default()
266 };
267
268 let gpu_config = GPUSparseFFTConfig {
270 base_config,
271 backend,
272 batch_size: signals.len(),
273 ..GPUSparseFFTConfig::default()
274 };
275
276 let mut processor = GPUSparseFFT::new(gpu_config);
278 processor.batch_sparse_fft(signals)
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Builds `n` samples of a sum of sines: each `(freq, amp)` pair adds
    /// `amp * sin(freq * 2πi / n)` at sample `i`.
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = 2.0 * PI * (i as f64) / (n as f64);
                // Accumulate in the same order as the frequency list, starting from 0.0.
                frequencies
                    .iter()
                    .fold(0.0, |acc, &(freq, amp)| acc + amp * (freq as f64 * t).sin())
            })
            .collect()
    }

    #[test]
    fn test_gpu_sparse_fft_cpu_fallback() {
        let n = 256;
        let frequencies = [(3, 1.0), (7, 0.5), (15, 0.25)];
        let signal = create_sparse_signal(n, &frequencies);

        let result = gpu_sparse_fft(
            &signal,
            6,
            GPUBackend::CPUFallback,
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
        )
        .unwrap();

        assert!(!result.values.is_empty());
        assert_eq!(result.algorithm, SparseFFTAlgorithm::Sublinear);
    }

    #[test]
    fn test_gpu_batch_processing() {
        let n = 128;
        let signals = [
            create_sparse_signal(n, &[(3, 1.0), (7, 0.5)]),
            create_sparse_signal(n, &[(5, 1.0), (10, 0.7)]),
            create_sparse_signal(n, &[(2, 0.8), (12, 0.6)]),
        ];

        let results = gpu_batch_sparse_fft(
            &signals,
            4,
            GPUBackend::CPUFallback,
            Some(SparseFFTAlgorithm::Sublinear),
            None,
        )
        .unwrap();

        assert_eq!(results.len(), signals.len());

        for result in results {
            assert!(!result.values.is_empty());
            assert_eq!(result.algorithm, SparseFFTAlgorithm::Sublinear);
        }
    }

    #[test]
    fn test_gpu_batch_empty_input() {
        let signals: Vec<Vec<f64>> = Vec::new();
        let results =
            gpu_batch_sparse_fft(&signals, 4, GPUBackend::CPUFallback, None, None).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_device_info_reports_backend() {
        let cases = [
            (GPUBackend::CUDA, "CUDA GPU device (simulated)"),
            (GPUBackend::HIP, "ROCm GPU device (simulated)"),
            (GPUBackend::SYCL, "SYCL device (simulated)"),
            (GPUBackend::CPUFallback, "CPU fallback device"),
        ];

        for &(backend, expected) in cases.iter() {
            let mut processor = GPUSparseFFT::new(GPUSparseFFTConfig {
                backend,
                ..GPUSparseFFTConfig::default()
            });
            assert_eq!(processor.get_device_info().unwrap(), expected);
        }
    }
}