scirs2_series/
correlation.rs

1//! Time series correlation and relationship analysis
2//!
3//! This module provides various methods for analyzing correlations and relationships between time series:
4//! - Cross-correlation functions (CCF)
5//! - Dynamic time warping (DTW) with various constraints
6//! - Time-frequency analysis using wavelets and spectrograms
7//! - Coherence analysis for frequency domain relationships
8
9use crate::error::TimeSeriesError;
10use scirs2_core::ndarray::{s, Array1, Array2};
11use scirs2_core::validation::checkarray_finite;
12use statrs::statistics::Statistics;
13use std::f64::consts::PI;
14
/// Result type for correlation analysis
///
/// All fallible routines in this module report failures as [`TimeSeriesError`].
pub type CorrelationResult<T> = Result<T, TimeSeriesError>;

/// Cross-correlation function result
///
/// Element `i` of `correlations` is the correlation evaluated at lag `lags[i]`.
#[derive(Debug, Clone)]
pub struct CrossCorrelationResult {
    /// Cross-correlation values, one per entry of `lags`
    pub correlations: Array1<f64>,
    /// Lag values corresponding to correlations (from `-max_lag` to `+max_lag`)
    pub lags: Array1<i32>,
    /// Correlation with the largest absolute value (its sign is preserved)
    pub max_correlation: f64,
    /// Lag at maximum correlation
    pub lag_at_max: i32,
    /// Lower confidence bound for each lag (if computed)
    pub confidence_lower: Option<Array1<f64>>,
    /// Upper confidence bound for each lag (if computed)
    pub confidence_upper: Option<Array1<f64>>,
}

/// Dynamic time warping result
#[derive(Debug, Clone)]
pub struct DTWResult {
    /// Accumulated DTW cost at the final cell of the cost matrix
    pub distance: f64,
    /// Warping path as `(index_in_x, index_in_y)` pairs
    /// (NOTE(review): indices are converted to 0-based during backtracking; confirm ordering)
    pub warping_path: Vec<(usize, usize)>,
    /// Accumulated cost matrix of shape `(x.len(), y.len())`; cells skipped by a
    /// constraint keep the value `f64::INFINITY`
    pub cost_matrix: Array2<f64>,
    /// `distance` divided by the warping-path length (if normalization was requested)
    pub normalized_distance: Option<f64>,
    /// Local cost function used
    pub cost_function: DTWCostFunction,
    /// Constraint type used
    pub constraint: DTWConstraint,
}

/// Time-frequency analysis result
#[derive(Debug, Clone)]
pub struct TimeFrequencyResult {
    /// Time-frequency representation (spectrogram)
    pub spectrogram: Array2<f64>,
    /// Time vector
    pub times: Array1<f64>,
    /// Frequency vector
    pub frequencies: Array1<f64>,
    /// Analysis method used
    pub method: TimeFrequencyMethod,
    /// Window parameters (if applicable)
    pub window_info: Option<WindowInfo>,
}

/// Coherence analysis result
///
/// All per-frequency arrays share the same length and are indexed like `frequencies`.
#[derive(Debug, Clone)]
pub struct CoherenceResult {
    /// Coherence value for each frequency bin
    pub coherence: Array1<f64>,
    /// Phase difference for each frequency bin (radians)
    pub phase: Array1<f64>,
    /// Frequency of each bin, in the units of the configured sampling frequency
    pub frequencies: Array1<f64>,
    /// Cross-power spectral density estimate, averaged over windows
    pub cross_psd: Array1<f64>,
    /// Power spectral density of the first series, averaged over windows
    pub psd_x: Array1<f64>,
    /// Power spectral density of the second series, averaged over windows
    pub psd_y: Array1<f64>,
    /// Confidence level requested in the configuration (if any)
    pub confidence_level: Option<f64>,
    /// Confidence threshold for coherence, derived from `confidence_level` and the
    /// number of averaged windows (if a confidence level was requested)
    pub confidence_threshold: Option<f64>,
}
87
/// DTW cost functions
///
/// Local distance used to compare one sample of each series.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DTWCostFunction {
    /// Euclidean distance
    Euclidean,
    /// Manhattan distance
    Manhattan,
    /// Squared Euclidean distance
    SquaredEuclidean,
    /// Cosine distance
    Cosine,
}

/// DTW constraint types
///
/// Restrict which cells of the accumulated cost matrix are evaluated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DTWConstraint {
    /// No constraint (full DTW)
    None,
    /// Sakoe-Chiba band constraint with the given radius
    SakoeChiba(usize),
    /// Itakura parallelogram constraint
    Itakura,
}

/// Time-frequency analysis methods
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimeFrequencyMethod {
    /// Short-time Fourier transform
    STFT,
    /// Continuous wavelet transform
    CWT,
    /// Morlet wavelet
    Morlet,
    /// Gabor transform
    Gabor,
}
124
/// Window information for time-frequency analysis
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WindowInfo {
    /// Window type
    pub window_type: WindowType,
    /// Window size in samples
    pub window_size: usize,
    /// Overlap between consecutive windows, in samples
    pub overlap: usize,
}

/// Window types for analysis
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowType {
    /// Hamming window
    Hamming,
    /// Hanning window
    Hanning,
    /// Blackman window
    Blackman,
    /// Gaussian window
    Gaussian,
    /// Rectangular window
    Rectangular,
}
150
/// Configuration for cross-correlation analysis
#[derive(Debug, Clone, PartialEq)]
pub struct CrossCorrelationConfig {
    /// Maximum lag to compute; lags `-max_lag..=max_lag` are evaluated
    pub max_lag: usize,
    /// Whether to standardize each series (zero mean, unit variance) first
    pub normalize: bool,
    /// Confidence level for intervals (e.g. `0.95`); `None` skips them
    pub confidence_level: Option<f64>,
    /// Method for correlation calculation
    pub method: CorrelationMethod,
}

