1use crate::error::{FFTError, FFTResult};
6use crate::fft::{fft, ifft};
7use scirs2_core::numeric::Complex64;
8use scirs2_core::numeric::NumCast;
9use scirs2_core::random::{Rng, SeedableRng};
10use std::fmt::Debug;
11use std::time::Instant;
12
13use super::config::{SparseFFTAlgorithm, SparseFFTConfig};
14use super::estimation::estimate_sparsity;
15use super::windowing::apply_window;
16
/// Output of a sparse FFT computation.
///
/// `values` and `indices` are parallel: `values[i]` is the spectral coefficient that was
/// selected at frequency bin `indices[i]`.
#[derive(Debug, Clone)]
pub struct SparseFFTResult {
    /// Selected complex spectral coefficients.
    pub values: Vec<Complex64>,
    /// Frequency-bin index of each entry in `values`.
    pub indices: Vec<usize>,
    /// Sparsity estimate computed for the windowed input signal.
    pub estimated_sparsity: usize,
    /// Elapsed time measured (via `Instant`) around the whole sparse FFT call.
    pub computation_time: std::time::Duration,
    /// Algorithm that produced this result.
    pub algorithm: SparseFFTAlgorithm,
}
31
/// Sparse FFT processor: a configuration plus a seeded random number generator.
pub struct SparseFFT {
    /// Settings for truncation, windowing, sparsity estimation and algorithm selection.
    config: SparseFFTConfig,
    /// RNG seeded in `SparseFFT::new` from `config.seed` (or a random seed).
    /// NOTE(review): within this file it is only read by the compressed-sensing path.
    rng: scirs2_core::random::rngs::StdRng,
}
39
40impl SparseFFT {
41 pub fn new(config: SparseFFTConfig) -> Self {
43 let seed = config.seed.unwrap_or_else(scirs2_core::random::random);
44 let rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
45
46 Self { config, rng }
47 }
48
49 pub fn with_default_config() -> Self {
51 Self::new(SparseFFTConfig::default())
52 }
53
54 pub fn estimate_sparsity<T>(&mut self, signal: &[T]) -> FFTResult<usize>
56 where
57 T: NumCast + Copy + Debug + 'static,
58 {
59 estimate_sparsity(signal, &self.config)
60 }
61
62 fn calculate_spectral_flatness(&self, magnitudes: &[f64]) -> f64 {
67 if magnitudes.is_empty() {
68 return 1.0; }
70
71 let epsilon = 1e-10;
73
74 let log_sum: f64 = magnitudes.iter().map(|&x| (x + epsilon).ln()).sum::<f64>();
76 let geometric_mean = (log_sum / magnitudes.len() as f64).exp();
77
78 let arithmetic_mean: f64 = magnitudes.iter().sum::<f64>() / magnitudes.len() as f64;
80
81 if arithmetic_mean < epsilon {
82 return 1.0; }
84
85 let flatness = geometric_mean / arithmetic_mean;
87
88 flatness.clamp(0.0, 1.0)
90 }
91
92 pub fn sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
94 where
95 T: NumCast + Copy + Debug + 'static,
96 {
97 let start = Instant::now();
99
100 let limit = signal.len().min(self.config.max_signal_size);
102 let limited_signal = &signal[..limit];
103
104 let windowed_signal = apply_window(
106 limited_signal,
107 self.config.window_function,
108 self.config.kaiser_beta,
109 )?;
110
111 let estimated_sparsity = self.estimate_sparsity(&windowed_signal)?;
113
114 let (values, indices) = match self.config.algorithm {
116 SparseFFTAlgorithm::Sublinear => {
117 self.sublinear_sfft(&windowed_signal, estimated_sparsity)?
118 }
119 SparseFFTAlgorithm::CompressedSensing => {
120 self.compressed_sensing_sfft(&windowed_signal, estimated_sparsity)?
121 }
122 SparseFFTAlgorithm::Iterative => {
123 self.iterative_sfft(&windowed_signal, estimated_sparsity)?
124 }
125 SparseFFTAlgorithm::Deterministic => {
126 self.deterministic_sfft(&windowed_signal, estimated_sparsity)?
127 }
128 SparseFFTAlgorithm::FrequencyPruning => {
129 self.frequency_pruning_sfft(&windowed_signal, estimated_sparsity)?
130 }
131 SparseFFTAlgorithm::SpectralFlatness => {
132 self.spectral_flatness_sfft(&windowed_signal, estimated_sparsity)?
133 }
134 };
135
136 let computation_time = start.elapsed();
138
139 Ok(SparseFFTResult {
140 values,
141 indices,
142 estimated_sparsity,
143 computation_time,
144 algorithm: self.config.algorithm,
145 })
146 }
147
148 pub fn sparse_fft_full<T>(&mut self, signal: &[T]) -> FFTResult<Vec<Complex64>>
150 where
151 T: NumCast + Copy + Debug + 'static,
152 {
153 let n = signal.len().min(self.config.max_signal_size);
154
155 let windowed_signal = apply_window(
157 &signal[..n],
158 self.config.window_function,
159 self.config.kaiser_beta,
160 )?;
161 let result = self.sparse_fft(&windowed_signal)?;
162
163 let mut spectrum = vec![Complex64::new(0.0, 0.0); n];
165 for (value, &index) in result.values.iter().zip(result.indices.iter()) {
166 spectrum[index] = *value;
167 }
168
169 Ok(spectrum)
170 }
171
172 pub fn reconstruct_signal(
174 &self,
175 sparse_result: &SparseFFTResult,
176 n: usize,
177 ) -> FFTResult<Vec<Complex64>> {
178 let mut spectrum = vec![Complex64::new(0.0, 0.0); n];
180 for (value, &index) in sparse_result
181 .values
182 .iter()
183 .zip(sparse_result.indices.iter())
184 {
185 spectrum[index] = *value;
186 }
187
188 ifft(&spectrum, None)
190 }
191
192 fn sublinear_sfft<T>(
194 &mut self,
195 signal: &[T],
196 k: usize,
197 ) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
198 where
199 T: NumCast + Copy + Debug + 'static,
200 {
201 let signal_complex: Vec<Complex64> = signal
203 .iter()
204 .map(|&val| {
205 let val_f64 = NumCast::from(val).ok_or_else(|| {
206 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
207 })?;
208 Ok(Complex64::new(val_f64, 0.0))
209 })
210 .collect::<FFTResult<Vec<_>>>()?;
211
212 let _n = signal_complex.len();
213
214 let spectrum = fft(&signal_complex, None)?;
217
218 let mut freq_with_magnitudes: Vec<(f64, usize, Complex64)> = spectrum
220 .iter()
221 .enumerate()
222 .map(|(i, &coef)| (coef.norm(), i, coef))
223 .collect();
224
225 freq_with_magnitudes
227 .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
228
229 let mut selected_indices = Vec::new();
231 let mut selected_values = Vec::new();
232
233 for &(_, idx, val) in freq_with_magnitudes.iter().take(k) {
234 selected_indices.push(idx);
235 selected_values.push(val);
236 }
237
238 Ok((selected_values, selected_indices))
239 }
240
241 fn compressed_sensing_sfft<T>(
243 &mut self,
244 signal: &[T],
245 k: usize,
246 ) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
247 where
248 T: NumCast + Copy + Debug + 'static,
249 {
250 let signal_complex: Vec<Complex64> = signal
252 .iter()
253 .map(|&val| {
254 let val_f64 = NumCast::from(val).ok_or_else(|| {
255 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
256 })?;
257 Ok(Complex64::new(val_f64, 0.0))
258 })
259 .collect::<FFTResult<Vec<_>>>()?;
260
261 let n = signal_complex.len();
262
263 let m = (4 * k * (self.config.iterations as f64).log2() as usize).min(n);
265
266 let mut measurements = Vec::with_capacity(m);
268 let mut sample_indices = Vec::with_capacity(m);
269
270 for _ in 0..m {
271 let idx = self.rng.gen_range(0..n);
272 sample_indices.push(idx);
273 measurements.push(signal_complex[idx]);
274 }
275
276 let spectrum = fft(&signal_complex, None)?;
278
279 let mut freq_with_magnitudes: Vec<(f64, usize, Complex64)> = spectrum
281 .iter()
282 .enumerate()
283 .map(|(i, &coef)| (coef.norm(), i, coef))
284 .collect();
285
286 freq_with_magnitudes
288 .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
289
290 let mut selected_indices = Vec::new();
292 let mut selected_values = Vec::new();
293
294 for &(_, idx, val) in freq_with_magnitudes.iter().take(k) {
295 selected_indices.push(idx);
296 selected_values.push(val);
297 }
298
299 Ok((selected_values, selected_indices))
300 }
301
302 fn iterative_sfft<T>(
304 &mut self,
305 signal: &[T],
306 k: usize,
307 ) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
308 where
309 T: NumCast + Copy + Debug + 'static,
310 {
311 let mut signal_complex: Vec<Complex64> = signal
313 .iter()
314 .map(|&val| {
315 let val_f64 = NumCast::from(val).ok_or_else(|| {
316 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
317 })?;
318 Ok(Complex64::new(val_f64, 0.0))
319 })
320 .collect::<FFTResult<Vec<_>>>()?;
321
322 let mut selected_indices = Vec::new();
323 let mut selected_values = Vec::new();
324
325 for _ in 0..k.min(self.config.iterations) {
327 let spectrum = fft(&signal_complex, None)?;
329
330 let (best_idx, best_value) = spectrum
332 .iter()
333 .enumerate()
334 .max_by(|(_, a), (_, b)| {
335 a.norm()
336 .partial_cmp(&b.norm())
337 .unwrap_or(std::cmp::Ordering::Equal)
338 })
339 .map(|(i, &val)| (i, val))
340 .ok_or_else(|| FFTError::ValueError("Empty spectrum".to_string()))?;
341
342 if best_value.norm() < 1e-10 {
344 break;
345 }
346
347 selected_indices.push(best_idx);
349 selected_values.push(best_value);
350
351 let n = signal_complex.len();
354 for (i, sample) in signal_complex.iter_mut().enumerate() {
355 let phase =
356 2.0 * std::f64::consts::PI * (best_idx as f64) * (i as f64) / (n as f64);
357 let component = best_value * Complex64::new(phase.cos(), phase.sin()) / (n as f64);
358 *sample -= component;
359 }
360 }
361
362 Ok((selected_values, selected_indices))
363 }
364
365 fn deterministic_sfft<T>(
367 &mut self,
368 signal: &[T],
369 k: usize,
370 ) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
371 where
372 T: NumCast + Copy + Debug + 'static,
373 {
374 self.sublinear_sfft(signal, k)
377 }
378
379 fn frequency_pruning_sfft<T>(
381 &mut self,
382 signal: &[T],
383 k: usize,
384 ) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
385 where
386 T: NumCast + Copy + Debug + 'static,
387 {
388 let signal_complex: Vec<Complex64> = signal
390 .iter()
391 .map(|&val| {
392 let val_f64 = NumCast::from(val).ok_or_else(|| {
393 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
394 })?;
395 Ok(Complex64::new(val_f64, 0.0))
396 })
397 .collect::<FFTResult<Vec<_>>>()?;
398
399 let spectrum = fft(&signal_complex, None)?;
401 let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
402
403 let n = magnitudes.len();
405 let mean: f64 = magnitudes.iter().sum::<f64>() / n as f64;
406 let variance: f64 = magnitudes.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
407 let std_dev = variance.sqrt();
408
409 let threshold = mean + self.config.pruning_sensitivity * std_dev;
411
412 let mut candidates: Vec<(f64, usize, Complex64)> = spectrum
414 .iter()
415 .enumerate()
416 .filter(|(_, c)| c.norm() > threshold)
417 .map(|(i, &c)| (c.norm(), i, c))
418 .collect();
419
420 candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
422
423 let selected_count = k.min(candidates.len());
425 let selected_indices: Vec<usize> = candidates[..selected_count]
426 .iter()
427 .map(|(_, i_, _)| *i_)
428 .collect();
429 let selected_values: Vec<Complex64> = candidates[..selected_count]
430 .iter()
431 .map(|(_, _, c)| *c)
432 .collect();
433
434 Ok((selected_values, selected_indices))
435 }
436
437 fn spectral_flatness_sfft<T>(
439 &mut self,
440 signal: &[T],
441 k: usize,
442 ) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
443 where
444 T: NumCast + Copy + Debug + 'static,
445 {
446 let signal_complex: Vec<Complex64> = signal
448 .iter()
449 .map(|&val| {
450 let val_f64 = NumCast::from(val).ok_or_else(|| {
451 FFTError::ValueError(format!("Could not convert {val:?} to f64"))
452 })?;
453 Ok(Complex64::new(val_f64, 0.0))
454 })
455 .collect::<FFTResult<Vec<_>>>()?;
456
457 let spectrum = fft(&signal_complex, None)?;
459 let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
460
461 let n = magnitudes.len();
463 let window_size = self.config.window_size.min(n);
464 let mut selected_indices = Vec::new();
465 let mut selected_values = Vec::new();
466
467 for start in (0..n).step_by(window_size / 2) {
468 let end = (start + window_size).min(n);
469 if start >= n {
470 break;
471 }
472
473 let window_mags = &magnitudes[start..end];
474 let flatness = self.calculate_spectral_flatness(window_mags);
475
476 if flatness < self.config.flatness_threshold {
478 if let Some((local_idx_, _)) = window_mags
479 .iter()
480 .enumerate()
481 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
482 {
483 let global_idx = start + local_idx_;
484 if !selected_indices.contains(&global_idx) {
485 selected_indices.push(global_idx);
486 selected_values.push(spectrum[global_idx]);
487 }
488 }
489 }
490
491 if selected_indices.len() >= k {
493 break;
494 }
495 }
496
497 if selected_indices.len() < k {
499 let mut remaining_candidates: Vec<(f64, usize, Complex64)> = spectrum
500 .iter()
501 .enumerate()
502 .filter(|(i_, _)| !selected_indices.contains(i_))
503 .map(|(i, &c)| (c.norm(), i, c))
504 .collect();
505
506 remaining_candidates
507 .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
508
509 let needed = k - selected_indices.len();
510 for (_, idx, val) in remaining_candidates.into_iter().take(needed) {
511 selected_indices.push(idx);
512 selected_values.push(val);
513 }
514 }
515
516 Ok((selected_values, selected_indices))
517 }
518}
519
520#[allow(dead_code)]
524pub fn sparse_fft<T>(
525 signal: &[T],
526 k: usize,
527 algorithm: Option<SparseFFTAlgorithm>,
528 seed: Option<u64>,
529) -> FFTResult<SparseFFTResult>
530where
531 T: NumCast + Copy + Debug + 'static,
532{
533 let config = SparseFFTConfig {
534 sparsity: k,
535 algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
536 seed,
537 ..SparseFFTConfig::default()
538 };
539
540 let mut processor = SparseFFT::new(config);
541 processor.sparse_fft(signal)
542}
543
544#[allow(dead_code)]
546pub fn adaptive_sparse_fft<T>(signal: &[T], threshold: f64) -> FFTResult<SparseFFTResult>
547where
548 T: NumCast + Copy + Debug + 'static,
549{
550 let config = SparseFFTConfig {
551 estimation_method: super::config::SparsityEstimationMethod::Adaptive,
552 threshold,
553 adaptivity_factor: threshold,
554 ..SparseFFTConfig::default()
555 };
556
557 let mut processor = SparseFFT::new(config);
558 processor.sparse_fft(signal)
559}
560
561#[allow(dead_code)]
563pub fn frequency_pruning_sparse_fft<T>(
564 _signal: &[T],
565 sensitivity: f64,
566) -> FFTResult<SparseFFTResult>
567where
568 T: NumCast + Copy + Debug + 'static,
569{
570 let config = SparseFFTConfig {
571 estimation_method: super::config::SparsityEstimationMethod::FrequencyPruning,
572 algorithm: SparseFFTAlgorithm::FrequencyPruning,
573 pruning_sensitivity: sensitivity,
574 ..SparseFFTConfig::default()
575 };
576
577 let mut processor = SparseFFT::new(config);
578 processor.sparse_fft(_signal)
579}
580
581#[allow(dead_code)]
583pub fn spectral_flatness_sparse_fft<T>(
584 signal: &[T],
585 flatness_threshold: f64,
586 window_size: usize,
587) -> FFTResult<SparseFFTResult>
588where
589 T: NumCast + Copy + Debug + 'static,
590{
591 let config = SparseFFTConfig {
592 estimation_method: super::config::SparsityEstimationMethod::SpectralFlatness,
593 algorithm: SparseFFTAlgorithm::SpectralFlatness,
594 flatness_threshold,
595 window_size,
596 ..SparseFFTConfig::default()
597 };
598
599 let mut processor = SparseFFT::new(config);
600 processor.sparse_fft(signal)
601}
602
603#[allow(dead_code)]
605pub fn sparse_fft2<T>(
606 _signal: &[Vec<T>],
607 _k: usize,
608 _algorithm: Option<SparseFFTAlgorithm>,
609) -> FFTResult<SparseFFTResult>
610where
611 T: NumCast + Copy + Debug + 'static,
612{
613 Err(FFTError::ValueError(
615 "2D sparse FFT not yet implemented".to_string(),
616 ))
617}
618
619#[allow(dead_code)]
621pub fn sparse_fftn<T>(
622 _signal: &[T],
623 _shape: &[usize],
624 _k: usize,
625 _algorithm: Option<SparseFFTAlgorithm>,
626) -> FFTResult<SparseFFTResult>
627where
628 T: NumCast + Copy + Debug + 'static,
629{
630 Err(FFTError::ValueError(
632 "N-dimensional sparse FFT not yet implemented".to_string(),
633 ))
634}