1use crate::constraints::Assertion;
10use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
11use crate::prelude::*;
12use crate::security::SqlSecurity;
13use arrow::array::Array;
14use async_trait::async_trait;
15use datafusion::prelude::*;
16use serde::{Deserialize, Serialize};
17use tracing::instrument;
18#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
20pub enum CorrelationType {
21 Pearson,
23 Spearman,
25 KendallTau,
27 MutualInformation {
29 bins: usize,
31 },
32 Covariance,
34 Custom { sql_expression: String },
36}
37
38impl CorrelationType {
39 fn name(&self) -> &str {
41 match self {
42 CorrelationType::Pearson => "Pearson correlation",
43 CorrelationType::Spearman => "Spearman correlation",
44 CorrelationType::KendallTau => "Kendall's tau",
45 CorrelationType::MutualInformation { .. } => "mutual information",
46 CorrelationType::Covariance => "covariance",
47 CorrelationType::Custom { .. } => "custom correlation",
48 }
49 }
50
51 fn constraint_name(&self) -> &str {
53 match self {
54 CorrelationType::Pearson => "correlation",
55 CorrelationType::Spearman => "spearman_correlation",
56 CorrelationType::KendallTau => "kendall_correlation",
57 CorrelationType::MutualInformation { .. } => "mutual_information",
58 CorrelationType::Covariance => "covariance",
59 CorrelationType::Custom { .. } => "custom_correlation",
60 }
61 }
62}
63
64#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
66pub struct MultiCorrelationConfig {
67 pub columns: Vec<String>,
69 pub correlation_type: CorrelationType,
71 pub pairwise: bool,
73 pub min_correlation: Option<f64>,
75}
76
77#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
79pub enum CorrelationValidation {
80 Pairwise {
82 column1: String,
83 column2: String,
84 correlation_type: CorrelationType,
85 assertion: Assertion,
86 },
87
88 Range {
90 column1: String,
91 column2: String,
92 correlation_type: CorrelationType,
93 min: f64,
94 max: f64,
95 },
96
97 MultiColumn(MultiCorrelationConfig),
99
100 Independence {
102 column1: String,
103 column2: String,
104 max_correlation: f64,
105 },
106
107 Stability {
109 column1: String,
110 column2: String,
111 segment_column: String,
112 max_variance: f64,
113 },
114}
115
116#[derive(Debug, Clone)]
150pub struct CorrelationConstraint {
151 validation: CorrelationValidation,
153}
154
155impl CorrelationConstraint {
156 pub fn new(validation: CorrelationValidation) -> Result<Self> {
166 match &validation {
168 CorrelationValidation::Pairwise {
169 column1, column2, ..
170 }
171 | CorrelationValidation::Range {
172 column1, column2, ..
173 }
174 | CorrelationValidation::Independence {
175 column1, column2, ..
176 }
177 | CorrelationValidation::Stability {
178 column1, column2, ..
179 } => {
180 SqlSecurity::validate_identifier(column1)?;
181 SqlSecurity::validate_identifier(column2)?;
182 }
183 CorrelationValidation::MultiColumn(config) => {
184 if config.columns.len() < 2 {
185 return Err(TermError::Configuration(
186 "At least 2 columns required for correlation analysis".to_string(),
187 ));
188 }
189 for column in &config.columns {
190 SqlSecurity::validate_identifier(column)?;
191 }
192 }
193 }
194
195 Ok(Self { validation })
196 }
197
198 pub fn pearson(
200 column1: impl Into<String>,
201 column2: impl Into<String>,
202 assertion: Assertion,
203 ) -> Result<Self> {
204 Self::new(CorrelationValidation::Pairwise {
205 column1: column1.into(),
206 column2: column2.into(),
207 correlation_type: CorrelationType::Pearson,
208 assertion,
209 })
210 }
211
212 pub fn spearman(
214 column1: impl Into<String>,
215 column2: impl Into<String>,
216 assertion: Assertion,
217 ) -> Result<Self> {
218 Self::new(CorrelationValidation::Pairwise {
219 column1: column1.into(),
220 column2: column2.into(),
221 correlation_type: CorrelationType::Spearman,
222 assertion,
223 })
224 }
225
226 pub fn mutual_information(
228 column1: impl Into<String>,
229 column2: impl Into<String>,
230 bins: usize,
231 assertion: Assertion,
232 ) -> Result<Self> {
233 Self::new(CorrelationValidation::Pairwise {
234 column1: column1.into(),
235 column2: column2.into(),
236 correlation_type: CorrelationType::MutualInformation { bins },
237 assertion,
238 })
239 }
240
241 pub fn independence(
243 column1: impl Into<String>,
244 column2: impl Into<String>,
245 max_correlation: f64,
246 ) -> Result<Self> {
247 if !(0.0..=1.0).contains(&max_correlation) {
248 return Err(TermError::Configuration(
249 "Max correlation must be between 0.0 and 1.0".to_string(),
250 ));
251 }
252 Self::new(CorrelationValidation::Independence {
253 column1: column1.into(),
254 column2: column2.into(),
255 max_correlation,
256 })
257 }
258
259 fn pearson_sql(&self, col1: &str, col2: &str) -> Result<String> {
261 let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
262 let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
263
264 Ok(format!("CORR({escaped_col1}, {escaped_col2})"))
266 }
267
268 fn covariance_sql(&self, col1: &str, col2: &str) -> Result<String> {
270 let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
271 let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
272
273 Ok(format!("COVAR_SAMP({escaped_col1}, {escaped_col2})"))
275 }
276
277 #[allow(dead_code)]
279 fn spearman_sql(&self, col1: &str, col2: &str) -> Result<String> {
280 let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
281 let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
282
283 Ok(format!(
286 "CORR(
287 RANK() OVER (ORDER BY {escaped_col1}) AS rank1,
288 RANK() OVER (ORDER BY {escaped_col2}) AS rank2
289 )"
290 ))
291 }
292}
293
294#[async_trait]
295impl Constraint for CorrelationConstraint {
296 #[instrument(skip(self, ctx), fields(
297 validation = ?self.validation
298 ))]
299 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
300 let validation_ctx = current_validation_context();
302 let table_name = validation_ctx.table_name();
303
304 match &self.validation {
305 CorrelationValidation::Pairwise {
306 column1,
307 column2,
308 correlation_type,
309 assertion,
310 } => {
311 let sql = match correlation_type {
312 CorrelationType::Pearson => {
313 format!(
314 "SELECT {} as corr_value FROM {table_name}",
315 self.pearson_sql(column1, column2)?
316 )
317 }
318 CorrelationType::Covariance => {
319 format!(
320 "SELECT {} as corr_value FROM {table_name}",
321 self.covariance_sql(column1, column2)?
322 )
323 }
324 CorrelationType::Custom { sql_expression } => {
325 if sql_expression.contains(';')
327 || sql_expression.to_lowercase().contains("drop")
328 {
329 return Ok(ConstraintResult::failure(
330 "Custom SQL expression contains potentially unsafe content",
331 ));
332 }
333 let escaped_col1 = SqlSecurity::escape_identifier(column1)?;
334 let escaped_col2 = SqlSecurity::escape_identifier(column2)?;
335 let expr = sql_expression
336 .replace("{column1}", &escaped_col1)
337 .replace("{column2}", &escaped_col2);
338 format!("SELECT {expr} as corr_value FROM {table_name}")
339 }
340 _ => {
341 return Ok(ConstraintResult::skipped(
343 "Correlation type not yet implemented",
344 ));
345 }
346 };
347
348 let df = ctx.sql(&sql).await?;
349 let batches = df.collect().await?;
350
351 if batches.is_empty() || batches[0].num_rows() == 0 {
352 return Ok(ConstraintResult::skipped("No data to validate"));
353 }
354
355 let value = batches[0]
356 .column(0)
357 .as_any()
358 .downcast_ref::<arrow::array::Float64Array>()
359 .ok_or_else(|| {
360 TermError::Internal("Failed to downcast to Float64Array".to_string())
361 })?
362 .value(0);
363
364 if assertion.evaluate(value) {
365 Ok(ConstraintResult::success_with_metric(value))
366 } else {
367 Ok(ConstraintResult::failure_with_metric(
368 value,
369 format!(
370 "{} between {column1} and {column2} is {value} which does not {assertion}",
371 correlation_type.name()
372 ),
373 ))
374 }
375 }
376 CorrelationValidation::Range {
377 column1,
378 column2,
379 correlation_type,
380 min,
381 max,
382 } => {
383 let result = self
385 .evaluate_with_validation(
386 ctx,
387 &CorrelationValidation::Pairwise {
388 column1: column1.clone(),
389 column2: column2.clone(),
390 correlation_type: correlation_type.clone(),
391 assertion: Assertion::Between(*min, *max),
392 },
393 )
394 .await?;
395 Ok(result)
396 }
397 CorrelationValidation::Independence {
398 column1,
399 column2,
400 max_correlation,
401 } => {
402 let validation_ctx = current_validation_context();
405
406 let table_name = validation_ctx.table_name();
407
408 let sql = format!(
409 "SELECT ABS({}) as abs_corr FROM {table_name}",
410 self.pearson_sql(column1, column2)?
411 );
412
413 let df = ctx.sql(&sql).await?;
414 let batches = df.collect().await?;
415
416 if batches.is_empty() || batches[0].num_rows() == 0 {
417 return Ok(ConstraintResult::skipped("No data to validate"));
418 }
419
420 let abs_corr = batches[0]
421 .column(0)
422 .as_any()
423 .downcast_ref::<arrow::array::Float64Array>()
424 .ok_or_else(|| {
425 TermError::Internal("Failed to downcast to Float64Array".to_string())
426 })?
427 .value(0);
428
429 if abs_corr <= *max_correlation {
430 Ok(ConstraintResult::success_with_metric(abs_corr))
431 } else {
432 Ok(ConstraintResult::failure_with_metric(
433 abs_corr,
434 format!(
435 "Columns {column1} and {column2} have correlation {abs_corr} exceeding independence threshold {max_correlation}"
436 ),
437 ))
438 }
439 }
440 _ => Ok(ConstraintResult::skipped(
441 "Validation type not yet implemented",
442 )),
443 }
444 }
445
446 fn name(&self) -> &str {
447 match &self.validation {
448 CorrelationValidation::Pairwise {
449 correlation_type, ..
450 } => correlation_type.constraint_name(),
451 CorrelationValidation::Range { .. } => "correlation_range",
452 CorrelationValidation::Independence { .. } => "independence",
453 CorrelationValidation::MultiColumn { .. } => "multi_correlation",
454 CorrelationValidation::Stability { .. } => "correlation_stability",
455 }
456 }
457
458 fn metadata(&self) -> ConstraintMetadata {
459 let description = match &self.validation {
460 CorrelationValidation::Pairwise {
461 column1,
462 column2,
463 correlation_type,
464 ..
465 } => format!(
466 "Validates {} between '{column1}' and '{column2}'",
467 correlation_type.name()
468 ),
469 CorrelationValidation::Range {
470 column1, column2, ..
471 } => format!(
472 "Validates correlation range between '{column1}' and '{column2}'"
473 ),
474 CorrelationValidation::Independence {
475 column1, column2, ..
476 } => format!(
477 "Validates independence between '{column1}' and '{column2}'"
478 ),
479 CorrelationValidation::MultiColumn(config) => format!(
480 "Validates correlations among columns: {}",
481 config.columns.join(", ")
482 ),
483 CorrelationValidation::Stability {
484 column1,
485 column2,
486 segment_column,
487 ..
488 } => format!(
489 "Validates correlation stability between '{column1}' and '{column2}' across '{segment_column}'"
490 ),
491 };
492
493 ConstraintMetadata::new().with_description(description)
494 }
495}
496
497impl CorrelationConstraint {
498 async fn evaluate_with_validation(
500 &self,
501 ctx: &SessionContext,
502 validation: &CorrelationValidation,
503 ) -> Result<ConstraintResult> {
504 let temp_constraint = Self::new(validation.clone())?;
505 temp_constraint.evaluate(ctx).await
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    /// Registers an in-memory table named `data` with nullable `Float64`
    /// columns `x` and `y` filled from `points`.
    fn context_with_points(points: impl Iterator<Item = (f64, f64)>) -> SessionContext {
        let (x_values, y_values): (Vec<Option<f64>>, Vec<Option<f64>>) =
            points.map(|(x, y)| (Some(x), Some(y))).unzip();

        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("y", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float64Array::from(x_values)),
                Arc::new(Float64Array::from(y_values)),
            ],
        )
        .unwrap();

        let ctx = SessionContext::new();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();
        ctx
    }

    /// 100 rows where `y` is `2x` plus a small periodic offset, i.e. strongly
    /// positively correlated with `x`.
    async fn create_test_context_correlated() -> SessionContext {
        context_with_points((0..100_i32).map(|i| {
            let x = i as f64;
            let y = 2.0 * x + (i % 10) as f64 - 5.0;
            (x, y)
        }))
    }

    /// 100 rows where `y` is a multiplicative scramble of the row index.
    async fn create_test_context_independent() -> SessionContext {
        context_with_points((0..100_i32).map(|i| (i as f64, ((i * 37) % 100) as f64)))
    }

    #[tokio::test]
    async fn test_pearson_correlation() {
        let ctx = create_test_context_correlated().await;
        let constraint =
            CorrelationConstraint::pearson("x", "y", Assertion::GreaterThan(0.9)).unwrap();

        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert!(outcome.metric.unwrap() > 0.9);
    }

    #[tokio::test]
    async fn test_independence_check() {
        let ctx = create_test_context_independent().await;
        let constraint = CorrelationConstraint::independence("x", "y", 0.3).unwrap();

        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_correlation_range() {
        let ctx = create_test_context_correlated().await;
        let validation = CorrelationValidation::Range {
            column1: "x".to_string(),
            column2: "y".to_string(),
            correlation_type: CorrelationType::Pearson,
            min: 0.8,
            max: 1.0,
        };
        let constraint = CorrelationConstraint::new(validation).unwrap();

        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[test]
    fn test_invalid_max_correlation() {
        let attempt = CorrelationConstraint::independence("x", "y", 1.5);

        assert!(attempt.is_err());
        let message = attempt.unwrap_err().to_string();
        assert!(message.contains("Max correlation must be between 0.0 and 1.0"));
    }

    #[test]
    fn test_multi_column_validation() {
        let config = MultiCorrelationConfig {
            columns: vec!["a".to_string(), "b".to_string(), "c".to_string()],
            correlation_type: CorrelationType::Pearson,
            pairwise: true,
            min_correlation: Some(0.5),
        };

        let attempt = CorrelationConstraint::new(CorrelationValidation::MultiColumn(config));

        assert!(attempt.is_ok());
    }

    #[test]
    fn test_multi_column_too_few() {
        let config = MultiCorrelationConfig {
            columns: vec!["a".to_string()],
            correlation_type: CorrelationType::Pearson,
            pairwise: true,
            min_correlation: None,
        };

        let attempt = CorrelationConstraint::new(CorrelationValidation::MultiColumn(config));

        assert!(attempt.is_err());
        let message = attempt.unwrap_err().to_string();
        assert!(message.contains("At least 2 columns required"));
    }
}