1use crate::error::{FFTError, FFTResult};
8use crate::sparse_fft::{
9 SparseFFTAlgorithm, SparseFFTConfig, SparseFFTResult, SparsityEstimationMethod, WindowFunction,
10};
11use crate::sparse_fft_gpu::{GPUBackend, GPUSparseFFTConfig};
12use crate::sparse_fft_gpu_memory::{init_global_memory_manager, AllocationStrategy};
13
14use scirs2_core::numeric::Complex64;
15use scirs2_core::numeric::NumCast;
16use scirs2_core::parallel_ops::*;
17use std::fmt::Debug;
18use std::time::Instant;
19
/// Options controlling how the batch sparse-FFT entry points in this module
/// split up and execute their work.
#[derive(Clone, Debug)]
pub struct BatchConfig {
    /// Upper bound on how many signals the GPU-oriented entry points place in
    /// a single batch.
    pub max_batch_size: usize,
    /// When `true`, the CPU code paths process signals with a parallel
    /// iterator; when `false`, they process signals one after another.
    pub use_parallel: bool,
    /// Memory budget per batch. Forwarded to the GPU configuration and used
    /// when sizing the global GPU memory manager.
    /// NOTE(review): presumably in bytes with `0` meaning "no explicit
    /// limit" — confirm against the GPU memory manager.
    pub max_memory_per_batch: usize,
    /// Forwarded unchanged to the GPU sparse-FFT configuration.
    pub use_mixed_precision: bool,
    /// Forwarded unchanged to the GPU sparse-FFT configuration.
    pub use_inplace: bool,
    /// NOTE(review): not read by any batch function visible in this file;
    /// presumably requests that inputs are left unmodified — confirm.
    pub preserve_input: bool,
}
36
37impl Default for BatchConfig {
38 fn default() -> Self {
39 Self {
40 max_batch_size: 32,
41 use_parallel: true,
42 max_memory_per_batch: 0, use_mixed_precision: false,
44 use_inplace: true,
45 preserve_input: true,
46 }
47 }
48}
49
50#[allow(clippy::too_many_arguments)]
66#[allow(dead_code)]
67pub fn batch_sparse_fft<T>(
68 signals: &[Vec<T>],
69 k: usize,
70 algorithm: Option<SparseFFTAlgorithm>,
71 window_function: Option<WindowFunction>,
72 batchconfig: Option<BatchConfig>,
73) -> FFTResult<Vec<SparseFFTResult>>
74where
75 T: NumCast + Copy + Debug + Sync + 'static,
76{
77 let config = batchconfig.unwrap_or_default();
78 let alg = algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear);
79 let window = window_function.unwrap_or(WindowFunction::None);
80
81 let start = Instant::now();
82
83 let fftconfig = SparseFFTConfig {
85 estimation_method: SparsityEstimationMethod::Manual,
86 sparsity: k,
87 algorithm: alg,
88 window_function: window,
89 ..SparseFFTConfig::default()
90 };
91
92 let results = if config.use_parallel {
93 signals
95 .par_iter()
96 .map(|signal| {
97 let mut processor = crate::sparse_fft::SparseFFT::new(fftconfig.clone());
98
99 let signal_complex: FFTResult<Vec<Complex64>> = signal
101 .iter()
102 .map(|&val| {
103 let val_f64 = NumCast::from(val).ok_or_else(|| {
104 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
105 })?;
106 Ok(Complex64::new(val_f64, 0.0))
107 })
108 .collect();
109
110 processor.sparse_fft(&signal_complex?)
111 })
112 .collect::<FFTResult<Vec<_>>>()
113 } else {
114 let mut results = Vec::with_capacity(signals.len());
116 for signal in signals {
117 let mut processor = crate::sparse_fft::SparseFFT::new(fftconfig.clone());
118
119 let signal_complex: FFTResult<Vec<Complex64>> = signal
121 .iter()
122 .map(|&val| {
123 let val_f64 = NumCast::from(val).ok_or_else(|| {
124 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
125 })?;
126 Ok(Complex64::new(val_f64, 0.0))
127 })
128 .collect();
129
130 results.push(processor.sparse_fft(&signal_complex?)?);
131 }
132 Ok(results)
133 }?;
134
135 let total_time = start.elapsed();
137 let avg_time_per_signal = total_time.div_f64(signals.len() as f64);
138
139 let mut final_results = Vec::with_capacity(results.len());
141 for mut result in results {
142 result.computation_time = avg_time_per_signal;
143 final_results.push(result);
144 }
145
146 Ok(final_results)
147}
148
149#[allow(clippy::too_many_arguments)]
167#[allow(dead_code)]
168pub fn gpu_batch_sparse_fft<T>(
169 signals: &[Vec<T>],
170 k: usize,
171 device_id: i32,
172 backend: GPUBackend,
173 algorithm: Option<SparseFFTAlgorithm>,
174 window_function: Option<WindowFunction>,
175 batchconfig: Option<BatchConfig>,
176) -> FFTResult<Vec<SparseFFTResult>>
177where
178 T: NumCast + Copy + Debug + Sync + 'static,
179{
180 let config = batchconfig.unwrap_or_default();
181 let alg = algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear);
182 let window = window_function.unwrap_or(WindowFunction::None);
183
184 let total_signals = signals.len();
186 let batch_size = config.max_batch_size.min(total_signals);
187 let num_batches = total_signals.div_ceil(batch_size);
188
189 let base_fftconfig = SparseFFTConfig {
191 estimation_method: SparsityEstimationMethod::Manual,
192 sparsity: k,
193 algorithm: alg,
194 window_function: window,
195 ..SparseFFTConfig::default()
196 };
197
198 let _gpuconfig = GPUSparseFFTConfig {
200 base_config: base_fftconfig,
201 backend,
202 device_id,
203 batch_size,
204 max_memory: config.max_memory_per_batch,
205 use_mixed_precision: config.use_mixed_precision,
206 use_inplace: config.use_inplace,
207 stream_count: 2, };
209
210 let start = Instant::now();
211
212 let mut all_results = Vec::with_capacity(total_signals);
214 for batch_idx in 0..num_batches {
215 let start_idx = batch_idx * batch_size;
216 let end_idx = (start_idx + batch_size).min(total_signals);
217 let current_batch = &signals[start_idx..end_idx];
218
219 match backend {
221 GPUBackend::CUDA => {
222 let batch_results = crate::cuda_batch_sparse_fft(
223 current_batch,
224 k,
225 device_id,
226 Some(alg),
227 Some(window),
228 )?;
229 all_results.extend(batch_results);
230 }
231 _ => {
232 let batch_results =
234 batch_sparse_fft(current_batch, k, Some(alg), Some(window), None)?;
235 all_results.extend(batch_results);
236 }
237 }
238 }
239
240 let total_time = start.elapsed();
242 let avg_time_per_signal = total_time.div_f64(signals.len() as f64);
243
244 let mut final_results = Vec::with_capacity(all_results.len());
246 for mut result in all_results {
247 result.computation_time = avg_time_per_signal;
248 final_results.push(result);
249 }
250
251 Ok(final_results)
252}
253
254#[allow(clippy::too_many_arguments)]
273#[allow(dead_code)]
274pub fn spectral_flatness_batch_sparse_fft<T>(
275 signals: &[Vec<T>],
276 flatness_threshold: f64,
277 window_size: usize,
278 window_function: Option<WindowFunction>,
279 device_id: Option<i32>,
280 batchconfig: Option<BatchConfig>,
281) -> FFTResult<Vec<SparseFFTResult>>
282where
283 T: NumCast + Copy + Debug + Sync + 'static,
284{
285 let config = batchconfig.unwrap_or_default();
286 let window = window_function.unwrap_or(WindowFunction::Hann); let device = device_id.unwrap_or(-1); let total_signals = signals.len();
291 let batch_size = config.max_batch_size.min(total_signals);
292 let num_batches = total_signals.div_ceil(batch_size);
293
294 if device >= 0 {
296 init_global_memory_manager(
297 GPUBackend::CUDA,
298 device,
299 AllocationStrategy::CacheBySize,
300 config.max_memory_per_batch.max(1024 * 1024 * 1024), )?;
302 }
303
304 let start = Instant::now();
305
306 let mut all_results = Vec::with_capacity(total_signals);
308
309 if device >= 0 && cfg!(feature = "cuda") {
310 for batch_idx in 0..num_batches {
312 let start_idx = batch_idx * batch_size;
313 let end_idx = (start_idx + batch_size).min(total_signals);
314 let current_batch = &signals[start_idx..end_idx];
315
316 let _baseconfig = SparseFFTConfig {
318 estimation_method: SparsityEstimationMethod::SpectralFlatness,
319 sparsity: 0, algorithm: SparseFFTAlgorithm::SpectralFlatness,
321 window_function: window,
322 flatness_threshold,
323 window_size,
324 ..SparseFFTConfig::default()
325 };
326
327 for signal in current_batch {
329 let signal_complex: FFTResult<Vec<Complex64>> = signal
331 .iter()
332 .map(|&val| {
333 let val_f64 = NumCast::from(val).ok_or_else(|| {
334 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
335 })?;
336 Ok(Complex64::new(val_f64, 0.0))
337 })
338 .collect();
339
340 let result = crate::execute_cuda_spectral_flatness_sparse_fft(
342 &signal_complex?,
343 0, flatness_threshold,
345 )?;
346
347 all_results.push(result);
348 }
349 }
350 } else {
351 if config.use_parallel {
353 let parallel_results: FFTResult<Vec<_>> = signals
355 .par_iter()
356 .map(|signal| {
357 let fftconfig = SparseFFTConfig {
359 estimation_method: SparsityEstimationMethod::SpectralFlatness,
360 sparsity: 0, algorithm: SparseFFTAlgorithm::SpectralFlatness,
362 window_function: window,
363 flatness_threshold,
364 window_size,
365 ..SparseFFTConfig::default()
366 };
367
368 let signal_complex: FFTResult<Vec<Complex64>> = signal
370 .iter()
371 .map(|&val| {
372 let val_f64 = NumCast::from(val).ok_or_else(|| {
373 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
374 })?;
375 Ok(Complex64::new(val_f64, 0.0))
376 })
377 .collect();
378
379 let mut processor = crate::sparse_fft::SparseFFT::new(fftconfig);
381 processor.sparse_fft(&signal_complex?)
382 })
383 .collect();
384
385 all_results = parallel_results?;
386 } else {
387 for signal in signals {
389 let fftconfig = SparseFFTConfig {
391 estimation_method: SparsityEstimationMethod::SpectralFlatness,
392 sparsity: 0, algorithm: SparseFFTAlgorithm::SpectralFlatness,
394 window_function: window,
395 flatness_threshold,
396 window_size,
397 ..SparseFFTConfig::default()
398 };
399
400 let signal_complex: FFTResult<Vec<Complex64>> = signal
402 .iter()
403 .map(|&val| {
404 let val_f64 = NumCast::from(val).ok_or_else(|| {
405 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
406 })?;
407 Ok(Complex64::new(val_f64, 0.0))
408 })
409 .collect();
410
411 let mut processor = crate::sparse_fft::SparseFFT::new(fftconfig);
413 let result = processor.sparse_fft(&signal_complex?)?;
414 all_results.push(result);
415 }
416 }
417 }
418
419 let total_time = start.elapsed();
421 let avg_time_per_signal = total_time.div_f64(signals.len() as f64);
422
423 let mut final_results = Vec::with_capacity(all_results.len());
425 for mut result in all_results {
426 result.computation_time = avg_time_per_signal;
427 final_results.push(result);
428 }
429
430 Ok(final_results)
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Builds an `n`-sample signal as a sum of sines: every `(freq, amp)`
    /// pair contributes `amp * sin(freq * 2π * i / n)` to sample `i`.
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = 2.0 * PI * (i as f64) / (n as f64);
                frequencies
                    .iter()
                    .fold(0.0, |acc, &(freq, amp)| acc + amp * (freq as f64 * t).sin())
            })
            .collect()
    }

    /// Returns a copy of `signal` with a random offset drawn from
    /// `-noise_level..noise_level` added to every sample.
    fn add_noise(signal: &[f64], noise_level: f64) -> Vec<f64> {
        use scirs2_core::random::Rng;
        let mut rng = scirs2_core::random::rng();
        signal
            .iter()
            .map(|&sample| sample + rng.gen_range(-noise_level..noise_level))
            .collect()
    }

    /// Creates `count` independently noised copies of the same sparse signal.
    fn create_signal_batch(
        count: usize,
        n: usize,
        frequencies: &[(usize, f64)],
        noise_level: f64,
    ) -> Vec<Vec<f64>> {
        let clean = create_sparse_signal(n, frequencies);
        let mut batch = Vec::with_capacity(count);
        for _ in 0..count {
            batch.push(add_noise(&clean, noise_level));
        }
        batch
    }

    /// Shared checks for the CPU batch tests: every result must report at
    /// least one component, carry as many values as indices, and contain a
    /// component within 32 bins of either end of the `n`-bin spectrum.
    fn assert_low_frequency_results(results: &[SparseFFTResult], n: usize) {
        for (i, result) in results.iter().enumerate() {
            assert!(
                !result.indices.is_empty(),
                "No frequencies detected for signal {}",
                i
            );
            assert!(
                result.values.len() == result.indices.len(),
                "Mismatched indices and values"
            );

            let low_freq_count = result
                .indices
                .iter()
                .filter(|&&idx| idx <= 32 || idx >= n - 32)
                .count();

            assert!(low_freq_count >= 1, "Should find at least 1 low-frequency component for signal {}, but found none. All frequencies: {:?}", i, result.indices);
        }
    }

    #[test]
    fn test_cpu_batch_processing() {
        let n = 256;
        let frequencies: [(usize, f64); 3] = [(3, 1.0), (7, 0.5), (15, 0.5)];
        let signals = create_signal_batch(5, n, &frequencies, 0.05);

        // Default batch configuration.
        let results = batch_sparse_fft(
            &signals,
            6,
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
            None,
        )
        .unwrap();

        assert_eq!(results.len(), signals.len());
        assert_low_frequency_results(&results, n);
    }

    #[test]
    fn test_parallel_batch_processing() {
        let n = 256;
        let frequencies: [(usize, f64); 3] = [(3, 1.0), (7, 0.5), (15, 0.5)];
        let signals = create_signal_batch(10, n, &frequencies, 0.05);

        // Request parallel processing explicitly.
        let batchconfig = BatchConfig {
            use_parallel: true,
            ..BatchConfig::default()
        };

        let results = batch_sparse_fft(
            &signals,
            6,
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
            Some(batchconfig),
        )
        .unwrap();

        assert_eq!(results.len(), signals.len());
        assert_low_frequency_results(&results, n);
    }

    #[test]
    fn test_spectral_flatness_batch() {
        let n = 512;
        let frequencies: [(usize, f64); 3] = [(30, 1.0), (70, 0.5), (120, 0.25)];

        // Five copies of the same clean signal with increasing noise levels.
        let base_signal = create_sparse_signal(n, &frequencies);
        let signals: Vec<Vec<f64>> = (0..5)
            .map(|i| add_noise(&base_signal, 0.05 * (i + 1) as f64))
            .collect();

        // No device id and no batch configuration: CPU path with defaults.
        let results = spectral_flatness_batch_sparse_fft(
            &signals,
            0.3,
            32,
            Some(WindowFunction::Hann),
            None,
            None,
        )
        .unwrap();

        assert_eq!(results.len(), signals.len());

        for result in &results {
            assert_eq!(result.algorithm, SparseFFTAlgorithm::SpectralFlatness);

            // A key frequency counts as found at its own bin or its mirror bin.
            let has_component = |freq: usize| {
                result.indices.contains(&freq) || result.indices.contains(&(n - freq))
            };

            assert!(
                has_component(30) || has_component(70) || has_component(120),
                "Failed to find any of the key frequencies"
            );
        }
    }
}