Skip to main content

shape_runtime/multi_table/
analysis.rs

1//! Multi-series analysis utilities
2
3use super::types::Divergence;
4use crate::data::OwnedDataRow as RowValue;
5use shape_ast::error::{Result, ShapeError};
6
/// Namespace for multi-series analysis utilities.
///
/// This is a stateless unit struct: its analysis routines (e.g.
/// `MultiTableAnalysis::correlation`) are associated functions, so no instance
/// is required to use them. The common traits are derived so the type can still
/// be debug-printed, copied, or defaulted when it is embedded in other values.
#[derive(Debug, Clone, Copy, Default)]
pub struct MultiTableAnalysis;
9
10impl MultiTableAnalysis {
11    /// Calculate correlation between two aligned datasets
12    pub fn correlation(data1: &[RowValue], data2: &[RowValue]) -> Result<f64> {
13        if data1.len() != data2.len() {
14            return Err(ShapeError::RuntimeError {
15                message: "Data lengths must match for correlation calculation".into(),
16                location: None,
17            });
18        }
19
20        if data1.is_empty() {
21            return Err(ShapeError::RuntimeError {
22                message: "Cannot calculate correlation on empty data".into(),
23                location: None,
24            });
25        }
26
27        let n = data1.len() as f64;
28
29        let get_value = |row: &RowValue| {
30            row.field_names()
31                .next()
32                .and_then(|f| row.get_field(f))
33                .unwrap_or(f64::NAN)
34        };
35
36        let mean1 = data1.iter().map(get_value).sum::<f64>() / n;
37        let mean2 = data2.iter().map(get_value).sum::<f64>() / n;
38
39        let mut covariance = 0.0;
40        let mut variance1 = 0.0;
41        let mut variance2 = 0.0;
42
43        for i in 0..data1.len() {
44            let diff1 = get_value(&data1[i]) - mean1;
45            let diff2 = get_value(&data2[i]) - mean2;
46
47            covariance += diff1 * diff2;
48            variance1 += diff1 * diff1;
49            variance2 += diff2 * diff2;
50        }
51
52        let std1 = (variance1 / n).sqrt();
53        let std2 = (variance2 / n).sqrt();
54
55        if std1 == 0.0 || std2 == 0.0 {
56            return Err(ShapeError::RuntimeError {
57                message: "Cannot calculate correlation with zero variance".into(),
58                location: None,
59            });
60        }
61
62        Ok(covariance / (n * std1 * std2))
63    }
64
65    /// Find divergences between two data series
66    pub fn find_divergences(
67        data1: &[RowValue],
68        data2: &[RowValue],
69        window: usize,
70    ) -> Result<Vec<Divergence>> {
71        if data1.len() != data2.len() {
72            return Err(ShapeError::RuntimeError {
73                message: "Data series must have equal length".into(),
74                location: None,
75            });
76        }
77
78        if window == 0 || window > data1.len() {
79            return Err(ShapeError::RuntimeError {
80                message: "Invalid window size".into(),
81                location: None,
82            });
83        }
84
85        let mut divergences = Vec::new();
86
87        for i in window..data1.len() {
88            let trend1 = Self::calculate_trend(&data1[i - window..i]);
89            let trend2 = Self::calculate_trend(&data2[i - window..i]);
90
91            if (trend1 > 0.0 && trend2 < 0.0) || (trend1 < 0.0 && trend2 > 0.0) {
92                divergences.push(Divergence {
93                    timestamp: data1[i].timestamp,
94                    index: i,
95                    id1_trend: trend1,
96                    id2_trend: trend2,
97                    strength: (trend1 - trend2).abs(),
98                });
99            }
100        }
101
102        Ok(divergences)
103    }
104
105    /// Calculate simple trend (slope) over a window
106    fn calculate_trend(rows: &[RowValue]) -> f64 {
107        if rows.is_empty() {
108            return 0.0;
109        }
110
111        let n = rows.len() as f64;
112        let mut sum_x = 0.0;
113        let mut sum_y = 0.0;
114        let mut sum_xy = 0.0;
115        let mut sum_x2 = 0.0;
116
117        for (i, row) in rows.iter().enumerate() {
118            let x = i as f64;
119            let y = row
120                .field_names()
121                .next()
122                .and_then(|f| row.get_field(f))
123                .unwrap_or(f64::NAN);
124
125            sum_x += x;
126            sum_y += y;
127            sum_xy += x * y;
128            sum_x2 += x * x;
129        }
130
131        let denominator = n * sum_x2 - sum_x * sum_x;
132        if denominator == 0.0 {
133            return 0.0;
134        }
135
136        (n * sum_xy - sum_x * sum_y) / denominator
137    }
138}