scirs2_fft/
sparse_fft_batch.rs

//! Batch processing for sparse FFT algorithms
//!
//! This module provides batch processing capabilities for sparse FFT algorithms,
//! which can significantly improve performance when processing multiple signals,
//! especially on GPU hardware.
6
7use crate::error::{FFTError, FFTResult};
8use crate::sparse_fft::{
9    SparseFFTAlgorithm, SparseFFTConfig, SparseFFTResult, SparsityEstimationMethod, WindowFunction,
10};
11use crate::sparse_fft_gpu::{GPUBackend, GPUSparseFFTConfig};
12use crate::sparse_fft_gpu_memory::{init_global_memory_manager, AllocationStrategy};
13
14use scirs2_core::numeric::Complex64;
15use scirs2_core::numeric::NumCast;
16use scirs2_core::parallel_ops::*;
17use std::fmt::Debug;
18use std::time::Instant;
19
/// Tuning knobs for running sparse FFTs over many signals at once.
///
/// Start from [`BatchConfig::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `BatchConfig { use_parallel: false, ..BatchConfig::default() }`.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Upper bound on how many signals are grouped into a single batch.
    pub max_batch_size: usize,
    /// When `true`, CPU work is spread across worker threads.
    pub use_parallel: bool,
    /// Per-batch memory budget in bytes; `0` means no limit.
    pub max_memory_per_batch: usize,
    /// Request mixed-precision arithmetic where the backend supports it.
    pub use_mixed_precision: bool,
    /// Allow in-place computation where possible.
    pub use_inplace: bool,
    /// Keep the caller's input signals intact (`false` permits modifying them).
    pub preserve_input: bool,
}

