// rustkernel_accounting/temporal.rs

1//! Temporal correlation kernel.
2//!
3//! This module provides temporal correlation analysis for accounting:
4//! - Calculate correlations between account time series
5//! - Detect anomalies based on expected correlations
6//! - Identify pattern changes
7
8use crate::types::{
9    AccountCorrelation, AccountTimeSeries, AnomalyType, CorrelationAnomaly, CorrelationResult,
10    CorrelationStats, CorrelationType,
11};
12use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
13use std::collections::HashMap;
14
15// ============================================================================
16// Temporal Correlation Kernel
17// ============================================================================
18
/// Temporal correlation kernel.
///
/// Analyzes correlations between account time series: pairwise Pearson
/// correlation, expectation-based anomaly checks, and rolling-window
/// break detection.
#[derive(Debug, Clone)]
pub struct TemporalCorrelation {
    /// Kernel registration metadata (id, domain, performance hints).
    metadata: KernelMetadata,
}
26
27impl Default for TemporalCorrelation {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl TemporalCorrelation {
    /// Create a new temporal correlation kernel.
    ///
    /// Builds the kernel metadata: a batch kernel registered under
    /// `accounting/temporal-correlation` in the accounting domain, with a
    /// declared throughput of 5_000 and a latency figure of 500.0 µs
    /// (exact units defined by `KernelMetadata` in `rustkernel_core` —
    /// confirm there).
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("accounting/temporal-correlation", Domain::Accounting)
                .with_description("Account time series correlation analysis")
                .with_throughput(5_000)
                .with_latency_us(500.0),
        }
    }
44
    /// Calculate correlations between account time series.
    ///
    /// Three passes:
    /// 1. Pearson correlation for every unordered pair of series.
    /// 2. If `config.expected_correlations` is set, flag pairs whose actual
    ///    coefficient deviates from the expected one by more than
    ///    `config.correlation_threshold`.
    /// 3. For each series with at least one strongly correlated peer
    ///    (|r| > `config.significant_correlation`), flag individual data
    ///    points that deviate from a correlation-based prediction.
    ///
    /// Returns the correlations, all detected anomalies, and summary stats.
    pub fn correlate(
        time_series: &[AccountTimeSeries],
        config: &CorrelationConfig,
    ) -> CorrelationResult {
        let mut correlations = Vec::new();
        let mut anomalies = Vec::new();

        // Calculate pairwise correlations (i < j, so each pair only once)
        for i in 0..time_series.len() {
            for j in (i + 1)..time_series.len() {
                let ts_a = &time_series[i];
                let ts_b = &time_series[j];

                if let Some(corr) = Self::calculate_correlation(ts_a, ts_b, config) {
                    correlations.push(corr);
                }
            }
        }

        // Detect anomalies based on expected correlations
        if let Some(ref expected) = config.expected_correlations {
            for (pair, expected_coef) in expected {
                // A pair is matched in either orientation (a,b) or (b,a).
                let actual = correlations.iter().find(|c| {
                    (c.account_a == pair.0 && c.account_b == pair.1)
                        || (c.account_a == pair.1 && c.account_b == pair.0)
                });

                // NOTE(review): an expected pair with NO computed correlation
                // (e.g. insufficient overlap) is silently skipped here; only
                // present-but-deviating pairs are flagged.
                if let Some(actual_corr) = actual {
                    let diff = (actual_corr.coefficient - expected_coef).abs();
                    if diff > config.correlation_threshold {
                        // Find the corresponding time series
                        if let Some(ts) = time_series.iter().find(|t| t.account_code == pair.0) {
                            if let Some(last_point) = ts.data_points.last() {
                                anomalies.push(CorrelationAnomaly {
                                    account_code: pair.0.clone(),
                                    date: last_point.date,
                                    expected: *expected_coef,
                                    actual: actual_corr.coefficient,
                                    z_score: diff / 0.1, // Simplified z-score (fixed 0.1 "std dev")
                                    anomaly_type: AnomalyType::MissingCorrelation,
                                });
                            }
                        }
                    }
                }
            }
        }

        // Detect point anomalies using correlation-based prediction
        for ts in time_series {
            // Peers whose correlation with `ts` is strong enough to predict from.
            let related: Vec<_> = correlations
                .iter()
                .filter(|c| c.account_a == ts.account_code || c.account_b == ts.account_code)
                .filter(|c| c.coefficient.abs() > config.significant_correlation)
                .collect();

            if !related.is_empty() {
                let ts_anomalies = Self::detect_point_anomalies(ts, &related, time_series, config);
                anomalies.extend(ts_anomalies);
            }
        }

        let significant_count = correlations
            .iter()
            .filter(|c| c.coefficient.abs() >= config.significant_correlation)
            .count();

        // Mean of |r| over all computed pairs; 0.0 when no pair qualified.
        let avg_correlation = if !correlations.is_empty() {
            correlations
                .iter()
                .map(|c| c.coefficient.abs())
                .sum::<f64>()
                / correlations.len() as f64
        } else {
            0.0
        };

        let anomaly_count = anomalies.len();

        CorrelationResult {
            correlations,
            anomalies,
            stats: CorrelationStats {
                accounts_analyzed: time_series.len(),
                significant_correlations: significant_count,
                anomaly_count,
                avg_correlation,
            },
        }
    }
