1use crate::types::{
9 AccountCorrelation, AccountTimeSeries, AnomalyType, CorrelationAnomaly, CorrelationResult,
10 CorrelationStats, CorrelationType,
11};
12use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
13use std::collections::HashMap;
14
/// Batch kernel that analyzes pairwise Pearson correlations between account
/// balance time series and flags anomalies against expected relationships.
#[derive(Debug, Clone)]
pub struct TemporalCorrelation {
    // Kernel registration metadata (id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
26
impl Default for TemporalCorrelation {
    /// Delegates to [`TemporalCorrelation::new`].
    fn default() -> Self {
        Self::new()
    }
}
32
33impl TemporalCorrelation {
    /// Creates the kernel with its fixed registration metadata: batch-mode
    /// kernel `accounting/temporal-correlation` in the Accounting domain,
    /// declaring 5_000 items/s throughput and 500 µs latency.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("accounting/temporal-correlation", Domain::Accounting)
                .with_description("Account time series correlation analysis")
                .with_throughput(5_000)
                .with_latency_us(500.0),
        }
    }
44
    /// Runs the full correlation analysis over a set of account time series.
    ///
    /// Three passes:
    /// 1. Pearson correlation for every unordered pair of series.
    /// 2. If `config.expected_correlations` is set, flag pairs whose observed
    ///    coefficient deviates from the expected one by more than
    ///    `config.correlation_threshold` (`MissingCorrelation` anomalies).
    /// 3. For each series with at least one significant correlation, flag
    ///    individual data points that deviate from correlation-based
    ///    predictions (see `detect_point_anomalies`).
    ///
    /// Returns the correlations, all anomalies, and summary statistics.
    pub fn correlate(
        time_series: &[AccountTimeSeries],
        config: &CorrelationConfig,
    ) -> CorrelationResult {
        let mut correlations = Vec::new();
        let mut anomalies = Vec::new();

        // Pass 1: every unordered pair (i < j).
        for i in 0..time_series.len() {
            for j in (i + 1)..time_series.len() {
                let ts_a = &time_series[i];
                let ts_b = &time_series[j];

                if let Some(corr) = Self::calculate_correlation(ts_a, ts_b, config) {
                    correlations.push(corr);
                }
            }
        }

        // Pass 2: compare observed coefficients against expectations.
        if let Some(ref expected) = config.expected_correlations {
            for (pair, expected_coef) in expected {
                // Pairs are unordered; match either orientation.
                let actual = correlations.iter().find(|c| {
                    (c.account_a == pair.0 && c.account_b == pair.1)
                        || (c.account_a == pair.1 && c.account_b == pair.0)
                });

                if let Some(actual_corr) = actual {
                    let diff = (actual_corr.coefficient - expected_coef).abs();
                    if diff > config.correlation_threshold {
                        // Anchor the anomaly to the latest data point of the
                        // first account in the expected pair.
                        if let Some(ts) = time_series.iter().find(|t| t.account_code == pair.0) {
                            if let Some(last_point) = ts.data_points.last() {
                                anomalies.push(CorrelationAnomaly {
                                    account_code: pair.0.clone(),
                                    date: last_point.date,
                                    expected: *expected_coef,
                                    actual: actual_corr.coefficient,
                                    // Pseudo z-score: deviation scaled by a
                                    // nominal 0.1 spread (heuristic constant).
                                    z_score: diff / 0.1,
                                    anomaly_type: AnomalyType::MissingCorrelation,
                                });
                            }
                        }
                    }
                }
            }
        }

        // Pass 3: point-level anomalies for series that have at least one
        // significant correlation partner.
        for ts in time_series {
            let related: Vec<_> = correlations
                .iter()
                .filter(|c| c.account_a == ts.account_code || c.account_b == ts.account_code)
                .filter(|c| c.coefficient.abs() > config.significant_correlation)
                .collect();

            if !related.is_empty() {
                let ts_anomalies = Self::detect_point_anomalies(ts, &related, time_series, config);
                anomalies.extend(ts_anomalies);
            }
        }

        let significant_count = correlations
            .iter()
            .filter(|c| c.coefficient.abs() >= config.significant_correlation)
            .count();

        // Mean of |r| over all computed pairs; 0.0 when no pair qualified.
        let avg_correlation = if !correlations.is_empty() {
            correlations
                .iter()
                .map(|c| c.coefficient.abs())
                .sum::<f64>()
                / correlations.len() as f64
        } else {
            0.0
        };

        let anomaly_count = anomalies.len();

        CorrelationResult {
            correlations,
            anomalies,
            stats: CorrelationStats {
                accounts_analyzed: time_series.len(),
                significant_correlations: significant_count,
                anomaly_count,
                avg_correlation,
            },
        }
    }
