term_guard/analyzers/advanced/
correlation.rs1use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
8use crate::security::SqlSecurity;
9use arrow::array::{Array, ArrayRef};
10use async_trait::async_trait;
11use datafusion::prelude::*;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14use tracing::instrument;
15
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub enum CorrelationType {
19 Pearson,
21 Spearman,
23 KendallTau,
25 Covariance,
27}
28
29impl CorrelationType {
30 pub fn name(&self) -> &str {
32 match self {
33 CorrelationType::Pearson => "Pearson",
34 CorrelationType::Spearman => "Spearman",
35 CorrelationType::KendallTau => "Kendall's tau",
36 CorrelationType::Covariance => "Covariance",
37 }
38 }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CorrelationState {
44 pub n: u64,
46 pub sum_x: f64,
48 pub sum_y: f64,
50 pub sum_x2: f64,
52 pub sum_y2: f64,
54 pub sum_xy: f64,
56 pub x_ranks: Option<Vec<f64>>,
58 pub y_ranks: Option<Vec<f64>>,
60 pub correlation_type: CorrelationType,
62}
63
64impl AnalyzerState for CorrelationState {
65 fn merge(states: Vec<Self>) -> AnalyzerResult<Self>
66 where
67 Self: Sized,
68 {
69 if states.is_empty() {
70 return Err(AnalyzerError::state_merge("Cannot merge empty states"));
71 }
72
73 let first = &states[0];
74 let correlation_type = first.correlation_type.clone();
75
76 if matches!(
78 correlation_type,
79 CorrelationType::Pearson | CorrelationType::Covariance
80 ) {
81 let mut merged = CorrelationState {
82 n: 0,
83 sum_x: 0.0,
84 sum_y: 0.0,
85 sum_x2: 0.0,
86 sum_y2: 0.0,
87 sum_xy: 0.0,
88 x_ranks: None,
89 y_ranks: None,
90 correlation_type,
91 };
92
93 for state in states {
94 merged.n += state.n;
95 merged.sum_x += state.sum_x;
96 merged.sum_y += state.sum_y;
97 merged.sum_x2 += state.sum_x2;
98 merged.sum_y2 += state.sum_y2;
99 merged.sum_xy += state.sum_xy;
100 }
101
102 Ok(merged)
103 } else {
104 Err(AnalyzerError::state_merge(
107 "Cannot merge rank-based correlation states",
108 ))
109 }
110 }
111
112 fn is_empty(&self) -> bool {
113 self.n == 0
114 }
115}
116
117#[derive(Debug, Clone)]
131pub struct CorrelationAnalyzer {
132 column1: String,
134 column2: String,
136 correlation_type: CorrelationType,
138}
139
140impl CorrelationAnalyzer {
141 pub fn new(
143 column1: impl Into<String>,
144 column2: impl Into<String>,
145 correlation_type: CorrelationType,
146 ) -> Self {
147 Self {
148 column1: column1.into(),
149 column2: column2.into(),
150 correlation_type,
151 }
152 }
153
154 pub fn pearson(column1: impl Into<String>, column2: impl Into<String>) -> Self {
156 Self::new(column1, column2, CorrelationType::Pearson)
157 }
158
159 pub fn spearman(column1: impl Into<String>, column2: impl Into<String>) -> Self {
161 Self::new(column1, column2, CorrelationType::Spearman)
162 }
163
164 pub fn covariance(column1: impl Into<String>, column2: impl Into<String>) -> Self {
166 Self::new(column1, column2, CorrelationType::Covariance)
167 }
168
169 #[allow(dead_code)]
171 fn compute_ranks(values: &[f64]) -> Vec<f64> {
172 let mut indexed: Vec<(usize, f64)> =
173 values.iter().enumerate().map(|(i, &v)| (i, v)).collect();
174 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
175
176 let mut ranks = vec![0.0; values.len()];
177 let mut i = 0;
178 while i < indexed.len() {
179 let mut j = i;
180 while j < indexed.len() && indexed[j].1 == indexed[i].1 {
182 j += 1;
183 }
184 let avg_rank = (i + j) as f64 / 2.0 + 0.5;
186 for k in i..j {
187 ranks[indexed[k].0] = avg_rank;
188 }
189 i = j;
190 }
191 ranks
192 }
193
194 fn extract_numeric_value(column: &ArrayRef, field_name: &str) -> AnalyzerResult<f64> {
196 if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Float64Array>() {
198 Ok(arr.value(0))
199 } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
200 Ok(arr.value(0) as f64)
201 } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::UInt64Array>() {
202 Ok(arr.value(0) as f64)
203 } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int32Array>() {
204 Ok(arr.value(0) as f64)
205 } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::UInt32Array>() {
206 Ok(arr.value(0) as f64)
207 } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Float32Array>() {
208 Ok(arr.value(0) as f64)
209 } else {
210 Err(AnalyzerError::state_computation(format!(
211 "Failed to get {field_name}: unsupported array type"
212 )))
213 }
214 }
215}
216
217#[async_trait]
218impl Analyzer for CorrelationAnalyzer {
219 type State = CorrelationState;
220 type Metric = MetricValue;
221
222 #[instrument(skip(self, ctx), fields(
223 column1 = %self.column1,
224 column2 = %self.column2,
225 correlation_type = ?self.correlation_type
226 ))]
227 async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
228 match self.correlation_type {
229 CorrelationType::Pearson | CorrelationType::Covariance => {
230 let col1_escaped = SqlSecurity::escape_identifier(&self.column1).map_err(|e| {
232 AnalyzerError::state_computation(format!("Invalid column1 name: {e}"))
233 })?;
234 let col2_escaped = SqlSecurity::escape_identifier(&self.column2).map_err(|e| {
235 AnalyzerError::state_computation(format!("Invalid column2 name: {e}"))
236 })?;
237
238 let sql = format!(
240 "SELECT
241 COUNT(*) as n,
242 SUM(CAST({col1_escaped} AS DOUBLE)) as sum_x,
243 SUM(CAST({col2_escaped} AS DOUBLE)) as sum_y,
244 SUM(CAST({col1_escaped} AS DOUBLE) * CAST({col1_escaped} AS DOUBLE)) as sum_x2,
245 SUM(CAST({col2_escaped} AS DOUBLE) * CAST({col2_escaped} AS DOUBLE)) as sum_y2,
246 SUM(CAST({col1_escaped} AS DOUBLE) * CAST({col2_escaped} AS DOUBLE)) as sum_xy
247 FROM data
248 WHERE {col1_escaped} IS NOT NULL AND {col2_escaped} IS NOT NULL"
249 );
250
251 let df = ctx.sql(&sql).await?;
252 let batches = df.collect().await?;
253
254 if batches.is_empty() || batches[0].num_rows() == 0 {
255 return Ok(CorrelationState {
256 n: 0,
257 sum_x: 0.0,
258 sum_y: 0.0,
259 sum_x2: 0.0,
260 sum_y2: 0.0,
261 sum_xy: 0.0,
262 x_ranks: None,
263 y_ranks: None,
264 correlation_type: self.correlation_type.clone(),
265 });
266 }
267
268 let batch = &batches[0];
269 let n = batch
270 .column(0)
271 .as_any()
272 .downcast_ref::<arrow::array::Int64Array>()
273 .ok_or_else(|| AnalyzerError::state_computation("Failed to get count"))?
274 .value(0) as u64;
275
276 let sum_x = batch
277 .column(1)
278 .as_any()
279 .downcast_ref::<arrow::array::Float64Array>()
280 .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_x"))?
281 .value(0);
282
283 let sum_y = batch
284 .column(2)
285 .as_any()
286 .downcast_ref::<arrow::array::Float64Array>()
287 .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_y"))?
288 .value(0);
289
290 let sum_x2 = batch
291 .column(3)
292 .as_any()
293 .downcast_ref::<arrow::array::Float64Array>()
294 .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_x2"))?
295 .value(0);
296
297 let sum_y2 = batch
298 .column(4)
299 .as_any()
300 .downcast_ref::<arrow::array::Float64Array>()
301 .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_y2"))?
302 .value(0);
303
304 let sum_xy = batch
305 .column(5)
306 .as_any()
307 .downcast_ref::<arrow::array::Float64Array>()
308 .ok_or_else(|| AnalyzerError::state_computation("Failed to get sum_xy"))?
309 .value(0);
310
311 Ok(CorrelationState {
312 n,
313 sum_x,
314 sum_y,
315 sum_x2,
316 sum_y2,
317 sum_xy,
318 x_ranks: None,
319 y_ranks: None,
320 correlation_type: self.correlation_type.clone(),
321 })
322 }
323 CorrelationType::Spearman => {
324 let col1_escaped = SqlSecurity::escape_identifier(&self.column1).map_err(|e| {
326 AnalyzerError::state_computation(format!("Invalid column1 name: {e}"))
327 })?;
328 let col2_escaped = SqlSecurity::escape_identifier(&self.column2).map_err(|e| {
329 AnalyzerError::state_computation(format!("Invalid column2 name: {e}"))
330 })?;
331
332 let sql = format!(
335 "WITH ranked AS (
336 SELECT
337 RANK() OVER (ORDER BY CAST({col1_escaped} AS DOUBLE)) as rank_x,
338 RANK() OVER (ORDER BY CAST({col2_escaped} AS DOUBLE)) as rank_y
339 FROM data
340 WHERE {col1_escaped} IS NOT NULL AND {col2_escaped} IS NOT NULL
341 )
342 SELECT
343 COUNT(*) as n,
344 SUM(rank_x) as sum_x,
345 SUM(rank_y) as sum_y,
346 SUM(rank_x * rank_x) as sum_x2,
347 SUM(rank_y * rank_y) as sum_y2,
348 SUM(rank_x * rank_y) as sum_xy
349 FROM ranked"
350 );
351
352 let df = ctx.sql(&sql).await?;
353 let batches = df.collect().await?;
354
355 if batches.is_empty() || batches[0].num_rows() == 0 {
356 return Ok(CorrelationState {
357 n: 0,
358 sum_x: 0.0,
359 sum_y: 0.0,
360 sum_x2: 0.0,
361 sum_y2: 0.0,
362 sum_xy: 0.0,
363 x_ranks: None, y_ranks: None,
365 correlation_type: self.correlation_type.clone(),
366 });
367 }
368
369 let batch = &batches[0];
370
371 let n = batch
373 .column(0)
374 .as_any()
375 .downcast_ref::<arrow::array::Int64Array>()
376 .ok_or_else(|| AnalyzerError::state_computation("Failed to get count"))?
377 .value(0) as u64;
378
379 let sum_x = Self::extract_numeric_value(batch.column(1), "sum_x")?;
381
382 let sum_y = Self::extract_numeric_value(batch.column(2), "sum_y")?;
383 let sum_x2 = Self::extract_numeric_value(batch.column(3), "sum_x2")?;
384 let sum_y2 = Self::extract_numeric_value(batch.column(4), "sum_y2")?;
385 let sum_xy = Self::extract_numeric_value(batch.column(5), "sum_xy")?;
386
387 Ok(CorrelationState {
388 n,
389 sum_x,
390 sum_y,
391 sum_x2,
392 sum_y2,
393 sum_xy,
394 x_ranks: None, y_ranks: None,
396 correlation_type: self.correlation_type.clone(),
397 })
398 }
399 CorrelationType::KendallTau => {
400 Err(AnalyzerError::custom("Kendall's tau not yet implemented"))
403 }
404 }
405 }
406
407 fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
408 if state.n < 2 {
409 return Ok(MetricValue::Double(f64::NAN));
410 }
411
412 let n = state.n as f64;
413
414 match state.correlation_type {
415 CorrelationType::Pearson | CorrelationType::Spearman => {
416 let numerator = n * state.sum_xy - state.sum_x * state.sum_y;
418 let denominator = ((n * state.sum_x2 - state.sum_x * state.sum_x)
419 * (n * state.sum_y2 - state.sum_y * state.sum_y))
420 .sqrt();
421
422 if denominator == 0.0 {
423 Ok(MetricValue::Double(0.0))
424 } else {
425 Ok(MetricValue::Double(numerator / denominator))
426 }
427 }
428 CorrelationType::Covariance => {
429 let covariance = (state.sum_xy - (state.sum_x * state.sum_y) / n) / (n - 1.0);
431 Ok(MetricValue::Double(covariance))
432 }
433 CorrelationType::KendallTau => Ok(MetricValue::Double(f64::NAN)),
434 }
435 }
436
437 fn name(&self) -> &str {
438 "correlation"
439 }
440
441 fn description(&self) -> &str {
442 "Computes correlation between two numeric columns"
443 }
444
445 fn metric_key(&self) -> String {
446 format!(
447 "correlation_{}_{}_{}",
448 self.correlation_type.name().to_lowercase(),
449 self.column1,
450 self.column2
451 )
452 }
453
454 fn columns(&self) -> Vec<&str> {
455 vec![&self.column1, &self.column2]
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Registers a 100-row `data` table where `y = 2x + 1` for `x = 0, 1, ..., 99`.
    async fn create_test_context() -> SessionContext {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("y", DataType::Float64, true),
        ]));

        let xs: Vec<Option<f64>> = (0..100_i32).map(|i| Some(f64::from(i))).collect();
        let ys: Vec<Option<f64>> = xs.iter().map(|x| x.map(|v| 2.0 * v + 1.0)).collect();
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(xs)),
            Arc::new(Float64Array::from(ys)),
        ];
        let batch = RecordBatch::try_new(Arc::clone(&schema), columns).unwrap();

        let ctx = SessionContext::new();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// Unwraps a `MetricValue::Double`, failing the test for any other variant.
    fn expect_double(metric: MetricValue) -> f64 {
        if let MetricValue::Double(value) = metric {
            value
        } else {
            panic!("Expected Double metric");
        }
    }

    /// Computes state and metric for `analyzer` over the shared fixture.
    async fn metric_for(analyzer: &CorrelationAnalyzer) -> f64 {
        let ctx = create_test_context().await;
        let state = analyzer.compute_state_from_data(&ctx).await.unwrap();
        expect_double(analyzer.compute_metric_from_state(&state).unwrap())
    }

    #[tokio::test]
    async fn test_pearson_correlation_perfect() {
        let corr = metric_for(&CorrelationAnalyzer::pearson("x", "y")).await;
        assert!((corr - 1.0).abs() < 0.0001, "Expected perfect correlation");
    }

    #[tokio::test]
    async fn test_covariance() {
        // cov(x, 2x + 1) = 2 * var(x), and var(0..99) is about 841.67.
        let cov = metric_for(&CorrelationAnalyzer::covariance("x", "y")).await;
        assert!(
            cov > 1600.0 && cov < 1700.0,
            "Expected covariance around 1666"
        );
    }

    #[tokio::test]
    async fn test_spearman_correlation() {
        let corr = metric_for(&CorrelationAnalyzer::spearman("x", "y")).await;
        assert!(
            (corr - 1.0).abs() < 0.0001,
            "Expected perfect rank correlation"
        );
    }

    #[test]
    fn test_compute_ranks() {
        // The two 1.0 values tie for positions 1 and 2, so both get 1.5.
        let ranks = CorrelationAnalyzer::compute_ranks(&[3.0, 1.0, 4.0, 1.0, 5.0]);
        assert_eq!(ranks, [3.0, 1.5, 4.0, 1.5, 5.0]);
    }
}