136
137    /// Calculate correlation between two time series.
138    fn calculate_correlation(
139        ts_a: &AccountTimeSeries,
140        ts_b: &AccountTimeSeries,
141        config: &CorrelationConfig,
142    ) -> Option<AccountCorrelation> {
143        // Align time series by date
144        let (values_a, values_b) = Self::align_series(ts_a, ts_b);
145
146        if values_a.len() < config.min_data_points {
147            return None;
148        }
149
150        // Calculate Pearson correlation
151        let n = values_a.len() as f64;
152        let sum_a: f64 = values_a.iter().sum();
153        let sum_b: f64 = values_b.iter().sum();
154        let sum_ab: f64 = values_a
155            .iter()
156            .zip(values_b.iter())
157            .map(|(a, b)| a * b)
158            .sum();
159        let sum_a2: f64 = values_a.iter().map(|a| a * a).sum();
160        let sum_b2: f64 = values_b.iter().map(|b| b * b).sum();
161
162        let numerator = n * sum_ab - sum_a * sum_b;
163        let denominator = ((n * sum_a2 - sum_a * sum_a) * (n * sum_b2 - sum_b * sum_b)).sqrt();
164
165        if denominator.abs() < 1e-10 {
166            return None;
167        }
168
169        let coefficient = numerator / denominator;
170
171        // Calculate p-value (simplified t-test approximation)
172        let t_stat = coefficient * ((n - 2.0) / (1.0 - coefficient * coefficient)).sqrt();
173        let p_value = Self::t_distribution_pvalue(t_stat.abs(), (n - 2.0) as u32);
174
175        let correlation_type = if p_value > config.significance_level {
176            CorrelationType::None
177        } else if coefficient > 0.0 {
178            CorrelationType::Positive
179        } else {
180            CorrelationType::Negative
181        };
182
183        Some(AccountCorrelation {
184            account_a: ts_a.account_code.clone(),
185            account_b: ts_b.account_code.clone(),
186            coefficient,
187            p_value,
188            correlation_type,
189        })
190    }
191
192    /// Align two time series by date.
193    fn align_series(ts_a: &AccountTimeSeries, ts_b: &AccountTimeSeries) -> (Vec<f64>, Vec<f64>) {
194        let dates_a: HashMap<u64, f64> = ts_a
195            .data_points
196            .iter()
197            .map(|p| (p.date, p.balance))
198            .collect();
199
200        let dates_b: HashMap<u64, f64> = ts_b
201            .data_points
202            .iter()
203            .map(|p| (p.date, p.balance))
204            .collect();
205
206        let common_dates: Vec<u64> = dates_a
207            .keys()
208            .filter(|d| dates_b.contains_key(d))
209            .copied()
210            .collect();
211
212        let values_a: Vec<f64> = common_dates
213            .iter()
214            .filter_map(|d| dates_a.get(d))
215            .copied()
216            .collect();
217        let values_b: Vec<f64> = common_dates
218            .iter()
219            .filter_map(|d| dates_b.get(d))
220            .copied()
221            .collect();
222
223        (values_a, values_b)
224    }
225
    /// Two-tailed p-value for a Student's t statistic with `df` degrees of
    /// freedom.
    ///
    /// Uses the exact identity with the regularized incomplete beta
    /// function; for df > 100 it falls back to the normal approximation,
    /// which is numerically stable at large df.
    fn t_distribution_pvalue(t: f64, df: u32) -> f64 {
        // Degenerate case: no degrees of freedom, report "not significant".
        if df == 0 {
            return 1.0;
        }

        let t_abs = t.abs();
        let df_f = df as f64;

        // For large df, use normal approximation (numerically stable)
        if df > 100 {
            return 2.0 * (1.0 - Self::normal_cdf(t_abs));
        }

        // Use regularized incomplete beta function:
        // 2 * P(T > |t|) = I_{x}(df/2, 1/2) where x = df/(df + t²)
        let x = df_f / (df_f + t_abs * t_abs);

        // Two-tailed p-value: 2 * P(T > |t|) = I_x(df/2, 1/2)
        Self::regularized_incomplete_beta(x, df_f / 2.0, 0.5)
    }
