term_guard/analyzers/advanced/
correlation.rs

//! Correlation analyzer for computing relationships between numeric columns.
//!
//! This module provides an analyzer for the Pearson correlation coefficient,
//! the Spearman rank correlation coefficient, and the sample covariance between
//! a pair of numeric columns. Kendall's tau is declared as a correlation type
//! but is not yet implemented.
6
7use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
8use crate::security::SqlSecurity;
9use arrow::array::{Array, ArrayRef};
10use async_trait::async_trait;
11use datafusion::prelude::*;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14use tracing::instrument;
15
16/// Types of correlation that can be computed.
17#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub enum CorrelationType {
19    /// Pearson correlation coefficient (-1 to 1)
20    Pearson,
21    /// Spearman rank correlation coefficient
22    Spearman,
23    /// Kendall's tau correlation
24    KendallTau,
25    /// Covariance
26    Covariance,
27}
28
29impl CorrelationType {
30    /// Returns a human-readable name for this correlation type.
31    pub fn name(&self) -> &str {
32        match self {
33            CorrelationType::Pearson => "Pearson",
34            CorrelationType::Spearman => "Spearman",
35            CorrelationType::KendallTau => "Kendall's tau",
36            CorrelationType::Covariance => "Covariance",
37        }
38    }
39}
40
/// State for correlation computation.
///
/// Holds the number of usable value pairs together with the sums needed by the
/// Pearson and covariance formulas. For Spearman the analyzer in this file
/// fills the same sum fields with sums over SQL-computed ranks instead of the
/// raw column values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationState {
    /// Number of valid pairs (both values non-null)
    pub n: u64,
    /// Sum of x values (sum of x ranks for Spearman)
    pub sum_x: f64,
    /// Sum of y values (sum of y ranks for Spearman)
    pub sum_y: f64,
    /// Sum of x squared
    pub sum_x2: f64,
    /// Sum of y squared
    pub sum_y2: f64,
    /// Sum of x*y
    pub sum_xy: f64,
    /// For Spearman: ranks of x values.
    /// Every state constructed in this file sets this to `None`, because the
    /// ranks are aggregated in SQL rather than kept individually.
    pub x_ranks: Option<Vec<f64>>,
    /// For Spearman: ranks of y values.
    /// Every state constructed in this file sets this to `None` as well.
    pub y_ranks: Option<Vec<f64>>,
    /// The correlation type being computed
    pub correlation_type: CorrelationType,
}
63
64impl AnalyzerState for CorrelationState {
65    fn merge(states: Vec<Self>) -> AnalyzerResult<Self>
66    where
67        Self: Sized,
68    {
69        if states.is_empty() {
70            return Err(AnalyzerError::state_merge("Cannot merge empty states"));
71        }
72
73        let first = &states[0];
74        let correlation_type = first.correlation_type.clone();
75
76        // For Pearson and Covariance, we can merge by summing
77        if matches!(
78            correlation_type,
79            CorrelationType::Pearson | CorrelationType::Covariance
80        ) {
81            let mut merged = CorrelationState {
82                n: 0,
83                sum_x: 0.0,
84                sum_y: 0.0,
85                sum_x2: 0.0,
86                sum_y2: 0.0,
87                sum_xy: 0.0,
88                x_ranks: None,
89                y_ranks: None,
90                correlation_type,
91            };
92
93            for state in states {
94                merged.n += state.n;
95                merged.sum_x += state.sum_x;
96                merged.sum_y += state.sum_y;
97                merged.sum_x2 += state.sum_x2;
98                merged.sum_y2 += state.sum_y2;
99                merged.sum_xy += state.sum_xy;
100            }
101
102            Ok(merged)
103        } else {
104            // For rank-based correlations, merging is more complex
105            // and would require re-ranking combined data
106            Err(AnalyzerError::state_merge(
107                "Cannot merge rank-based correlation states",
108            ))
109        }
110    }
111
112    fn is_empty(&self) -> bool {
113        self.n == 0
114    }
115}
116
/// Analyzer for computing correlation between two numeric columns.
///
/// The analyzer reads both columns from the table registered as `data`, casts
/// them to `DOUBLE`, and ignores every row in which either column is NULL.
///
/// # Example
///
/// ```rust,ignore
/// use term_guard::analyzers::advanced::{CorrelationAnalyzer, CorrelationType};
///
/// let analyzer = CorrelationAnalyzer::new(
///     "height",
///     "weight",
///     CorrelationType::Pearson
/// );
/// ```
#[derive(Debug, Clone)]
pub struct CorrelationAnalyzer {
    /// First column name (the `x` side of the statistic)
    column1: String,
    /// Second column name (the `y` side of the statistic)
    column2: String,
    /// Type of correlation to compute
    correlation_type: CorrelationType,
}
139
140impl CorrelationAnalyzer {
141    /// Creates a new correlation analyzer.
142    pub fn new(
143        column1: impl Into<String>,
144        column2: impl Into<String>,
145        correlation_type: CorrelationType,
146    ) -> Self {
147        Self {
148            column1: column1.into(),
149            column2: column2.into(),
150            correlation_type,
151        }
152    }
153
154    /// Creates a Pearson correlation analyzer.
155    pub fn pearson(column1: impl Into<String>, column2: impl Into<String>) -> Self {
156        Self::new(column1, column2, CorrelationType::Pearson)
157    }
158
159    /// Creates a Spearman correlation analyzer.
160    pub fn spearman(column1: impl Into<String>, column2: impl Into<String>) -> Self {
161        Self::new(column1, column2, CorrelationType::Spearman)
162    }
163
164    /// Creates a covariance analyzer.
165    pub fn covariance(column1: impl Into<String>, column2: impl Into<String>) -> Self {
166        Self::new(column1, column2, CorrelationType::Covariance)
167    }
168
169    /// Computes ranks for Spearman correlation (used in tests).
170    #[allow(dead_code)]
171    fn compute_ranks(values: &[f64]) -> Vec<f64> {
172        let mut indexed: Vec<(usize, f64)> =
173            values.iter().enumerate().map(|(i, &v)| (i, v)).collect();
174        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
175
176        let mut ranks = vec![0.0; values.len()];
177        let mut i = 0;
178        while i < indexed.len() {
179            let mut j = i;
180            // Find all equal values
181            while j < indexed.len() && indexed[j].1 == indexed[i].1 {
182                j += 1;
183            }
184            // Assign average rank to all equal values
185            let avg_rank = (i + j) as f64 / 2.0 + 0.5;
186            for k in i..j {
187                ranks[indexed[k].0] = avg_rank;
188            }
189            i = j;
190        }
191        ranks
192    }
193
194    /// Extracts a numeric value from an Arrow array column, supporting various numeric types.
195    fn extract_numeric_value(column: &ArrayRef, field_name: &str) -> AnalyzerResult<f64> {
196        // Try different numeric array types that DataFusion might return
197        if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Float64Array>() {
198            Ok(arr.value(0))
199        } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
200            Ok(arr.value(0) as f64)
201        } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::UInt64Array>() {
202            Ok(arr.value(0) as f64)
203        } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int32Array>() {
204            Ok(arr.value(0) as f64)
205        } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::UInt32Array>() {
206            Ok(arr.value(0) as f64)
207        } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Float32Array>() {
208            Ok(arr.value(0) as f64)
209        } else {
210            Err(AnalyzerError::state_computation(format!(
211                "Failed to get {field_name}: unsupported array type"
212            )))
213        }
214    }
215}
216
217#[async_trait]
218impl Analyzer for CorrelationAnalyzer {
219    type State = CorrelationState;
220    type Metric = MetricValue;
221
222    #[instrument(skip(self, ctx), fields(
223        column1 = %self.column1,
224        column2 = %self.column2,
225        correlation_type = ?self.correlation_type
226    ))]
227    async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
228        match self.correlation_type {
229            CorrelationType::Pearson | CorrelationType::Covariance => {
230                // Validate and escape column identifiers to prevent SQL injection
231                let col1_escaped = SqlSecurity::escape_identifier(&self.column1).map_err(|e| {
232                    AnalyzerError::state_computation(format!("Invalid column1 name: {e}"))
233                })?;
234                let col2_escaped = SqlSecurity::escape_identifier(&self.column2).map_err(|e| {
235                    AnalyzerError::state_computation(format!("Invalid column2 name: {e}"))
236                })?;
237
238                // Compute sums for Pearson correlation or covariance using escaped identifiers
239                let sql = format!(
240                    "SELECT 
241                        COUNT(*) as n,
242                        SUM(CAST({col1_escaped} AS DOUBLE)) as sum_x,
243                        SUM(CAST({col2_escaped} AS DOUBLE)) as sum_y,
244                        SUM(CAST({col1_escaped} AS DOUBLE) * CAST({col1_escaped} AS DOUBLE)) as sum_x2,
245                        SUM(CAST({col2_escaped} AS DOUBLE) * CAST({col2_escaped} AS DOUBLE)) as sum_y2,
246                        SUM(CAST({col1_escaped} AS DOUBLE) * CAST({col2_escaped} AS DOUBLE)) as sum_xy
247                    FROM data
248                    WHERE {col1_escaped} IS NOT NULL AND {col2_escaped} IS NOT NULL"
249                );
250
251                let df = ctx.sql(&sql).await?;
252                let batches = df.collect().await?;
253
254                if batches.is_empty() || batches[0].num_rows() == 0 {
255                    return Ok(CorrelationState {
256                        n: 0,
257                        sum_x: 0.0,
258                        sum_y: 0.0,
259                        sum_x2: 0.0,
260                        sum_y2: 0.0,
261                        sum_xy: 0.0,
262                        x_ranks: None,
263                        y_ranks: None,
264                        correlation_type: self.correlation_type.clone(),
265                    });
266                }
267
268                let batch = &batches[0];
269                let n = batch
270                    .column(0)
271                    .as_any()
272                    .downcast_ref::<arrow::array::Int64Array>()
273                    .ok_or_else(|| AnalyzerError::state_computation("Failed to get count"))?
274                    .value(0) as u64;
275
276                let sum_x = batch
277                    .column(1)
278                    .as_any()
279                    .downcast_ref::<arrow::array::Float64Array>()
280                    .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_x"))?
281                    .value(0);
282
283                let sum_y = batch
284                    .column(2)
285                    .as_any()
286                    .downcast_ref::<arrow::array::Float64Array>()
287                    .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_y"))?
288                    .value(0);
289
290                let sum_x2 = batch
291                    .column(3)
292                    .as_any()
293                    .downcast_ref::<arrow::array::Float64Array>()
294                    .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_x2"))?
295                    .value(0);
296
297                let sum_y2 = batch
298                    .column(4)
299                    .as_any()
300                    .downcast_ref::<arrow::array::Float64Array>()
301                    .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_y2"))?
302                    .value(0);
303
304                let sum_xy = batch
305                    .column(5)
306                    .as_any()
307                    .downcast_ref::<arrow::array::Float64Array>()
308                    .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_xy"))?
309                    .value(0);
310
311                Ok(CorrelationState {
312                    n,
313                    sum_x,
314                    sum_y,
315                    sum_x2,
316                    sum_y2,
317                    sum_xy,
318                    x_ranks: None,
319                    y_ranks: None,
320                    correlation_type: self.correlation_type.clone(),
321                })
322            }
323            CorrelationType::Spearman => {
324                // Validate and escape column identifiers to prevent SQL injection
325                let col1_escaped = SqlSecurity::escape_identifier(&self.column1).map_err(|e| {
326                    AnalyzerError::state_computation(format!("Invalid column1 name: {e}"))
327                })?;
328                let col2_escaped = SqlSecurity::escape_identifier(&self.column2).map_err(|e| {
329                    AnalyzerError::state_computation(format!("Invalid column2 name: {e}"))
330                })?;
331
332                // For Spearman, compute ranks directly in SQL for memory efficiency
333                // This avoids loading all data into memory and is much faster for large datasets
334                let sql = format!(
335                    "WITH ranked AS (
336                        SELECT 
337                            RANK() OVER (ORDER BY CAST({col1_escaped} AS DOUBLE)) as rank_x,
338                            RANK() OVER (ORDER BY CAST({col2_escaped} AS DOUBLE)) as rank_y
339                        FROM data
340                        WHERE {col1_escaped} IS NOT NULL AND {col2_escaped} IS NOT NULL
341                    )
342                    SELECT 
343                        COUNT(*) as n,
344                        SUM(rank_x) as sum_x,
345                        SUM(rank_y) as sum_y,
346                        SUM(rank_x * rank_x) as sum_x2,
347                        SUM(rank_y * rank_y) as sum_y2,
348                        SUM(rank_x * rank_y) as sum_xy
349                    FROM ranked"
350                );
351
352                let df = ctx.sql(&sql).await?;
353                let batches = df.collect().await?;
354
355                if batches.is_empty() || batches[0].num_rows() == 0 {
356                    return Ok(CorrelationState {
357                        n: 0,
358                        sum_x: 0.0,
359                        sum_y: 0.0,
360                        sum_x2: 0.0,
361                        sum_y2: 0.0,
362                        sum_xy: 0.0,
363                        x_ranks: None, // No need to store ranks when computed in SQL
364                        y_ranks: None,
365                        correlation_type: self.correlation_type.clone(),
366                    });
367                }
368
369                let batch = &batches[0];
370
371                // Handle count which should be Int64
372                let n = batch
373                    .column(0)
374                    .as_any()
375                    .downcast_ref::<arrow::array::Int64Array>()
376                    .ok_or_else(|| AnalyzerError::state_computation("Failed to get count"))?
377                    .value(0) as u64;
378
379                // Extract numeric values with support for various Arrow types that DataFusion might return
380                let sum_x = Self::extract_numeric_value(batch.column(1), "sum_x")?;
381
382                let sum_y = Self::extract_numeric_value(batch.column(2), "sum_y")?;
383                let sum_x2 = Self::extract_numeric_value(batch.column(3), "sum_x2")?;
384                let sum_y2 = Self::extract_numeric_value(batch.column(4), "sum_y2")?;
385                let sum_xy = Self::extract_numeric_value(batch.column(5), "sum_xy")?;
386
387                Ok(CorrelationState {
388                    n,
389                    sum_x,
390                    sum_y,
391                    sum_x2,
392                    sum_y2,
393                    sum_xy,
394                    x_ranks: None, // No need to store individual ranks with SQL-based computation
395                    y_ranks: None,
396                    correlation_type: self.correlation_type.clone(),
397                })
398            }
399            CorrelationType::KendallTau => {
400                // Kendall's tau requires pairwise comparisons
401                // This is a simplified implementation
402                Err(AnalyzerError::custom("Kendall's tau not yet implemented"))
403            }
404        }
405    }
406
407    fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
408        if state.n < 2 {
409            return Ok(MetricValue::Double(f64::NAN));
410        }
411
412        let n = state.n as f64;
413
414        match state.correlation_type {
415            CorrelationType::Pearson | CorrelationType::Spearman => {
416                // Pearson correlation formula (same for Spearman on ranks)
417                let numerator = n * state.sum_xy - state.sum_x * state.sum_y;
418                let denominator = ((n * state.sum_x2 - state.sum_x * state.sum_x)
419                    * (n * state.sum_y2 - state.sum_y * state.sum_y))
420                    .sqrt();
421
422                if denominator == 0.0 {
423                    Ok(MetricValue::Double(0.0))
424                } else {
425                    Ok(MetricValue::Double(numerator / denominator))
426                }
427            }
428            CorrelationType::Covariance => {
429                // Sample covariance
430                let covariance = (state.sum_xy - (state.sum_x * state.sum_y) / n) / (n - 1.0);
431                Ok(MetricValue::Double(covariance))
432            }
433            CorrelationType::KendallTau => Ok(MetricValue::Double(f64::NAN)),
434        }
435    }
436
437    fn name(&self) -> &str {
438        "correlation"
439    }
440
441    fn description(&self) -> &str {
442        "Computes correlation between two numeric columns"
443    }
444
445    fn metric_key(&self) -> String {
446        format!(
447            "correlation_{}_{}_{}",
448            self.correlation_type.name().to_lowercase(),
449            self.column1,
450            self.column2
451        )
452    }
453
454    fn columns(&self) -> Vec<&str> {
455        vec![&self.column1, &self.column2]
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session whose `data` table holds 100 rows of perfectly
    /// correlated values: `x` runs over 0..100 and `y = 2x + 1`.
    async fn create_test_context() -> SessionContext {
        let xs: Vec<Option<f64>> = (0..100).map(|i| Some(i as f64)).collect();
        let ys: Vec<Option<f64>> = xs.iter().map(|x| x.map(|v| 2.0 * v + 1.0)).collect();

        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("y", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Float64Array::from(xs)),
                Arc::new(Float64Array::from(ys)),
            ],
        )
        .unwrap();

        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// Runs `analyzer` against the shared test table and returns its metric,
    /// which has to be a `Double`.
    async fn double_metric(analyzer: &CorrelationAnalyzer) -> f64 {
        let ctx = create_test_context().await;
        let state = analyzer.compute_state_from_data(&ctx).await.unwrap();
        match analyzer.compute_metric_from_state(&state).unwrap() {
            MetricValue::Double(value) => value,
            _ => panic!("Expected Double metric"),
        }
    }

    #[tokio::test]
    async fn test_pearson_correlation_perfect() {
        let corr = double_metric(&CorrelationAnalyzer::pearson("x", "y")).await;
        assert!((corr - 1.0).abs() < 0.0001, "Expected perfect correlation");
    }

    #[tokio::test]
    async fn test_covariance() {
        let cov = double_metric(&CorrelationAnalyzer::covariance("x", "y")).await;
        // With y = 2x + 1 the covariance equals twice the variance of x; the
        // accepted window covers that value for x = 0..99.
        assert!(
            cov > 1600.0 && cov < 1700.0,
            "Expected covariance around 1666"
        );
    }

    #[tokio::test]
    async fn test_spearman_correlation() {
        let corr = double_metric(&CorrelationAnalyzer::spearman("x", "y")).await;
        // A strictly increasing relationship has a rank correlation of 1.0.
        assert!(
            (corr - 1.0).abs() < 0.0001,
            "Expected perfect rank correlation"
        );
    }

    #[test]
    fn test_compute_ranks() {
        let ranks = CorrelationAnalyzer::compute_ranks(&[3.0, 1.0, 4.0, 1.0, 5.0]);

        // The two 1.0 values share ranks 1 and 2, hence 1.5 each.
        let expected = [3.0, 1.5, 4.0, 1.5, 5.0];
        for (position, want) in expected.iter().enumerate() {
            assert_eq!(ranks[position], *want);
        }
    }
}