136
137 fn calculate_correlation(
139 ts_a: &AccountTimeSeries,
140 ts_b: &AccountTimeSeries,
141 config: &CorrelationConfig,
142 ) -> Option<AccountCorrelation> {
143 let (values_a, values_b) = Self::align_series(ts_a, ts_b);
145
146 if values_a.len() < config.min_data_points {
147 return None;
148 }
149
150 let n = values_a.len() as f64;
152 let sum_a: f64 = values_a.iter().sum();
153 let sum_b: f64 = values_b.iter().sum();
154 let sum_ab: f64 = values_a
155 .iter()
156 .zip(values_b.iter())
157 .map(|(a, b)| a * b)
158 .sum();
159 let sum_a2: f64 = values_a.iter().map(|a| a * a).sum();
160 let sum_b2: f64 = values_b.iter().map(|b| b * b).sum();
161
162 let numerator = n * sum_ab - sum_a * sum_b;
163 let denominator = ((n * sum_a2 - sum_a * sum_a) * (n * sum_b2 - sum_b * sum_b)).sqrt();
164
165 if denominator.abs() < 1e-10 {
166 return None;
167 }
168
169 let coefficient = numerator / denominator;
170
171 let t_stat = coefficient * ((n - 2.0) / (1.0 - coefficient * coefficient)).sqrt();
173 let p_value = Self::t_distribution_pvalue(t_stat.abs(), (n - 2.0) as u32);
174
175 let correlation_type = if p_value > config.significance_level {
176 CorrelationType::None
177 } else if coefficient > 0.0 {
178 CorrelationType::Positive
179 } else {
180 CorrelationType::Negative
181 };
182
183 Some(AccountCorrelation {
184 account_a: ts_a.account_code.clone(),
185 account_b: ts_b.account_code.clone(),
186 coefficient,
187 p_value,
188 correlation_type,
189 })
190 }
191
192 fn align_series(ts_a: &AccountTimeSeries, ts_b: &AccountTimeSeries) -> (Vec<f64>, Vec<f64>) {
194 let dates_a: HashMap<u64, f64> = ts_a
195 .data_points
196 .iter()
197 .map(|p| (p.date, p.balance))
198 .collect();
199
200 let dates_b: HashMap<u64, f64> = ts_b
201 .data_points
202 .iter()
203 .map(|p| (p.date, p.balance))
204 .collect();
205
206 let common_dates: Vec<u64> = dates_a
207 .keys()
208 .filter(|d| dates_b.contains_key(d))
209 .copied()
210 .collect();
211
212 let values_a: Vec<f64> = common_dates
213 .iter()
214 .filter_map(|d| dates_a.get(d))
215 .copied()
216 .collect();
217 let values_b: Vec<f64> = common_dates
218 .iter()
219 .filter_map(|d| dates_b.get(d))
220 .copied()
221 .collect();
222
223 (values_a, values_b)
224 }
225
    /// Two-sided p-value for a Student's t statistic with `df` degrees of
    /// freedom: P(|T| > t).
    fn t_distribution_pvalue(t: f64, df: u32) -> f64 {
        // No degrees of freedom: the test is meaningless; report "not
        // significant" rather than dividing by zero downstream.
        if df == 0 {
            return 1.0;
        }

        let t_abs = t.abs();
        let df_f = df as f64;

        // For large df the t-distribution is well approximated by the
        // standard normal; use the cheaper erf-based CDF.
        if df > 100 {
            return 2.0 * (1.0 - Self::normal_cdf(t_abs));
        }

        // Exact identity: P(|T| > t) = I_x(df/2, 1/2) with x = df/(df + t^2),
        // where I is the regularized incomplete beta function.
        let x = df_f / (df_f + t_abs * t_abs);

        Self::regularized_incomplete_beta(x, df_f / 2.0, 0.5)
    }
