1use crate::constraints::Assertion;
10use crate::core::{Constraint, ConstraintMetadata, ConstraintResult};
11use crate::prelude::*;
12use crate::security::SqlSecurity;
13use arrow::array::Array;
14use async_trait::async_trait;
15use datafusion::prelude::*;
16use serde::{Deserialize, Serialize};
17use tracing::instrument;
18
19#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub enum CorrelationType {
22 Pearson,
24 Spearman,
26 KendallTau,
28 MutualInformation {
30 bins: usize,
32 },
33 Covariance,
35 Custom { sql_expression: String },
37}
38
39impl CorrelationType {
40 fn name(&self) -> &str {
42 match self {
43 CorrelationType::Pearson => "Pearson correlation",
44 CorrelationType::Spearman => "Spearman correlation",
45 CorrelationType::KendallTau => "Kendall's tau",
46 CorrelationType::MutualInformation { .. } => "mutual information",
47 CorrelationType::Covariance => "covariance",
48 CorrelationType::Custom { .. } => "custom correlation",
49 }
50 }
51
52 fn constraint_name(&self) -> &str {
54 match self {
55 CorrelationType::Pearson => "correlation",
56 CorrelationType::Spearman => "spearman_correlation",
57 CorrelationType::KendallTau => "kendall_correlation",
58 CorrelationType::MutualInformation { .. } => "mutual_information",
59 CorrelationType::Covariance => "covariance",
60 CorrelationType::Custom { .. } => "custom_correlation",
61 }
62 }
63}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
67pub struct MultiCorrelationConfig {
68 pub columns: Vec<String>,
70 pub correlation_type: CorrelationType,
72 pub pairwise: bool,
74 pub min_correlation: Option<f64>,
76}
77
78#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
80pub enum CorrelationValidation {
81 Pairwise {
83 column1: String,
84 column2: String,
85 correlation_type: CorrelationType,
86 assertion: Assertion,
87 },
88
89 Range {
91 column1: String,
92 column2: String,
93 correlation_type: CorrelationType,
94 min: f64,
95 max: f64,
96 },
97
98 MultiColumn(MultiCorrelationConfig),
100
101 Independence {
103 column1: String,
104 column2: String,
105 max_correlation: f64,
106 },
107
108 Stability {
110 column1: String,
111 column2: String,
112 segment_column: String,
113 max_variance: f64,
114 },
115}
116
117#[derive(Debug, Clone)]
151pub struct CorrelationConstraint {
152 validation: CorrelationValidation,
154}
155
156impl CorrelationConstraint {
157 pub fn new(validation: CorrelationValidation) -> Result<Self> {
167 match &validation {
169 CorrelationValidation::Pairwise {
170 column1, column2, ..
171 }
172 | CorrelationValidation::Range {
173 column1, column2, ..
174 }
175 | CorrelationValidation::Independence {
176 column1, column2, ..
177 }
178 | CorrelationValidation::Stability {
179 column1, column2, ..
180 } => {
181 SqlSecurity::validate_identifier(column1)?;
182 SqlSecurity::validate_identifier(column2)?;
183 }
184 CorrelationValidation::MultiColumn(config) => {
185 if config.columns.len() < 2 {
186 return Err(TermError::Configuration(
187 "At least 2 columns required for correlation analysis".to_string(),
188 ));
189 }
190 for column in &config.columns {
191 SqlSecurity::validate_identifier(column)?;
192 }
193 }
194 }
195
196 Ok(Self { validation })
197 }
198
199 pub fn pearson(
201 column1: impl Into<String>,
202 column2: impl Into<String>,
203 assertion: Assertion,
204 ) -> Result<Self> {
205 Self::new(CorrelationValidation::Pairwise {
206 column1: column1.into(),
207 column2: column2.into(),
208 correlation_type: CorrelationType::Pearson,
209 assertion,
210 })
211 }
212
213 pub fn spearman(
215 column1: impl Into<String>,
216 column2: impl Into<String>,
217 assertion: Assertion,
218 ) -> Result<Self> {
219 Self::new(CorrelationValidation::Pairwise {
220 column1: column1.into(),
221 column2: column2.into(),
222 correlation_type: CorrelationType::Spearman,
223 assertion,
224 })
225 }
226
227 pub fn mutual_information(
229 column1: impl Into<String>,
230 column2: impl Into<String>,
231 bins: usize,
232 assertion: Assertion,
233 ) -> Result<Self> {
234 Self::new(CorrelationValidation::Pairwise {
235 column1: column1.into(),
236 column2: column2.into(),
237 correlation_type: CorrelationType::MutualInformation { bins },
238 assertion,
239 })
240 }
241
242 pub fn independence(
244 column1: impl Into<String>,
245 column2: impl Into<String>,
246 max_correlation: f64,
247 ) -> Result<Self> {
248 if !(0.0..=1.0).contains(&max_correlation) {
249 return Err(TermError::Configuration(
250 "Max correlation must be between 0.0 and 1.0".to_string(),
251 ));
252 }
253 Self::new(CorrelationValidation::Independence {
254 column1: column1.into(),
255 column2: column2.into(),
256 max_correlation,
257 })
258 }
259
260 fn pearson_sql(&self, col1: &str, col2: &str) -> Result<String> {
262 let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
263 let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
264
265 Ok(format!("CORR({escaped_col1}, {escaped_col2})"))
267 }
268
269 fn covariance_sql(&self, col1: &str, col2: &str) -> Result<String> {
271 let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
272 let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
273
274 Ok(format!("COVAR_SAMP({escaped_col1}, {escaped_col2})"))
276 }
277
278 #[allow(dead_code)]
280 fn spearman_sql(&self, col1: &str, col2: &str) -> Result<String> {
281 let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
282 let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
283
284 Ok(format!(
287 "CORR(
288 RANK() OVER (ORDER BY {escaped_col1}) AS rank1,
289 RANK() OVER (ORDER BY {escaped_col2}) AS rank2
290 )"
291 ))
292 }
293}
294
295#[async_trait]
296impl Constraint for CorrelationConstraint {
297 #[instrument(skip(self, ctx), fields(
298 validation = ?self.validation
299 ))]
300 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
301 match &self.validation {
302 CorrelationValidation::Pairwise {
303 column1,
304 column2,
305 correlation_type,
306 assertion,
307 } => {
308 let sql = match correlation_type {
309 CorrelationType::Pearson => {
310 format!(
311 "SELECT {} as corr_value FROM data",
312 self.pearson_sql(column1, column2)?
313 )
314 }
315 CorrelationType::Covariance => {
316 format!(
317 "SELECT {} as corr_value FROM data",
318 self.covariance_sql(column1, column2)?
319 )
320 }
321 CorrelationType::Custom { sql_expression } => {
322 if sql_expression.contains(';')
324 || sql_expression.to_lowercase().contains("drop")
325 {
326 return Ok(ConstraintResult::failure(
327 "Custom SQL expression contains potentially unsafe content",
328 ));
329 }
330 let escaped_col1 = SqlSecurity::escape_identifier(column1)?;
331 let escaped_col2 = SqlSecurity::escape_identifier(column2)?;
332 let expr = sql_expression
333 .replace("{column1}", &escaped_col1)
334 .replace("{column2}", &escaped_col2);
335 format!("SELECT {expr} as corr_value FROM data")
336 }
337 _ => {
338 return Ok(ConstraintResult::skipped(
340 "Correlation type not yet implemented",
341 ));
342 }
343 };
344
345 let df = ctx.sql(&sql).await?;
346 let batches = df.collect().await?;
347
348 if batches.is_empty() || batches[0].num_rows() == 0 {
349 return Ok(ConstraintResult::skipped("No data to validate"));
350 }
351
352 let value = batches[0]
353 .column(0)
354 .as_any()
355 .downcast_ref::<arrow::array::Float64Array>()
356 .ok_or_else(|| {
357 TermError::Internal("Failed to downcast to Float64Array".to_string())
358 })?
359 .value(0);
360
361 if assertion.evaluate(value) {
362 Ok(ConstraintResult::success_with_metric(value))
363 } else {
364 Ok(ConstraintResult::failure_with_metric(
365 value,
366 format!(
367 "{} between {column1} and {column2} is {value} which does not {assertion}",
368 correlation_type.name()
369 ),
370 ))
371 }
372 }
373 CorrelationValidation::Range {
374 column1,
375 column2,
376 correlation_type,
377 min,
378 max,
379 } => {
380 let result = self
382 .evaluate_with_validation(
383 ctx,
384 &CorrelationValidation::Pairwise {
385 column1: column1.clone(),
386 column2: column2.clone(),
387 correlation_type: correlation_type.clone(),
388 assertion: Assertion::Between(*min, *max),
389 },
390 )
391 .await?;
392 Ok(result)
393 }
394 CorrelationValidation::Independence {
395 column1,
396 column2,
397 max_correlation,
398 } => {
399 let sql = format!(
400 "SELECT ABS({}) as abs_corr FROM data",
401 self.pearson_sql(column1, column2)?
402 );
403
404 let df = ctx.sql(&sql).await?;
405 let batches = df.collect().await?;
406
407 if batches.is_empty() || batches[0].num_rows() == 0 {
408 return Ok(ConstraintResult::skipped("No data to validate"));
409 }
410
411 let abs_corr = batches[0]
412 .column(0)
413 .as_any()
414 .downcast_ref::<arrow::array::Float64Array>()
415 .ok_or_else(|| {
416 TermError::Internal("Failed to downcast to Float64Array".to_string())
417 })?
418 .value(0);
419
420 if abs_corr <= *max_correlation {
421 Ok(ConstraintResult::success_with_metric(abs_corr))
422 } else {
423 Ok(ConstraintResult::failure_with_metric(
424 abs_corr,
425 format!(
426 "Columns {column1} and {column2} have correlation {abs_corr} exceeding independence threshold {max_correlation}"
427 ),
428 ))
429 }
430 }
431 _ => Ok(ConstraintResult::skipped(
432 "Validation type not yet implemented",
433 )),
434 }
435 }
436
437 fn name(&self) -> &str {
438 match &self.validation {
439 CorrelationValidation::Pairwise {
440 correlation_type, ..
441 } => correlation_type.constraint_name(),
442 CorrelationValidation::Range { .. } => "correlation_range",
443 CorrelationValidation::Independence { .. } => "independence",
444 CorrelationValidation::MultiColumn { .. } => "multi_correlation",
445 CorrelationValidation::Stability { .. } => "correlation_stability",
446 }
447 }
448
449 fn metadata(&self) -> ConstraintMetadata {
450 let description = match &self.validation {
451 CorrelationValidation::Pairwise {
452 column1,
453 column2,
454 correlation_type,
455 ..
456 } => format!(
457 "Validates {} between '{column1}' and '{column2}'",
458 correlation_type.name()
459 ),
460 CorrelationValidation::Range {
461 column1, column2, ..
462 } => format!(
463 "Validates correlation range between '{column1}' and '{column2}'"
464 ),
465 CorrelationValidation::Independence {
466 column1, column2, ..
467 } => format!(
468 "Validates independence between '{column1}' and '{column2}'"
469 ),
470 CorrelationValidation::MultiColumn(config) => format!(
471 "Validates correlations among columns: {}",
472 config.columns.join(", ")
473 ),
474 CorrelationValidation::Stability {
475 column1,
476 column2,
477 segment_column,
478 ..
479 } => format!(
480 "Validates correlation stability between '{column1}' and '{column2}' across '{segment_column}'"
481 ),
482 };
483
484 ConstraintMetadata::new().with_description(description)
485 }
486}
487
488impl CorrelationConstraint {
489 async fn evaluate_with_validation(
491 &self,
492 ctx: &SessionContext,
493 validation: &CorrelationValidation,
494 ) -> Result<ConstraintResult> {
495 let temp_constraint = Self::new(validation.clone())?;
496 temp_constraint.evaluate(ctx).await
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Registers a `data` table with nullable Float64 columns `x` and `y`, one row per
    /// `(x, y)` pair, in a fresh session.
    fn context_with_xy(pairs: impl IntoIterator<Item = (f64, f64)>) -> SessionContext {
        let (xs, ys): (Vec<Option<f64>>, Vec<Option<f64>>) =
            pairs.into_iter().map(|(x, y)| (Some(x), Some(y))).unzip();

        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("y", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float64Array::from(xs)),
                Arc::new(Float64Array::from(ys)),
            ],
        )
        .unwrap();

        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// `y = 2x + (i mod 10) - 5` for `i` in `0..100`: a strong positive linear relation
    /// with a small periodic perturbation.
    async fn create_test_context_correlated() -> SessionContext {
        context_with_xy((0..100).map(|i| {
            let x = f64::from(i);
            (x, 2.0 * x + f64::from(i % 10) - 5.0)
        }))
    }

    /// `y = (37 i) mod 100`: a permutation of `x` that scrambles its ordering.
    async fn create_test_context_independent() -> SessionContext {
        context_with_xy((0..100).map(|i| (f64::from(i), f64::from((i * 37) % 100))))
    }

    #[tokio::test]
    async fn test_pearson_correlation() {
        let ctx = create_test_context_correlated().await;
        let constraint =
            CorrelationConstraint::pearson("x", "y", Assertion::GreaterThan(0.9)).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
        assert!(result.metric.unwrap() > 0.9);
    }

    #[tokio::test]
    async fn test_independence_check() {
        let ctx = create_test_context_independent().await;
        let constraint = CorrelationConstraint::independence("x", "y", 0.3).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_correlation_range() {
        let ctx = create_test_context_correlated().await;
        let constraint = CorrelationConstraint::new(CorrelationValidation::Range {
            column1: "x".to_string(),
            column2: "y".to_string(),
            correlation_type: CorrelationType::Pearson,
            min: 0.8,
            max: 1.0,
        })
        .unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[test]
    fn test_invalid_max_correlation() {
        let message = CorrelationConstraint::independence("x", "y", 1.5)
            .unwrap_err()
            .to_string();

        assert!(message.contains("Max correlation must be between 0.0 and 1.0"));
    }

    #[test]
    fn test_multi_column_validation() {
        let config = MultiCorrelationConfig {
            columns: ["a", "b", "c"].iter().map(|c| c.to_string()).collect(),
            correlation_type: CorrelationType::Pearson,
            pairwise: true,
            min_correlation: Some(0.5),
        };

        assert!(CorrelationConstraint::new(CorrelationValidation::MultiColumn(config)).is_ok());
    }

    #[test]
    fn test_multi_column_too_few() {
        let config = MultiCorrelationConfig {
            columns: vec!["a".to_string()],
            correlation_type: CorrelationType::Pearson,
            pairwise: true,
            min_correlation: None,
        };

        let message = CorrelationConstraint::new(CorrelationValidation::MultiColumn(config))
            .unwrap_err()
            .to_string();

        assert!(message.contains("At least 2 columns required"));
    }
}