1use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use super::config::{DenoisingMethod, WaveletConfig, WaveletFamily};
12use crate::error::{Result, TimeSeriesError};
13
14#[derive(Debug, Clone)]
20pub struct WaveletFeatures<F> {
21 pub energy_bands: Vec<F>,
23 pub relative_energy: Vec<F>,
25 pub wavelet_entropy: F,
27 pub wavelet_variance: F,
29 pub regularity_index: F,
31 pub dominant_scale: usize,
33 pub mra_features: MultiResolutionFeatures<F>,
35 pub time_frequency_features: TimeFrequencyFeatures<F>,
37 pub coefficient_stats: WaveletCoefficientStats<F>,
39}
40
41impl<F> Default for WaveletFeatures<F>
42where
43 F: Float + FromPrimitive,
44{
45 fn default() -> Self {
46 Self {
47 energy_bands: Vec::new(),
48 relative_energy: Vec::new(),
49 wavelet_entropy: F::zero(),
50 wavelet_variance: F::zero(),
51 regularity_index: F::zero(),
52 dominant_scale: 0,
53 mra_features: MultiResolutionFeatures::default(),
54 time_frequency_features: TimeFrequencyFeatures::default(),
55 coefficient_stats: WaveletCoefficientStats::default(),
56 }
57 }
58}
59
60#[derive(Debug, Clone)]
62pub struct MultiResolutionFeatures<F> {
63 pub level_energies: Vec<F>,
65 pub level_relative_energies: Vec<F>,
67 pub level_entropy: F,
69 pub dominant_level: usize,
71 pub level_cv: F,
73}
74
75impl<F> Default for MultiResolutionFeatures<F>
76where
77 F: Float + FromPrimitive,
78{
79 fn default() -> Self {
80 Self {
81 level_energies: Vec::new(),
82 level_relative_energies: Vec::new(),
83 level_entropy: F::zero(),
84 dominant_level: 0,
85 level_cv: F::zero(),
86 }
87 }
88}
89
90#[derive(Debug, Clone)]
92pub struct TimeFrequencyFeatures<F> {
93 pub instantaneous_frequencies: Vec<F>,
95 pub energy_concentrations: Vec<F>,
97 pub frequency_stability: F,
99 pub scalogram_entropy: F,
101 pub frequency_evolution: Vec<F>,
103}
104
105impl<F> Default for TimeFrequencyFeatures<F>
106where
107 F: Float + FromPrimitive,
108{
109 fn default() -> Self {
110 Self {
111 instantaneous_frequencies: Vec::new(),
112 energy_concentrations: Vec::new(),
113 frequency_stability: F::zero(),
114 scalogram_entropy: F::zero(),
115 frequency_evolution: Vec::new(),
116 }
117 }
118}
119
120#[derive(Debug, Clone)]
122pub struct WaveletCoefficientStats<F> {
123 pub level_means: Vec<F>,
125 pub level_stds: Vec<F>,
127 pub level_skewness: Vec<F>,
129 pub level_kurtosis: Vec<F>,
131 pub level_max_magnitudes: Vec<F>,
133 pub level_zero_crossings: Vec<usize>,
135}
136
137impl<F> Default for WaveletCoefficientStats<F>
138where
139 F: Float + FromPrimitive,
140{
141 fn default() -> Self {
142 Self {
143 level_means: Vec::new(),
144 level_stds: Vec::new(),
145 level_skewness: Vec::new(),
146 level_kurtosis: Vec::new(),
147 level_max_magnitudes: Vec::new(),
148 level_zero_crossings: Vec::new(),
149 }
150 }
151}
152
153#[derive(Debug, Clone)]
155pub struct WaveletDenoisingFeatures<F> {
156 pub snr_improvement: F,
158 pub energy_preserved: F,
160 pub coefficients_thresholded: usize,
162 pub optimal_threshold: F,
164 pub mse_reduction: F,
166}
167
168impl<F> Default for WaveletDenoisingFeatures<F>
169where
170 F: Float + FromPrimitive,
171{
172 fn default() -> Self {
173 Self {
174 snr_improvement: F::zero(),
175 energy_preserved: F::zero(),
176 coefficients_thresholded: 0,
177 optimal_threshold: F::zero(),
178 mse_reduction: F::zero(),
179 }
180 }
181}
182
183#[allow(dead_code)]
222pub fn calculate_wavelet_features<F>(
223 ts: &Array1<F>,
224 config: &WaveletConfig,
225) -> Result<WaveletFeatures<F>>
226where
227 F: Float + FromPrimitive + Debug + Clone + scirs2_core::ndarray::ScalarOperand,
228{
229 let n = ts.len();
230 if n < 8 {
231 return Ok(WaveletFeatures::default());
232 }
233
234 let dwt_result = discrete_wavelet_transform(ts, config)?;
236
237 let energy_bands = calculate_wavelet_energy_bands(&dwt_result.coefficients)?;
239 let relative_energy = calculate_relative_wavelet_energy(&energy_bands)?;
240
241 let wavelet_entropy = calculate_wavelet_entropy(&dwt_result.coefficients)?;
243
244 let wavelet_variance = calculate_wavelet_variance(&dwt_result.coefficients)?;
246
247 let regularity_index = calculate_regularity_index(&dwt_result.coefficients)?;
249
250 let dominant_scale = find_dominant_wavelet_scale(&energy_bands);
252
253 let mra_features = calculate_mra_features(&dwt_result)?;
255
256 let time_frequency_features = if config.calculate_cwt {
258 calculate_time_frequency_features(ts, config)?
259 } else {
260 TimeFrequencyFeatures::default()
261 };
262
263 let coefficient_stats = calculate_coefficient_statistics(&dwt_result.coefficients)?;
265
266 Ok(WaveletFeatures {
267 energy_bands,
268 relative_energy,
269 wavelet_entropy,
270 wavelet_variance,
271 regularity_index,
272 dominant_scale,
273 mra_features,
274 time_frequency_features,
275 coefficient_stats,
276 })
277}
278
279#[derive(Debug, Clone)]
285struct DWTResult<F> {
286 coefficients: Vec<Array1<F>>,
290 #[allow(dead_code)]
292 levels: usize,
293 #[allow(dead_code)]
295 original_length: usize,
296}
297
298#[allow(dead_code)]
304fn discrete_wavelet_transform<F>(signal: &Array1<F>, config: &WaveletConfig) -> Result<DWTResult<F>>
305where
306 F: Float + FromPrimitive + Debug + Clone,
307{
308 let n = signal.len();
309 let max_levels = (n as f64).log2().floor() as usize - 1;
310 let levels = config.levels.min(max_levels).max(1);
311
312 let mut coefficients = Vec::with_capacity(levels + 1);
313 let mut current_signal = signal.clone();
314
315 let (h, g) = get_wavelet_filters(&config.family)?;
317
318 for _level in 0..levels {
320 let (approx, detail) = wavelet_decompose_level(¤t_signal, &h, &g)?;
321
322 coefficients.push(detail);
324
325 current_signal = approx;
327
328 if current_signal.len() < 4 {
330 break;
331 }
332 }
333
334 coefficients.insert(0, current_signal);
336
337 Ok(DWTResult {
338 coefficients,
339 levels,
340 original_length: n,
341 })
342}
343
344#[allow(dead_code)]
346fn get_wavelet_filters<F>(family: &WaveletFamily) -> Result<(Array1<F>, Array1<F>)>
347where
348 F: Float + FromPrimitive,
349{
350 match family {
351 WaveletFamily::Haar => {
352 let sqrt_2_inv = F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap();
354 let h = Array1::from_vec(vec![sqrt_2_inv, sqrt_2_inv]);
355 let g = Array1::from_vec(vec![-sqrt_2_inv, sqrt_2_inv]);
356 Ok((h, g))
357 }
358 WaveletFamily::Daubechies(n) => {
359 match n {
360 2 => {
361 let sqrt_2_inv = F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap();
363 let h = Array1::from_vec(vec![sqrt_2_inv, sqrt_2_inv]);
364 let g = Array1::from_vec(vec![-sqrt_2_inv, sqrt_2_inv]);
365 Ok((h, g))
366 }
367 4 => {
368 let h = Array1::from_vec(vec![
370 F::from(0.48296291314469025).unwrap(),
371 F::from(0.8365163037378079).unwrap(),
372 F::from(0.22414386804185735).unwrap(),
373 F::from(-0.12940952255092145).unwrap(),
374 ]);
375 let g = Array1::from_vec(vec![
376 F::from(-0.12940952255092145).unwrap(),
377 F::from(-0.22414386804185735).unwrap(),
378 F::from(0.8365163037378079).unwrap(),
379 F::from(-0.48296291314469025).unwrap(),
380 ]);
381 Ok((h, g))
382 }
383 6 => {
384 let h = Array1::from_vec(vec![
386 F::from(0.3326705529509569).unwrap(),
387 F::from(0.8068915093133388).unwrap(),
388 F::from(0.4598775021193313).unwrap(),
389 F::from(-0.13501102001039084).unwrap(),
390 F::from(-0.08544127388224149).unwrap(),
391 F::from(0.035226291882100656).unwrap(),
392 ]);
393 let g = Array1::from_vec(vec![
394 F::from(0.035226291882100656).unwrap(),
395 F::from(0.08544127388224149).unwrap(),
396 F::from(-0.13501102001039084).unwrap(),
397 F::from(-0.4598775021193313).unwrap(),
398 F::from(0.8068915093133388).unwrap(),
399 F::from(-0.3326705529509569).unwrap(),
400 ]);
401 Ok((h, g))
402 }
403 _ => {
404 let h = Array1::from_vec(vec![
406 F::from(0.48296291314469025).unwrap(),
407 F::from(0.8365163037378079).unwrap(),
408 F::from(0.22414386804185735).unwrap(),
409 F::from(-0.12940952255092145).unwrap(),
410 ]);
411 let g = Array1::from_vec(vec![
412 F::from(-0.12940952255092145).unwrap(),
413 F::from(-0.22414386804185735).unwrap(),
414 F::from(0.8365163037378079).unwrap(),
415 F::from(-0.48296291314469025).unwrap(),
416 ]);
417 Ok((h, g))
418 }
419 }
420 }
421 _ => {
422 let h = Array1::from_vec(vec![
424 F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap(),
425 F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap(),
426 ]);
427 let g = Array1::from_vec(vec![
428 F::from(-std::f64::consts::FRAC_1_SQRT_2).unwrap(),
429 F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap(),
430 ]);
431 Ok((h, g))
432 }
433 }
434}
435
436#[allow(dead_code)]
438fn wavelet_decompose_level<F>(
439 signal: &Array1<F>,
440 h: &Array1<F>, g: &Array1<F>, ) -> Result<(Array1<F>, Array1<F>)>
443where
444 F: Float + FromPrimitive + Clone,
445{
446 let n = signal.len();
447 let filter_len = h.len();
448
449 if n < filter_len {
450 return Err(TimeSeriesError::InsufficientData {
451 message: "Signal too short for wavelet decomposition".to_string(),
452 required: filter_len,
453 actual: n,
454 });
455 }
456
457 let approx_len = (n + filter_len - 1) / 2;
459 let detail_len = approx_len;
460
461 let mut approx = Array1::zeros(approx_len);
462 let mut detail = Array1::zeros(detail_len);
463
464 let mut approx_idx = 0;
465 let mut detail_idx = 0;
466
467 for i in (0..n).step_by(2) {
469 let mut approx_val = F::zero();
470 let mut detail_val = F::zero();
471
472 for j in 0..filter_len {
473 let signal_idx = if i + j < n { i + j } else { n - 1 };
474
475 approx_val = approx_val + h[j] * signal[signal_idx];
476 detail_val = detail_val + g[j] * signal[signal_idx];
477 }
478
479 if approx_idx < approx_len {
480 approx[approx_idx] = approx_val;
481 approx_idx += 1;
482 }
483
484 if detail_idx < detail_len {
485 detail[detail_idx] = detail_val;
486 detail_idx += 1;
487 }
488 }
489
490 Ok((approx, detail))
491}
492
493#[allow(dead_code)]
499fn calculate_wavelet_energy_bands<F>(coefficients: &[Array1<F>]) -> Result<Vec<F>>
500where
501 F: Float + FromPrimitive,
502{
503 let mut energy_bands = Vec::with_capacity(coefficients.len());
504
505 for coeff_level in coefficients {
506 let energy = coeff_level.mapv(|x| x * x).sum();
507 energy_bands.push(energy);
508 }
509
510 Ok(energy_bands)
511}
512
513#[allow(dead_code)]
515fn calculate_relative_wavelet_energy<F>(_energybands: &[F]) -> Result<Vec<F>>
516where
517 F: Float + FromPrimitive,
518{
519 let total_energy: F = _energybands.iter().fold(F::zero(), |acc, &x| acc + x);
520
521 if total_energy <= F::zero() {
522 return Ok(vec![F::zero(); _energybands.len()]);
523 }
524
525 let relative_energy = _energybands
526 .iter()
527 .map(|&energy| energy / total_energy)
528 .collect();
529
530 Ok(relative_energy)
531}
532
533#[allow(dead_code)]
544fn calculate_wavelet_entropy<F>(coefficients: &[Array1<F>]) -> Result<F>
545where
546 F: Float + FromPrimitive,
547{
548 let energy_bands = calculate_wavelet_energy_bands(coefficients)?;
549 let relative_energy = calculate_relative_wavelet_energy(&energy_bands)?;
550
551 let mut entropy = F::zero();
552 for &p in &relative_energy {
553 if p > F::zero() {
554 entropy = entropy - p * p.ln();
555 }
556 }
557
558 Ok(entropy)
559}
560
561#[allow(dead_code)]
563fn calculate_wavelet_variance<F>(coefficients: &[Array1<F>]) -> Result<F>
564where
565 F: Float + FromPrimitive,
566{
567 let mut total_variance = F::zero();
568 let mut total_count = 0;
569
570 for coeff_level in coefficients.iter().skip(1) {
572 if coeff_level.len() > 1 {
573 let mean = coeff_level.sum() / F::from(coeff_level.len()).unwrap();
574 let variance = coeff_level.mapv(|x| (x - mean) * (x - mean)).sum()
575 / F::from(coeff_level.len() - 1).unwrap();
576
577 total_variance = total_variance + variance;
578 total_count += 1;
579 }
580 }
581
582 if total_count > 0 {
583 Ok(total_variance / F::from(total_count).unwrap())
584 } else {
585 Ok(F::zero())
586 }
587}
588
589#[allow(dead_code)]
594fn calculate_regularity_index<F>(coefficients: &[Array1<F>]) -> Result<F>
595where
596 F: Float + FromPrimitive,
597{
598 if coefficients.len() < 2 {
599 return Ok(F::zero());
600 }
601
602 let mut scale_energies = Vec::new();
603
604 for (scale, coeff_level) in coefficients.iter().enumerate().skip(1) {
606 if !coeff_level.is_empty() {
607 let avg_energy =
608 coeff_level.mapv(|x| x * x).sum() / F::from(coeff_level.len()).unwrap();
609
610 if avg_energy > F::zero() {
611 let log_energy = avg_energy.ln();
612 let log_scale = F::from(scale).unwrap().ln();
613 scale_energies.push((log_scale, log_energy));
614 }
615 }
616 }
617
618 if scale_energies.len() < 2 {
619 return Ok(F::zero());
620 }
621
622 let n = F::from(scale_energies.len()).unwrap();
624 let sum_x: F = scale_energies
625 .iter()
626 .map(|(x_, _)| *x_)
627 .fold(F::zero(), |acc, x| acc + x);
628 let sum_y: F = scale_energies
629 .iter()
630 .map(|(_, y)| *y)
631 .fold(F::zero(), |acc, y| acc + y);
632 let sum_xy: F = scale_energies
633 .iter()
634 .map(|(x, y)| *x * *y)
635 .fold(F::zero(), |acc, xy| acc + xy);
636 let sum_xx: F = scale_energies
637 .iter()
638 .map(|(x_, _)| *x_ * *x_)
639 .fold(F::zero(), |acc, xx| acc + xx);
640
641 let denominator = n * sum_xx - sum_x * sum_x;
642 if denominator.abs() < F::from(1e-10).unwrap() {
643 return Ok(F::zero());
644 }
645
646 let slope = (n * sum_xy - sum_x * sum_y) / denominator;
647
648 Ok(-slope)
650}
651
652#[allow(dead_code)]
654fn find_dominant_wavelet_scale<F>(_energybands: &[F]) -> usize
655where
656 F: Float + PartialOrd,
657{
658 _energybands
659 .iter()
660 .enumerate()
661 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
662 .map(|(idx_, _)| idx_)
663 .unwrap_or(0)
664}
665
666#[allow(dead_code)]
672fn calculate_mra_features<F>(_dwtresult: &DWTResult<F>) -> Result<MultiResolutionFeatures<F>>
673where
674 F: Float + FromPrimitive,
675{
676 let level_energies = calculate_wavelet_energy_bands(&_dwtresult.coefficients)?;
677 let level_relative_energies = calculate_relative_wavelet_energy(&level_energies)?;
678
679 let mut level_entropy = F::zero();
681 for &p in &level_relative_energies {
682 if p > F::zero() {
683 level_entropy = level_entropy - p * p.ln();
684 }
685 }
686
687 let dominant_level = find_dominant_wavelet_scale(&level_energies);
689
690 let mean_energy = level_energies.iter().fold(F::zero(), |acc, &x| acc + x)
692 / F::from(level_energies.len()).unwrap();
693
694 let variance_energy = level_energies.iter().fold(F::zero(), |acc, &x| {
695 acc + (x - mean_energy) * (x - mean_energy)
696 }) / F::from(level_energies.len()).unwrap();
697
698 let level_cv = if mean_energy > F::zero() {
699 variance_energy.sqrt() / mean_energy
700 } else {
701 F::zero()
702 };
703
704 Ok(MultiResolutionFeatures {
705 level_energies,
706 level_relative_energies,
707 level_entropy,
708 dominant_level,
709 level_cv,
710 })
711}
712
713#[allow(dead_code)]
719fn calculate_time_frequency_features<F>(
720 signal: &Array1<F>,
721 config: &WaveletConfig,
722) -> Result<TimeFrequencyFeatures<F>>
723where
724 F: Float + FromPrimitive + Debug + Clone,
725{
726 let n = signal.len();
727 if n < 16 {
728 return Ok(TimeFrequencyFeatures::default());
729 }
730
731 let scales = generate_cwt_scales(config);
733 let cwt_matrix = compute_simplified_cwt(signal, &scales)?;
734
735 let instantaneous_frequencies = estimate_instantaneous_frequencies(&cwt_matrix, &scales)?;
737
738 let energy_concentrations = calculate_energy_concentrations(&cwt_matrix)?;
740
741 let frequency_stability = calculate_frequency_stability(&instantaneous_frequencies)?;
743
744 let scalogram_entropy = calculate_scalogram_entropy(&cwt_matrix)?;
746
747 let frequency_evolution = calculate_frequency_evolution(&cwt_matrix, &scales)?;
749
750 Ok(TimeFrequencyFeatures {
751 instantaneous_frequencies,
752 energy_concentrations,
753 frequency_stability,
754 scalogram_entropy,
755 frequency_evolution,
756 })
757}
758
759#[allow(dead_code)]
761fn generate_cwt_scales(config: &WaveletConfig) -> Vec<f64> {
762 let (min_scale, max_scale) = config.cwt_scales.unwrap_or((1.0, 32.0));
763 let count = config.cwt_scale_count;
764
765 let log_min = min_scale.ln();
766 let log_max = max_scale.ln();
767 let step = (log_max - log_min) / (count - 1) as f64;
768
769 (0..count)
770 .map(|i| (log_min + i as f64 * step).exp())
771 .collect()
772}
773
774#[allow(dead_code)]
776fn compute_simplified_cwt<F>(signal: &Array1<F>, scales: &[f64]) -> Result<Array2<F>>
777where
778 F: Float + FromPrimitive + Clone,
779{
780 let n = signal.len();
781 let n_scales = scales.len();
782 let mut cwt_matrix = Array2::zeros((n_scales, n));
783
784 for (scale_idx, &scale) in scales.iter().enumerate() {
785 let omega0 = 6.0; let wavelet_support = (8.0 * scale) as usize;
788
789 for t in 0..n {
790 let mut cwt_value = F::zero();
791 let mut norm = F::zero();
792
793 for tau in 0..wavelet_support {
794 let t_shifted = t as isize - tau as isize;
795 if t_shifted >= 0 && (t_shifted as usize) < n {
796 let signal_idx = t_shifted as usize;
797
798 let t_norm = (tau as f64) / scale;
800 let envelope = (-0.5 * t_norm * t_norm).exp();
801 let oscillation = (omega0 * t_norm).cos();
802 let wavelet_val = F::from(envelope * oscillation).unwrap();
803
804 cwt_value = cwt_value + signal[signal_idx] * wavelet_val;
805 norm = norm + wavelet_val * wavelet_val;
806 }
807 }
808
809 if norm > F::zero() {
811 cwt_matrix[[scale_idx, t]] = cwt_value / norm.sqrt();
812 }
813 }
814 }
815
816 Ok(cwt_matrix)
817}
818
819#[allow(dead_code)]
821fn estimate_instantaneous_frequencies<F>(_cwtmatrix: &Array2<F>, scales: &[f64]) -> Result<Vec<F>>
822where
823 F: Float + FromPrimitive + PartialOrd,
824{
825 let (_, n_time) = _cwtmatrix.dim();
826 let mut inst_freqs = Vec::with_capacity(n_time);
827
828 for t in 0..n_time {
829 let time_slice = _cwtmatrix.column(t);
830
831 let max_scale_idx = time_slice
833 .iter()
834 .enumerate()
835 .max_by(|(_, a), (_, b)| {
836 a.abs()
837 .partial_cmp(&b.abs())
838 .unwrap_or(std::cmp::Ordering::Equal)
839 })
840 .map(|(idx_, _)| idx_)
841 .unwrap_or(0);
842
843 let scale = scales[max_scale_idx];
845 let freq = 1.0 / scale; inst_freqs.push(F::from(freq).unwrap());
847 }
848
849 Ok(inst_freqs)
850}
851
852#[allow(dead_code)]
854fn calculate_energy_concentrations<F>(_cwtmatrix: &Array2<F>) -> Result<Vec<F>>
855where
856 F: Float + FromPrimitive,
857{
858 let (_, n_time) = _cwtmatrix.dim();
859 let mut concentrations = Vec::with_capacity(n_time);
860
861 for t in 0..n_time {
862 let time_slice = _cwtmatrix.column(t);
863 let energy = time_slice.mapv(|x| x * x).sum();
864 concentrations.push(energy);
865 }
866
867 Ok(concentrations)
868}
869
870#[allow(dead_code)]
872fn calculate_frequency_stability<F>(_instantaneousfrequencies: &[F]) -> Result<F>
873where
874 F: Float + FromPrimitive,
875{
876 if _instantaneousfrequencies.len() < 2 {
877 return Ok(F::zero());
878 }
879
880 let n = _instantaneousfrequencies.len();
881 let mean = _instantaneousfrequencies
882 .iter()
883 .fold(F::zero(), |acc, &x| acc + x)
884 / F::from(n).unwrap();
885
886 let variance = _instantaneousfrequencies
887 .iter()
888 .fold(F::zero(), |acc, &x| acc + (x - mean) * (x - mean))
889 / F::from(n - 1).unwrap();
890
891 if mean > F::zero() {
893 let cv = variance.sqrt() / mean;
894 Ok(F::one() / (F::one() + cv))
895 } else {
896 Ok(F::zero())
897 }
898}
899
900#[allow(dead_code)]
902fn calculate_scalogram_entropy<F>(_cwtmatrix: &Array2<F>) -> Result<F>
903where
904 F: Float + FromPrimitive,
905{
906 let total_energy = _cwtmatrix.mapv(|x| x * x).sum();
907
908 if total_energy <= F::zero() {
909 return Ok(F::zero());
910 }
911
912 let mut entropy = F::zero();
913 for &coeff in _cwtmatrix.iter() {
914 let energy = coeff * coeff;
915 if energy > F::zero() {
916 let p = energy / total_energy;
917 entropy = entropy - p * p.ln();
918 }
919 }
920
921 Ok(entropy)
922}
923
924#[allow(dead_code)]
926fn calculate_frequency_evolution<F>(_cwtmatrix: &Array2<F>, scales: &[f64]) -> Result<Vec<F>>
927where
928 F: Float + FromPrimitive + PartialOrd,
929{
930 let (_, n_time) = _cwtmatrix.dim();
931 let mut evolution = Vec::with_capacity(n_time);
932
933 for t in 0..n_time {
934 let time_slice = _cwtmatrix.column(t);
935
936 let mut weighted_freq = F::zero();
938 let mut total_weight = F::zero();
939
940 for (scale_idx, &scale) in scales.iter().enumerate() {
941 let weight = time_slice[scale_idx] * time_slice[scale_idx];
942 let freq = F::from(1.0 / scale).unwrap();
943
944 weighted_freq = weighted_freq + weight * freq;
945 total_weight = total_weight + weight;
946 }
947
948 if total_weight > F::zero() {
949 evolution.push(weighted_freq / total_weight);
950 } else {
951 evolution.push(F::zero());
952 }
953 }
954
955 Ok(evolution)
956}
957
958#[allow(dead_code)]
964fn calculate_coefficient_statistics<F>(
965 coefficients: &[Array1<F>],
966) -> Result<WaveletCoefficientStats<F>>
967where
968 F: Float + FromPrimitive + PartialOrd,
969{
970 let mut level_means = Vec::new();
971 let mut level_stds = Vec::new();
972 let mut level_skewness = Vec::new();
973 let mut level_kurtosis = Vec::new();
974 let mut level_max_magnitudes = Vec::new();
975 let mut level_zero_crossings = Vec::new();
976
977 for coeff_level in coefficients {
978 if coeff_level.is_empty() {
979 level_means.push(F::zero());
980 level_stds.push(F::zero());
981 level_skewness.push(F::zero());
982 level_kurtosis.push(F::zero());
983 level_max_magnitudes.push(F::zero());
984 level_zero_crossings.push(0);
985 continue;
986 }
987
988 let n = coeff_level.len();
989 let n_f = F::from(n).unwrap();
990
991 let mean = coeff_level.sum() / n_f;
993 level_means.push(mean);
994
995 let variance = coeff_level.mapv(|x| (x - mean) * (x - mean)).sum() / n_f;
997 let std_dev = variance.sqrt();
998 level_stds.push(std_dev);
999
1000 if std_dev > F::zero() {
1002 let mut sum_cube = F::zero();
1003 let mut sum_fourth = F::zero();
1004
1005 for &x in coeff_level.iter() {
1006 let norm_dev = (x - mean) / std_dev;
1007 let norm_dev_sq = norm_dev * norm_dev;
1008 sum_cube = sum_cube + norm_dev * norm_dev_sq;
1009 sum_fourth = sum_fourth + norm_dev_sq * norm_dev_sq;
1010 }
1011
1012 let skewness = sum_cube / n_f;
1013 let kurtosis = sum_fourth / n_f - F::from(3.0).unwrap();
1014
1015 level_skewness.push(skewness);
1016 level_kurtosis.push(kurtosis);
1017 } else {
1018 level_skewness.push(F::zero());
1019 level_kurtosis.push(F::zero());
1020 }
1021
1022 let max_magnitude = coeff_level
1024 .iter()
1025 .map(|&x| x.abs())
1026 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1027 .unwrap_or(F::zero());
1028 level_max_magnitudes.push(max_magnitude);
1029
1030 let mut zero_crossings = 0;
1032 for i in 1..coeff_level.len() {
1033 if (coeff_level[i - 1] >= F::zero()) != (coeff_level[i] >= F::zero()) {
1034 zero_crossings += 1;
1035 }
1036 }
1037 level_zero_crossings.push(zero_crossings);
1038 }
1039
1040 Ok(WaveletCoefficientStats {
1041 level_means,
1042 level_stds,
1043 level_skewness,
1044 level_kurtosis,
1045 level_max_magnitudes,
1046 level_zero_crossings,
1047 })
1048}
1049
1050#[allow(dead_code)]
1065pub fn wavelet_denoise<F>(
1066 signal: &Array1<F>,
1067 config: &WaveletConfig,
1068) -> Result<(Array1<F>, WaveletDenoisingFeatures<F>)>
1069where
1070 F: Float + FromPrimitive + Debug + Clone + PartialOrd,
1071{
1072 let dwt_result = discrete_wavelet_transform(signal, config)?;
1074
1075 let threshold =
1077 calculate_optimal_threshold(&dwt_result.coefficients, &config.denoising_method)?;
1078
1079 let (thresholded_coeffs, coefficients_thresholded) = apply_thresholding(
1081 &dwt_result.coefficients,
1082 threshold,
1083 &config.denoising_method,
1084 )?;
1085
1086 let denoised_signal = reconstruct_signal_simplified(&thresholded_coeffs)?;
1088
1089 let original_energy = signal.mapv(|x| x * x).sum();
1091 let denoised_energy = denoised_signal.mapv(|x| x * x).sum();
1092 let energy_preserved = if original_energy > F::zero() {
1093 denoised_energy / original_energy
1094 } else {
1095 F::zero()
1096 };
1097
1098 let snr_improvement = calculate_snr_improvement(signal, &denoised_signal)?;
1100
1101 let mse_reduction = calculate_mse_reduction(signal, &denoised_signal)?;
1103
1104 let features = WaveletDenoisingFeatures {
1105 snr_improvement,
1106 energy_preserved,
1107 coefficients_thresholded,
1108 optimal_threshold: threshold,
1109 mse_reduction,
1110 };
1111
1112 Ok((denoised_signal, features))
1113}
1114
1115#[allow(dead_code)]
1117fn calculate_optimal_threshold<F>(coefficients: &[Array1<F>], method: &DenoisingMethod) -> Result<F>
1118where
1119 F: Float + FromPrimitive + PartialOrd,
1120{
1121 let finest_detail = &coefficients[coefficients.len() - 1];
1123 if finest_detail.is_empty() {
1124 return Ok(F::zero());
1125 }
1126
1127 let mut sorted_coeffs: Vec<F> = finest_detail.iter().map(|&x| x.abs()).collect();
1128 sorted_coeffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1129
1130 let median_idx = sorted_coeffs.len() / 2;
1131 let mad = if sorted_coeffs.len().is_multiple_of(2) {
1132 (sorted_coeffs[median_idx - 1] + sorted_coeffs[median_idx]) / F::from(2.0).unwrap()
1133 } else {
1134 sorted_coeffs[median_idx]
1135 };
1136
1137 let sigma = mad / F::from(0.6745).unwrap(); match method {
1140 DenoisingMethod::Hard | DenoisingMethod::Soft => {
1141 let n = F::from(finest_detail.len()).unwrap();
1143 Ok(sigma * (F::from(2.0).unwrap() * n.ln()).sqrt())
1144 }
1145 DenoisingMethod::Sure => {
1146 Ok(sigma * F::from(1.5).unwrap())
1148 }
1149 DenoisingMethod::Minimax => {
1150 Ok(sigma * F::from(0.8).unwrap())
1152 }
1153 }
1154}
1155
1156#[allow(dead_code)]
1158fn apply_thresholding<F>(
1159 coefficients: &[Array1<F>],
1160 threshold: F,
1161 method: &DenoisingMethod,
1162) -> Result<(Vec<Array1<F>>, usize)>
1163where
1164 F: Float + FromPrimitive + PartialOrd + Clone,
1165{
1166 let mut thresholded_coeffs = Vec::new();
1167 let mut total_thresholded = 0;
1168
1169 for (level, coeff_level) in coefficients.iter().enumerate() {
1170 if level == 0 {
1171 thresholded_coeffs.push(coeff_level.clone());
1173 continue;
1174 }
1175
1176 let mut thresholded_level = Array1::zeros(coeff_level.len());
1177 let mut _level_thresholded = 0;
1178
1179 for (i, &coeff) in coeff_level.iter().enumerate() {
1180 let abs_coeff = coeff.abs();
1181
1182 if abs_coeff <= threshold {
1183 _level_thresholded += 1;
1184 total_thresholded += 1;
1185 } else {
1187 thresholded_level[i] = match method {
1188 DenoisingMethod::Hard => coeff,
1189 DenoisingMethod::Soft => {
1190 let sign = if coeff >= F::zero() {
1191 F::one()
1192 } else {
1193 -F::one()
1194 };
1195 sign * (abs_coeff - threshold)
1196 }
1197 DenoisingMethod::Sure | DenoisingMethod::Minimax => {
1198 let sign = if coeff >= F::zero() {
1200 F::one()
1201 } else {
1202 -F::one()
1203 };
1204 sign * (abs_coeff - threshold)
1205 }
1206 };
1207 }
1208 }
1209
1210 thresholded_coeffs.push(thresholded_level);
1211 }
1212
1213 Ok((thresholded_coeffs, total_thresholded))
1214}
1215
1216#[allow(dead_code)]
1218fn reconstruct_signal_simplified<F>(coefficients: &[Array1<F>]) -> Result<Array1<F>>
1219where
1220 F: Float + FromPrimitive + Clone,
1221{
1222 if coefficients.is_empty() {
1223 return Ok(Array1::zeros(0));
1224 }
1225
1226 let approx_coeffs = &coefficients[0];
1228 let mut reconstructed = approx_coeffs.clone();
1229
1230 for (level, detail_coeffs) in coefficients.iter().enumerate().skip(1) {
1232 let scale_factor = F::from(2.0_f64.powi(level as i32)).unwrap();
1233
1234 for (i, &detail) in detail_coeffs.iter().enumerate() {
1236 let target_idx = i.min(reconstructed.len() - 1);
1237 reconstructed[target_idx] = reconstructed[target_idx] + detail / scale_factor;
1238 }
1239 }
1240
1241 Ok(reconstructed)
1242}
1243
1244#[allow(dead_code)]
1246fn calculate_snr_improvement<F>(original: &Array1<F>, denoised: &Array1<F>) -> Result<F>
1247where
1248 F: Float + FromPrimitive,
1249{
1250 let signal_power = original.mapv(|x| x * x).sum();
1251 let noise_power = original
1252 .iter()
1253 .zip(denoised.iter())
1254 .fold(F::zero(), |acc, (&orig, &den)| {
1255 let diff = orig - den;
1256 acc + diff * diff
1257 });
1258
1259 if noise_power > F::zero() && signal_power > F::zero() {
1260 let snr = (signal_power / noise_power).ln() / F::from(10.0).unwrap().ln()
1261 * F::from(10.0).unwrap();
1262 Ok(snr)
1263 } else {
1264 Ok(F::zero())
1265 }
1266}
1267
1268#[allow(dead_code)]
1270fn calculate_mse_reduction<F>(original: &Array1<F>, denoised: &Array1<F>) -> Result<F>
1271where
1272 F: Float + FromPrimitive,
1273{
1274 let n = F::from(original.len()).unwrap();
1275 let mse = original
1276 .iter()
1277 .zip(denoised.iter())
1278 .fold(F::zero(), |acc, (&orig, &den)| {
1279 let diff = orig - den;
1280 acc + diff * diff
1281 })
1282 / n;
1283
1284 let signal_variance = original.mapv(|x| x * x).sum() / n;
1286 if signal_variance > F::zero() {
1287 Ok(F::one() - (mse / signal_variance))
1288 } else {
1289 Ok(F::zero())
1290 }
1291}