249
    /// Regularized incomplete beta function I_x(a, b).
    ///
    /// Evaluates the continued fraction expansion with the modified Lentz
    /// algorithm for numerical stability. When `x` lies past the region of
    /// fast convergence, the symmetry relation I_x(a,b) = 1 - I_{1-x}(b,a)
    /// is applied first.
    fn regularized_incomplete_beta(x: f64, a: f64, b: f64) -> f64 {
        // Exact values at the domain edges.
        if x <= 0.0 {
            return 0.0;
        }
        if x >= 1.0 {
            return 1.0;
        }

        // Use symmetry: I_x(a,b) = 1 - I_{1-x}(b,a) when x > (a+1)/(a+b+2)
        let threshold = (a + 1.0) / (a + b + 2.0);

        if x > threshold {
            return 1.0 - Self::regularized_incomplete_beta(1.0 - x, b, a);
        }

        // Beta function using log-gamma approximation
        let log_beta = Self::log_gamma(a) + Self::log_gamma(b) - Self::log_gamma(a + b);

        // Front factor: x^a * (1-x)^b / (a * B(a,b))
        let front = (a * x.ln() + b * (1.0 - x).ln() - log_beta - a.ln()).exp();

        // Continued fraction expansion (Lentz's method).
        // `f` accumulates the fraction value; `c` and `d` are the running
        // numerator/denominator ratios, floored at 1e-30 to avoid division
        // by zero.
        let mut f = 1.0;
        let mut c = 1.0;
        let mut d = 0.0;

        // Cap the iteration count; convergence is usually much faster.
        for m in 1..200 {
            let m_f = m as f64;

            // Even term: d_{2m}
            let num_even = m_f * (b - m_f) * x / ((a + 2.0 * m_f - 1.0) * (a + 2.0 * m_f));
            d = 1.0 + num_even * d;
            if d.abs() < 1e-30 {
                d = 1e-30;
            }
            d = 1.0 / d;
            c = 1.0 + num_even / c;
            if c.abs() < 1e-30 {
                c = 1e-30;
            }
            f *= d * c;

            // Odd term: d_{2m+1}
            let num_odd =
                -(a + m_f) * (a + b + m_f) * x / ((a + 2.0 * m_f) * (a + 2.0 * m_f + 1.0));
            d = 1.0 + num_odd * d;
            if d.abs() < 1e-30 {
                d = 1e-30;
            }
            d = 1.0 / d;
            c = 1.0 + num_odd / c;
            if c.abs() < 1e-30 {
                c = 1e-30;
            }
            let delta = d * c;
            f *= delta;

            // Check convergence: the multiplicative update is ~1 when done.
            if (delta - 1.0).abs() < 1e-10 {
                break;
            }
        }

        front * f / a
    }
