shape_runtime/multi_table/
analysis.rs1use super::types::Divergence;
4use crate::data::OwnedDataRow as RowValue;
5use shape_ast::error::{Result, ShapeError};
6
7pub struct MultiTableAnalysis;
9
10impl MultiTableAnalysis {
11 pub fn correlation(data1: &[RowValue], data2: &[RowValue]) -> Result<f64> {
13 if data1.len() != data2.len() {
14 return Err(ShapeError::RuntimeError {
15 message: "Data lengths must match for correlation calculation".into(),
16 location: None,
17 });
18 }
19
20 if data1.is_empty() {
21 return Err(ShapeError::RuntimeError {
22 message: "Cannot calculate correlation on empty data".into(),
23 location: None,
24 });
25 }
26
27 let n = data1.len() as f64;
28
29 let get_value = |row: &RowValue| {
30 row.field_names()
31 .next()
32 .and_then(|f| row.get_field(f))
33 .unwrap_or(f64::NAN)
34 };
35
36 let mean1 = data1.iter().map(get_value).sum::<f64>() / n;
37 let mean2 = data2.iter().map(get_value).sum::<f64>() / n;
38
39 let mut covariance = 0.0;
40 let mut variance1 = 0.0;
41 let mut variance2 = 0.0;
42
43 for i in 0..data1.len() {
44 let diff1 = get_value(&data1[i]) - mean1;
45 let diff2 = get_value(&data2[i]) - mean2;
46
47 covariance += diff1 * diff2;
48 variance1 += diff1 * diff1;
49 variance2 += diff2 * diff2;
50 }
51
52 let std1 = (variance1 / n).sqrt();
53 let std2 = (variance2 / n).sqrt();
54
55 if std1 == 0.0 || std2 == 0.0 {
56 return Err(ShapeError::RuntimeError {
57 message: "Cannot calculate correlation with zero variance".into(),
58 location: None,
59 });
60 }
61
62 Ok(covariance / (n * std1 * std2))
63 }
64
65 pub fn find_divergences(
67 data1: &[RowValue],
68 data2: &[RowValue],
69 window: usize,
70 ) -> Result<Vec<Divergence>> {
71 if data1.len() != data2.len() {
72 return Err(ShapeError::RuntimeError {
73 message: "Data series must have equal length".into(),
74 location: None,
75 });
76 }
77
78 if window == 0 || window > data1.len() {
79 return Err(ShapeError::RuntimeError {
80 message: "Invalid window size".into(),
81 location: None,
82 });
83 }
84
85 let mut divergences = Vec::new();
86
87 for i in window..data1.len() {
88 let trend1 = Self::calculate_trend(&data1[i - window..i]);
89 let trend2 = Self::calculate_trend(&data2[i - window..i]);
90
91 if (trend1 > 0.0 && trend2 < 0.0) || (trend1 < 0.0 && trend2 > 0.0) {
92 divergences.push(Divergence {
93 timestamp: data1[i].timestamp,
94 index: i,
95 id1_trend: trend1,
96 id2_trend: trend2,
97 strength: (trend1 - trend2).abs(),
98 });
99 }
100 }
101
102 Ok(divergences)
103 }
104
105 fn calculate_trend(rows: &[RowValue]) -> f64 {
107 if rows.is_empty() {
108 return 0.0;
109 }
110
111 let n = rows.len() as f64;
112 let mut sum_x = 0.0;
113 let mut sum_y = 0.0;
114 let mut sum_xy = 0.0;
115 let mut sum_x2 = 0.0;
116
117 for (i, row) in rows.iter().enumerate() {
118 let x = i as f64;
119 let y = row
120 .field_names()
121 .next()
122 .and_then(|f| row.get_field(f))
123 .unwrap_or(f64::NAN);
124
125 sum_x += x;
126 sum_y += y;
127 sum_xy += x * y;
128 sum_x2 += x * x;
129 }
130
131 let denominator = n * sum_x2 - sum_x * sum_x;
132 if denominator == 0.0 {
133 return 0.0;
134 }
135
136 (n * sum_xy - sum_x * sum_y) / denominator
137 }
138}