scirs2_series/features/
wavelet.rs

1//! Wavelet transform features for time series analysis
2//!
3//! This module provides comprehensive wavelet-based feature extraction including
4//! Discrete Wavelet Transform (DWT), Continuous Wavelet Transform (CWT),
5//! multi-resolution analysis, time-frequency analysis, and wavelet-based denoising.
6
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use super::config::{DenoisingMethod, WaveletConfig, WaveletFamily};
12use crate::error::{Result, TimeSeriesError};
13
14/// Wavelet-based features for time series analysis
15///
16/// This struct contains comprehensive wavelet transform features including
17/// energy distribution across scales, entropy measures, regularity indices,
18/// and time-frequency analysis results.
19#[derive(Debug, Clone)]
20pub struct WaveletFeatures<F> {
21    /// Energy at different frequency bands from DWT decomposition
22    pub energy_bands: Vec<F>,
23    /// Relative wavelet energy (normalized energy distribution)
24    pub relative_energy: Vec<F>,
25    /// Wavelet entropy (Shannon entropy of wavelet coefficients)
26    pub wavelet_entropy: F,
27    /// Wavelet variance (measure of signal variability)
28    pub wavelet_variance: F,
29    /// Regularity measure based on wavelet coefficients
30    pub regularity_index: F,
31    /// Dominant scale from wavelet decomposition
32    pub dominant_scale: usize,
33    /// Multi-resolution analysis features
34    pub mra_features: MultiResolutionFeatures<F>,
35    /// Time-frequency analysis features
36    pub time_frequency_features: TimeFrequencyFeatures<F>,
37    /// Wavelet coefficient statistics
38    pub coefficient_stats: WaveletCoefficientStats<F>,
39}
40
41impl<F> Default for WaveletFeatures<F>
42where
43    F: Float + FromPrimitive,
44{
45    fn default() -> Self {
46        Self {
47            energy_bands: Vec::new(),
48            relative_energy: Vec::new(),
49            wavelet_entropy: F::zero(),
50            wavelet_variance: F::zero(),
51            regularity_index: F::zero(),
52            dominant_scale: 0,
53            mra_features: MultiResolutionFeatures::default(),
54            time_frequency_features: TimeFrequencyFeatures::default(),
55            coefficient_stats: WaveletCoefficientStats::default(),
56        }
57    }
58}
59
60/// Multi-resolution analysis features from wavelet decomposition
61#[derive(Debug, Clone)]
62pub struct MultiResolutionFeatures<F> {
63    /// Energy per resolution level
64    pub level_energies: Vec<F>,
65    /// Relative energy per level
66    pub level_relative_energies: Vec<F>,
67    /// Energy distribution entropy across levels
68    pub level_entropy: F,
69    /// Dominant resolution level
70    pub dominant_level: usize,
71    /// Coefficient of variation across levels
72    pub level_cv: F,
73}
74
75impl<F> Default for MultiResolutionFeatures<F>
76where
77    F: Float + FromPrimitive,
78{
79    fn default() -> Self {
80        Self {
81            level_energies: Vec::new(),
82            level_relative_energies: Vec::new(),
83            level_entropy: F::zero(),
84            dominant_level: 0,
85            level_cv: F::zero(),
86        }
87    }
88}
89
90/// Time-frequency analysis features from continuous wavelet transform
91#[derive(Debug, Clone)]
92pub struct TimeFrequencyFeatures<F> {
93    /// Instantaneous frequency estimates
94    pub instantaneous_frequencies: Vec<F>,
95    /// Time-localized energy concentrations
96    pub energy_concentrations: Vec<F>,
97    /// Frequency content stability over time
98    pub frequency_stability: F,
99    /// Scalogram entropy (time-frequency entropy)
100    pub scalogram_entropy: F,
101    /// Peak frequency evolution over time
102    pub frequency_evolution: Vec<F>,
103}
104
105impl<F> Default for TimeFrequencyFeatures<F>
106where
107    F: Float + FromPrimitive,
108{
109    fn default() -> Self {
110        Self {
111            instantaneous_frequencies: Vec::new(),
112            energy_concentrations: Vec::new(),
113            frequency_stability: F::zero(),
114            scalogram_entropy: F::zero(),
115            frequency_evolution: Vec::new(),
116        }
117    }
118}
119
120/// Statistical features of wavelet coefficients
121#[derive(Debug, Clone)]
122pub struct WaveletCoefficientStats<F> {
123    /// Mean of coefficients per level
124    pub level_means: Vec<F>,
125    /// Standard deviation of coefficients per level
126    pub level_stds: Vec<F>,
127    /// Skewness of coefficients per level
128    pub level_skewness: Vec<F>,
129    /// Kurtosis of coefficients per level
130    pub level_kurtosis: Vec<F>,
131    /// Maximum coefficient magnitude per level
132    pub level_max_magnitudes: Vec<F>,
133    /// Zero-crossing rate per level
134    pub level_zero_crossings: Vec<usize>,
135}
136
137impl<F> Default for WaveletCoefficientStats<F>
138where
139    F: Float + FromPrimitive,
140{
141    fn default() -> Self {
142        Self {
143            level_means: Vec::new(),
144            level_stds: Vec::new(),
145            level_skewness: Vec::new(),
146            level_kurtosis: Vec::new(),
147            level_max_magnitudes: Vec::new(),
148            level_zero_crossings: Vec::new(),
149        }
150    }
151}
152
153/// Wavelet denoising features
154#[derive(Debug, Clone)]
155pub struct WaveletDenoisingFeatures<F> {
156    /// Signal-to-noise ratio improvement
157    pub snr_improvement: F,
158    /// Energy preserved after denoising
159    pub energy_preserved: F,
160    /// Number of coefficients thresholded
161    pub coefficients_thresholded: usize,
162    /// Optimal threshold value used
163    pub optimal_threshold: F,
164    /// Mean squared error reduction
165    pub mse_reduction: F,
166}
167
168impl<F> Default for WaveletDenoisingFeatures<F>
169where
170    F: Float + FromPrimitive,
171{
172    fn default() -> Self {
173        Self {
174            snr_improvement: F::zero(),
175            energy_preserved: F::zero(),
176            coefficients_thresholded: 0,
177            optimal_threshold: F::zero(),
178            mse_reduction: F::zero(),
179        }
180    }
181}
182
183// =============================================================================
184// Main Calculation Functions
185// =============================================================================
186
187/// Calculate comprehensive wavelet-based features
188///
189/// This function performs wavelet decomposition and extracts various features
190/// including energy distribution, entropy measures, regularity indices,
191/// and time-frequency characteristics.
192///
193/// # Mathematical Background
194///
195/// The Discrete Wavelet Transform (DWT) decomposes a signal into different
196/// frequency bands (scales). For a signal x(t), the DWT coefficients are:
197///
198/// ```text
199/// W(j,k) = ∑ x(n) ψ*_{j,k}(n)
200/// ```
201///
202/// where ψ_{j,k} are the wavelet basis functions at scale j and position k.
203///
204/// The Continuous Wavelet Transform (CWT) provides time-frequency analysis:
205///
206/// ```text
207/// CWT(a,b) = (1/√a) ∫ x(t) ψ*((t-b)/a) dt
208/// ```
209///
210/// where a is the scale parameter and b is the translation parameter.
211///
212/// # Arguments
213///
214/// * `ts` - Input time series data
215/// * `config` - Wavelet analysis configuration
216///
217/// # Returns
218///
219/// Comprehensive wavelet features including energy distribution,
220/// entropy measures, and time-frequency characteristics.
221#[allow(dead_code)]
222pub fn calculate_wavelet_features<F>(
223    ts: &Array1<F>,
224    config: &WaveletConfig,
225) -> Result<WaveletFeatures<F>>
226where
227    F: Float + FromPrimitive + Debug + Clone + scirs2_core::ndarray::ScalarOperand,
228{
229    let n = ts.len();
230    if n < 8 {
231        return Ok(WaveletFeatures::default());
232    }
233
234    // Perform Discrete Wavelet Transform
235    let dwt_result = discrete_wavelet_transform(ts, config)?;
236
237    // Calculate energy-based features
238    let energy_bands = calculate_wavelet_energy_bands(&dwt_result.coefficients)?;
239    let relative_energy = calculate_relative_wavelet_energy(&energy_bands)?;
240
241    // Calculate wavelet entropy
242    let wavelet_entropy = calculate_wavelet_entropy(&dwt_result.coefficients)?;
243
244    // Calculate wavelet variance
245    let wavelet_variance = calculate_wavelet_variance(&dwt_result.coefficients)?;
246
247    // Calculate regularity index
248    let regularity_index = calculate_regularity_index(&dwt_result.coefficients)?;
249
250    // Find dominant scale
251    let dominant_scale = find_dominant_wavelet_scale(&energy_bands);
252
253    // Calculate multi-resolution analysis features
254    let mra_features = calculate_mra_features(&dwt_result)?;
255
256    // Calculate time-frequency features (CWT-based)
257    let time_frequency_features = if config.calculate_cwt {
258        calculate_time_frequency_features(ts, config)?
259    } else {
260        TimeFrequencyFeatures::default()
261    };
262
263    // Calculate coefficient statistics
264    let coefficient_stats = calculate_coefficient_statistics(&dwt_result.coefficients)?;
265
266    Ok(WaveletFeatures {
267        energy_bands,
268        relative_energy,
269        wavelet_entropy,
270        wavelet_variance,
271        regularity_index,
272        dominant_scale,
273        mra_features,
274        time_frequency_features,
275        coefficient_stats,
276    })
277}
278
279// =============================================================================
280// DWT Implementation
281// =============================================================================
282
283/// Result of Discrete Wavelet Transform
284#[derive(Debug, Clone)]
285struct DWTResult<F> {
286    /// Wavelet coefficients organized by decomposition level
287    /// coefficients[0] = approximation coefficients (lowest frequency)
288    /// coefficients[1..n] = detail coefficients from level 1 to n
289    coefficients: Vec<Array1<F>>,
290    /// Number of decomposition levels
291    #[allow(dead_code)]
292    levels: usize,
293    /// Original signal length
294    #[allow(dead_code)]
295    original_length: usize,
296}
297
298/// Perform Discrete Wavelet Transform
299///
300/// Implements a simplified DWT using Haar wavelets or Daubechies wavelets.
301/// This is a basic implementation for demonstration purposes.
302/// In production, you would typically use a specialized wavelet library.
303#[allow(dead_code)]
304fn discrete_wavelet_transform<F>(signal: &Array1<F>, config: &WaveletConfig) -> Result<DWTResult<F>>
305where
306    F: Float + FromPrimitive + Debug + Clone,
307{
308    let n = signal.len();
309    let max_levels = (n as f64).log2().floor() as usize - 1;
310    let levels = config.levels.min(max_levels).max(1);
311
312    let mut coefficients = Vec::with_capacity(levels + 1);
313    let mut current_signal = signal.clone();
314
315    // Get wavelet filter coefficients
316    let (h, g) = get_wavelet_filters(&config.family)?;
317
318    // Perform multilevel decomposition
319    for _level in 0..levels {
320        let (approx, detail) = wavelet_decompose_level(&current_signal, &h, &g)?;
321
322        // Store detail coefficients for this level
323        coefficients.push(detail);
324
325        // Use approximation for next level
326        current_signal = approx;
327
328        // Stop if _signal becomes too short
329        if current_signal.len() < 4 {
330            break;
331        }
332    }
333
334    // Store final approximation coefficients
335    coefficients.insert(0, current_signal);
336
337    Ok(DWTResult {
338        coefficients,
339        levels,
340        original_length: n,
341    })
342}
343
344/// Get wavelet filter coefficients for different wavelet families
345#[allow(dead_code)]
346fn get_wavelet_filters<F>(family: &WaveletFamily) -> Result<(Array1<F>, Array1<F>)>
347where
348    F: Float + FromPrimitive,
349{
350    match family {
351        WaveletFamily::Haar => {
352            // Haar wavelet filters
353            let sqrt_2_inv = F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap();
354            let h = Array1::from_vec(vec![sqrt_2_inv, sqrt_2_inv]);
355            let g = Array1::from_vec(vec![-sqrt_2_inv, sqrt_2_inv]);
356            Ok((h, g))
357        }
358        WaveletFamily::Daubechies(n) => {
359            match n {
360                2 => {
361                    // db2 (same as Haar)
362                    let sqrt_2_inv = F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap();
363                    let h = Array1::from_vec(vec![sqrt_2_inv, sqrt_2_inv]);
364                    let g = Array1::from_vec(vec![-sqrt_2_inv, sqrt_2_inv]);
365                    Ok((h, g))
366                }
367                4 => {
368                    // db4 Daubechies-4 coefficients
369                    let h = Array1::from_vec(vec![
370                        F::from(0.48296291314469025).unwrap(),
371                        F::from(0.8365163037378079).unwrap(),
372                        F::from(0.22414386804185735).unwrap(),
373                        F::from(-0.12940952255092145).unwrap(),
374                    ]);
375                    let g = Array1::from_vec(vec![
376                        F::from(-0.12940952255092145).unwrap(),
377                        F::from(-0.22414386804185735).unwrap(),
378                        F::from(0.8365163037378079).unwrap(),
379                        F::from(-0.48296291314469025).unwrap(),
380                    ]);
381                    Ok((h, g))
382                }
383                6 => {
384                    // db6 Daubechies-6 coefficients
385                    let h = Array1::from_vec(vec![
386                        F::from(0.3326705529509569).unwrap(),
387                        F::from(0.8068915093133388).unwrap(),
388                        F::from(0.4598775021193313).unwrap(),
389                        F::from(-0.13501102001039084).unwrap(),
390                        F::from(-0.08544127388224149).unwrap(),
391                        F::from(0.035226291882100656).unwrap(),
392                    ]);
393                    let g = Array1::from_vec(vec![
394                        F::from(0.035226291882100656).unwrap(),
395                        F::from(0.08544127388224149).unwrap(),
396                        F::from(-0.13501102001039084).unwrap(),
397                        F::from(-0.4598775021193313).unwrap(),
398                        F::from(0.8068915093133388).unwrap(),
399                        F::from(-0.3326705529509569).unwrap(),
400                    ]);
401                    Ok((h, g))
402                }
403                _ => {
404                    // Default to db4 for unsupported orders
405                    let h = Array1::from_vec(vec![
406                        F::from(0.48296291314469025).unwrap(),
407                        F::from(0.8365163037378079).unwrap(),
408                        F::from(0.22414386804185735).unwrap(),
409                        F::from(-0.12940952255092145).unwrap(),
410                    ]);
411                    let g = Array1::from_vec(vec![
412                        F::from(-0.12940952255092145).unwrap(),
413                        F::from(-0.22414386804185735).unwrap(),
414                        F::from(0.8365163037378079).unwrap(),
415                        F::from(-0.48296291314469025).unwrap(),
416                    ]);
417                    Ok((h, g))
418                }
419            }
420        }
421        _ => {
422            // Default to Haar for unsupported families
423            let h = Array1::from_vec(vec![
424                F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap(),
425                F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap(),
426            ]);
427            let g = Array1::from_vec(vec![
428                F::from(-std::f64::consts::FRAC_1_SQRT_2).unwrap(),
429                F::from(std::f64::consts::FRAC_1_SQRT_2).unwrap(),
430            ]);
431            Ok((h, g))
432        }
433    }
434}
435
436/// Perform one level of wavelet decomposition
437#[allow(dead_code)]
438fn wavelet_decompose_level<F>(
439    signal: &Array1<F>,
440    h: &Array1<F>, // Low-pass filter
441    g: &Array1<F>, // High-pass filter
442) -> Result<(Array1<F>, Array1<F>)>
443where
444    F: Float + FromPrimitive + Clone,
445{
446    let n = signal.len();
447    let filter_len = h.len();
448
449    if n < filter_len {
450        return Err(TimeSeriesError::InsufficientData {
451            message: "Signal too short for wavelet decomposition".to_string(),
452            required: filter_len,
453            actual: n,
454        });
455    }
456
457    // Convolve with filters and downsample
458    let approx_len = (n + filter_len - 1) / 2;
459    let detail_len = approx_len;
460
461    let mut approx = Array1::zeros(approx_len);
462    let mut detail = Array1::zeros(detail_len);
463
464    let mut approx_idx = 0;
465    let mut detail_idx = 0;
466
467    // Convolution with downsampling by 2
468    for i in (0..n).step_by(2) {
469        let mut approx_val = F::zero();
470        let mut detail_val = F::zero();
471
472        for j in 0..filter_len {
473            let signal_idx = if i + j < n { i + j } else { n - 1 };
474
475            approx_val = approx_val + h[j] * signal[signal_idx];
476            detail_val = detail_val + g[j] * signal[signal_idx];
477        }
478
479        if approx_idx < approx_len {
480            approx[approx_idx] = approx_val;
481            approx_idx += 1;
482        }
483
484        if detail_idx < detail_len {
485            detail[detail_idx] = detail_val;
486            detail_idx += 1;
487        }
488    }
489
490    Ok((approx, detail))
491}
492
493// =============================================================================
494// Energy and Entropy Analysis
495// =============================================================================
496
497/// Calculate energy in each wavelet frequency band
498#[allow(dead_code)]
499fn calculate_wavelet_energy_bands<F>(coefficients: &[Array1<F>]) -> Result<Vec<F>>
500where
501    F: Float + FromPrimitive,
502{
503    let mut energy_bands = Vec::with_capacity(coefficients.len());
504
505    for coeff_level in coefficients {
506        let energy = coeff_level.mapv(|x| x * x).sum();
507        energy_bands.push(energy);
508    }
509
510    Ok(energy_bands)
511}
512
513/// Calculate relative wavelet energy (normalized energy distribution)
514#[allow(dead_code)]
515fn calculate_relative_wavelet_energy<F>(_energybands: &[F]) -> Result<Vec<F>>
516where
517    F: Float + FromPrimitive,
518{
519    let total_energy: F = _energybands.iter().fold(F::zero(), |acc, &x| acc + x);
520
521    if total_energy <= F::zero() {
522        return Ok(vec![F::zero(); _energybands.len()]);
523    }
524
525    let relative_energy = _energybands
526        .iter()
527        .map(|&energy| energy / total_energy)
528        .collect();
529
530    Ok(relative_energy)
531}
532
533/// Calculate wavelet entropy based on energy distribution
534///
535/// Wavelet entropy measures the disorder in the wavelet coefficient
536/// energy distribution across different scales.
537///
538/// ```text
539/// WE = -∑ p_j * log(p_j)
540/// ```
541///
542/// where p_j is the relative energy at scale j.
543#[allow(dead_code)]
544fn calculate_wavelet_entropy<F>(coefficients: &[Array1<F>]) -> Result<F>
545where
546    F: Float + FromPrimitive,
547{
548    let energy_bands = calculate_wavelet_energy_bands(coefficients)?;
549    let relative_energy = calculate_relative_wavelet_energy(&energy_bands)?;
550
551    let mut entropy = F::zero();
552    for &p in &relative_energy {
553        if p > F::zero() {
554            entropy = entropy - p * p.ln();
555        }
556    }
557
558    Ok(entropy)
559}
560
561/// Calculate wavelet variance as a measure of signal variability
562#[allow(dead_code)]
563fn calculate_wavelet_variance<F>(coefficients: &[Array1<F>]) -> Result<F>
564where
565    F: Float + FromPrimitive,
566{
567    let mut total_variance = F::zero();
568    let mut total_count = 0;
569
570    // Skip the first level (approximation coefficients) and only use detail _coefficients
571    for coeff_level in coefficients.iter().skip(1) {
572        if coeff_level.len() > 1 {
573            let mean = coeff_level.sum() / F::from(coeff_level.len()).unwrap();
574            let variance = coeff_level.mapv(|x| (x - mean) * (x - mean)).sum()
575                / F::from(coeff_level.len() - 1).unwrap();
576
577            total_variance = total_variance + variance;
578            total_count += 1;
579        }
580    }
581
582    if total_count > 0 {
583        Ok(total_variance / F::from(total_count).unwrap())
584    } else {
585        Ok(F::zero())
586    }
587}
588
589/// Calculate regularity index based on wavelet coefficients
590///
591/// The regularity index measures the smoothness/regularity of the signal
592/// based on the decay of wavelet coefficients across scales.
593#[allow(dead_code)]
594fn calculate_regularity_index<F>(coefficients: &[Array1<F>]) -> Result<F>
595where
596    F: Float + FromPrimitive,
597{
598    if coefficients.len() < 2 {
599        return Ok(F::zero());
600    }
601
602    let mut scale_energies = Vec::new();
603
604    // Calculate log of average energy per scale
605    for (scale, coeff_level) in coefficients.iter().enumerate().skip(1) {
606        if !coeff_level.is_empty() {
607            let avg_energy =
608                coeff_level.mapv(|x| x * x).sum() / F::from(coeff_level.len()).unwrap();
609
610            if avg_energy > F::zero() {
611                let log_energy = avg_energy.ln();
612                let log_scale = F::from(scale).unwrap().ln();
613                scale_energies.push((log_scale, log_energy));
614            }
615        }
616    }
617
618    if scale_energies.len() < 2 {
619        return Ok(F::zero());
620    }
621
622    // Linear regression to estimate slope (regularity)
623    let n = F::from(scale_energies.len()).unwrap();
624    let sum_x: F = scale_energies
625        .iter()
626        .map(|(x_, _)| *x_)
627        .fold(F::zero(), |acc, x| acc + x);
628    let sum_y: F = scale_energies
629        .iter()
630        .map(|(_, y)| *y)
631        .fold(F::zero(), |acc, y| acc + y);
632    let sum_xy: F = scale_energies
633        .iter()
634        .map(|(x, y)| *x * *y)
635        .fold(F::zero(), |acc, xy| acc + xy);
636    let sum_xx: F = scale_energies
637        .iter()
638        .map(|(x_, _)| *x_ * *x_)
639        .fold(F::zero(), |acc, xx| acc + xx);
640
641    let denominator = n * sum_xx - sum_x * sum_x;
642    if denominator.abs() < F::from(1e-10).unwrap() {
643        return Ok(F::zero());
644    }
645
646    let slope = (n * sum_xy - sum_x * sum_y) / denominator;
647
648    // Regularity index is related to the negative slope
649    Ok(-slope)
650}
651
652/// Find the dominant scale (frequency band) based on energy distribution
653#[allow(dead_code)]
654fn find_dominant_wavelet_scale<F>(_energybands: &[F]) -> usize
655where
656    F: Float + PartialOrd,
657{
658    _energybands
659        .iter()
660        .enumerate()
661        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
662        .map(|(idx_, _)| idx_)
663        .unwrap_or(0)
664}
665
666// =============================================================================
667// Multi-Resolution Analysis
668// =============================================================================
669
670/// Calculate multi-resolution analysis features
671#[allow(dead_code)]
672fn calculate_mra_features<F>(_dwtresult: &DWTResult<F>) -> Result<MultiResolutionFeatures<F>>
673where
674    F: Float + FromPrimitive,
675{
676    let level_energies = calculate_wavelet_energy_bands(&_dwtresult.coefficients)?;
677    let level_relative_energies = calculate_relative_wavelet_energy(&level_energies)?;
678
679    // Calculate entropy across levels
680    let mut level_entropy = F::zero();
681    for &p in &level_relative_energies {
682        if p > F::zero() {
683            level_entropy = level_entropy - p * p.ln();
684        }
685    }
686
687    // Find dominant level
688    let dominant_level = find_dominant_wavelet_scale(&level_energies);
689
690    // Calculate coefficient of variation across levels
691    let mean_energy = level_energies.iter().fold(F::zero(), |acc, &x| acc + x)
692        / F::from(level_energies.len()).unwrap();
693
694    let variance_energy = level_energies.iter().fold(F::zero(), |acc, &x| {
695        acc + (x - mean_energy) * (x - mean_energy)
696    }) / F::from(level_energies.len()).unwrap();
697
698    let level_cv = if mean_energy > F::zero() {
699        variance_energy.sqrt() / mean_energy
700    } else {
701        F::zero()
702    };
703
704    Ok(MultiResolutionFeatures {
705        level_energies,
706        level_relative_energies,
707        level_entropy,
708        dominant_level,
709        level_cv,
710    })
711}
712
713// =============================================================================
714// Continuous Wavelet Transform (CWT)
715// =============================================================================
716
717/// Calculate time-frequency features using simplified CWT
718#[allow(dead_code)]
719fn calculate_time_frequency_features<F>(
720    signal: &Array1<F>,
721    config: &WaveletConfig,
722) -> Result<TimeFrequencyFeatures<F>>
723where
724    F: Float + FromPrimitive + Debug + Clone,
725{
726    let n = signal.len();
727    if n < 16 {
728        return Ok(TimeFrequencyFeatures::default());
729    }
730
731    // Simplified CWT using Morlet wavelet
732    let scales = generate_cwt_scales(config);
733    let cwt_matrix = compute_simplified_cwt(signal, &scales)?;
734
735    // Calculate instantaneous frequencies (simplified)
736    let instantaneous_frequencies = estimate_instantaneous_frequencies(&cwt_matrix, &scales)?;
737
738    // Calculate energy concentrations
739    let energy_concentrations = calculate_energy_concentrations(&cwt_matrix)?;
740
741    // Calculate frequency stability
742    let frequency_stability = calculate_frequency_stability(&instantaneous_frequencies)?;
743
744    // Calculate scalogram entropy
745    let scalogram_entropy = calculate_scalogram_entropy(&cwt_matrix)?;
746
747    // Calculate frequency evolution
748    let frequency_evolution = calculate_frequency_evolution(&cwt_matrix, &scales)?;
749
750    Ok(TimeFrequencyFeatures {
751        instantaneous_frequencies,
752        energy_concentrations,
753        frequency_stability,
754        scalogram_entropy,
755        frequency_evolution,
756    })
757}
758
759/// Generate scales for CWT analysis
760#[allow(dead_code)]
761fn generate_cwt_scales(config: &WaveletConfig) -> Vec<f64> {
762    let (min_scale, max_scale) = config.cwt_scales.unwrap_or((1.0, 32.0));
763    let count = config.cwt_scale_count;
764
765    let log_min = min_scale.ln();
766    let log_max = max_scale.ln();
767    let step = (log_max - log_min) / (count - 1) as f64;
768
769    (0..count)
770        .map(|i| (log_min + i as f64 * step).exp())
771        .collect()
772}
773
774/// Compute simplified CWT using Morlet-like wavelet
775#[allow(dead_code)]
776fn compute_simplified_cwt<F>(signal: &Array1<F>, scales: &[f64]) -> Result<Array2<F>>
777where
778    F: Float + FromPrimitive + Clone,
779{
780    let n = signal.len();
781    let n_scales = scales.len();
782    let mut cwt_matrix = Array2::zeros((n_scales, n));
783
784    for (scale_idx, &scale) in scales.iter().enumerate() {
785        // Simple wavelet: modulated Gaussian
786        let omega0 = 6.0; // Central frequency
787        let wavelet_support = (8.0 * scale) as usize;
788
789        for t in 0..n {
790            let mut cwt_value = F::zero();
791            let mut norm = F::zero();
792
793            for tau in 0..wavelet_support {
794                let t_shifted = t as isize - tau as isize;
795                if t_shifted >= 0 && (t_shifted as usize) < n {
796                    let signal_idx = t_shifted as usize;
797
798                    // Simplified Morlet wavelet
799                    let t_norm = (tau as f64) / scale;
800                    let envelope = (-0.5 * t_norm * t_norm).exp();
801                    let oscillation = (omega0 * t_norm).cos();
802                    let wavelet_val = F::from(envelope * oscillation).unwrap();
803
804                    cwt_value = cwt_value + signal[signal_idx] * wavelet_val;
805                    norm = norm + wavelet_val * wavelet_val;
806                }
807            }
808
809            // Normalize
810            if norm > F::zero() {
811                cwt_matrix[[scale_idx, t]] = cwt_value / norm.sqrt();
812            }
813        }
814    }
815
816    Ok(cwt_matrix)
817}
818
819/// Estimate instantaneous frequencies from CWT
820#[allow(dead_code)]
821fn estimate_instantaneous_frequencies<F>(_cwtmatrix: &Array2<F>, scales: &[f64]) -> Result<Vec<F>>
822where
823    F: Float + FromPrimitive + PartialOrd,
824{
825    let (_, n_time) = _cwtmatrix.dim();
826    let mut inst_freqs = Vec::with_capacity(n_time);
827
828    for t in 0..n_time {
829        let time_slice = _cwtmatrix.column(t);
830
831        // Find scale with maximum magnitude
832        let max_scale_idx = time_slice
833            .iter()
834            .enumerate()
835            .max_by(|(_, a), (_, b)| {
836                a.abs()
837                    .partial_cmp(&b.abs())
838                    .unwrap_or(std::cmp::Ordering::Equal)
839            })
840            .map(|(idx_, _)| idx_)
841            .unwrap_or(0);
842
843        // Convert scale to frequency (simplified)
844        let scale = scales[max_scale_idx];
845        let freq = 1.0 / scale; // Simplified frequency estimation
846        inst_freqs.push(F::from(freq).unwrap());
847    }
848
849    Ok(inst_freqs)
850}
851
852/// Calculate energy concentrations from CWT
853#[allow(dead_code)]
854fn calculate_energy_concentrations<F>(_cwtmatrix: &Array2<F>) -> Result<Vec<F>>
855where
856    F: Float + FromPrimitive,
857{
858    let (_, n_time) = _cwtmatrix.dim();
859    let mut concentrations = Vec::with_capacity(n_time);
860
861    for t in 0..n_time {
862        let time_slice = _cwtmatrix.column(t);
863        let energy = time_slice.mapv(|x| x * x).sum();
864        concentrations.push(energy);
865    }
866
867    Ok(concentrations)
868}
869
870/// Calculate frequency stability over time
871#[allow(dead_code)]
872fn calculate_frequency_stability<F>(_instantaneousfrequencies: &[F]) -> Result<F>
873where
874    F: Float + FromPrimitive,
875{
876    if _instantaneousfrequencies.len() < 2 {
877        return Ok(F::zero());
878    }
879
880    let n = _instantaneousfrequencies.len();
881    let mean = _instantaneousfrequencies
882        .iter()
883        .fold(F::zero(), |acc, &x| acc + x)
884        / F::from(n).unwrap();
885
886    let variance = _instantaneousfrequencies
887        .iter()
888        .fold(F::zero(), |acc, &x| acc + (x - mean) * (x - mean))
889        / F::from(n - 1).unwrap();
890
891    // Stability is inverse of coefficient of variation
892    if mean > F::zero() {
893        let cv = variance.sqrt() / mean;
894        Ok(F::one() / (F::one() + cv))
895    } else {
896        Ok(F::zero())
897    }
898}
899
900/// Calculate scalogram entropy
901#[allow(dead_code)]
902fn calculate_scalogram_entropy<F>(_cwtmatrix: &Array2<F>) -> Result<F>
903where
904    F: Float + FromPrimitive,
905{
906    let total_energy = _cwtmatrix.mapv(|x| x * x).sum();
907
908    if total_energy <= F::zero() {
909        return Ok(F::zero());
910    }
911
912    let mut entropy = F::zero();
913    for &coeff in _cwtmatrix.iter() {
914        let energy = coeff * coeff;
915        if energy > F::zero() {
916            let p = energy / total_energy;
917            entropy = entropy - p * p.ln();
918        }
919    }
920
921    Ok(entropy)
922}
923
924/// Calculate frequency evolution over time
925#[allow(dead_code)]
926fn calculate_frequency_evolution<F>(_cwtmatrix: &Array2<F>, scales: &[f64]) -> Result<Vec<F>>
927where
928    F: Float + FromPrimitive + PartialOrd,
929{
930    let (_, n_time) = _cwtmatrix.dim();
931    let mut evolution = Vec::with_capacity(n_time);
932
933    for t in 0..n_time {
934        let time_slice = _cwtmatrix.column(t);
935
936        // Calculate weighted average frequency
937        let mut weighted_freq = F::zero();
938        let mut total_weight = F::zero();
939
940        for (scale_idx, &scale) in scales.iter().enumerate() {
941            let weight = time_slice[scale_idx] * time_slice[scale_idx];
942            let freq = F::from(1.0 / scale).unwrap();
943
944            weighted_freq = weighted_freq + weight * freq;
945            total_weight = total_weight + weight;
946        }
947
948        if total_weight > F::zero() {
949            evolution.push(weighted_freq / total_weight);
950        } else {
951            evolution.push(F::zero());
952        }
953    }
954
955    Ok(evolution)
956}
957
958// =============================================================================
959// Coefficient Statistics
960// =============================================================================
961
962/// Calculate statistical features of wavelet coefficients
963#[allow(dead_code)]
964fn calculate_coefficient_statistics<F>(
965    coefficients: &[Array1<F>],
966) -> Result<WaveletCoefficientStats<F>>
967where
968    F: Float + FromPrimitive + PartialOrd,
969{
970    let mut level_means = Vec::new();
971    let mut level_stds = Vec::new();
972    let mut level_skewness = Vec::new();
973    let mut level_kurtosis = Vec::new();
974    let mut level_max_magnitudes = Vec::new();
975    let mut level_zero_crossings = Vec::new();
976
977    for coeff_level in coefficients {
978        if coeff_level.is_empty() {
979            level_means.push(F::zero());
980            level_stds.push(F::zero());
981            level_skewness.push(F::zero());
982            level_kurtosis.push(F::zero());
983            level_max_magnitudes.push(F::zero());
984            level_zero_crossings.push(0);
985            continue;
986        }
987
988        let n = coeff_level.len();
989        let n_f = F::from(n).unwrap();
990
991        // Mean
992        let mean = coeff_level.sum() / n_f;
993        level_means.push(mean);
994
995        // Standard deviation
996        let variance = coeff_level.mapv(|x| (x - mean) * (x - mean)).sum() / n_f;
997        let std_dev = variance.sqrt();
998        level_stds.push(std_dev);
999
1000        // Skewness and kurtosis
1001        if std_dev > F::zero() {
1002            let mut sum_cube = F::zero();
1003            let mut sum_fourth = F::zero();
1004
1005            for &x in coeff_level.iter() {
1006                let norm_dev = (x - mean) / std_dev;
1007                let norm_dev_sq = norm_dev * norm_dev;
1008                sum_cube = sum_cube + norm_dev * norm_dev_sq;
1009                sum_fourth = sum_fourth + norm_dev_sq * norm_dev_sq;
1010            }
1011
1012            let skewness = sum_cube / n_f;
1013            let kurtosis = sum_fourth / n_f - F::from(3.0).unwrap();
1014
1015            level_skewness.push(skewness);
1016            level_kurtosis.push(kurtosis);
1017        } else {
1018            level_skewness.push(F::zero());
1019            level_kurtosis.push(F::zero());
1020        }
1021
1022        // Maximum magnitude
1023        let max_magnitude = coeff_level
1024            .iter()
1025            .map(|&x| x.abs())
1026            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1027            .unwrap_or(F::zero());
1028        level_max_magnitudes.push(max_magnitude);
1029
1030        // Zero crossings
1031        let mut zero_crossings = 0;
1032        for i in 1..coeff_level.len() {
1033            if (coeff_level[i - 1] >= F::zero()) != (coeff_level[i] >= F::zero()) {
1034                zero_crossings += 1;
1035            }
1036        }
1037        level_zero_crossings.push(zero_crossings);
1038    }
1039
1040    Ok(WaveletCoefficientStats {
1041        level_means,
1042        level_stds,
1043        level_skewness,
1044        level_kurtosis,
1045        level_max_magnitudes,
1046        level_zero_crossings,
1047    })
1048}
1049
1050// =============================================================================
1051// Wavelet Denoising
1052// =============================================================================
1053
1054/// Perform wavelet denoising and extract denoising-related features
1055///
1056/// # Arguments
1057///
1058/// * `signal` - Input noisy signal
1059/// * `config` - Wavelet configuration including denoising method
1060///
1061/// # Returns
1062///
1063/// Tuple of (denoised_signal, denoising_features)
1064#[allow(dead_code)]
1065pub fn wavelet_denoise<F>(
1066    signal: &Array1<F>,
1067    config: &WaveletConfig,
1068) -> Result<(Array1<F>, WaveletDenoisingFeatures<F>)>
1069where
1070    F: Float + FromPrimitive + Debug + Clone + PartialOrd,
1071{
1072    // Perform DWT
1073    let dwt_result = discrete_wavelet_transform(signal, config)?;
1074
1075    // Calculate optimal threshold
1076    let threshold =
1077        calculate_optimal_threshold(&dwt_result.coefficients, &config.denoising_method)?;
1078
1079    // Apply thresholding
1080    let (thresholded_coeffs, coefficients_thresholded) = apply_thresholding(
1081        &dwt_result.coefficients,
1082        threshold,
1083        &config.denoising_method,
1084    )?;
1085
1086    // Reconstruct signal (simplified - in practice would use inverse DWT)
1087    let denoised_signal = reconstruct_signal_simplified(&thresholded_coeffs)?;
1088
1089    // Calculate denoising features
1090    let original_energy = signal.mapv(|x| x * x).sum();
1091    let denoised_energy = denoised_signal.mapv(|x| x * x).sum();
1092    let energy_preserved = if original_energy > F::zero() {
1093        denoised_energy / original_energy
1094    } else {
1095        F::zero()
1096    };
1097
1098    // Calculate SNR improvement (simplified)
1099    let snr_improvement = calculate_snr_improvement(signal, &denoised_signal)?;
1100
1101    // Calculate MSE reduction (simplified)
1102    let mse_reduction = calculate_mse_reduction(signal, &denoised_signal)?;
1103
1104    let features = WaveletDenoisingFeatures {
1105        snr_improvement,
1106        energy_preserved,
1107        coefficients_thresholded,
1108        optimal_threshold: threshold,
1109        mse_reduction,
1110    };
1111
1112    Ok((denoised_signal, features))
1113}
1114
1115/// Calculate optimal threshold for denoising
1116#[allow(dead_code)]
1117fn calculate_optimal_threshold<F>(coefficients: &[Array1<F>], method: &DenoisingMethod) -> Result<F>
1118where
1119    F: Float + FromPrimitive + PartialOrd,
1120{
1121    // Calculate noise level estimate using MAD of finest detail coefficients
1122    let finest_detail = &coefficients[coefficients.len() - 1];
1123    if finest_detail.is_empty() {
1124        return Ok(F::zero());
1125    }
1126
1127    let mut sorted_coeffs: Vec<F> = finest_detail.iter().map(|&x| x.abs()).collect();
1128    sorted_coeffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1129
1130    let median_idx = sorted_coeffs.len() / 2;
1131    let mad = if sorted_coeffs.len().is_multiple_of(2) {
1132        (sorted_coeffs[median_idx - 1] + sorted_coeffs[median_idx]) / F::from(2.0).unwrap()
1133    } else {
1134        sorted_coeffs[median_idx]
1135    };
1136
1137    let sigma = mad / F::from(0.6745).unwrap(); // MAD to standard deviation conversion
1138
1139    match method {
1140        DenoisingMethod::Hard | DenoisingMethod::Soft => {
1141            // Universal threshold
1142            let n = F::from(finest_detail.len()).unwrap();
1143            Ok(sigma * (F::from(2.0).unwrap() * n.ln()).sqrt())
1144        }
1145        DenoisingMethod::Sure => {
1146            // SURE threshold (simplified)
1147            Ok(sigma * F::from(1.5).unwrap())
1148        }
1149        DenoisingMethod::Minimax => {
1150            // Minimax threshold (simplified)
1151            Ok(sigma * F::from(0.8).unwrap())
1152        }
1153    }
1154}
1155
1156/// Apply thresholding to wavelet coefficients
1157#[allow(dead_code)]
1158fn apply_thresholding<F>(
1159    coefficients: &[Array1<F>],
1160    threshold: F,
1161    method: &DenoisingMethod,
1162) -> Result<(Vec<Array1<F>>, usize)>
1163where
1164    F: Float + FromPrimitive + PartialOrd + Clone,
1165{
1166    let mut thresholded_coeffs = Vec::new();
1167    let mut total_thresholded = 0;
1168
1169    for (level, coeff_level) in coefficients.iter().enumerate() {
1170        if level == 0 {
1171            // Don't threshold approximation coefficients
1172            thresholded_coeffs.push(coeff_level.clone());
1173            continue;
1174        }
1175
1176        let mut thresholded_level = Array1::zeros(coeff_level.len());
1177        let mut _level_thresholded = 0;
1178
1179        for (i, &coeff) in coeff_level.iter().enumerate() {
1180            let abs_coeff = coeff.abs();
1181
1182            if abs_coeff <= threshold {
1183                _level_thresholded += 1;
1184                total_thresholded += 1;
1185                // Coefficient is set to zero (already initialized)
1186            } else {
1187                thresholded_level[i] = match method {
1188                    DenoisingMethod::Hard => coeff,
1189                    DenoisingMethod::Soft => {
1190                        let sign = if coeff >= F::zero() {
1191                            F::one()
1192                        } else {
1193                            -F::one()
1194                        };
1195                        sign * (abs_coeff - threshold)
1196                    }
1197                    DenoisingMethod::Sure | DenoisingMethod::Minimax => {
1198                        // Use soft thresholding for these methods
1199                        let sign = if coeff >= F::zero() {
1200                            F::one()
1201                        } else {
1202                            -F::one()
1203                        };
1204                        sign * (abs_coeff - threshold)
1205                    }
1206                };
1207            }
1208        }
1209
1210        thresholded_coeffs.push(thresholded_level);
1211    }
1212
1213    Ok((thresholded_coeffs, total_thresholded))
1214}
1215
1216/// Simplified signal reconstruction from thresholded coefficients
1217#[allow(dead_code)]
1218fn reconstruct_signal_simplified<F>(coefficients: &[Array1<F>]) -> Result<Array1<F>>
1219where
1220    F: Float + FromPrimitive + Clone,
1221{
1222    if coefficients.is_empty() {
1223        return Ok(Array1::zeros(0));
1224    }
1225
1226    // Simplified reconstruction: use approximation coefficients scaled by levels
1227    let approx_coeffs = &coefficients[0];
1228    let mut reconstructed = approx_coeffs.clone();
1229
1230    // Add scaled detail _coefficients (simplified approach)
1231    for (level, detail_coeffs) in coefficients.iter().enumerate().skip(1) {
1232        let scale_factor = F::from(2.0_f64.powi(level as i32)).unwrap();
1233
1234        // Upsample and add details (very simplified)
1235        for (i, &detail) in detail_coeffs.iter().enumerate() {
1236            let target_idx = i.min(reconstructed.len() - 1);
1237            reconstructed[target_idx] = reconstructed[target_idx] + detail / scale_factor;
1238        }
1239    }
1240
1241    Ok(reconstructed)
1242}
1243
1244/// Calculate SNR improvement after denoising
1245#[allow(dead_code)]
1246fn calculate_snr_improvement<F>(original: &Array1<F>, denoised: &Array1<F>) -> Result<F>
1247where
1248    F: Float + FromPrimitive,
1249{
1250    let signal_power = original.mapv(|x| x * x).sum();
1251    let noise_power = original
1252        .iter()
1253        .zip(denoised.iter())
1254        .fold(F::zero(), |acc, (&orig, &den)| {
1255            let diff = orig - den;
1256            acc + diff * diff
1257        });
1258
1259    if noise_power > F::zero() && signal_power > F::zero() {
1260        let snr = (signal_power / noise_power).ln() / F::from(10.0).unwrap().ln()
1261            * F::from(10.0).unwrap();
1262        Ok(snr)
1263    } else {
1264        Ok(F::zero())
1265    }
1266}
1267
1268/// Calculate MSE reduction after denoising
1269#[allow(dead_code)]
1270fn calculate_mse_reduction<F>(original: &Array1<F>, denoised: &Array1<F>) -> Result<F>
1271where
1272    F: Float + FromPrimitive,
1273{
1274    let n = F::from(original.len()).unwrap();
1275    let mse = original
1276        .iter()
1277        .zip(denoised.iter())
1278        .fold(F::zero(), |acc, (&orig, &den)| {
1279            let diff = orig - den;
1280            acc + diff * diff
1281        })
1282        / n;
1283
1284    // Return normalized MSE reduction
1285    let signal_variance = original.mapv(|x| x * x).sum() / n;
1286    if signal_variance > F::zero() {
1287        Ok(F::one() - (mse / signal_variance))
1288    } else {
1289        Ok(F::zero())
1290    }
1291}