249
250 fn regularized_incomplete_beta(x: f64, a: f64, b: f64) -> f64 {
254 if x <= 0.0 {
255 return 0.0;
256 }
257 if x >= 1.0 {
258 return 1.0;
259 }
260
261 let threshold = (a + 1.0) / (a + b + 2.0);
263
264 if x > threshold {
265 return 1.0 - Self::regularized_incomplete_beta(1.0 - x, b, a);
266 }
267
268 let log_beta = Self::log_gamma(a) + Self::log_gamma(b) - Self::log_gamma(a + b);
270
271 let front = (a * x.ln() + b * (1.0 - x).ln() - log_beta - a.ln()).exp();
273
274 let mut f = 1.0;
276 let mut c = 1.0;
277 let mut d = 0.0;
278
279 for m in 1..200 {
280 let m_f = m as f64;
281
282 let num_even = m_f * (b - m_f) * x / ((a + 2.0 * m_f - 1.0) * (a + 2.0 * m_f));
284 d = 1.0 + num_even * d;
285 if d.abs() < 1e-30 {
286 d = 1e-30;
287 }
288 d = 1.0 / d;
289 c = 1.0 + num_even / c;
290 if c.abs() < 1e-30 {
291 c = 1e-30;
292 }
293 f *= d * c;
294
295 let num_odd =
297 -(a + m_f) * (a + b + m_f) * x / ((a + 2.0 * m_f) * (a + 2.0 * m_f + 1.0));
298 d = 1.0 + num_odd * d;
299 if d.abs() < 1e-30 {
300 d = 1e-30;
301 }
302 d = 1.0 / d;
303 c = 1.0 + num_odd / c;
304 if c.abs() < 1e-30 {
305 c = 1e-30;
306 }
307 let delta = d * c;
308 f *= delta;
309
310 if (delta - 1.0).abs() < 1e-10 {
312 break;
313 }
314 }
315
316 front * f / a
317 }
318
319 fn log_gamma(x: f64) -> f64 {
321 if x <= 0.0 {
322 return f64::INFINITY;
323 }
324
325 #[allow(clippy::excessive_precision)]
328 const LANCZOS_COEFFS: [f64; 9] = [
329 0.99999999999980993,
330 676.5203681218851,
331 -1259.1392167224028,
332 771.32342877765313,
333 -176.61502916214059,
334 12.507343278686905,
335 -0.13857109526572012,
336 9.9843695780195716e-6,
337 1.5056327351493116e-7,
338 ];
339 let g = 7.0;
340 let c = LANCZOS_COEFFS;
341
342 if x < 0.5 {
343 return std::f64::consts::PI.ln()
345 - (std::f64::consts::PI * x).sin().ln()
346 - Self::log_gamma(1.0 - x);
347 }
348
349 let x = x - 1.0;
350 let mut ag = c[0];
351 for (i, &coef) in c.iter().enumerate().skip(1) {
352 ag += coef / (x + i as f64);
353 }
354
355 let tmp = x + g + 0.5;
356 0.5 * (2.0 * std::f64::consts::PI).ln() + (x + 0.5) * tmp.ln() - tmp + ag.ln()
357 }
358
    /// Standard normal CDF: phi(x) = (1 + erf(x / sqrt(2))) / 2.
    fn normal_cdf(x: f64) -> f64 {
        0.5 * (1.0 + Self::erf(x / std::f64::consts::SQRT_2))
    }
