// term_guard/constraints/correlation.rs

//! Unified correlation constraint that consolidates correlation and relationship validations.
//!
//! This module provides a single, flexible correlation constraint that replaces the
//! previously separate constraints:
//! - `CorrelationConstraint` - Pearson correlation between columns
//! - `MutualInformationConstraint` - Mutual information between columns
//!
//! It also adds support for other correlation types and multi-column relationships.
//! Correlation and validation types that are not evaluated yet are reported as skipped.
8
9use crate::constraints::Assertion;
10use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
11use crate::prelude::*;
12use crate::security::SqlSecurity;
13use arrow::array::Array;
14use async_trait::async_trait;
15use datafusion::prelude::*;
16use serde::{Deserialize, Serialize};
17use tracing::instrument;
18/// Types of correlation that can be computed.
19#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
20pub enum CorrelationType {
21    /// Pearson correlation coefficient (-1 to 1)
22    Pearson,
23    /// Spearman rank correlation coefficient
24    Spearman,
25    /// Kendall's tau correlation
26    KendallTau,
27    /// Mutual information (non-negative)
28    MutualInformation {
29        /// Number of bins for discretization
30        bins: usize,
31    },
32    /// Covariance
33    Covariance,
34    /// Custom correlation using SQL expression
35    Custom { sql_expression: String },
36}
37
38impl CorrelationType {
39    /// Returns a human-readable name for this correlation type.
40    fn name(&self) -> &str {
41        match self {
42            CorrelationType::Pearson => "Pearson correlation",
43            CorrelationType::Spearman => "Spearman correlation",
44            CorrelationType::KendallTau => "Kendall's tau",
45            CorrelationType::MutualInformation { .. } => "mutual information",
46            CorrelationType::Covariance => "covariance",
47            CorrelationType::Custom { .. } => "custom correlation",
48        }
49    }
50
51    /// Returns the constraint name for backward compatibility.
52    fn constraint_name(&self) -> &str {
53        match self {
54            CorrelationType::Pearson => "correlation",
55            CorrelationType::Spearman => "spearman_correlation",
56            CorrelationType::KendallTau => "kendall_correlation",
57            CorrelationType::MutualInformation { .. } => "mutual_information",
58            CorrelationType::Covariance => "covariance",
59            CorrelationType::Custom { .. } => "custom_correlation",
60        }
61    }
62}
63
64/// Configuration for multi-column correlation analysis.
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
66pub struct MultiCorrelationConfig {
67    /// Columns to analyze (must have at least 2)
68    pub columns: Vec<String>,
69    /// Type of correlation to compute
70    pub correlation_type: CorrelationType,
71    /// Whether to compute pairwise correlations
72    pub pairwise: bool,
73    /// Minimum threshold for correlation strength (optional)
74    pub min_correlation: Option<f64>,
75}
76
77/// Types of correlation validation that can be performed.
78#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
79pub enum CorrelationValidation {
80    /// Check correlation between two columns
81    Pairwise {
82        column1: String,
83        column2: String,
84        correlation_type: CorrelationType,
85        assertion: Assertion,
86    },
87
88    /// Check that correlation is within expected range
89    Range {
90        column1: String,
91        column2: String,
92        correlation_type: CorrelationType,
93        min: f64,
94        max: f64,
95    },
96
97    /// Check multiple correlations
98    MultiColumn(MultiCorrelationConfig),
99
100    /// Check that columns are independent (low correlation)
101    Independence {
102        column1: String,
103        column2: String,
104        max_correlation: f64,
105    },
106
107    /// Check correlation stability over time/segments
108    Stability {
109        column1: String,
110        column2: String,
111        segment_column: String,
112        max_variance: f64,
113    },
114}
115
116/// A unified constraint that validates correlations and relationships between columns.
117///
118/// This constraint replaces individual correlation constraints and provides
119/// a consistent interface for all correlation-based validations.
120///
121/// # Examples
122///
123/// ```rust
124/// use term_guard::constraints::{CorrelationConstraint, CorrelationType, Assertion};
125/// use term_guard::core::Constraint;
126///
127/// // Check Pearson correlation
128/// let pearson_check = CorrelationConstraint::pearson(
129///     "height",
130///     "weight",
131///     Assertion::Between(0.6, 0.9)
132/// );
133///
134/// // Check independence
135/// let independence_check = CorrelationConstraint::independence(
136///     "user_id",
137///     "transaction_amount",
138///     0.1
139/// );
140///
141/// // Check mutual information
142/// let mi_check = CorrelationConstraint::mutual_information(
143///     "category",
144///     "price",
145///     10, // bins
146///     Assertion::GreaterThan(0.5)
147/// );
148/// ```
149#[derive(Debug, Clone)]
150pub struct CorrelationConstraint {
151    /// The type of validation to perform
152    validation: CorrelationValidation,
153}
154
155impl CorrelationConstraint {
156    /// Creates a new unified correlation constraint.
157    ///
158    /// # Arguments
159    ///
160    /// * `validation` - The type of validation to perform
161    ///
162    /// # Errors
163    ///
164    /// Returns an error if column names are invalid.
165    pub fn new(validation: CorrelationValidation) -> Result<Self> {
166        // Validate column names
167        match &validation {
168            CorrelationValidation::Pairwise {
169                column1, column2, ..
170            }
171            | CorrelationValidation::Range {
172                column1, column2, ..
173            }
174            | CorrelationValidation::Independence {
175                column1, column2, ..
176            }
177            | CorrelationValidation::Stability {
178                column1, column2, ..
179            } => {
180                SqlSecurity::validate_identifier(column1)?;
181                SqlSecurity::validate_identifier(column2)?;
182            }
183            CorrelationValidation::MultiColumn(config) => {
184                if config.columns.len() < 2 {
185                    return Err(TermError::Configuration(
186                        "At least 2 columns required for correlation analysis".to_string(),
187                    ));
188                }
189                for column in &config.columns {
190                    SqlSecurity::validate_identifier(column)?;
191                }
192            }
193        }
194
195        Ok(Self { validation })
196    }
197
198    /// Convenience constructor for Pearson correlation.
199    pub fn pearson(
200        column1: impl Into<String>,
201        column2: impl Into<String>,
202        assertion: Assertion,
203    ) -> Result<Self> {
204        Self::new(CorrelationValidation::Pairwise {
205            column1: column1.into(),
206            column2: column2.into(),
207            correlation_type: CorrelationType::Pearson,
208            assertion,
209        })
210    }
211
212    /// Convenience constructor for Spearman correlation.
213    pub fn spearman(
214        column1: impl Into<String>,
215        column2: impl Into<String>,
216        assertion: Assertion,
217    ) -> Result<Self> {
218        Self::new(CorrelationValidation::Pairwise {
219            column1: column1.into(),
220            column2: column2.into(),
221            correlation_type: CorrelationType::Spearman,
222            assertion,
223        })
224    }
225
226    /// Convenience constructor for mutual information.
227    pub fn mutual_information(
228        column1: impl Into<String>,
229        column2: impl Into<String>,
230        bins: usize,
231        assertion: Assertion,
232    ) -> Result<Self> {
233        Self::new(CorrelationValidation::Pairwise {
234            column1: column1.into(),
235            column2: column2.into(),
236            correlation_type: CorrelationType::MutualInformation { bins },
237            assertion,
238        })
239    }
240
241    /// Convenience constructor for independence check.
242    pub fn independence(
243        column1: impl Into<String>,
244        column2: impl Into<String>,
245        max_correlation: f64,
246    ) -> Result<Self> {
247        if !(0.0..=1.0).contains(&max_correlation) {
248            return Err(TermError::Configuration(
249                "Max correlation must be between 0.0 and 1.0".to_string(),
250            ));
251        }
252        Self::new(CorrelationValidation::Independence {
253            column1: column1.into(),
254            column2: column2.into(),
255            max_correlation,
256        })
257    }
258
259    /// Generates SQL for Pearson correlation.
260    fn pearson_sql(&self, col1: &str, col2: &str) -> Result<String> {
261        let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
262        let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
263
264        // Using DataFusion's CORR function
265        Ok(format!("CORR({escaped_col1}, {escaped_col2})"))
266    }
267
268    /// Generates SQL for covariance.
269    fn covariance_sql(&self, col1: &str, col2: &str) -> Result<String> {
270        let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
271        let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
272
273        // Using DataFusion's COVAR_SAMP function
274        Ok(format!("COVAR_SAMP({escaped_col1}, {escaped_col2})"))
275    }
276
277    /// Generates SQL for Spearman correlation (rank-based).
278    #[allow(dead_code)]
279    fn spearman_sql(&self, col1: &str, col2: &str) -> Result<String> {
280        let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
281        let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
282
283        // Spearman correlation is Pearson correlation of ranks
284        // Using window functions to compute ranks
285        Ok(format!(
286            "CORR(
287                RANK() OVER (ORDER BY {escaped_col1}) AS rank1,
288                RANK() OVER (ORDER BY {escaped_col2}) AS rank2
289            )"
290        ))
291    }
292}
293
294#[async_trait]
295impl Constraint for CorrelationConstraint {
296    #[instrument(skip(self, ctx), fields(
297        validation = ?self.validation
298    ))]
299    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
300        // Get the table name from the validation context
301        let validation_ctx = current_validation_context();
302        let table_name = validation_ctx.table_name();
303
304        match &self.validation {
305            CorrelationValidation::Pairwise {
306                column1,
307                column2,
308                correlation_type,
309                assertion,
310            } => {
311                let sql = match correlation_type {
312                    CorrelationType::Pearson => {
313                        format!(
314                            "SELECT {} as corr_value FROM {table_name}",
315                            self.pearson_sql(column1, column2)?
316                        )
317                    }
318                    CorrelationType::Covariance => {
319                        format!(
320                            "SELECT {} as corr_value FROM {table_name}",
321                            self.covariance_sql(column1, column2)?
322                        )
323                    }
324                    CorrelationType::Custom { sql_expression } => {
325                        // Basic validation to prevent obvious SQL injection
326                        if sql_expression.contains(';')
327                            || sql_expression.to_lowercase().contains("drop")
328                        {
329                            return Ok(ConstraintResult::failure(
330                                "Custom SQL expression contains potentially unsafe content",
331                            ));
332                        }
333                        let escaped_col1 = SqlSecurity::escape_identifier(column1)?;
334                        let escaped_col2 = SqlSecurity::escape_identifier(column2)?;
335                        let expr = sql_expression
336                            .replace("{column1}", &escaped_col1)
337                            .replace("{column2}", &escaped_col2);
338                        format!("SELECT {expr} as corr_value FROM {table_name}")
339                    }
340                    _ => {
341                        // Other correlation types would require more complex implementation
342                        return Ok(ConstraintResult::skipped(
343                            "Correlation type not yet implemented",
344                        ));
345                    }
346                };
347
348                let df = ctx.sql(&sql).await?;
349                let batches = df.collect().await?;
350
351                if batches.is_empty() || batches[0].num_rows() == 0 {
352                    return Ok(ConstraintResult::skipped("No data to validate"));
353                }
354
355                let value = batches[0]
356                    .column(0)
357                    .as_any()
358                    .downcast_ref::<arrow::array::Float64Array>()
359                    .ok_or_else(|| {
360                        TermError::Internal("Failed to downcast to Float64Array".to_string())
361                    })?
362                    .value(0);
363
364                if assertion.evaluate(value) {
365                    Ok(ConstraintResult::success_with_metric(value))
366                } else {
367                    Ok(ConstraintResult::failure_with_metric(
368                        value,
369                        format!(
370                            "{} between {column1} and {column2} is {value} which does not {assertion}",
371                            correlation_type.name()
372                        ),
373                    ))
374                }
375            }
376            CorrelationValidation::Range {
377                column1,
378                column2,
379                correlation_type,
380                min,
381                max,
382            } => {
383                // This is essentially the same as Pairwise with a Between assertion
384                let result = self
385                    .evaluate_with_validation(
386                        ctx,
387                        &CorrelationValidation::Pairwise {
388                            column1: column1.clone(),
389                            column2: column2.clone(),
390                            correlation_type: correlation_type.clone(),
391                            assertion: Assertion::Between(*min, *max),
392                        },
393                    )
394                    .await?;
395                Ok(result)
396            }
397            CorrelationValidation::Independence {
398                column1,
399                column2,
400                max_correlation,
401            } => {
402                // Get the table name from the validation context
403
404                let validation_ctx = current_validation_context();
405
406                let table_name = validation_ctx.table_name();
407
408                let sql = format!(
409                    "SELECT ABS({}) as abs_corr FROM {table_name}",
410                    self.pearson_sql(column1, column2)?
411                );
412
413                let df = ctx.sql(&sql).await?;
414                let batches = df.collect().await?;
415
416                if batches.is_empty() || batches[0].num_rows() == 0 {
417                    return Ok(ConstraintResult::skipped("No data to validate"));
418                }
419
420                let abs_corr = batches[0]
421                    .column(0)
422                    .as_any()
423                    .downcast_ref::<arrow::array::Float64Array>()
424                    .ok_or_else(|| {
425                        TermError::Internal("Failed to downcast to Float64Array".to_string())
426                    })?
427                    .value(0);
428
429                if abs_corr <= *max_correlation {
430                    Ok(ConstraintResult::success_with_metric(abs_corr))
431                } else {
432                    Ok(ConstraintResult::failure_with_metric(
433                        abs_corr,
434                        format!(
435                            "Columns {column1} and {column2} have correlation {abs_corr} exceeding independence threshold {max_correlation}"
436                        ),
437                    ))
438                }
439            }
440            _ => Ok(ConstraintResult::skipped(
441                "Validation type not yet implemented",
442            )),
443        }
444    }
445
446    fn name(&self) -> &str {
447        match &self.validation {
448            CorrelationValidation::Pairwise {
449                correlation_type, ..
450            } => correlation_type.constraint_name(),
451            CorrelationValidation::Range { .. } => "correlation_range",
452            CorrelationValidation::Independence { .. } => "independence",
453            CorrelationValidation::MultiColumn { .. } => "multi_correlation",
454            CorrelationValidation::Stability { .. } => "correlation_stability",
455        }
456    }
457
458    fn metadata(&self) -> ConstraintMetadata {
459        let description = match &self.validation {
460            CorrelationValidation::Pairwise {
461                column1,
462                column2,
463                correlation_type,
464                ..
465            } => format!(
466                "Validates {} between '{column1}' and '{column2}'",
467                correlation_type.name()
468            ),
469            CorrelationValidation::Range {
470                column1, column2, ..
471            } => format!(
472                "Validates correlation range between '{column1}' and '{column2}'"
473            ),
474            CorrelationValidation::Independence {
475                column1, column2, ..
476            } => format!(
477                "Validates independence between '{column1}' and '{column2}'"
478            ),
479            CorrelationValidation::MultiColumn(config) => format!(
480                "Validates correlations among columns: {}",
481                config.columns.join(", ")
482            ),
483            CorrelationValidation::Stability {
484                column1,
485                column2,
486                segment_column,
487                ..
488            } => format!(
489                "Validates correlation stability between '{column1}' and '{column2}' across '{segment_column}'"
490            ),
491        };
492
493        ConstraintMetadata::new().with_description(description)
494    }
495}
496
497impl CorrelationConstraint {
498    /// Helper method to evaluate with a different validation (for internal use).
499    async fn evaluate_with_validation(
500        &self,
501        ctx: &SessionContext,
502        validation: &CorrelationValidation,
503    ) -> Result<ConstraintResult> {
504        let temp_constraint = Self::new(validation.clone())?;
505        temp_constraint.evaluate(ctx).await
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use crate::test_helpers::evaluate_constraint_with_context;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session context holding a table `data` with nullable Float64
    /// columns `x` and `y` filled from the given values.
    fn context_with_xy(x_values: Vec<Option<f64>>, y_values: Vec<Option<f64>>) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("y", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Float64Array::from(x_values)),
                Arc::new(Float64Array::from(y_values)),
            ],
        )
        .unwrap();

        let ctx = SessionContext::new();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();
        ctx
    }

    /// Context whose columns are strongly correlated: y ≈ 2x + noise.
    async fn create_test_context_correlated() -> SessionContext {
        let x_values: Vec<Option<f64>> = (0..100).map(|i| Some(i as f64)).collect();
        let y_values: Vec<Option<f64>> = (0..100)
            .map(|i| Some(2.0 * (i as f64) + (i % 10) as f64 - 5.0)) // Some noise
            .collect();

        context_with_xy(x_values, y_values)
    }

    /// Context whose `y` column is a pseudo-random shuffle unrelated to `x`.
    async fn create_test_context_independent() -> SessionContext {
        let x_values: Vec<Option<f64>> = (0..100).map(|i| Some(i as f64)).collect();
        let y_values: Vec<Option<f64>> = (0..100)
            .map(|i| Some(((i * 37) % 100) as f64)) // Pseudo-random
            .collect();

        context_with_xy(x_values, y_values)
    }

    #[tokio::test]
    async fn test_pearson_correlation() {
        let ctx = create_test_context_correlated().await;
        let constraint =
            CorrelationConstraint::pearson("x", "y", Assertion::GreaterThan(0.9)).unwrap();

        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert!(outcome.metric.unwrap() > 0.9);
    }

    #[tokio::test]
    async fn test_independence_check() {
        let ctx = create_test_context_independent().await;
        let constraint = CorrelationConstraint::independence("x", "y", 0.3).unwrap();

        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();

        // Independent data should have low correlation
        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_correlation_range() {
        let ctx = create_test_context_correlated().await;
        let validation = CorrelationValidation::Range {
            column1: "x".to_string(),
            column2: "y".to_string(),
            correlation_type: CorrelationType::Pearson,
            min: 0.8,
            max: 1.0,
        };
        let constraint = CorrelationConstraint::new(validation).unwrap();

        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[test]
    fn test_invalid_max_correlation() {
        let error = CorrelationConstraint::independence("x", "y", 1.5).unwrap_err();

        assert!(error
            .to_string()
            .contains("Max correlation must be between 0.0 and 1.0"));
    }

    #[test]
    fn test_multi_column_validation() {
        let columns = ["a", "b", "c"].iter().map(|name| name.to_string()).collect();
        let config = MultiCorrelationConfig {
            columns,
            correlation_type: CorrelationType::Pearson,
            pairwise: true,
            min_correlation: Some(0.5),
        };

        let outcome = CorrelationConstraint::new(CorrelationValidation::MultiColumn(config));

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_multi_column_too_few() {
        let config = MultiCorrelationConfig {
            columns: vec!["a".to_string()],
            correlation_type: CorrelationType::Pearson,
            pairwise: true,
            min_correlation: None,
        };

        let error = CorrelationConstraint::new(CorrelationValidation::MultiColumn(config))
            .unwrap_err();

        assert!(error.to_string().contains("At least 2 columns required"));
    }
}