318
    /// Log-gamma function via the Lanczos approximation (g = 7, 9 coefficients).
    ///
    /// Arguments below 0.5 go through the reflection formula
    /// Γ(x)·Γ(1-x) = π / sin(πx). Non-positive x returns +∞ (Γ has poles at
    /// 0 and the negative integers); callers in this module only pass
    /// positive values.
    fn log_gamma(x: f64) -> f64 {
        if x <= 0.0 {
            return f64::INFINITY;
        }

        // Use Lanczos approximation coefficients
        // Note: These are standard mathematical constants with exact values
        #[allow(clippy::excessive_precision)]
        const LANCZOS_COEFFS: [f64; 9] = [
            0.99999999999980993,
            676.5203681218851,
            -1259.1392167224028,
            771.32342877765313,
            -176.61502916214059,
            12.507343278686905,
            -0.13857109526572012,
            9.9843695780195716e-6,
            1.5056327351493116e-7,
        ];
        let g = 7.0;
        let c = LANCZOS_COEFFS;

        if x < 0.5 {
            // Reflection formula
            return std::f64::consts::PI.ln()
                - (std::f64::consts::PI * x).sin().ln()
                - Self::log_gamma(1.0 - x);
        }

        // Lanczos series: ln Γ(x+1) form, hence the shift by 1 first.
        let x = x - 1.0;
        let mut ag = c[0];
        for (i, &coef) in c.iter().enumerate().skip(1) {
            ag += coef / (x + i as f64);
        }

        let tmp = x + g + 0.5;
        0.5 * (2.0 * std::f64::consts::PI).ln() + (x + 0.5) * tmp.ln() - tmp + ag.ln()
    }