/// Correlation calculation methods
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CorrelationMethod {
    /// Pearson correlation
    Pearson,
    /// Spearman rank correlation
    Spearman,
    /// Kendall's tau
    Kendall,
}

impl Default for CrossCorrelationConfig {
    /// Lags up to ±20, normalized input, 95% confidence intervals, Pearson correlation.
    fn default() -> Self {
        Self {
            max_lag: 20,
            normalize: true,
            confidence_level: Some(0.95),
            method: CorrelationMethod::Pearson,
        }
    }
}
185
186/// Configuration for DTW analysis
187#[derive(Debug, Clone)]
188pub struct DTWConfig {
189    /// Cost function to use
190    pub cost_function: DTWCostFunction,
191    /// Constraint type
192    pub constraint: DTWConstraint,
193    /// Whether to normalize the distance
194    pub normalize: bool,
195    /// Step pattern (for advanced DTW variants)
196    pub step_pattern: StepPattern,
197}
198
199/// DTW step patterns
200#[derive(Debug, Clone, Copy)]
201pub enum StepPattern {
202    /// Symmetric step pattern
203    Symmetric,
204    /// Asymmetric step pattern
205    Asymmetric,
206    /// Quasi-symmetric step pattern
207    QuasiSymmetric,
208}
209
210impl Default for DTWConfig {
211    fn default() -> Self {
212        Self {
213            cost_function: DTWCostFunction::Euclidean,
214            constraint: DTWConstraint::None,
215            normalize: true,
216            step_pattern: StepPattern::Symmetric,
217        }
218    }
219}
220
221/// Configuration for time-frequency analysis
222#[derive(Debug, Clone)]
223pub struct TimeFrequencyConfig {
224    /// Analysis method
225    pub method: TimeFrequencyMethod,
226    /// Window configuration
227    pub window: WindowInfo,
228    /// Sampling frequency
229    pub sampling_freq: f64,
230    /// Frequency range (optional)
231    pub freq_range: Option<(f64, f64)>,
232    /// Number of frequency bins
233    pub n_freq_bins: Option<usize>,
234}
235
236impl Default for TimeFrequencyConfig {
237    fn default() -> Self {
238        Self {
239            method: TimeFrequencyMethod::STFT,
240            window: WindowInfo {
241                window_type: WindowType::Hanning,
242                window_size: 256,
243                overlap: 128,
244            },
245            sampling_freq: 1.0,
246            freq_range: None,
247            n_freq_bins: None,
248        }
249    }
250}
251
252/// Configuration for coherence analysis
253#[derive(Debug, Clone)]
254pub struct CoherenceConfig {
255    /// Window size for spectral estimation
256    pub window_size: usize,
257    /// Overlap between windows
258    pub overlap: usize,
259    /// Window type
260    pub window_type: WindowType,
261    /// Sampling frequency
262    pub sampling_freq: f64,
263    /// Confidence level for significance testing
264    pub confidence_level: Option<f64>,
265    /// Detrending method
266    pub detrend: DetrendMethod,
267}
268
269/// Detrending methods
270#[derive(Debug, Clone, Copy)]
271pub enum DetrendMethod {
272    /// No detrending
273    None,
274    /// Linear detrending
275    Linear,
276    /// Mean removal
277    Mean,
278}
279
280impl Default for CoherenceConfig {
281    fn default() -> Self {
282        Self {
283            window_size: 256,
284            overlap: 128,
285            window_type: WindowType::Hanning,
286            sampling_freq: 1.0,
287            confidence_level: Some(0.95),
288            detrend: DetrendMethod::Linear,
289        }
290    }
291}
292
/// Main struct for correlation analysis
///
/// Entry point for the cross-correlation, dynamic time warping, time-frequency
/// and coherence routines of this module. Build one with
/// [`CorrelationAnalyzer::new`] or [`CorrelationAnalyzer::with_seed`].
pub struct CorrelationAnalyzer {
    /// Random seed for reproducibility
    ///
    /// NOTE(review): no routine in the reviewed portion of this file reads this
    /// field; confirm whether a randomized method elsewhere depends on it.
    pub random_seed: Option<u64>,
}
298
299impl CorrelationAnalyzer {
300    /// Create a new correlation analyzer
301    pub fn new() -> Self {
302        Self { random_seed: None }
303    }
304
305    /// Create a new analyzer with random seed
306    pub fn with_seed(seed: u64) -> Self {
307        Self {
308            random_seed: Some(seed),
309        }
310    }
311
312    /// Compute cross-correlation function between two time series
313    ///
314    /// # Arguments
315    ///
316    /// * `x` - First time series
317    /// * `y` - Second time series
318    /// * `config` - Configuration for cross-correlation
319    ///
320    /// # Returns
321    ///
322    /// Result containing cross-correlation function and statistics
323    pub fn cross_correlation(
324        &self,
325        x: &Array1<f64>,
326        y: &Array1<f64>,
327        config: &CrossCorrelationConfig,
328    ) -> CorrelationResult<CrossCorrelationResult> {
329        checkarray_finite(x, "x")?;
330        checkarray_finite(y, "y")?;
331
332        if x.len() != y.len() {
333            return Err(TimeSeriesError::InvalidInput(
334                "Time series must have the same length".to_string(),
335            ));
336        }
337
338        let n = x.len();
339        if n < 2 * config.max_lag + 1 {
340            return Err(TimeSeriesError::InvalidInput(
341                "Time series too short for the specified maximum lag".to_string(),
342            ));
343        }
344
345        // Normalize series if requested
346        let x_proc = if config.normalize {
347            self.normalize_series(x)?
348        } else {
349            x.clone()
350        };
351
352        let y_proc = if config.normalize {
353            self.normalize_series(y)?
354        } else {
355            y.clone()
356        };
357
358        // Compute cross-correlations for different lags
359        let mut correlations = Array1::zeros(2 * config.max_lag + 1);
360        let mut lags = Array1::zeros(2 * config.max_lag + 1);
361
362        for i in 0..correlations.len() {
363            let lag = i as i32 - config.max_lag as i32;
364            lags[i] = lag;
365
366            correlations[i] = match config.method {
367                CorrelationMethod::Pearson => {
368                    self.compute_lagged_correlation(&x_proc, &y_proc, lag)?
369                }
370                CorrelationMethod::Spearman => {
371                    self.compute_lagged_spearman(&x_proc, &y_proc, lag)?
372                }
373                CorrelationMethod::Kendall => self.compute_lagged_kendall(&x_proc, &y_proc, lag)?,
374            };
375        }
376
377        // Find maximum correlation and its lag
378        let max_idx = correlations
379            .iter()
380            .enumerate()
381            .max_by(|(_, a), (_, b)| a.abs().partial_cmp(&b.abs()).unwrap())
382            .map(|(idx_, _)| idx_)
383            .unwrap_or(0);
384
385        let max_correlation = correlations[max_idx];
386        let lag_at_max = lags[max_idx];
387
388        // Compute confidence intervals if requested
389        let (confidence_lower, confidence_upper) = if let Some(conf_level) = config.confidence_level
390        {
391            let (lower, upper) =
392                self.compute_correlation_confidence_intervals(&correlations, n, conf_level)?;
393            (Some(lower), Some(upper))
394        } else {
395            (None, None)
396        };
397
398        Ok(CrossCorrelationResult {
399            correlations,
400            lags,
401            max_correlation,
402            lag_at_max,
403            confidence_lower,
404            confidence_upper,
405        })
406    }
407
408    /// Compute dynamic time warping distance between two time series
409    ///
410    /// # Arguments
411    ///
412    /// * `x` - First time series
413    /// * `y` - Second time series
414    /// * `config` - Configuration for DTW
415    ///
416    /// # Returns
417    ///
418    /// Result containing DTW distance and warping path
419    pub fn dynamic_time_warping(
420        &self,
421        x: &Array1<f64>,
422        y: &Array1<f64>,
423        config: &DTWConfig,
424    ) -> CorrelationResult<DTWResult> {
425        checkarray_finite(x, "x")?;
426        checkarray_finite(y, "y")?;
427
428        let n = x.len();
429        let m = y.len();
430
431        if n == 0 || m == 0 {
432            return Err(TimeSeriesError::InvalidInput(
433                "Time series cannot be empty".to_string(),
434            ));
435        }
436
437        // Initialize cost matrix
438        let mut cost_matrix = Array2::from_elem((n + 1, m + 1), f64::INFINITY);
439        cost_matrix[[0, 0]] = 0.0;
440
441        // Fill cost matrix according to constraint
442        match config.constraint {
443            DTWConstraint::None => {
444                for i in 1..=n {
445                    for j in 1..=m {
446                        let local_cost =
447                            self.compute_local_cost(x[i - 1], y[j - 1], config.cost_function);
448                        cost_matrix[[i, j]] = local_cost
449                            + self.min_predecessor(&cost_matrix, i, j, config.step_pattern);
450                    }
451                }
452            }
453            DTWConstraint::SakoeChiba(radius) => {
454                for i in 1..=n {
455                    let j_start =
456                        ((i as f64 * m as f64 / n as f64) - radius as f64).max(1.0) as usize;
457                    let j_end =
458                        ((i as f64 * m as f64 / n as f64) + radius as f64).min(m as f64) as usize;
459
460                    for j in j_start..=j_end {
461                        let local_cost =
462                            self.compute_local_cost(x[i - 1], y[j - 1], config.cost_function);
463                        cost_matrix[[i, j]] = local_cost
464                            + self.min_predecessor(&cost_matrix, i, j, config.step_pattern);
465                    }
466                }
467            }
468            DTWConstraint::Itakura => {
469                // Simplified Itakura parallelogram constraint
470                for i in 1..=n {
471                    let slope_constraint = 2.0;
472                    let j_start = ((i as f64 / slope_constraint).max(1.0)) as usize;
473                    let j_end = ((i as f64 * slope_constraint).min(m as f64)) as usize;
474
475                    for j in j_start..=j_end {
476                        let local_cost =
477                            self.compute_local_cost(x[i - 1], y[j - 1], config.cost_function);
478                        cost_matrix[[i, j]] = local_cost
479                            + self.min_predecessor(&cost_matrix, i, j, config.step_pattern);
480                    }
481                }
482            }
483        }
484
485        let distance = cost_matrix[[n, m]];
486
487        if !distance.is_finite() {
488            return Err(TimeSeriesError::ComputationError(
489                "DTW computation resulted in infinite distance".to_string(),
490            ));
491        }
492
493        // Backtrack to find optimal warping path
494        let warping_path = self.backtrack_warping_path(&cost_matrix, n, m, config.step_pattern)?;
495
496        // Normalize distance if requested
497        let normalized_distance = if config.normalize {
498            Some(distance / warping_path.len() as f64)
499        } else {
500            None
501        };
502
503        Ok(DTWResult {
504            distance,
505            warping_path,
506            cost_matrix: cost_matrix.slice(s![1.., 1..]).to_owned(),
507            normalized_distance,
508            cost_function: config.cost_function,
509            constraint: config.constraint,
510        })
511    }
512
513    /// Perform time-frequency analysis on a time series
514    ///
515    /// # Arguments
516    ///
517    /// * `x` - Input time series
518    /// * `config` - Configuration for time-frequency analysis
519    ///
520    /// # Returns
521    ///
522    /// Result containing time-frequency representation
523    pub fn time_frequency_analysis(
524        &self,
525        x: &Array1<f64>,
526        config: &TimeFrequencyConfig,
527    ) -> CorrelationResult<TimeFrequencyResult> {
528        checkarray_finite(x, "x")?;
529
530        if x.len() < config.window.window_size {
531            return Err(TimeSeriesError::InvalidInput(
532                "Time series too short for the specified window size".to_string(),
533            ));
534        }
535
536        match config.method {
537            TimeFrequencyMethod::STFT => self.compute_stft(x, config),
538            TimeFrequencyMethod::CWT => self.compute_cwt(x, config),
539            TimeFrequencyMethod::Morlet => self.compute_morlet_wavelet(x, config),
540            TimeFrequencyMethod::Gabor => self.compute_gabor_transform(x, config),
541        }
542    }
543
544    /// Compute coherence between two time series
545    ///
546    /// # Arguments
547    ///
548    /// * `x` - First time series
549    /// * `y` - Second time series
550    /// * `config` - Configuration for coherence analysis
551    ///
552    /// # Returns
553    ///
554    /// Result containing coherence and phase information
555    pub fn coherence_analysis(
556        &self,
557        x: &Array1<f64>,
558        y: &Array1<f64>,
559        config: &CoherenceConfig,
560    ) -> CorrelationResult<CoherenceResult> {
561        checkarray_finite(x, "x")?;
562        checkarray_finite(y, "y")?;
563
564        if x.len() != y.len() {
565            return Err(TimeSeriesError::InvalidInput(
566                "Time series must have the same length".to_string(),
567            ));
568        }
569
570        if x.len() < config.window_size {
571            return Err(TimeSeriesError::InvalidInput(
572                "Time series too short for the specified window size".to_string(),
573            ));
574        }
575
576        // Preprocess series
577        let x_proc = self.detrend_series(x, config.detrend)?;
578        let y_proc = self.detrend_series(y, config.detrend)?;
579
580        // Compute windowed segments
581        let hop_size = config.window_size - config.overlap;
582        let n_windows = (x.len() - config.overlap) / hop_size;
583
584        if n_windows < 2 {
585            return Err(TimeSeriesError::InvalidInput(
586                "Not enough data for reliable coherence estimation".to_string(),
587            ));
588        }
589
590        // Generate window function
591        let window = self.generate_window(config.window_type, config.window_size)?;
592
593        // Compute spectral estimates
594        let freq_bins = config.window_size / 2 + 1;
595        let mut cross_psd = Array1::zeros(freq_bins);
596        let mut psd_x = Array1::zeros(freq_bins);
597        let mut psd_y = Array1::zeros(freq_bins);
598
599        for i in 0..n_windows {
600            let start_idx = i * hop_size;
601            let end_idx = start_idx + config.window_size;
602
603            if end_idx > x_proc.len() {
604                break;
605            }
606
607            let x_segment = x_proc.slice(s![start_idx..end_idx]).to_owned();
608            let y_segment = y_proc.slice(s![start_idx..end_idx]).to_owned();
609
610            // Apply window
611            let x_windowed = &x_segment * &window;
612            let y_windowed = &y_segment * &window;
613
614            // Compute FFTs
615            let x_fft = self.compute_fft(&x_windowed)?;
616            let y_fft = self.compute_fft(&y_windowed)?;
617
618            // Accumulate spectral estimates
619            for k in 0..freq_bins {
620                let x_complex = x_fft[k];
621                let y_complex = y_fft[k];
622
623                cross_psd[k] += x_complex.re * y_complex.re + x_complex.im * y_complex.im;
624                psd_x[k] += x_complex.re * x_complex.re + x_complex.im * x_complex.im;
625                psd_y[k] += y_complex.re * y_complex.re + y_complex.im * y_complex.im;
626            }
627        }
628
629        // Normalize by number of windows
630        cross_psd /= n_windows as f64;
631        psd_x /= n_windows as f64;
632        psd_y /= n_windows as f64;
633
634        // Compute coherence
635        let mut coherence = Array1::zeros(freq_bins);
636        let mut phase = Array1::zeros(freq_bins);
637
638        for k in 0..freq_bins {
639            let denominator = (psd_x[k] * psd_y[k]).sqrt();
640            if denominator > f64::EPSILON {
641                coherence[k] = (cross_psd[k] / denominator).abs();
642                phase[k] = cross_psd[k].atan2(0.0); // Simplified phase calculation
643            }
644        }
645
646        // Generate frequency vector
647        let frequencies = Array1::from_iter(
648            (0..freq_bins).map(|k| k as f64 * config.sampling_freq / (2.0 * freq_bins as f64)),
649        );
650
651        // Compute confidence threshold if requested
652        let confidence_threshold = config
653            .confidence_level
654            .map(|conf_level| self.coherence_confidence_threshold(conf_level, n_windows));
655
656        Ok(CoherenceResult {
657            coherence,
658            phase,
659            frequencies,
660            cross_psd,
661            psd_x,
662            psd_y,
663            confidence_level: config.confidence_level,
664            confidence_threshold,
665        })
666    }
667
668    // Helper methods
669
670    fn normalize_series(&self, x: &Array1<f64>) -> CorrelationResult<Array1<f64>> {
671        let mean = x.mean().unwrap_or(0.0);
672        let std = (x.mapv(|xi| (xi - mean).powi(2)).sum() / x.len() as f64).sqrt();
673
674        if std < f64::EPSILON {
675            return Ok(x - mean); // Only remove mean if std is zero
676        }
677
678        Ok((x - mean) / std)
679    }
680
681    fn compute_lagged_correlation(
682        &self,
683        x: &Array1<f64>,
684        y: &Array1<f64>,
685        lag: i32,
686    ) -> CorrelationResult<f64> {
687        let _n = x.len() as i32;
688
689        let (x_slice, y_slice) = if lag >= 0 {
690            let lag = lag as usize;
691            if lag >= x.len() {
692                return Ok(0.0);
693            }
694            (
695                x.slice(s![lag..]).to_owned(),
696                y.slice(s![..x.len() - lag]).to_owned(),
697            )
698        } else {
699            let lag = (-lag) as usize;
700            if lag >= y.len() {
701                return Ok(0.0);
702            }
703            (
704                x.slice(s![..x.len() - lag]).to_owned(),
705                y.slice(s![lag..]).to_owned(),
706            )
707        };
708
709        if x_slice.is_empty() || y_slice.is_empty() {
710            return Ok(0.0);
711        }
712
713        self.pearson_correlation(&x_slice, &y_slice)
714    }
715
716    fn compute_lagged_spearman(
717        &self,
718        x: &Array1<f64>,
719        y: &Array1<f64>,
720        lag: i32,
721    ) -> CorrelationResult<f64> {
722        let (x_slice, y_slice) = if lag >= 0 {
723            let lag = lag as usize;
724            if lag >= x.len() {
725                return Ok(0.0);
726            }
727            (
728                x.slice(s![lag..]).to_owned(),
729                y.slice(s![..x.len() - lag]).to_owned(),
730            )
731        } else {
732            let lag = (-lag) as usize;
733            if lag >= y.len() {
734                return Ok(0.0);
735            }
736            (
737                x.slice(s![..x.len() - lag]).to_owned(),
738                y.slice(s![lag..]).to_owned(),
739            )
740        };
741
742        if x_slice.is_empty() || y_slice.is_empty() {
743            return Ok(0.0);
744        }
745
746        // Convert to ranks and compute Pearson correlation
747        let x_ranks = self.compute_ranks(&x_slice);
748        let y_ranks = self.compute_ranks(&y_slice);
749
750        self.pearson_correlation(&x_ranks, &y_ranks)
751    }
752
753    fn compute_lagged_kendall(
754        &self,
755        x: &Array1<f64>,
756        y: &Array1<f64>,
757        lag: i32,
758    ) -> CorrelationResult<f64> {
759        let (x_slice, y_slice) = if lag >= 0 {
760            let lag = lag as usize;
761            if lag >= x.len() {
762                return Ok(0.0);
763            }
764            (
765                x.slice(s![lag..]).to_owned(),
766                y.slice(s![..x.len() - lag]).to_owned(),
767            )
768        } else {
769            let lag = (-lag) as usize;
770            if lag >= y.len() {
771                return Ok(0.0);
772            }
773            (
774                x.slice(s![..x.len() - lag]).to_owned(),
775                y.slice(s![lag..]).to_owned(),
776            )
777        };
778
779        if x_slice.is_empty() || y_slice.is_empty() {
780            return Ok(0.0);
781        }
782
783        self.kendall_tau(&x_slice, &y_slice)
784    }
785
786    fn pearson_correlation(&self, x: &Array1<f64>, y: &Array1<f64>) -> CorrelationResult<f64> {
787        if x.len() != y.len() || x.is_empty() {
788            return Ok(0.0);
789        }
790
791        let n = x.len() as f64;
792        let mean_x = x.sum() / n;
793        let mean_y = y.sum() / n;
794
795        let mut numerator = 0.0;
796        let mut sum_sq_x = 0.0;
797        let mut sum_sq_y = 0.0;
798
799        for (xi, yi) in x.iter().zip(y.iter()) {
800            let diff_x = xi - mean_x;
801            let diff_y = yi - mean_y;
802            numerator += diff_x * diff_y;
803            sum_sq_x += diff_x * diff_x;
804            sum_sq_y += diff_y * diff_y;
805        }
806
807        let denominator = (sum_sq_x * sum_sq_y).sqrt();
808        if denominator < f64::EPSILON {
809            return Ok(0.0);
810        }
811
812        Ok(numerator / denominator)
813    }
814
815    fn compute_ranks(&self, x: &Array1<f64>) -> Array1<f64> {
816        let mut indexed_values: Vec<(usize, f64)> =
817            x.iter().enumerate().map(|(i, &val)| (i, val)).collect();
818        indexed_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
819
820        let mut ranks = Array1::zeros(x.len());
821        for (rank, &(original_index, _)) in indexed_values.iter().enumerate() {
822            ranks[original_index] = rank as f64 + 1.0;
823        }
824
825        ranks
826    }
827
828    fn kendall_tau(&self, x: &Array1<f64>, y: &Array1<f64>) -> CorrelationResult<f64> {
829        let n = x.len();
830        if n < 2 {
831            return Ok(0.0);
832        }
833
834        let mut concordant = 0;
835        let mut discordant = 0;
836
837        for i in 0..n {
838            for j in i + 1..n {
839                let x_diff = x[i] - x[j];
840                let y_diff = y[i] - y[j];
841
842                if x_diff * y_diff > 0.0 {
843                    concordant += 1;
844                } else if x_diff * y_diff < 0.0 {
845                    discordant += 1;
846                }
847            }
848        }
849
850        let total_pairs = n * (n - 1) / 2;
851        if total_pairs == 0 {
852            return Ok(0.0);
853        }
854
855        Ok((concordant - discordant) as f64 / total_pairs as f64)
856    }
857
858    fn compute_correlation_confidence_intervals(
859        &self,
860        correlations: &Array1<f64>,
861        n: usize,
862        confidence_level: f64,
863    ) -> CorrelationResult<(Array1<f64>, Array1<f64>)> {
864        let alpha = 1.0 - confidence_level;
865        let z_score = self.normal_quantile(1.0 - alpha / 2.0);
866        let std_error = 1.0 / (n as f64).sqrt();
867
868        let lower = correlations.mapv(|r| r - z_score * std_error);
869        let upper = correlations.mapv(|r| r + z_score * std_error);
870
871        Ok((lower, upper))
872    }
873
874    fn normal_quantile(&self, p: f64) -> f64 {
875        // Beasley-Springer-Moro algorithm approximation
876        if p <= 0.0 {
877            return f64::NEG_INFINITY;
878        }
879        if p >= 1.0 {
880            return f64::INFINITY;
881        }
882        if (p - 0.5).abs() < f64::EPSILON {
883            return 0.0;
884        }
885
886        let a = [
887            -3.969_683_028_665_376e1,
888            2.209_460_984_245_205e2,
889            -2.759_285_104_469_687e2,
890            1.383_577_518_672_69e2,
891            -3.066_479_806_614_716e1,
892            2.506_628_277_459_239,
893        ];
894
895        let b = [
896            -5.447_609_879_822_406e1,
897            1.615_858_368_580_409e2,
898            -1.556_989_798_598_866e2,
899            6.680_131_188_771_972e1,
900            -1.328_068_155_288_572e1,
901        ];
902
903        if !(0.02425..=0.97575).contains(&p) {
904            let q = if p < 0.5 { p } else { 1.0 - p };
905            let u = (-2.0 * q.ln()).sqrt();
906            let sign = if p < 0.5 { -1.0 } else { 1.0 };
907            sign * (a[0] + u * (a[1] + u * (a[2] + u * (a[3] + u * (a[4] + u * a[5])))))
908                / (1.0 + u * (b[0] + u * (b[1] + u * (b[2] + u * (b[3] + u * b[4])))))
909        } else {
910            let q = p - 0.5;
911            let r = q * q;
912            q * (a[0] + r * (a[1] + r * (a[2] + r * (a[3] + r * (a[4] + r * a[5])))))
913                / (1.0 + r * (b[0] + r * (b[1] + r * (b[2] + r * (b[3] + r * b[4])))))
914        }
915    }
916
917    fn compute_local_cost(&self, x: f64, y: f64, costfunction: DTWCostFunction) -> f64 {
918        match costfunction {
919            DTWCostFunction::Euclidean => (x - y).abs(),
920            DTWCostFunction::Manhattan => (x - y).abs(),
921            DTWCostFunction::SquaredEuclidean => (x - y).powi(2),
922            DTWCostFunction::Cosine => {
923                let dot_product = x * y;
924                let magnitude = (x * x + y * y).sqrt();
925                if magnitude < f64::EPSILON {
926                    0.0
927                } else {
928                    1.0 - dot_product / magnitude
929                }
930            }
931        }
932    }
933
934    fn min_predecessor(
935        &self,
936        cost_matrix: &Array2<f64>,
937        i: usize,
938        j: usize,
939        step_pattern: StepPattern,
940    ) -> f64 {
941        match step_pattern {
942            StepPattern::Symmetric => {
943                let candidates = [
944                    cost_matrix[[i - 1, j]],     // Vertical
945                    cost_matrix[[i, j - 1]],     // Horizontal
946                    cost_matrix[[i - 1, j - 1]], // Diagonal
947                ];
948                candidates.iter().cloned().fold(f64::INFINITY, f64::min)
949            }
950            StepPattern::Asymmetric => {
951                let candidates = [
952                    cost_matrix[[i - 1, j]] + cost_matrix[[i - 1, j - 1]],
953                    cost_matrix[[i, j - 1]],
954                    cost_matrix[[i - 1, j - 1]],
955                ];
956                candidates.iter().cloned().fold(f64::INFINITY, f64::min)
957            }
958            StepPattern::QuasiSymmetric => {
959                let candidates = [
960                    cost_matrix[[i - 1, j]],
961                    cost_matrix[[i, j - 1]],
962                    2.0 * cost_matrix[[i - 1, j - 1]],
963                ];
964                candidates.iter().cloned().fold(f64::INFINITY, f64::min)
965            }
966        }
967    }
968
969    fn backtrack_warping_path(
970        &self,
971        cost_matrix: &Array2<f64>,
972        n: usize,
973        m: usize,
974        step_pattern: StepPattern,
975    ) -> CorrelationResult<Vec<(usize, usize)>> {
976        let mut path = Vec::new();
977        let mut i = n;
978        let mut j = m;
979
980        path.push((i - 1, j - 1)); // Convert to 0-indexed
981
982        while i > 1 || j > 1 {
983            let candidates = match step_pattern {
984                StepPattern::Symmetric => vec![
985                    (
986                        i.saturating_sub(1),
987                        j,
988                        cost_matrix[[i.saturating_sub(1), j]],
989                    ),
990                    (
991                        i,
992                        j.saturating_sub(1),
993                        cost_matrix[[i, j.saturating_sub(1)]],
994                    ),
995                    (
996                        i.saturating_sub(1),
997                        j.saturating_sub(1),
998                        cost_matrix[[i.saturating_sub(1), j.saturating_sub(1)]],
999                    ),
1000                ],
1001                _ => vec![
1002                    // Simplified for other patterns
1003                    (
1004                        i.saturating_sub(1),
1005                        j,
1006                        cost_matrix[[i.saturating_sub(1), j]],
1007                    ),
1008                    (
1009                        i,
1010                        j.saturating_sub(1),
1011                        cost_matrix[[i, j.saturating_sub(1)]],
1012                    ),
1013                    (
1014                        i.saturating_sub(1),
1015                        j.saturating_sub(1),
1016                        cost_matrix[[i.saturating_sub(1), j.saturating_sub(1)]],
1017                    ),
1018                ],
1019            };
1020
1021            let (next_i, next_j_, _) = candidates
1022                .into_iter()
1023                .filter(|(ni, nj_, _)| *ni > 0 && *nj_ > 0)
1024                .min_by(|a, b| a.2.partial_cmp(&b.2).unwrap())
1025                .unwrap_or((1, 1, 0.0));
1026
1027            i = next_i;
1028            j = next_j_;
1029            path.push((i - 1, j - 1)); // Convert to 0-indexed
1030        }
1031
1032        path.reverse();
1033        Ok(path)
1034    }
1035
1036    fn compute_stft(
1037        &self,
1038        x: &Array1<f64>,
1039        config: &TimeFrequencyConfig,
1040    ) -> CorrelationResult<TimeFrequencyResult> {
1041        let window = self.generate_window(config.window.window_type, config.window.window_size)?;
1042        let hop_size = config.window.window_size - config.window.overlap;
1043        let n_windows = (x.len() - config.window.overlap) / hop_size;
1044        let freq_bins = config.window.window_size / 2 + 1;
1045
1046        let mut spectrogram = Array2::zeros((freq_bins, n_windows));
1047        let mut times = Array1::zeros(n_windows);
1048
1049        for i in 0..n_windows {
1050            let start_idx = i * hop_size;
1051            let end_idx = start_idx + config.window.window_size;
1052
1053            if end_idx > x.len() {
1054                break;
1055            }
1056
1057            times[i] = start_idx as f64 / config.sampling_freq;
1058
1059            let segment = x.slice(s![start_idx..end_idx]).to_owned();
1060            let windowed = &segment * &window;
1061
1062            let fft_result = self.compute_fft(&windowed)?;
1063
1064            for k in 0..freq_bins {
1065                let magnitude = (fft_result[k].re.powi(2) + fft_result[k].im.powi(2)).sqrt();
1066                spectrogram[[k, i]] = magnitude;
1067            }
1068        }
1069
1070        let frequencies = Array1::from_iter(
1071            (0..freq_bins)
1072                .map(|k| k as f64 * config.sampling_freq / config.window.window_size as f64),
1073        );
1074
1075        Ok(TimeFrequencyResult {
1076            spectrogram,
1077            times,
1078            frequencies,
1079            method: config.method,
1080            window_info: Some(config.window.clone()),
1081        })
1082    }
1083
1084    fn compute_cwt(
1085        &self,
1086        x: &Array1<f64>,
1087        config: &TimeFrequencyConfig,
1088    ) -> CorrelationResult<TimeFrequencyResult> {
1089        // Simplified continuous wavelet transform implementation
1090        let n_scales = config.n_freq_bins.unwrap_or(50);
1091        let n_times = x.len();
1092
1093        let mut spectrogram = Array2::zeros((n_scales, n_times));
1094        let times = Array1::from_iter((0..n_times).map(|i| i as f64 / config.sampling_freq));
1095        let mut frequencies = Array1::zeros(n_scales);
1096
1097        // Generate scales logarithmically
1098        let min_scale = 1.0;
1099        let max_scale = n_times as f64 / 4.0;
1100        let scale_factor = (max_scale / min_scale).powf(1.0 / (n_scales - 1) as f64);
1101
1102        for scale_idx in 0..n_scales {
1103            let current_scale = min_scale * scale_factor.powi(scale_idx as i32);
1104            frequencies[scale_idx] = config.sampling_freq / (2.0 * PI * current_scale);
1105
1106            // Convolve with Morlet wavelet at current scale
1107            for t in 0..n_times {
1108                let mut convolution_result = 0.0;
1109
1110                for tau in 0..n_times {
1111                    let t_normalized = (tau as f64 - t as f64) / current_scale;
1112                    let wavelet_value = self.morlet_wavelet(t_normalized);
1113                    convolution_result += x[tau] * wavelet_value;
1114                }
1115
1116                spectrogram[[scale_idx, t]] = convolution_result.abs();
1117            }
1118        }
1119
1120        Ok(TimeFrequencyResult {
1121            spectrogram,
1122            times,
1123            frequencies,
1124            method: config.method,
1125            window_info: None,
1126        })
1127    }
1128
1129    fn compute_morlet_wavelet(
1130        &self,
1131        x: &Array1<f64>,
1132        config: &TimeFrequencyConfig,
1133    ) -> CorrelationResult<TimeFrequencyResult> {
1134        // Similar to CWT but specifically for Morlet wavelets
1135        self.compute_cwt(x, config)
1136    }
1137
1138    fn compute_gabor_transform(
1139        &self,
1140        x: &Array1<f64>,
1141        config: &TimeFrequencyConfig,
1142    ) -> CorrelationResult<TimeFrequencyResult> {
1143        // Gabor transform is essentially a windowed Fourier transform
1144        self.compute_stft(x, config)
1145    }
1146
1147    fn morlet_wavelet(&self, t: f64) -> f64 {
1148        let sigma = 1.0;
1149        let omega = 6.0;
1150        let gaussian = (-t * t / (2.0 * sigma * sigma)).exp();
1151        let oscillation = (omega * t).cos();
1152        gaussian * oscillation / (PI.powf(0.25) * sigma.sqrt())
1153    }
1154
1155    fn generate_window(
1156        &self,
1157        window_type: WindowType,
1158        size: usize,
1159    ) -> CorrelationResult<Array1<f64>> {
1160        let mut window = Array1::zeros(size);
1161
1162        match window_type {
1163            WindowType::Hamming => {
1164                for i in 0..size {
1165                    window[i] = 0.54 - 0.46 * (2.0 * PI * i as f64 / (size - 1) as f64).cos();
1166                }
1167            }
1168            WindowType::Hanning => {
1169                for i in 0..size {
1170                    window[i] = 0.5 * (1.0 - (2.0 * PI * i as f64 / (size - 1) as f64).cos());
1171                }
1172            }
1173            WindowType::Blackman => {
1174                for i in 0..size {
1175                    let factor = 2.0 * PI * i as f64 / (size - 1) as f64;
1176                    window[i] = 0.42 - 0.5 * factor.cos() + 0.08 * (2.0 * factor).cos();
1177                }
1178            }
1179            WindowType::Gaussian => {
1180                let sigma = (size as f64) / 6.0;
1181                let center = (size - 1) as f64 / 2.0;
1182                for i in 0..size {
1183                    let x = (i as f64 - center) / sigma;
1184                    window[i] = (-0.5 * x * x).exp();
1185                }
1186            }
1187            WindowType::Rectangular => {
1188                window.fill(1.0);
1189            }
1190        }
1191
1192        Ok(window)
1193    }
1194
1195    fn detrend_series(
1196        &self,
1197        x: &Array1<f64>,
1198        method: DetrendMethod,
1199    ) -> CorrelationResult<Array1<f64>> {
1200        match method {
1201            DetrendMethod::None => Ok(x.clone()),
1202            DetrendMethod::Mean => {
1203                let mean = x.mean().unwrap_or(0.0);
1204                Ok(x - mean)
1205            }
1206            DetrendMethod::Linear => {
1207                let n = x.len() as f64;
1208                let t: Array1<f64> = Array1::from_iter((0..x.len()).map(|i| i as f64));
1209
1210                // Linear regression: y = a + b*t
1211                let sum_t = t.sum();
1212                let sum_y = x.sum();
1213                let sum_tt = t.mapv(|ti| ti * ti).sum();
1214                let sum_ty = t.iter().zip(x.iter()).map(|(ti, yi)| ti * yi).sum::<f64>();
1215
1216                let b = (n * sum_ty - sum_t * sum_y) / (n * sum_tt - sum_t * sum_t);
1217                let a = (sum_y - b * sum_t) / n;
1218
1219                let trend = t.mapv(|ti| a + b * ti);
1220                Ok(x - &trend)
1221            }
1222        }
1223    }
1224
1225    fn compute_fft(&self, x: &Array1<f64>) -> CorrelationResult<Vec<Complex>> {
1226        // Simplified FFT implementation using DFT
1227        let n = x.len();
1228        let mut result = vec![Complex { re: 0.0, im: 0.0 }; n];
1229
1230        #[allow(clippy::needless_range_loop)]
1231        for k in 0..n {
1232            let mut sum = Complex { re: 0.0, im: 0.0 };
1233            for t in 0..n {
1234                let angle = -2.0 * PI * k as f64 * t as f64 / n as f64;
1235                sum.re += x[t] * angle.cos();
1236                sum.im += x[t] * angle.sin();
1237            }
1238            result[k] = sum;
1239        }
1240
1241        Ok(result)
1242    }
1243
1244    fn coherence_confidence_threshold(&self, confidence_level: f64, nsegments: usize) -> f64 {
1245        // Approximation for coherence confidence threshold
1246        let alpha = 1.0 - confidence_level;
1247        let dof = 2 * nsegments;
1248
1249        // For large DOF, use chi-squared approximation
1250        if dof > 30 {
1251            let z = self.normal_quantile(1.0 - alpha);
1252            1.0 - (z * z / dof as f64).exp()
1253        } else {
1254            // Conservative threshold for small samples
1255            0.5
1256        }
1257    }
1258}
1259
1260impl Default for CorrelationAnalyzer {
1261    fn default() -> Self {
1262        Self::new()
1263    }
1264}
1265
/// Minimal complex number `re + i * im` used to hold the coefficients produced by the
/// analyzer's internal discrete Fourier transform.
#[derive(Debug, Clone, Copy)]
struct Complex {
    /// Real part.
    re: f64,
    /// Imaginary part.
    im: f64,
}
1272
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_cross_correlation() {
        let n = 50;
        let x = Array1::from_vec((0..n).map(|i| (i as f64 * 0.1).sin()).collect());
        let y = Array1::from_vec((0..n).map(|i| ((i as f64 + 2.0) * 0.1).sin()).collect());

        let analyzer = CorrelationAnalyzer::new();
        let config = CrossCorrelationConfig::default();
        let result = analyzer.cross_correlation(&x, &y, &config).unwrap();

        assert_eq!(result.correlations.len(), 2 * config.max_lag + 1);
        assert!(result.max_correlation.abs() <= 1.0 + f64::EPSILON * 10.0);
        assert!(result.lag_at_max.abs() <= config.max_lag as i32);
    }

    #[test]
    fn test_dynamic_time_warping() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = Array1::from_vec(vec![1.0, 2.0, 2.0, 3.0, 4.0, 5.0]);

        let analyzer = CorrelationAnalyzer::new();
        let config = DTWConfig::default();
        let result = analyzer.dynamic_time_warping(&x, &y, &config).unwrap();

        assert!(result.distance >= 0.0);
        assert!(!result.warping_path.is_empty());
        assert_eq!(result.cost_matrix.nrows(), x.len());
        assert_eq!(result.cost_matrix.ncols(), y.len());

        // A warping path must join the first sample pair to the last one using
        // monotone unit steps.
        let path = &result.warping_path;
        assert_eq!(path.first(), Some(&(0, 0)));
        assert_eq!(path.last(), Some(&(x.len() - 1, y.len() - 1)));
        for step in path.windows(2) {
            let ((i0, j0), (i1, j1)) = (step[0], step[1]);
            assert!(i1 >= i0 && j1 >= j0 && i1 - i0 <= 1 && j1 - j0 <= 1);
            assert_ne!((i0, j0), (i1, j1));
        }
    }

    #[test]
    fn test_coherence_analysis() {
        let n = 512;
        let x = Array1::from_vec((0..n).map(|i| (i as f64 * 0.1).sin()).collect());
        // Deterministic pseudo-noise in [0, 0.1) keeps the test reproducible across runs.
        let y = Array1::from_vec(
            (0..n)
                .map(|i| (i as f64 * 0.1).sin() + 0.1 * (((i * 7919) % 101) as f64 / 101.0))
                .collect(),
        );

        let analyzer = CorrelationAnalyzer::new();
        let config = CoherenceConfig::default();
        let result = analyzer.coherence_analysis(&x, &y, &config).unwrap();

        assert!(!result.coherence.is_empty());
        assert!(!result.frequencies.is_empty());
        assert_eq!(result.coherence.len(), result.frequencies.len());

        // Coherence values should be between 0 and 1
        for &coh in result.coherence.iter() {
            assert!((0.0..=1.0).contains(&coh));
        }
    }
}