impl Default for BatchConfig {
    /// Batches of 32 signals, parallel CPU processing, no memory cap,
    /// full precision, in-place computation allowed, inputs preserved.
    fn default() -> Self {
        let max_batch_size = 32;
        let max_memory_per_batch = 0; // 0 == unlimited
        Self {
            max_batch_size,
            max_memory_per_batch,
            use_parallel: true,
            use_mixed_precision: false,
            use_inplace: true,
            preserve_input: true,
        }
    }
}
49
50/// Perform batch sparse FFT on CPU
51///
52/// Process multiple signals in a batch for better performance.
53///
54/// # Arguments
55///
56/// * `signals` - List of input signals
57/// * `k` - Expected sparsity (number of significant frequency components)
58/// * `algorithm` - Sparse FFT algorithm variant
59/// * `window_function` - Window function to apply before FFT
60/// * `batchconfig` - Batch processing configuration
61///
62/// # Returns
63///
64/// * Vector of sparse FFT results, one for each input signal
65#[allow(clippy::too_many_arguments)]
66#[allow(dead_code)]
67pub fn batch_sparse_fft<T>(
68    signals: &[Vec<T>],
69    k: usize,
70    algorithm: Option<SparseFFTAlgorithm>,
71    window_function: Option<WindowFunction>,
72    batchconfig: Option<BatchConfig>,
73) -> FFTResult<Vec<SparseFFTResult>>
74where
75    T: NumCast + Copy + Debug + Sync + 'static,
76{
77    let config = batchconfig.unwrap_or_default();
78    let alg = algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear);
79    let window = window_function.unwrap_or(WindowFunction::None);
80
81    let start = Instant::now();
82
83    // Create sparse FFT config
84    let fftconfig = SparseFFTConfig {
85        estimation_method: SparsityEstimationMethod::Manual,
86        sparsity: k,
87        algorithm: alg,
88        window_function: window,
89        ..SparseFFTConfig::default()
90    };
91
92    let results = if config.use_parallel {
93        // Process signals in parallel using Rayon
94        signals
95            .par_iter()
96            .map(|signal| {
97                let mut processor = crate::sparse_fft::SparseFFT::new(fftconfig.clone());
98
99                // Convert signal to complex
100                let signal_complex: FFTResult<Vec<Complex64>> = signal
101                    .iter()
102                    .map(|&val| {
103                        let val_f64 = NumCast::from(val).ok_or_else(|| {
104                            FFTError::ValueError(format!("Could not convert {val:?} to f64"))
105                        })?;
106                        Ok(Complex64::new(val_f64, 0.0))
107                    })
108                    .collect();
109
110                processor.sparse_fft(&signal_complex?)
111            })
112            .collect::<FFTResult<Vec<_>>>()
113    } else {
114        // Process signals sequentially
115        let mut results = Vec::with_capacity(signals.len());
116        for signal in signals {
117            let mut processor = crate::sparse_fft::SparseFFT::new(fftconfig.clone());
118
119            // Convert signal to complex
120            let signal_complex: FFTResult<Vec<Complex64>> = signal
121                .iter()
122                .map(|&val| {
123                    let val_f64 = NumCast::from(val).ok_or_else(|| {
124                        FFTError::ValueError(format!("Could not convert {val:?} to f64"))
125                    })?;
126                    Ok(Complex64::new(val_f64, 0.0))
127                })
128                .collect();
129
130            results.push(processor.sparse_fft(&signal_complex?)?);
131        }
132        Ok(results)
133    }?;
134
135    // Update computation time to include batching overhead
136    let total_time = start.elapsed();
137    let avg_time_per_signal = total_time.div_f64(signals.len() as f64);
138
139    // Return results with updated computation time
140    let mut final_results = Vec::with_capacity(results.len());
141    for mut result in results {
142        result.computation_time = avg_time_per_signal;
143        final_results.push(result);
144    }
145
146    Ok(final_results)
147}
148
149/// Perform batch sparse FFT on GPU
150///
151/// Process multiple signals in a batch for better GPU utilization.
152///
153/// # Arguments
154///
155/// * `signals` - List of input signals
156/// * `k` - Expected sparsity (number of significant frequency components)
157/// * `device_id` - GPU device ID (-1 for auto-select)
158/// * `backend` - GPU backend (CUDA, HIP, SYCL)
159/// * `algorithm` - Sparse FFT algorithm variant
160/// * `window_function` - Window function to apply before FFT
161/// * `batchconfig` - Batch processing configuration
162///
163/// # Returns
164///
165/// * Vector of sparse FFT results, one for each input signal
166#[allow(clippy::too_many_arguments)]
167#[allow(dead_code)]
168pub fn gpu_batch_sparse_fft<T>(
169    signals: &[Vec<T>],
170    k: usize,
171    device_id: i32,
172    backend: GPUBackend,
173    algorithm: Option<SparseFFTAlgorithm>,
174    window_function: Option<WindowFunction>,
175    batchconfig: Option<BatchConfig>,
176) -> FFTResult<Vec<SparseFFTResult>>
177where
178    T: NumCast + Copy + Debug + Sync + 'static,
179{
180    let config = batchconfig.unwrap_or_default();
181    let alg = algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear);
182    let window = window_function.unwrap_or(WindowFunction::None);
183
184    // Calculate batch sizes
185    let total_signals = signals.len();
186    let batch_size = config.max_batch_size.min(total_signals);
187    let num_batches = total_signals.div_ceil(batch_size);
188
189    // Create sparse FFT config
190    let base_fftconfig = SparseFFTConfig {
191        estimation_method: SparsityEstimationMethod::Manual,
192        sparsity: k,
193        algorithm: alg,
194        window_function: window,
195        ..SparseFFTConfig::default()
196    };
197
198    // Create GPU config
199    let _gpuconfig = GPUSparseFFTConfig {
200        base_config: base_fftconfig,
201        backend,
202        device_id,
203        batch_size,
204        max_memory: config.max_memory_per_batch,
205        use_mixed_precision: config.use_mixed_precision,
206        use_inplace: config.use_inplace,
207        stream_count: 2, // Use 2 streams for overlap
208    };
209
210    let start = Instant::now();
211
212    // Process signals in batches
213    let mut all_results = Vec::with_capacity(total_signals);
214    for batch_idx in 0..num_batches {
215        let start_idx = batch_idx * batch_size;
216        let end_idx = (start_idx + batch_size).min(total_signals);
217        let current_batch = &signals[start_idx..end_idx];
218
219        // Process this batch
220        match backend {
221            GPUBackend::CUDA => {
222                let batch_results = crate::cuda_batch_sparse_fft(
223                    current_batch,
224                    k,
225                    device_id,
226                    Some(alg),
227                    Some(window),
228                )?;
229                all_results.extend(batch_results);
230            }
231            _ => {
232                // For other backends, fall back to CPU for now
233                let batch_results =
234                    batch_sparse_fft(current_batch, k, Some(alg), Some(window), None)?;
235                all_results.extend(batch_results);
236            }
237        }
238    }
239
240    // Update computation time to include batching overhead
241    let total_time = start.elapsed();
242    let avg_time_per_signal = total_time.div_f64(signals.len() as f64);
243
244    // Return results with updated computation time
245    let mut final_results = Vec::with_capacity(all_results.len());
246    for mut result in all_results {
247        result.computation_time = avg_time_per_signal;
248        final_results.push(result);
249    }
250
251    Ok(final_results)
252}
253
254/// Optimized batch processing for spectral flatness sparse FFT
255///
256/// This function is specialized for the spectral flatness algorithm,
257/// which can benefit from batch processing due to its signal analysis
258/// requirements.
259///
260/// # Arguments
261///
262/// * `signals` - List of input signals
263/// * `flatness_threshold` - Threshold for spectral flatness (0-1, lower = more selective)
264/// * `window_size` - Size of windows for local flatness analysis
265/// * `window_function` - Window function to apply before FFT
266/// * `device_id` - GPU device ID (-1 for auto-select)
267/// * `batchconfig` - Batch processing configuration
268///
269/// # Returns
270///
271/// * Vector of sparse FFT results, one for each input signal
272#[allow(clippy::too_many_arguments)]
273#[allow(dead_code)]
274pub fn spectral_flatness_batch_sparse_fft<T>(
275    signals: &[Vec<T>],
276    flatness_threshold: f64,
277    window_size: usize,
278    window_function: Option<WindowFunction>,
279    device_id: Option<i32>,
280    batchconfig: Option<BatchConfig>,
281) -> FFTResult<Vec<SparseFFTResult>>
282where
283    T: NumCast + Copy + Debug + Sync + 'static,
284{
285    let config = batchconfig.unwrap_or_default();
286    let window = window_function.unwrap_or(WindowFunction::Hann); // Default to Hann for spectral flatness
287    let device = device_id.unwrap_or(-1); // -1 indicates CPU
288
289    // Calculate batch sizes
290    let total_signals = signals.len();
291    let batch_size = config.max_batch_size.min(total_signals);
292    let num_batches = total_signals.div_ceil(batch_size);
293
294    // Initialize the memory manager if GPU is used
295    if device >= 0 {
296        init_global_memory_manager(
297            GPUBackend::CUDA,
298            device,
299            AllocationStrategy::CacheBySize,
300            config.max_memory_per_batch.max(1024 * 1024 * 1024), // At least 1 GB
301        )?;
302    }
303
304    let start = Instant::now();
305
306    // Process signals in batches
307    let mut all_results = Vec::with_capacity(total_signals);
308
309    if device >= 0 && cfg!(feature = "cuda") {
310        // GPU processing with CUDA
311        for batch_idx in 0..num_batches {
312            let start_idx = batch_idx * batch_size;
313            let end_idx = (start_idx + batch_size).min(total_signals);
314            let current_batch = &signals[start_idx..end_idx];
315
316            // Create a base configuration for this batch
317            let _baseconfig = SparseFFTConfig {
318                estimation_method: SparsityEstimationMethod::SpectralFlatness,
319                sparsity: 0, // Will be determined automatically
320                algorithm: SparseFFTAlgorithm::SpectralFlatness,
321                window_function: window,
322                flatness_threshold,
323                window_size,
324                ..SparseFFTConfig::default()
325            };
326
327            // Use standard GPU batch processing
328            for signal in current_batch {
329                // Convert signal to complex
330                let signal_complex: FFTResult<Vec<Complex64>> = signal
331                    .iter()
332                    .map(|&val| {
333                        let val_f64 = NumCast::from(val).ok_or_else(|| {
334                            FFTError::ValueError(format!("Could not convert {val:?} to f64"))
335                        })?;
336                        Ok(Complex64::new(val_f64, 0.0))
337                    })
338                    .collect();
339
340                // Process with GPU
341                let result = crate::execute_cuda_spectral_flatness_sparse_fft(
342                    &signal_complex?,
343                    0, // Will be determined automatically
344                    flatness_threshold,
345                )?;
346
347                all_results.push(result);
348            }
349        }
350    } else {
351        // CPU processing
352        if config.use_parallel {
353            // Process all signals in parallel using Rayon
354            let parallel_results: FFTResult<Vec<_>> = signals
355                .par_iter()
356                .map(|signal| {
357                    // Create configuration
358                    let fftconfig = SparseFFTConfig {
359                        estimation_method: SparsityEstimationMethod::SpectralFlatness,
360                        sparsity: 0, // Will be determined automatically
361                        algorithm: SparseFFTAlgorithm::SpectralFlatness,
362                        window_function: window,
363                        flatness_threshold,
364                        window_size,
365                        ..SparseFFTConfig::default()
366                    };
367
368                    // Convert signal to complex
369                    let signal_complex: FFTResult<Vec<Complex64>> = signal
370                        .iter()
371                        .map(|&val| {
372                            let val_f64 = NumCast::from(val).ok_or_else(|| {
373                                FFTError::ValueError(format!("Could not convert {val:?} to f64"))
374                            })?;
375                            Ok(Complex64::new(val_f64, 0.0))
376                        })
377                        .collect();
378
379                    // Process with CPU
380                    let mut processor = crate::sparse_fft::SparseFFT::new(fftconfig);
381                    processor.sparse_fft(&signal_complex?)
382                })
383                .collect();
384
385            all_results = parallel_results?;
386        } else {
387            // Process signals sequentially
388            for signal in signals {
389                // Create configuration
390                let fftconfig = SparseFFTConfig {
391                    estimation_method: SparsityEstimationMethod::SpectralFlatness,
392                    sparsity: 0, // Will be determined automatically
393                    algorithm: SparseFFTAlgorithm::SpectralFlatness,
394                    window_function: window,
395                    flatness_threshold,
396                    window_size,
397                    ..SparseFFTConfig::default()
398                };
399
400                // Convert signal to complex
401                let signal_complex: FFTResult<Vec<Complex64>> = signal
402                    .iter()
403                    .map(|&val| {
404                        let val_f64 = NumCast::from(val).ok_or_else(|| {
405                            FFTError::ValueError(format!("Could not convert {val:?} to f64"))
406                        })?;
407                        Ok(Complex64::new(val_f64, 0.0))
408                    })
409                    .collect();
410
411                // Process with CPU
412                let mut processor = crate::sparse_fft::SparseFFT::new(fftconfig);
413                let result = processor.sparse_fft(&signal_complex?)?;
414                all_results.push(result);
415            }
416        }
417    }
418
419    // Update computation time to include batching overhead
420    let total_time = start.elapsed();
421    let avg_time_per_signal = total_time.div_f64(signals.len() as f64);
422
423    // Return results with updated computation time
424    let mut final_results = Vec::with_capacity(all_results.len());
425    for mut result in all_results {
426        result.computation_time = avg_time_per_signal;
427        final_results.push(result);
428    }
429
430    Ok(final_results)
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Builds `n` samples of a sum of sines: each `(freq, amp)` pair contributes
    /// `amp * sin(2π · freq · i / n)`, i.e. `freq` whole cycles over the signal.
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = 2.0 * PI * (i as f64) / (n as f64);
                frequencies
                    .iter()
                    .map(|&(freq, amp)| amp * (freq as f64 * t).sin())
                    .sum::<f64>()
            })
            .collect()
    }

    /// Returns a copy of `signal` with random noise drawn from
    /// `-noise_level..noise_level` added to every sample.
    fn add_noise(signal: &[f64], noise_level: f64) -> Vec<f64> {
        use scirs2_core::random::Rng;
        let mut rng = scirs2_core::random::rng();
        signal
            .iter()
            .map(|&x| x + rng.gen_range(-noise_level..noise_level))
            .collect()
    }

    /// Creates `count` independently-noised copies of the same sparse signal.
    fn create_signal_batch(
        count: usize,
        n: usize,
        frequencies: &[(usize, f64)],
        noise_level: f64,
    ) -> Vec<Vec<f64>> {
        let base_signal = create_sparse_signal(n, frequencies);
        (0..count)
            .map(|_| add_noise(&base_signal, noise_level))
            .collect()
    }

    /// Shared checks for the sublinear batch tests. Requires one result per input
    /// signal, matching index/value counts, and at least one detected component
    /// within 32 bins of either end of an `n`-point spectrum.
    fn assert_finds_low_frequencies(
        results: &[SparseFFTResult],
        expected_count: usize,
        n: usize,
    ) {
        assert_eq!(results.len(), expected_count);

        for (i, result) in results.iter().enumerate() {
            assert!(
                !result.indices.is_empty(),
                "No frequencies detected for signal {}",
                i
            );
            assert_eq!(
                result.values.len(),
                result.indices.len(),
                "Mismatched indices and values"
            );

            // Finding any low-frequency component indicates the algorithm works.
            let low_freq_count = result
                .indices
                .iter()
                .filter(|&&idx| idx <= 32 || idx >= n - 32)
                .count();

            assert!(low_freq_count >= 1, "Should find at least 1 low-frequency component for signal {}, but found none. All frequencies: {:?}", i, result.indices);
        }
    }

    #[test]
    fn test_cpu_batch_processing() {
        let n = 256;
        let frequencies = vec![(3, 1.0), (7, 0.5), (15, 0.5)]; // Increased amplitude for better detection
        let signals = create_signal_batch(5, n, &frequencies, 0.05); // Reduced noise

        // Exercise the sequential CPU path explicitly; the default config is parallel.
        let batchconfig = BatchConfig {
            use_parallel: false,
            ..BatchConfig::default()
        };

        let results = batch_sparse_fft(
            &signals,
            6, // Look for up to 6 components
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
            Some(batchconfig),
        )
        .unwrap();

        assert_finds_low_frequencies(&results, signals.len(), n);
    }

    #[test]
    fn test_parallel_batch_processing() {
        let n = 256;
        let frequencies = vec![(3, 1.0), (7, 0.5), (15, 0.5)]; // Increased amplitude for better detection
        let signals = create_signal_batch(10, n, &frequencies, 0.05); // Reduced noise

        let batchconfig = BatchConfig {
            use_parallel: true,
            ..BatchConfig::default()
        };

        let results = batch_sparse_fft(
            &signals,
            6, // Look for up to 6 components
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
            Some(batchconfig),
        )
        .unwrap();

        assert_finds_low_frequencies(&results, signals.len(), n);
    }

    #[test]
    fn test_spectral_flatness_batch() {
        let n = 512;
        let frequencies = vec![(30, 1.0), (70, 0.5), (120, 0.25)];

        // Same underlying signal with steadily increasing noise.
        let base_signal = create_sparse_signal(n, &frequencies);
        let signals: Vec<Vec<f64>> = (0..5)
            .map(|i| add_noise(&base_signal, 0.05 * (i + 1) as f64))
            .collect();

        let results = spectral_flatness_batch_sparse_fft(
            &signals,
            0.3, // Flatness threshold
            32,  // Window size
            Some(WindowFunction::Hann),
            None, // Use CPU
            None, // Default config
        )
        .unwrap();

        assert_eq!(results.len(), signals.len());

        // Spectral flatness should find the main frequencies even with noise.
        for result in &results {
            assert_eq!(result.algorithm, SparseFFTAlgorithm::SpectralFlatness);

            let found = |freq: usize| {
                result.indices.contains(&freq) || result.indices.contains(&(n - freq))
            };

            assert!(
                found(30) || found(70) || found(120),
                "Failed to find any of the key frequencies"
            );
        }
    }
}