scirs2_fft/sparse_fft/
estimation.rs

//! Sparsity estimation methods for Sparse FFT
//!
//! This module provides several methods for estimating the sparsity of a signal,
//! that is, how many significant frequency components it contains.
5
6use crate::error::FFTResult;
7use crate::fft::fft;
8// Complex64 is used through the FFT functions
9use scirs2_core::numeric::NumCast;
10use std::f64::consts::PI;
11use std::fmt::Debug;
12
13use super::config::{SparseFFTConfig, SparsityEstimationMethod};
14
15/// Estimate sparsity of a signal using various methods
16#[allow(dead_code)]
17pub fn estimate_sparsity<T>(signal: &[T], config: &SparseFFTConfig) -> FFTResult<usize>
18where
19    T: NumCast + Copy + Debug + 'static,
20{
21    match config.estimation_method {
22        SparsityEstimationMethod::Manual => Ok(config.sparsity),
23
24        SparsityEstimationMethod::Threshold => {
25            estimate_sparsity_threshold(signal, config.threshold)
26        }
27
28        SparsityEstimationMethod::Adaptive => {
29            estimate_sparsity_adaptive(signal, config.adaptivity_factor, config.sparsity)
30        }
31
32        SparsityEstimationMethod::FrequencyPruning => {
33            estimate_sparsity_frequency_pruning(signal, config.pruning_sensitivity)
34        }
35
36        SparsityEstimationMethod::SpectralFlatness => estimate_sparsity_spectral_flatness(
37            signal,
38            config.flatness_threshold,
39            config.window_size,
40        ),
41    }
42}
43
44/// Estimate sparsity using magnitude thresholding
45#[allow(dead_code)]
46pub fn estimate_sparsity_threshold<T>(signal: &[T], threshold: f64) -> FFTResult<usize>
47where
48    T: NumCast + Copy + Debug + 'static,
49{
50    // Compute regular FFT
51    let spectrum = fft(signal, None)?;
52
53    // Find magnitudes
54    let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
55
56    // Find maximum magnitude
57    let max_magnitude = magnitudes.iter().cloned().fold(0.0, f64::max);
58
59    // Count coefficients above threshold
60    let threshold_value = max_magnitude * threshold;
61    let count = magnitudes.iter().filter(|&&m| m > threshold_value).count();
62
63    Ok(count)
64}
65
66/// Estimate sparsity using adaptive energy-based method
67#[allow(dead_code)]
68pub fn estimate_sparsity_adaptive<T>(
69    signal: &[T],
70    adaptivity_factor: f64,
71    fallback_sparsity: usize,
72) -> FFTResult<usize>
73where
74    T: NumCast + Copy + Debug + 'static,
75{
76    // Compute regular FFT
77    let spectrum = fft(signal, None)?;
78
79    // Find magnitudes and sort them
80    let mut magnitudes: Vec<(usize, f64)> = spectrum
81        .iter()
82        .enumerate()
83        .map(|(i, c)| (i, c.norm()))
84        .collect();
85
86    magnitudes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
87
88    // Find "elbow" in the magnitude curve using adaptivity _factor
89    let signal_energy: f64 = magnitudes.iter().map(|(_, m)| m * m).sum();
90    let mut cumulative_energy = 0.0;
91    let energy_threshold = signal_energy * (1.0 - adaptivity_factor);
92
93    for (i, (_, mag)) in magnitudes.iter().enumerate() {
94        cumulative_energy += mag * mag;
95        if cumulative_energy >= energy_threshold {
96            return Ok(i + 1);
97        }
98    }
99
100    // Fallback: return a default small value if we couldn't determine _sparsity
101    Ok(fallback_sparsity)
102}
103
104/// Estimate sparsity using frequency pruning method
105#[allow(dead_code)]
106pub fn estimate_sparsity_frequency_pruning<T>(
107    signal: &[T],
108    pruning_sensitivity: f64,
109) -> FFTResult<usize>
110where
111    T: NumCast + Copy + Debug + 'static,
112{
113    // Compute regular FFT
114    let spectrum = fft(signal, None)?;
115
116    // Use frequency pruning approach
117    let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
118    let n = magnitudes.len();
119
120    // Compute local variance in frequency domain
121    let mut local_variances = Vec::with_capacity(n);
122    let window_size = (n / 16).max(3).min(n);
123
124    for i in 0..n {
125        let start = i.saturating_sub(window_size / 2);
126        let end = (i + window_size / 2 + 1).min(n);
127
128        let window_mags = &magnitudes[start..end];
129        let mean = window_mags.iter().sum::<f64>() / window_mags.len() as f64;
130        let variance =
131            window_mags.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / window_mags.len() as f64;
132
133        local_variances.push(variance);
134    }
135
136    // Count significant components based on local variance
137    let mean_variance = local_variances.iter().sum::<f64>() / local_variances.len() as f64;
138    let variance_threshold = mean_variance * pruning_sensitivity;
139
140    let significant_count = local_variances
141        .iter()
142        .zip(magnitudes.iter())
143        .filter(|(&var, &mag)| var > variance_threshold && mag > 0.0)
144        .count();
145
146    Ok(significant_count.max(1))
147}
148
149/// Estimate sparsity using spectral flatness measure
150#[allow(dead_code)]
151pub fn estimate_sparsity_spectral_flatness<T>(
152    signal: &[T],
153    flatness_threshold: f64,
154    window_size: usize,
155) -> FFTResult<usize>
156where
157    T: NumCast + Copy + Debug + 'static,
158{
159    // Compute regular FFT
160    let spectrum = fft(signal, None)?;
161
162    // Compute power spectrum
163    let power_spectrum: Vec<f64> = spectrum.iter().map(|c| c.norm_sqr()).collect();
164    let n = power_spectrum.len();
165
166    // Compute spectral flatness for overlapping windows
167    let mut significant_components = 0;
168    let step_size = window_size / 2;
169
170    for start in (0..n).step_by(step_size) {
171        let end = (start + window_size).min(n);
172        let window_power = &power_spectrum[start..end];
173
174        // Skip if window is too small or contains only zeros
175        if window_power.len() < 2 || window_power.iter().all(|&x| x == 0.0) {
176            continue;
177        }
178
179        // Compute geometric mean
180        let geometric_mean = {
181            let log_sum = window_power
182                .iter()
183                .filter(|&&x| x > 0.0)
184                .map(|&x| x.ln())
185                .sum::<f64>();
186            let count = window_power.iter().filter(|&&x| x > 0.0).count() as f64;
187            if count > 0.0 {
188                (log_sum / count).exp()
189            } else {
190                0.0
191            }
192        };
193
194        // Compute arithmetic mean
195        let arithmetic_mean = window_power.iter().sum::<f64>() / window_power.len() as f64;
196
197        // Compute spectral flatness
198        let spectral_flatness = if arithmetic_mean > 0.0 {
199            geometric_mean / arithmetic_mean
200        } else {
201            0.0
202        };
203
204        // Count as significant if flatness is below _threshold (indicating peaks)
205        if spectral_flatness < flatness_threshold {
206            significant_components += window_power
207                .iter()
208                .filter(|&&x| x > arithmetic_mean * 0.1)
209                .count();
210        }
211    }
212
213    Ok(significant_components.max(1))
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` samples of a sum of sinusoids. Each `(freq, amp)` pair contributes
    /// `amp * sin(2 * PI * freq * i / n)` at sample `i`.
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = 2.0 * PI * (i as f64) / (n as f64);
                frequencies
                    .iter()
                    .fold(0.0, |sample, &(freq, amp)| sample + amp * (freq as f64 * t).sin())
            })
            .collect()
    }

    #[test]
    fn test_estimate_sparsity_threshold() {
        let signal = create_sparse_signal(64, &[(3, 1.0), (7, 0.5)]);

        let estimated = estimate_sparsity_threshold(&signal, 0.1).unwrap();
        // Each real sinusoid shows up at a positive and a negative frequency bin,
        // so roughly 4 components are expected.
        assert!((2..=8).contains(&estimated));
    }

    #[test]
    fn test_estimate_sparsity_adaptive() {
        let signal = create_sparse_signal(64, &[(3, 1.0), (7, 0.5), (15, 0.25)]);

        let estimated = estimate_sparsity_adaptive(&signal, 0.25, 10).unwrap();
        // The adaptive estimate depends on the energy split, so only a range is checked.
        assert!((2..=15).contains(&estimated));
    }

    #[test]
    fn test_estimate_sparsity_frequency_pruning() {
        let signal = create_sparse_signal(64, &[(3, 1.0), (7, 0.5)]);

        let estimated = estimate_sparsity_frequency_pruning(&signal, 2.0).unwrap();
        assert!(estimated >= 1);
    }

    #[test]
    fn test_estimate_sparsity_spectral_flatness() {
        let signal = create_sparse_signal(64, &[(3, 1.0), (7, 0.5)]);

        let estimated = estimate_sparsity_spectral_flatness(&signal, 0.3, 8).unwrap();
        assert!(estimated >= 1);
    }
}