358
359    /// Standard normal CDF approximation.
360    fn normal_cdf(x: f64) -> f64 {
361        // Approximation using error function
362        0.5 * (1.0 + Self::erf(x / std::f64::consts::SQRT_2))
363    }
364
365    /// Error function approximation.
366    fn erf(x: f64) -> f64 {
367        // Horner form approximation
368        let a1 = 0.254829592;
369        let a2 = -0.284496736;
370        let a3 = 1.421413741;
371        let a4 = -1.453152027;
372        let a5 = 1.061405429;
373        let p = 0.3275911;
374
375        let sign = if x >= 0.0 { 1.0 } else { -1.0 };
376        let x = x.abs();
377
378        let t = 1.0 / (1.0 + p * x);
379        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
380
381        sign * y
382    }
383
    /// Detect point anomalies using correlation-based prediction.
    ///
    /// For each data point of `ts`, predicts the balance from every strongly
    /// correlated peer series (same-date peer balance × correlation
    /// coefficient), averages those predictions, and flags the point when
    /// its z-score against the spread of predictions exceeds
    /// `config.anomaly_threshold`.
    fn detect_point_anomalies(
        ts: &AccountTimeSeries,
        related_correlations: &[&AccountCorrelation],
        all_series: &[AccountTimeSeries],
        config: &CorrelationConfig,
    ) -> Vec<CorrelationAnomaly> {
        let mut anomalies = Vec::new();

        for point in &ts.data_points {
            // Predict value based on correlated accounts
            let mut predictions = Vec::new();

            for corr in related_correlations {
                // The peer is whichever side of the pair is not `ts` itself.
                let related_code = if corr.account_a == ts.account_code {
                    &corr.account_b
                } else {
                    &corr.account_a
                };

                if let Some(related_ts) =
                    all_series.iter().find(|t| t.account_code == *related_code)
                {
                    if let Some(related_point) =
                        related_ts.data_points.iter().find(|p| p.date == point.date)
                    {
                        // Simple linear prediction (no intercept term)
                        let predicted = related_point.balance * corr.coefficient;
                        predictions.push(predicted);
                    }
                }
            }

            // No peer had a data point on this date: nothing to compare with.
            if predictions.is_empty() {
                continue;
            }

            let avg_prediction = predictions.iter().sum::<f64>() / predictions.len() as f64;
            let std_dev = if predictions.len() > 1 {
                // Sample standard deviation (n-1 denominator) of the
                // per-peer predictions.
                let variance = predictions
                    .iter()
                    .map(|p| (p - avg_prediction).powi(2))
                    .sum::<f64>()
                    / (predictions.len() - 1) as f64;
                variance.sqrt()
            } else {
                avg_prediction.abs() * 0.1 // Fallback: 10% of the single prediction
            };

            // A zero std_dev (identical or single zero prediction) is skipped
            // to avoid dividing by zero.
            if std_dev > 0.0 {
                let z_score = (point.balance - avg_prediction) / std_dev;

                if z_score.abs() > config.anomaly_threshold {
                    let anomaly_type = if z_score > 0.0 {
                        AnomalyType::UnexpectedHigh
                    } else {
                        AnomalyType::UnexpectedLow
                    };

                    anomalies.push(CorrelationAnomaly {
                        account_code: ts.account_code.clone(),
                        date: point.date,
                        expected: avg_prediction,
                        actual: point.balance,
                        z_score,
                        anomaly_type,
                    });
                }
            }
        }

        anomalies
    }
