scirs2_fft/sparse_fft/
estimation.rs1use crate::error::FFTResult;
7use crate::fft::fft;
8use scirs2_core::numeric::NumCast;
10use std::f64::consts::PI;
11use std::fmt::Debug;
12
13use super::config::{SparseFFTConfig, SparsityEstimationMethod};
14
15#[allow(dead_code)]
17pub fn estimate_sparsity<T>(signal: &[T], config: &SparseFFTConfig) -> FFTResult<usize>
18where
19 T: NumCast + Copy + Debug + 'static,
20{
21 match config.estimation_method {
22 SparsityEstimationMethod::Manual => Ok(config.sparsity),
23
24 SparsityEstimationMethod::Threshold => {
25 estimate_sparsity_threshold(signal, config.threshold)
26 }
27
28 SparsityEstimationMethod::Adaptive => {
29 estimate_sparsity_adaptive(signal, config.adaptivity_factor, config.sparsity)
30 }
31
32 SparsityEstimationMethod::FrequencyPruning => {
33 estimate_sparsity_frequency_pruning(signal, config.pruning_sensitivity)
34 }
35
36 SparsityEstimationMethod::SpectralFlatness => estimate_sparsity_spectral_flatness(
37 signal,
38 config.flatness_threshold,
39 config.window_size,
40 ),
41 }
42}
43
44#[allow(dead_code)]
46pub fn estimate_sparsity_threshold<T>(signal: &[T], threshold: f64) -> FFTResult<usize>
47where
48 T: NumCast + Copy + Debug + 'static,
49{
50 let spectrum = fft(signal, None)?;
52
53 let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
55
56 let max_magnitude = magnitudes.iter().cloned().fold(0.0, f64::max);
58
59 let threshold_value = max_magnitude * threshold;
61 let count = magnitudes.iter().filter(|&&m| m > threshold_value).count();
62
63 Ok(count)
64}
65
66#[allow(dead_code)]
68pub fn estimate_sparsity_adaptive<T>(
69 signal: &[T],
70 adaptivity_factor: f64,
71 fallback_sparsity: usize,
72) -> FFTResult<usize>
73where
74 T: NumCast + Copy + Debug + 'static,
75{
76 let spectrum = fft(signal, None)?;
78
79 let mut magnitudes: Vec<(usize, f64)> = spectrum
81 .iter()
82 .enumerate()
83 .map(|(i, c)| (i, c.norm()))
84 .collect();
85
86 magnitudes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
87
88 let signal_energy: f64 = magnitudes.iter().map(|(_, m)| m * m).sum();
90 let mut cumulative_energy = 0.0;
91 let energy_threshold = signal_energy * (1.0 - adaptivity_factor);
92
93 for (i, (_, mag)) in magnitudes.iter().enumerate() {
94 cumulative_energy += mag * mag;
95 if cumulative_energy >= energy_threshold {
96 return Ok(i + 1);
97 }
98 }
99
100 Ok(fallback_sparsity)
102}
103
104#[allow(dead_code)]
106pub fn estimate_sparsity_frequency_pruning<T>(
107 signal: &[T],
108 pruning_sensitivity: f64,
109) -> FFTResult<usize>
110where
111 T: NumCast + Copy + Debug + 'static,
112{
113 let spectrum = fft(signal, None)?;
115
116 let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
118 let n = magnitudes.len();
119
120 let mut local_variances = Vec::with_capacity(n);
122 let window_size = (n / 16).max(3).min(n);
123
124 for i in 0..n {
125 let start = i.saturating_sub(window_size / 2);
126 let end = (i + window_size / 2 + 1).min(n);
127
128 let window_mags = &magnitudes[start..end];
129 let mean = window_mags.iter().sum::<f64>() / window_mags.len() as f64;
130 let variance =
131 window_mags.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / window_mags.len() as f64;
132
133 local_variances.push(variance);
134 }
135
136 let mean_variance = local_variances.iter().sum::<f64>() / local_variances.len() as f64;
138 let variance_threshold = mean_variance * pruning_sensitivity;
139
140 let significant_count = local_variances
141 .iter()
142 .zip(magnitudes.iter())
143 .filter(|(&var, &mag)| var > variance_threshold && mag > 0.0)
144 .count();
145
146 Ok(significant_count.max(1))
147}
148
149#[allow(dead_code)]
151pub fn estimate_sparsity_spectral_flatness<T>(
152 signal: &[T],
153 flatness_threshold: f64,
154 window_size: usize,
155) -> FFTResult<usize>
156where
157 T: NumCast + Copy + Debug + 'static,
158{
159 let spectrum = fft(signal, None)?;
161
162 let power_spectrum: Vec<f64> = spectrum.iter().map(|c| c.norm_sqr()).collect();
164 let n = power_spectrum.len();
165
166 let mut significant_components = 0;
168 let step_size = window_size / 2;
169
170 for start in (0..n).step_by(step_size) {
171 let end = (start + window_size).min(n);
172 let window_power = &power_spectrum[start..end];
173
174 if window_power.len() < 2 || window_power.iter().all(|&x| x == 0.0) {
176 continue;
177 }
178
179 let geometric_mean = {
181 let log_sum = window_power
182 .iter()
183 .filter(|&&x| x > 0.0)
184 .map(|&x| x.ln())
185 .sum::<f64>();
186 let count = window_power.iter().filter(|&&x| x > 0.0).count() as f64;
187 if count > 0.0 {
188 (log_sum / count).exp()
189 } else {
190 0.0
191 }
192 };
193
194 let arithmetic_mean = window_power.iter().sum::<f64>() / window_power.len() as f64;
196
197 let spectral_flatness = if arithmetic_mean > 0.0 {
199 geometric_mean / arithmetic_mean
200 } else {
201 0.0
202 };
203
204 if spectral_flatness < flatness_threshold {
206 significant_components += window_power
207 .iter()
208 .filter(|&&x| x > arithmetic_mean * 0.1)
209 .count();
210 }
211 }
212
213 Ok(significant_components.max(1))
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n`-sample real signal that is a sum of sine waves.
    ///
    /// Each `(freq, amp)` pair contributes `amp * sin(freq * t)`, where `t`
    /// sweeps `[0, 2π)` across the `n` samples.
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        (0..n)
            .map(|sample| {
                let t = 2.0 * PI * (sample as f64) / (n as f64);
                // Accumulate from 0.0 in the order the tones are listed.
                frequencies
                    .iter()
                    .fold(0.0_f64, |acc, &(freq, amp)| acc + amp * (freq as f64 * t).sin())
            })
            .collect()
    }

    /// Two tones with a 10% relative threshold should give a small count.
    #[test]
    fn test_estimate_sparsity_threshold() {
        let tones = [(3, 1.0), (7, 0.5)];
        let signal = create_sparse_signal(64, &tones);

        let estimate = estimate_sparsity_threshold(&signal, 0.1).unwrap();
        assert!((2..=8).contains(&estimate));
    }

    /// Three tones with an adaptivity factor of 0.25 should stay within a
    /// loose range.
    #[test]
    fn test_estimate_sparsity_adaptive() {
        let tones = [(3, 1.0), (7, 0.5), (15, 0.25)];
        let signal = create_sparse_signal(64, &tones);

        let estimate = estimate_sparsity_adaptive(&signal, 0.25, 10).unwrap();
        assert!((2..=15).contains(&estimate));
    }

    /// Frequency pruning must report at least one component.
    #[test]
    fn test_estimate_sparsity_frequency_pruning() {
        let tones = [(3, 1.0), (7, 0.5)];
        let signal = create_sparse_signal(64, &tones);

        let estimate = estimate_sparsity_frequency_pruning(&signal, 2.0).unwrap();
        assert!(estimate >= 1);
    }

    /// Spectral flatness must report at least one component.
    #[test]
    fn test_estimate_sparsity_spectral_flatness() {
        let tones = [(3, 1.0), (7, 0.5)];
        let signal = create_sparse_signal(64, &tones);

        let estimate = estimate_sparsity_spectral_flatness(&signal, 0.3, 8).unwrap();
        assert!(estimate >= 1);
    }
}