364
365 fn erf(x: f64) -> f64 {
367 let a1 = 0.254829592;
369 let a2 = -0.284496736;
370 let a3 = 1.421413741;
371 let a4 = -1.453152027;
372 let a5 = 1.061405429;
373 let p = 0.3275911;
374
375 let sign = if x >= 0.0 { 1.0 } else { -1.0 };
376 let x = x.abs();
377
378 let t = 1.0 / (1.0 + p * x);
379 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
380
381 sign * y
382 }
383
    /// Scans each data point of `ts` and flags balances that deviate strongly
    /// from what the significantly correlated accounts would predict.
    ///
    /// For every point, each related account's same-date balance is scaled by
    /// the correlation coefficient to form a naive prediction; the point is
    /// anomalous when its z-score against the prediction ensemble exceeds
    /// `config.anomaly_threshold`.
    fn detect_point_anomalies(
        ts: &AccountTimeSeries,
        related_correlations: &[&AccountCorrelation],
        all_series: &[AccountTimeSeries],
        config: &CorrelationConfig,
    ) -> Vec<CorrelationAnomaly> {
        let mut anomalies = Vec::new();

        for point in &ts.data_points {
            let mut predictions = Vec::new();

            for corr in related_correlations {
                // The "other" account of the correlated pair.
                let related_code = if corr.account_a == ts.account_code {
                    &corr.account_b
                } else {
                    &corr.account_a
                };

                if let Some(related_ts) =
                    all_series.iter().find(|t| t.account_code == *related_code)
                {
                    if let Some(related_point) =
                        related_ts.data_points.iter().find(|p| p.date == point.date)
                    {
                        // Naive linear prediction: scale the related balance
                        // by the correlation coefficient.
                        let predicted = related_point.balance * corr.coefficient;
                        predictions.push(predicted);
                    }
                }
            }

            // No related account has data on this date: nothing to compare.
            if predictions.is_empty() {
                continue;
            }

            let avg_prediction = predictions.iter().sum::<f64>() / predictions.len() as f64;
            let std_dev = if predictions.len() > 1 {
                // Sample standard deviation (n - 1 denominator).
                let variance = predictions
                    .iter()
                    .map(|p| (p - avg_prediction).powi(2))
                    .sum::<f64>()
                    / (predictions.len() - 1) as f64;
                variance.sqrt()
            } else {
                // Single prediction: fall back to 10% of its magnitude as a
                // rough spread estimate (heuristic).
                avg_prediction.abs() * 0.1
            };

            // std_dev == 0 (e.g. zero prediction): z-score undefined, skip.
            if std_dev > 0.0 {
                let z_score = (point.balance - avg_prediction) / std_dev;

                if z_score.abs() > config.anomaly_threshold {
                    let anomaly_type = if z_score > 0.0 {
                        AnomalyType::UnexpectedHigh
                    } else {
                        AnomalyType::UnexpectedLow
                    };

                    anomalies.push(CorrelationAnomaly {
                        account_code: ts.account_code.clone(),
                        date: point.date,
                        expected: avg_prediction,
                        actual: point.balance,
                        z_score,
                        anomaly_type,
                    });
                }
            }
        }

        anomalies
    }
