// term_guard/constraints/correlation.rs

//! Unified correlation constraint that consolidates correlation and relationship validations.
//!
//! This module provides a single, flexible correlation constraint that replaces the former
//! separate constraints:
//! - the original Pearson-only `CorrelationConstraint` (Pearson correlation between columns)
//! - `MutualInformationConstraint` (mutual information between columns)
//!
//! It also models other correlation types and multi-column relationships. Validation kinds
//! and correlation types whose evaluation is not implemented yet produce a skipped result
//! rather than an error.
8
9use crate::constraints::Assertion;
10use crate::core::{Constraint, ConstraintMetadata, ConstraintResult};
11use crate::prelude::*;
12use crate::security::SqlSecurity;
13use arrow::array::Array;
14use async_trait::async_trait;
15use datafusion::prelude::*;
16use serde::{Deserialize, Serialize};
17use tracing::instrument;
18
19/// Types of correlation that can be computed.
20#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub enum CorrelationType {
22    /// Pearson correlation coefficient (-1 to 1)
23    Pearson,
24    /// Spearman rank correlation coefficient
25    Spearman,
26    /// Kendall's tau correlation
27    KendallTau,
28    /// Mutual information (non-negative)
29    MutualInformation {
30        /// Number of bins for discretization
31        bins: usize,
32    },
33    /// Covariance
34    Covariance,
35    /// Custom correlation using SQL expression
36    Custom { sql_expression: String },
37}
38
39impl CorrelationType {
40    /// Returns a human-readable name for this correlation type.
41    fn name(&self) -> &str {
42        match self {
43            CorrelationType::Pearson => "Pearson correlation",
44            CorrelationType::Spearman => "Spearman correlation",
45            CorrelationType::KendallTau => "Kendall's tau",
46            CorrelationType::MutualInformation { .. } => "mutual information",
47            CorrelationType::Covariance => "covariance",
48            CorrelationType::Custom { .. } => "custom correlation",
49        }
50    }
51
52    /// Returns the constraint name for backward compatibility.
53    fn constraint_name(&self) -> &str {
54        match self {
55            CorrelationType::Pearson => "correlation",
56            CorrelationType::Spearman => "spearman_correlation",
57            CorrelationType::KendallTau => "kendall_correlation",
58            CorrelationType::MutualInformation { .. } => "mutual_information",
59            CorrelationType::Covariance => "covariance",
60            CorrelationType::Custom { .. } => "custom_correlation",
61        }
62    }
63}
64
65/// Configuration for multi-column correlation analysis.
66#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
67pub struct MultiCorrelationConfig {
68    /// Columns to analyze (must have at least 2)
69    pub columns: Vec<String>,
70    /// Type of correlation to compute
71    pub correlation_type: CorrelationType,
72    /// Whether to compute pairwise correlations
73    pub pairwise: bool,
74    /// Minimum threshold for correlation strength (optional)
75    pub min_correlation: Option<f64>,
76}
77
78/// Types of correlation validation that can be performed.
79#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
80pub enum CorrelationValidation {
81    /// Check correlation between two columns
82    Pairwise {
83        column1: String,
84        column2: String,
85        correlation_type: CorrelationType,
86        assertion: Assertion,
87    },
88
89    /// Check that correlation is within expected range
90    Range {
91        column1: String,
92        column2: String,
93        correlation_type: CorrelationType,
94        min: f64,
95        max: f64,
96    },
97
98    /// Check multiple correlations
99    MultiColumn(MultiCorrelationConfig),
100
101    /// Check that columns are independent (low correlation)
102    Independence {
103        column1: String,
104        column2: String,
105        max_correlation: f64,
106    },
107
108    /// Check correlation stability over time/segments
109    Stability {
110        column1: String,
111        column2: String,
112        segment_column: String,
113        max_variance: f64,
114    },
115}
116
117/// A unified constraint that validates correlations and relationships between columns.
118///
119/// This constraint replaces individual correlation constraints and provides
120/// a consistent interface for all correlation-based validations.
121///
122/// # Examples
123///
124/// ```rust
125/// use term_guard::constraints::{CorrelationConstraint, CorrelationType, Assertion};
126/// use term_guard::core::Constraint;
127///
128/// // Check Pearson correlation
129/// let pearson_check = CorrelationConstraint::pearson(
130///     "height",
131///     "weight",
132///     Assertion::Between(0.6, 0.9)
133/// );
134///
135/// // Check independence
136/// let independence_check = CorrelationConstraint::independence(
137///     "user_id",
138///     "transaction_amount",
139///     0.1
140/// );
141///
142/// // Check mutual information
143/// let mi_check = CorrelationConstraint::mutual_information(
144///     "category",
145///     "price",
146///     10, // bins
147///     Assertion::GreaterThan(0.5)
148/// );
149/// ```
150#[derive(Debug, Clone)]
151pub struct CorrelationConstraint {
152    /// The type of validation to perform
153    validation: CorrelationValidation,
154}
155
156impl CorrelationConstraint {
157    /// Creates a new unified correlation constraint.
158    ///
159    /// # Arguments
160    ///
161    /// * `validation` - The type of validation to perform
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if column names are invalid.
166    pub fn new(validation: CorrelationValidation) -> Result<Self> {
167        // Validate column names
168        match &validation {
169            CorrelationValidation::Pairwise {
170                column1, column2, ..
171            }
172            | CorrelationValidation::Range {
173                column1, column2, ..
174            }
175            | CorrelationValidation::Independence {
176                column1, column2, ..
177            }
178            | CorrelationValidation::Stability {
179                column1, column2, ..
180            } => {
181                SqlSecurity::validate_identifier(column1)?;
182                SqlSecurity::validate_identifier(column2)?;
183            }
184            CorrelationValidation::MultiColumn(config) => {
185                if config.columns.len() < 2 {
186                    return Err(TermError::Configuration(
187                        "At least 2 columns required for correlation analysis".to_string(),
188                    ));
189                }
190                for column in &config.columns {
191                    SqlSecurity::validate_identifier(column)?;
192                }
193            }
194        }
195
196        Ok(Self { validation })
197    }
198
199    /// Convenience constructor for Pearson correlation.
200    pub fn pearson(
201        column1: impl Into<String>,
202        column2: impl Into<String>,
203        assertion: Assertion,
204    ) -> Result<Self> {
205        Self::new(CorrelationValidation::Pairwise {
206            column1: column1.into(),
207            column2: column2.into(),
208            correlation_type: CorrelationType::Pearson,
209            assertion,
210        })
211    }
212
213    /// Convenience constructor for Spearman correlation.
214    pub fn spearman(
215        column1: impl Into<String>,
216        column2: impl Into<String>,
217        assertion: Assertion,
218    ) -> Result<Self> {
219        Self::new(CorrelationValidation::Pairwise {
220            column1: column1.into(),
221            column2: column2.into(),
222            correlation_type: CorrelationType::Spearman,
223            assertion,
224        })
225    }
226
227    /// Convenience constructor for mutual information.
228    pub fn mutual_information(
229        column1: impl Into<String>,
230        column2: impl Into<String>,
231        bins: usize,
232        assertion: Assertion,
233    ) -> Result<Self> {
234        Self::new(CorrelationValidation::Pairwise {
235            column1: column1.into(),
236            column2: column2.into(),
237            correlation_type: CorrelationType::MutualInformation { bins },
238            assertion,
239        })
240    }
241
242    /// Convenience constructor for independence check.
243    pub fn independence(
244        column1: impl Into<String>,
245        column2: impl Into<String>,
246        max_correlation: f64,
247    ) -> Result<Self> {
248        if !(0.0..=1.0).contains(&max_correlation) {
249            return Err(TermError::Configuration(
250                "Max correlation must be between 0.0 and 1.0".to_string(),
251            ));
252        }
253        Self::new(CorrelationValidation::Independence {
254            column1: column1.into(),
255            column2: column2.into(),
256            max_correlation,
257        })
258    }
259
260    /// Generates SQL for Pearson correlation.
261    fn pearson_sql(&self, col1: &str, col2: &str) -> Result<String> {
262        let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
263        let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
264
265        // Using DataFusion's CORR function
266        Ok(format!("CORR({escaped_col1}, {escaped_col2})"))
267    }
268
269    /// Generates SQL for covariance.
270    fn covariance_sql(&self, col1: &str, col2: &str) -> Result<String> {
271        let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
272        let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
273
274        // Using DataFusion's COVAR_SAMP function
275        Ok(format!("COVAR_SAMP({escaped_col1}, {escaped_col2})"))
276    }
277
278    /// Generates SQL for Spearman correlation (rank-based).
279    #[allow(dead_code)]
280    fn spearman_sql(&self, col1: &str, col2: &str) -> Result<String> {
281        let escaped_col1 = SqlSecurity::escape_identifier(col1)?;
282        let escaped_col2 = SqlSecurity::escape_identifier(col2)?;
283
284        // Spearman correlation is Pearson correlation of ranks
285        // Using window functions to compute ranks
286        Ok(format!(
287            "CORR(
288                RANK() OVER (ORDER BY {escaped_col1}) AS rank1,
289                RANK() OVER (ORDER BY {escaped_col2}) AS rank2
290            )"
291        ))
292    }
293}
294
295#[async_trait]
296impl Constraint for CorrelationConstraint {
297    #[instrument(skip(self, ctx), fields(
298        validation = ?self.validation
299    ))]
300    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
301        match &self.validation {
302            CorrelationValidation::Pairwise {
303                column1,
304                column2,
305                correlation_type,
306                assertion,
307            } => {
308                let sql = match correlation_type {
309                    CorrelationType::Pearson => {
310                        format!(
311                            "SELECT {} as corr_value FROM data",
312                            self.pearson_sql(column1, column2)?
313                        )
314                    }
315                    CorrelationType::Covariance => {
316                        format!(
317                            "SELECT {} as corr_value FROM data",
318                            self.covariance_sql(column1, column2)?
319                        )
320                    }
321                    CorrelationType::Custom { sql_expression } => {
322                        // Basic validation to prevent obvious SQL injection
323                        if sql_expression.contains(';')
324                            || sql_expression.to_lowercase().contains("drop")
325                        {
326                            return Ok(ConstraintResult::failure(
327                                "Custom SQL expression contains potentially unsafe content",
328                            ));
329                        }
330                        let escaped_col1 = SqlSecurity::escape_identifier(column1)?;
331                        let escaped_col2 = SqlSecurity::escape_identifier(column2)?;
332                        let expr = sql_expression
333                            .replace("{column1}", &escaped_col1)
334                            .replace("{column2}", &escaped_col2);
335                        format!("SELECT {expr} as corr_value FROM data")
336                    }
337                    _ => {
338                        // Other correlation types would require more complex implementation
339                        return Ok(ConstraintResult::skipped(
340                            "Correlation type not yet implemented",
341                        ));
342                    }
343                };
344
345                let df = ctx.sql(&sql).await?;
346                let batches = df.collect().await?;
347
348                if batches.is_empty() || batches[0].num_rows() == 0 {
349                    return Ok(ConstraintResult::skipped("No data to validate"));
350                }
351
352                let value = batches[0]
353                    .column(0)
354                    .as_any()
355                    .downcast_ref::<arrow::array::Float64Array>()
356                    .ok_or_else(|| {
357                        TermError::Internal("Failed to downcast to Float64Array".to_string())
358                    })?
359                    .value(0);
360
361                if assertion.evaluate(value) {
362                    Ok(ConstraintResult::success_with_metric(value))
363                } else {
364                    Ok(ConstraintResult::failure_with_metric(
365                        value,
366                        format!(
367                            "{} between {column1} and {column2} is {value} which does not {assertion}",
368                            correlation_type.name()
369                        ),
370                    ))
371                }
372            }
373            CorrelationValidation::Range {
374                column1,
375                column2,
376                correlation_type,
377                min,
378                max,
379            } => {
380                // This is essentially the same as Pairwise with a Between assertion
381                let result = self
382                    .evaluate_with_validation(
383                        ctx,
384                        &CorrelationValidation::Pairwise {
385                            column1: column1.clone(),
386                            column2: column2.clone(),
387                            correlation_type: correlation_type.clone(),
388                            assertion: Assertion::Between(*min, *max),
389                        },
390                    )
391                    .await?;
392                Ok(result)
393            }
394            CorrelationValidation::Independence {
395                column1,
396                column2,
397                max_correlation,
398            } => {
399                let sql = format!(
400                    "SELECT ABS({}) as abs_corr FROM data",
401                    self.pearson_sql(column1, column2)?
402                );
403
404                let df = ctx.sql(&sql).await?;
405                let batches = df.collect().await?;
406
407                if batches.is_empty() || batches[0].num_rows() == 0 {
408                    return Ok(ConstraintResult::skipped("No data to validate"));
409                }
410
411                let abs_corr = batches[0]
412                    .column(0)
413                    .as_any()
414                    .downcast_ref::<arrow::array::Float64Array>()
415                    .ok_or_else(|| {
416                        TermError::Internal("Failed to downcast to Float64Array".to_string())
417                    })?
418                    .value(0);
419
420                if abs_corr <= *max_correlation {
421                    Ok(ConstraintResult::success_with_metric(abs_corr))
422                } else {
423                    Ok(ConstraintResult::failure_with_metric(
424                        abs_corr,
425                        format!(
426                            "Columns {column1} and {column2} have correlation {abs_corr} exceeding independence threshold {max_correlation}"
427                        ),
428                    ))
429                }
430            }
431            _ => Ok(ConstraintResult::skipped(
432                "Validation type not yet implemented",
433            )),
434        }
435    }
436
437    fn name(&self) -> &str {
438        match &self.validation {
439            CorrelationValidation::Pairwise {
440                correlation_type, ..
441            } => correlation_type.constraint_name(),
442            CorrelationValidation::Range { .. } => "correlation_range",
443            CorrelationValidation::Independence { .. } => "independence",
444            CorrelationValidation::MultiColumn { .. } => "multi_correlation",
445            CorrelationValidation::Stability { .. } => "correlation_stability",
446        }
447    }
448
449    fn metadata(&self) -> ConstraintMetadata {
450        let description = match &self.validation {
451            CorrelationValidation::Pairwise {
452                column1,
453                column2,
454                correlation_type,
455                ..
456            } => format!(
457                "Validates {} between '{column1}' and '{column2}'",
458                correlation_type.name()
459            ),
460            CorrelationValidation::Range {
461                column1, column2, ..
462            } => format!(
463                "Validates correlation range between '{column1}' and '{column2}'"
464            ),
465            CorrelationValidation::Independence {
466                column1, column2, ..
467            } => format!(
468                "Validates independence between '{column1}' and '{column2}'"
469            ),
470            CorrelationValidation::MultiColumn(config) => format!(
471                "Validates correlations among columns: {}",
472                config.columns.join(", ")
473            ),
474            CorrelationValidation::Stability {
475                column1,
476                column2,
477                segment_column,
478                ..
479            } => format!(
480                "Validates correlation stability between '{column1}' and '{column2}' across '{segment_column}'"
481            ),
482        };
483
484        ConstraintMetadata::new().with_description(description)
485    }
486}
487
488impl CorrelationConstraint {
489    /// Helper method to evaluate with a different validation (for internal use).
490    async fn evaluate_with_validation(
491        &self,
492        ctx: &SessionContext,
493        validation: &CorrelationValidation,
494    ) -> Result<ConstraintResult> {
495        let temp_constraint = Self::new(validation.clone())?;
496        temp_constraint.evaluate(ctx).await
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session with a single-batch `data` table holding nullable Float64
    /// columns `x` and `y`.
    fn context_with_columns(x: Vec<Option<f64>>, y: Vec<Option<f64>>) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("y", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float64Array::from(x)),
                Arc::new(Float64Array::from(y)),
            ],
        )
        .unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// `y` is roughly `2x` plus a small repeating offset, so the columns correlate strongly.
    fn create_test_context_correlated() -> SessionContext {
        let (x, y): (Vec<_>, Vec<_>) = (0..100)
            .map(|i| {
                let x = i as f64;
                (Some(x), Some(2.0 * x + (i % 10) as f64 - 5.0))
            })
            .unzip();
        context_with_columns(x, y)
    }

    /// `y` is a scrambled permutation of `x`, giving weakly correlated columns.
    fn create_test_context_independent() -> SessionContext {
        let (x, y): (Vec<_>, Vec<_>) = (0..100)
            .map(|i| (Some(i as f64), Some(((i * 37) % 100) as f64)))
            .unzip();
        context_with_columns(x, y)
    }

    #[tokio::test]
    async fn test_pearson_correlation() {
        let ctx = create_test_context_correlated();
        let constraint =
            CorrelationConstraint::pearson("x", "y", Assertion::GreaterThan(0.9)).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
        assert!(result.metric.unwrap() > 0.9);
    }

    #[tokio::test]
    async fn test_independence_check() {
        let ctx = create_test_context_independent();
        let constraint = CorrelationConstraint::independence("x", "y", 0.3).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();

        // Weakly correlated data should stay under the threshold.
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_correlation_range() {
        let ctx = create_test_context_correlated();
        let validation = CorrelationValidation::Range {
            column1: "x".to_string(),
            column2: "y".to_string(),
            correlation_type: CorrelationType::Pearson,
            min: 0.8,
            max: 1.0,
        };
        let constraint = CorrelationConstraint::new(validation).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[test]
    fn test_invalid_max_correlation() {
        let message = CorrelationConstraint::independence("x", "y", 1.5)
            .unwrap_err()
            .to_string();

        assert!(message.contains("Max correlation must be between 0.0 and 1.0"));
    }

    #[test]
    fn test_multi_column_validation() {
        let config = MultiCorrelationConfig {
            columns: ["a", "b", "c"].iter().map(|c| c.to_string()).collect(),
            correlation_type: CorrelationType::Pearson,
            pairwise: true,
            min_correlation: Some(0.5),
        };

        assert!(CorrelationConstraint::new(CorrelationValidation::MultiColumn(config)).is_ok());
    }

    #[test]
    fn test_multi_column_too_few() {
        let config = MultiCorrelationConfig {
            columns: vec!["a".to_string()],
            correlation_type: CorrelationType::Pearson,
            pairwise: true,
            min_correlation: None,
        };

        let message = CorrelationConstraint::new(CorrelationValidation::MultiColumn(config))
            .unwrap_err()
            .to_string();

        assert!(message.contains("At least 2 columns required"));
    }
}