457
458    /// Calculate rolling correlation over time windows.
459    pub fn rolling_correlation(
460        ts_a: &AccountTimeSeries,
461        ts_b: &AccountTimeSeries,
462        window_size: usize,
463    ) -> Vec<RollingCorrelation> {
464        let mut results = Vec::new();
465        let (values_a, values_b) = Self::align_series(ts_a, ts_b);
466
467        if values_a.len() < window_size {
468            return results;
469        }
470
471        // Get dates for the aligned series
472        let _dates_a: HashMap<u64, usize> = ts_a
473            .data_points
474            .iter()
475            .enumerate()
476            .map(|(i, p)| (p.date, i))
477            .collect();
478
479        let common_dates: Vec<u64> = ts_a
480            .data_points
481            .iter()
482            .filter(|p| ts_b.data_points.iter().any(|pb| pb.date == p.date))
483            .map(|p| p.date)
484            .collect();
485
486        for i in window_size..=values_a.len() {
487            let window_a = &values_a[i - window_size..i];
488            let window_b = &values_b[i - window_size..i];
489
490            let correlation = Self::calculate_window_correlation(window_a, window_b);
491
492            if i - 1 < common_dates.len() {
493                results.push(RollingCorrelation {
494                    end_date: common_dates[i - 1],
495                    correlation,
496                    window_size,
497                });
498            }
499        }
500
501        results
502    }
503
504    /// Calculate correlation for a window.
505    fn calculate_window_correlation(values_a: &[f64], values_b: &[f64]) -> f64 {
506        let n = values_a.len() as f64;
507        let sum_a: f64 = values_a.iter().sum();
508        let sum_b: f64 = values_b.iter().sum();
509        let sum_ab: f64 = values_a
510            .iter()
511            .zip(values_b.iter())
512            .map(|(a, b)| a * b)
513            .sum();
514        let sum_a2: f64 = values_a.iter().map(|a| a * a).sum();
515        let sum_b2: f64 = values_b.iter().map(|b| b * b).sum();
516
517        let numerator = n * sum_ab - sum_a * sum_b;
518        let denominator = ((n * sum_a2 - sum_a * sum_a) * (n * sum_b2 - sum_b * sum_b)).sqrt();
519
520        if denominator.abs() < 1e-10 {
521            0.0
522        } else {
523            numerator / denominator
524        }
525    }
526
527    /// Detect structural breaks in correlation.
528    pub fn detect_correlation_breaks(
529        rolling: &[RollingCorrelation],
530        threshold: f64,
531    ) -> Vec<CorrelationBreak> {
532        let mut breaks = Vec::new();
533
534        for i in 1..rolling.len() {
535            let change = rolling[i].correlation - rolling[i - 1].correlation;
536            if change.abs() > threshold {
537                breaks.push(CorrelationBreak {
538                    date: rolling[i].end_date,
539                    change,
540                    before: rolling[i - 1].correlation,
541                    after: rolling[i].correlation,
542                });
543            }
544        }
545
546        breaks
547    }
548}
549
// Expose kernel metadata through the shared GpuKernel trait.
impl GpuKernel for TemporalCorrelation {
    /// Return the metadata built in `TemporalCorrelation::new`.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
555
/// Correlation configuration.
///
/// See the `Default` impl below for the baseline values.
#[derive(Debug, Clone)]
pub struct CorrelationConfig {
    /// Minimum number of overlapping data points required before a pairwise
    /// correlation is computed at all.
    pub min_data_points: usize,
    /// Significance level for correlation; p-values above this are
    /// classified as `CorrelationType::None`.
    pub significance_level: f64,
    /// Threshold on |coefficient| for treating a correlation as significant.
    pub significant_correlation: f64,
    /// Maximum allowed deviation between an expected and an actual
    /// coefficient before an anomaly is raised.
    pub correlation_threshold: f64,
    /// Z-score threshold for point anomalies.
    pub anomaly_threshold: f64,
    /// Expected correlations (account pair -> expected coefficient).
    pub expected_correlations: Option<HashMap<(String, String), f64>>,
}
572
573impl Default for CorrelationConfig {
574    fn default() -> Self {
575        Self {
576            min_data_points: 10,
577            significance_level: 0.05,
578            significant_correlation: 0.5,
579            correlation_threshold: 0.3,
580            anomaly_threshold: 2.0,
581            expected_correlations: None,
582        }
583    }
584}
585
/// Rolling correlation result.
///
/// One entry per sliding window produced by
/// `TemporalCorrelation::rolling_correlation`.
#[derive(Debug, Clone)]
pub struct RollingCorrelation {
    /// End date of window (same units as `TimeSeriesPoint::date`).
    pub end_date: u64,
    /// Correlation value.
    pub correlation: f64,
    /// Window size used.
    pub window_size: usize,
}
596
/// Correlation break point.
///
/// Emitted by `TemporalCorrelation::detect_correlation_breaks` when the
/// rolling correlation jumps by more than the caller's threshold between
/// two consecutive windows.
#[derive(Debug, Clone)]
pub struct CorrelationBreak {
    /// Date of break (end date of the later of the two windows).
    pub date: u64,
    /// Change in correlation (after - before; signed).
    pub change: f64,
    /// Correlation before break.
    pub before: f64,
    /// Correlation after break.
    pub after: f64,
}
609
610#[cfg(test)]
611mod tests {
612    use super::*;
613    use crate::types::{TimeFrequency, TimeSeriesPoint};
614
    /// Build two daily series over the same 10 dates where account "2000"
    /// is an exact linear transform of account "1000" (b = 0.5·a + 20), so
    /// their Pearson correlation is exactly 1.
    fn create_correlated_series() -> (AccountTimeSeries, AccountTimeSeries) {
        let base_values = vec![
            100.0, 110.0, 105.0, 120.0, 115.0, 130.0, 125.0, 140.0, 135.0, 150.0,
        ];
        let correlated_values: Vec<f64> = base_values.iter().map(|v| v * 0.5 + 20.0).collect();

        let ts_a = AccountTimeSeries {
            account_code: "1000".to_string(),
            data_points: base_values
                .iter()
                .enumerate()
                .map(|(i, &v)| TimeSeriesPoint {
                    // One point per day, starting from a fixed epoch.
                    date: 1700000000 + (i as u64 * 86400),
                    balance: v,
                    period_change: if i > 0 { v - base_values[i - 1] } else { 0.0 },
                })
                .collect(),
            frequency: TimeFrequency::Daily,
        };

        let ts_b = AccountTimeSeries {
            account_code: "2000".to_string(),
            data_points: correlated_values
                .iter()
                .enumerate()
                .map(|(i, &v)| TimeSeriesPoint {
                    date: 1700000000 + (i as u64 * 86400),
                    balance: v,
                    period_change: if i > 0 {
                        v - correlated_values[i - 1]
                    } else {
                        0.0
                    },
                })
                .collect(),
            frequency: TimeFrequency::Daily,
        };

        (ts_a, ts_b)
    }
655
656    #[test]
657    fn test_temporal_metadata() {
658        let kernel = TemporalCorrelation::new();
659        assert_eq!(kernel.metadata().id, "accounting/temporal-correlation");
660        assert_eq!(kernel.metadata().domain, Domain::Accounting);
661    }
662
663    #[test]
664    fn test_positive_correlation() {
665        let (ts_a, ts_b) = create_correlated_series();
666        let time_series = vec![ts_a, ts_b];
667        let config = CorrelationConfig::default();
668
669        let result = TemporalCorrelation::correlate(&time_series, &config);
670
671        assert_eq!(result.correlations.len(), 1);
672        let corr = &result.correlations[0];
673        assert!(corr.coefficient > 0.9); // Should be highly correlated
674        assert_eq!(corr.correlation_type, CorrelationType::Positive);
675    }
676
677    #[test]
678    fn test_negative_correlation() {
679        let (ts_a, mut ts_b) = create_correlated_series();
680
681        // Make ts_b negatively correlated
682        for point in &mut ts_b.data_points {
683            point.balance = 200.0 - point.balance;
684        }
685
686        let time_series = vec![ts_a, ts_b];
687        let config = CorrelationConfig::default();
688
689        let result = TemporalCorrelation::correlate(&time_series, &config);
690
691        assert_eq!(result.correlations.len(), 1);
692        let corr = &result.correlations[0];
693        assert!(corr.coefficient < -0.9); // Should be highly negatively correlated
694        assert_eq!(corr.correlation_type, CorrelationType::Negative);
695    }
696
    /// A monotonic series against an oscillating one should yield only a
    /// weak correlation.
    #[test]
    fn test_no_correlation() {
        // Strictly increasing balances: 100, 101, ..., 109.
        let ts_a = AccountTimeSeries {
            account_code: "1000".to_string(),
            data_points: (0..10)
                .map(|i| TimeSeriesPoint {
                    date: 1700000000 + (i as u64 * 86400),
                    balance: 100.0 + (i as f64),
                    period_change: 1.0,
                })
                .collect(),
            frequency: TimeFrequency::Daily,
        };

        // Random-ish uncorrelated data
        let ts_b = AccountTimeSeries {
            account_code: "2000".to_string(),
            data_points: (0..10)
                .map(|i| TimeSeriesPoint {
                    date: 1700000000 + (i as u64 * 86400),
                    balance: [50.0, 80.0, 45.0, 90.0, 40.0, 85.0, 35.0, 95.0, 30.0, 100.0][i],
                    period_change: 0.0,
                })
                .collect(),
            frequency: TimeFrequency::Daily,
        };

        let time_series = vec![ts_a, ts_b];
        let config = CorrelationConfig::default();

        let result = TemporalCorrelation::correlate(&time_series, &config);

        assert_eq!(result.correlations.len(), 1);
        // Correlation should be weak
        assert!(result.correlations[0].coefficient.abs() < 0.5);
    }
733
    /// One data point per series while the config demands ten: no
    /// correlation may be computed.
    #[test]
    fn test_insufficient_data() {
        let ts_a = AccountTimeSeries {
            account_code: "1000".to_string(),
            data_points: vec![TimeSeriesPoint {
                date: 1700000000,
                balance: 100.0,
                period_change: 0.0,
            }],
            frequency: TimeFrequency::Daily,
        };

        let ts_b = AccountTimeSeries {
            account_code: "2000".to_string(),
            data_points: vec![TimeSeriesPoint {
                date: 1700000000,
                balance: 50.0,
                period_change: 0.0,
            }],
            frequency: TimeFrequency::Daily,
        };

        let time_series = vec![ts_a, ts_b];
        let config = CorrelationConfig {
            min_data_points: 10,
            ..Default::default()
        };

        let result = TemporalCorrelation::correlate(&time_series, &config);

        assert!(result.correlations.is_empty());
    }
766
767    #[test]
768    fn test_rolling_correlation() {
769        let (ts_a, ts_b) = create_correlated_series();
770
771        let rolling = TemporalCorrelation::rolling_correlation(&ts_a, &ts_b, 5);
772
773        assert!(!rolling.is_empty());
774        // All rolling correlations should be high for perfectly correlated series
775        assert!(rolling.iter().all(|r| r.correlation > 0.9));
776    }
777
778    #[test]
779    fn test_correlation_break_detection() {
780        let rolling = vec![
781            RollingCorrelation {
782                end_date: 1700000000,
783                correlation: 0.9,
784                window_size: 5,
785            },
786            RollingCorrelation {
787                end_date: 1700086400,
788                correlation: 0.85,
789                window_size: 5,
790            },
791            RollingCorrelation {
792                end_date: 1700172800,
793                correlation: 0.2,
794                window_size: 5,
795            }, // Break!
796            RollingCorrelation {
797                end_date: 1700259200,
798                correlation: 0.25,
799                window_size: 5,
800            },
801        ];
802
803        let breaks = TemporalCorrelation::detect_correlation_breaks(&rolling, 0.5);
804
805        assert_eq!(breaks.len(), 1);
806        assert!((breaks[0].change - (-0.65)).abs() < 0.01);
807    }
808
809    #[test]
810    fn test_correlation_stats() {
811        let (ts_a, ts_b) = create_correlated_series();
812        let time_series = vec![ts_a, ts_b];
813        let config = CorrelationConfig::default();
814
815        let result = TemporalCorrelation::correlate(&time_series, &config);
816
817        assert_eq!(result.stats.accounts_analyzed, 2);
818        assert_eq!(result.stats.significant_correlations, 1);
819    }
820
    /// An expected strong correlation that the data no longer shows must be
    /// reported as a `MissingCorrelation` anomaly.
    #[test]
    fn test_expected_correlation_anomaly() {
        let (ts_a, mut ts_b) = create_correlated_series();

        // Make ts_b uncorrelated
        for (i, point) in ts_b.data_points.iter_mut().enumerate() {
            point.balance = [50.0, 80.0, 45.0, 90.0, 40.0, 85.0, 35.0, 95.0, 30.0, 100.0][i];
        }

        // Declare an expected coefficient of 0.9 for the pair.
        let mut expected = HashMap::new();
        expected.insert(("1000".to_string(), "2000".to_string()), 0.9);

        let time_series = vec![ts_a, ts_b];
        let config = CorrelationConfig {
            expected_correlations: Some(expected),
            correlation_threshold: 0.3,
            ..Default::default()
        };

        let result = TemporalCorrelation::correlate(&time_series, &config);

        // Should detect anomaly because correlation doesn't match expected
        assert!(
            result
                .anomalies
                .iter()
                .any(|a| a.anomaly_type == AnomalyType::MissingCorrelation)
        );
    }
850}