457
458 pub fn rolling_correlation(
460 ts_a: &AccountTimeSeries,
461 ts_b: &AccountTimeSeries,
462 window_size: usize,
463 ) -> Vec<RollingCorrelation> {
464 let mut results = Vec::new();
465 let (values_a, values_b) = Self::align_series(ts_a, ts_b);
466
467 if values_a.len() < window_size {
468 return results;
469 }
470
471 let _dates_a: HashMap<u64, usize> = ts_a
473 .data_points
474 .iter()
475 .enumerate()
476 .map(|(i, p)| (p.date, i))
477 .collect();
478
479 let common_dates: Vec<u64> = ts_a
480 .data_points
481 .iter()
482 .filter(|p| ts_b.data_points.iter().any(|pb| pb.date == p.date))
483 .map(|p| p.date)
484 .collect();
485
486 for i in window_size..=values_a.len() {
487 let window_a = &values_a[i - window_size..i];
488 let window_b = &values_b[i - window_size..i];
489
490 let correlation = Self::calculate_window_correlation(window_a, window_b);
491
492 if i - 1 < common_dates.len() {
493 results.push(RollingCorrelation {
494 end_date: common_dates[i - 1],
495 correlation,
496 window_size,
497 });
498 }
499 }
500
501 results
502 }
503
504 fn calculate_window_correlation(values_a: &[f64], values_b: &[f64]) -> f64 {
506 let n = values_a.len() as f64;
507 let sum_a: f64 = values_a.iter().sum();
508 let sum_b: f64 = values_b.iter().sum();
509 let sum_ab: f64 = values_a
510 .iter()
511 .zip(values_b.iter())
512 .map(|(a, b)| a * b)
513 .sum();
514 let sum_a2: f64 = values_a.iter().map(|a| a * a).sum();
515 let sum_b2: f64 = values_b.iter().map(|b| b * b).sum();
516
517 let numerator = n * sum_ab - sum_a * sum_b;
518 let denominator = ((n * sum_a2 - sum_a * sum_a) * (n * sum_b2 - sum_b * sum_b)).sqrt();
519
520 if denominator.abs() < 1e-10 {
521 0.0
522 } else {
523 numerator / denominator
524 }
525 }
526
527 pub fn detect_correlation_breaks(
529 rolling: &[RollingCorrelation],
530 threshold: f64,
531 ) -> Vec<CorrelationBreak> {
532 let mut breaks = Vec::new();
533
534 for i in 1..rolling.len() {
535 let change = rolling[i].correlation - rolling[i - 1].correlation;
536 if change.abs() > threshold {
537 breaks.push(CorrelationBreak {
538 date: rolling[i].end_date,
539 change,
540 before: rolling[i - 1].correlation,
541 after: rolling[i].correlation,
542 });
543 }
544 }
545
546 breaks
547 }
548}
549
impl GpuKernel for TemporalCorrelation {
    /// Exposes the kernel's registration metadata to the runtime.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
555
/// Tunable parameters for [`TemporalCorrelation::correlate`].
#[derive(Debug, Clone)]
pub struct CorrelationConfig {
    /// Minimum number of common dates two series must share before a
    /// correlation is computed.
    pub min_data_points: usize,
    /// Maximum p-value for a correlation to be classified Positive/Negative
    /// rather than None.
    pub significance_level: f64,
    /// |r| above which a correlation counts as significant (stats counter and
    /// point-anomaly detection).
    pub significant_correlation: f64,
    /// Allowed absolute deviation between an expected and observed
    /// coefficient before a MissingCorrelation anomaly is raised.
    pub correlation_threshold: f64,
    /// Z-score magnitude beyond which an individual data point is flagged.
    pub anomaly_threshold: f64,
    /// Optional map of (account_a, account_b) pairs to the correlation
    /// coefficient they are expected to exhibit.
    pub expected_correlations: Option<HashMap<(String, String), f64>>,
}
572
impl Default for CorrelationConfig {
    /// Conservative defaults: require 10 shared dates, 5% significance,
    /// |r| >= 0.5 significant, 0.3 expected-coefficient tolerance, and a
    /// 2-sigma point-anomaly threshold; no expected correlations.
    fn default() -> Self {
        Self {
            min_data_points: 10,
            significance_level: 0.05,
            significant_correlation: 0.5,
            correlation_threshold: 0.3,
            anomaly_threshold: 2.0,
            expected_correlations: None,
        }
    }
}
585
/// One point of a rolling-window correlation series.
#[derive(Debug, Clone)]
pub struct RollingCorrelation {
    /// Date (epoch seconds) of the last observation in the window.
    pub end_date: u64,
    /// Pearson coefficient over the window.
    pub correlation: f64,
    /// Number of aligned observations in the window.
    pub window_size: usize,
}
596
/// A sudden jump between two consecutive rolling-correlation points.
#[derive(Debug, Clone)]
pub struct CorrelationBreak {
    /// Date of the later (post-jump) rolling point.
    pub date: u64,
    /// Signed change in the coefficient (after - before).
    pub change: f64,
    /// Coefficient before the jump.
    pub before: f64,
    /// Coefficient after the jump.
    pub after: f64,
}
609
610#[cfg(test)]
611mod tests {
612 use super::*;
613 use crate::types::{TimeFrequency, TimeSeriesPoint};
614
    /// Builds two daily series where account "2000" is an exact linear
    /// function of account "1000" (b = 0.5 * a + 20), i.e. perfectly
    /// positively correlated.
    fn create_correlated_series() -> (AccountTimeSeries, AccountTimeSeries) {
        let base_values = vec![
            100.0, 110.0, 105.0, 120.0, 115.0, 130.0, 125.0, 140.0, 135.0, 150.0,
        ];
        let correlated_values: Vec<f64> = base_values.iter().map(|v| v * 0.5 + 20.0).collect();

        let ts_a = AccountTimeSeries {
            account_code: "1000".to_string(),
            data_points: base_values
                .iter()
                .enumerate()
                .map(|(i, &v)| TimeSeriesPoint {
                    date: 1700000000 + (i as u64 * 86400),
                    balance: v,
                    period_change: if i > 0 { v - base_values[i - 1] } else { 0.0 },
                })
                .collect(),
            frequency: TimeFrequency::Daily,
        };

        let ts_b = AccountTimeSeries {
            account_code: "2000".to_string(),
            data_points: correlated_values
                .iter()
                .enumerate()
                .map(|(i, &v)| TimeSeriesPoint {
                    date: 1700000000 + (i as u64 * 86400),
                    balance: v,
                    period_change: if i > 0 {
                        v - correlated_values[i - 1]
                    } else {
                        0.0
                    },
                })
                .collect(),
            frequency: TimeFrequency::Daily,
        };

        (ts_a, ts_b)
    }
655
    // Metadata is registered with the expected id and domain.
    #[test]
    fn test_temporal_metadata() {
        let kernel = TemporalCorrelation::new();
        assert_eq!(kernel.metadata().id, "accounting/temporal-correlation");
        assert_eq!(kernel.metadata().domain, Domain::Accounting);
    }
662
663 #[test]
664 fn test_positive_correlation() {
665 let (ts_a, ts_b) = create_correlated_series();
666 let time_series = vec![ts_a, ts_b];
667 let config = CorrelationConfig::default();
668
669 let result = TemporalCorrelation::correlate(&time_series, &config);
670
671 assert_eq!(result.correlations.len(), 1);
672 let corr = &result.correlations[0];
673 assert!(corr.coefficient > 0.9); assert_eq!(corr.correlation_type, CorrelationType::Positive);
675 }
676
677 #[test]
678 fn test_negative_correlation() {
679 let (ts_a, mut ts_b) = create_correlated_series();
680
681 for point in &mut ts_b.data_points {
683 point.balance = 200.0 - point.balance;
684 }
685
686 let time_series = vec![ts_a, ts_b];
687 let config = CorrelationConfig::default();
688
689 let result = TemporalCorrelation::correlate(&time_series, &config);
690
691 assert_eq!(result.correlations.len(), 1);
692 let corr = &result.correlations[0];
693 assert!(corr.coefficient < -0.9); assert_eq!(corr.correlation_type, CorrelationType::Negative);
695 }
696
    // A monotone series against an oscillating one should show only a weak
    // correlation.
    #[test]
    fn test_no_correlation() {
        let ts_a = AccountTimeSeries {
            account_code: "1000".to_string(),
            data_points: (0..10)
                .map(|i| TimeSeriesPoint {
                    date: 1700000000 + (i as u64 * 86400),
                    balance: 100.0 + (i as f64),
                    period_change: 1.0,
                })
                .collect(),
            frequency: TimeFrequency::Daily,
        };

        let ts_b = AccountTimeSeries {
            account_code: "2000".to_string(),
            data_points: (0..10)
                .map(|i| TimeSeriesPoint {
                    date: 1700000000 + (i as u64 * 86400),
                    balance: [50.0, 80.0, 45.0, 90.0, 40.0, 85.0, 35.0, 95.0, 30.0, 100.0][i],
                    period_change: 0.0,
                })
                .collect(),
            frequency: TimeFrequency::Daily,
        };

        let time_series = vec![ts_a, ts_b];
        let config = CorrelationConfig::default();

        let result = TemporalCorrelation::correlate(&time_series, &config);

        assert_eq!(result.correlations.len(), 1);
        assert!(result.correlations[0].coefficient.abs() < 0.5);
    }
733
    // With fewer common dates than min_data_points, no correlation is
    // produced at all.
    #[test]
    fn test_insufficient_data() {
        let ts_a = AccountTimeSeries {
            account_code: "1000".to_string(),
            data_points: vec![TimeSeriesPoint {
                date: 1700000000,
                balance: 100.0,
                period_change: 0.0,
            }],
            frequency: TimeFrequency::Daily,
        };

        let ts_b = AccountTimeSeries {
            account_code: "2000".to_string(),
            data_points: vec![TimeSeriesPoint {
                date: 1700000000,
                balance: 50.0,
                period_change: 0.0,
            }],
            frequency: TimeFrequency::Daily,
        };

        let time_series = vec![ts_a, ts_b];
        let config = CorrelationConfig {
            min_data_points: 10,
            ..Default::default()
        };

        let result = TemporalCorrelation::correlate(&time_series, &config);

        assert!(result.correlations.is_empty());
    }
766
    // Every rolling window over perfectly correlated series stays strongly
    // positive.
    #[test]
    fn test_rolling_correlation() {
        let (ts_a, ts_b) = create_correlated_series();

        let rolling = TemporalCorrelation::rolling_correlation(&ts_a, &ts_b, 5);

        assert!(!rolling.is_empty());
        assert!(rolling.iter().all(|r| r.correlation > 0.9));
    }
777
778 #[test]
779 fn test_correlation_break_detection() {
780 let rolling = vec![
781 RollingCorrelation {
782 end_date: 1700000000,
783 correlation: 0.9,
784 window_size: 5,
785 },
786 RollingCorrelation {
787 end_date: 1700086400,
788 correlation: 0.85,
789 window_size: 5,
790 },
791 RollingCorrelation {
792 end_date: 1700172800,
793 correlation: 0.2,
794 window_size: 5,
795 }, RollingCorrelation {
797 end_date: 1700259200,
798 correlation: 0.25,
799 window_size: 5,
800 },
801 ];
802
803 let breaks = TemporalCorrelation::detect_correlation_breaks(&rolling, 0.5);
804
805 assert_eq!(breaks.len(), 1);
806 assert!((breaks[0].change - (-0.65)).abs() < 0.01);
807 }
808
    // Summary stats count both analyzed accounts and the one significant
    // correlation between them.
    #[test]
    fn test_correlation_stats() {
        let (ts_a, ts_b) = create_correlated_series();
        let time_series = vec![ts_a, ts_b];
        let config = CorrelationConfig::default();

        let result = TemporalCorrelation::correlate(&time_series, &config);

        assert_eq!(result.stats.accounts_analyzed, 2);
        assert_eq!(result.stats.significant_correlations, 1);
    }
820
    // When the expected 0.9 correlation is broken (second series replaced by
    // oscillating noise), a MissingCorrelation anomaly must be raised.
    #[test]
    fn test_expected_correlation_anomaly() {
        let (ts_a, mut ts_b) = create_correlated_series();

        // Destroy the correlation by overwriting with oscillating values.
        for (i, point) in ts_b.data_points.iter_mut().enumerate() {
            point.balance = [50.0, 80.0, 45.0, 90.0, 40.0, 85.0, 35.0, 95.0, 30.0, 100.0][i];
        }

        let mut expected = HashMap::new();
        expected.insert(("1000".to_string(), "2000".to_string()), 0.9);

        let time_series = vec![ts_a, ts_b];
        let config = CorrelationConfig {
            expected_correlations: Some(expected),
            correlation_threshold: 0.3,
            ..Default::default()
        };

        let result = TemporalCorrelation::correlate(&time_series, &config);

        assert!(
            result
                .anomalies
                .iter()
                .any(|a| a.anomaly_type == AnomalyType::MissingCorrelation)
        );
    }
850}