term_guard/constraints/
statistics.rs

1//! Unified statistical constraint that consolidates all statistical checks.
2//!
3//! This module provides a single, flexible statistical constraint that replaces:
4//! - `MinConstraint`
5//! - `MaxConstraint`
6//! - `MeanConstraint`
7//! - `SumConstraint`
8//! - `StandardDeviationConstraint`
9//!
10//! And adds support for new statistics like variance, median, and percentiles.
11
12use crate::constraints::Assertion;
13use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
14use crate::prelude::*;
15use crate::security::SqlSecurity;
16use arrow::array::Array;
17use async_trait::async_trait;
18use datafusion::prelude::*;
19use serde::{Deserialize, Serialize};
20use std::fmt;
21use tracing::instrument;
22/// Types of statistics that can be computed.
23#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
24pub enum StatisticType {
25    /// Minimum value
26    Min,
27    /// Maximum value
28    Max,
29    /// Mean/average value
30    Mean,
31    /// Sum of all values
32    Sum,
33    /// Standard deviation
34    StandardDeviation,
35    /// Variance
36    Variance,
37    /// Median (50th percentile)
38    Median,
39    /// Specific percentile (0.0 to 1.0)
40    Percentile(f64),
41}
42
43impl StatisticType {
44    /// Returns the SQL function name for this statistic.
45    fn sql_function(&self) -> String {
46        match self {
47            StatisticType::Min => "MIN".to_string(),
48            StatisticType::Max => "MAX".to_string(),
49            StatisticType::Mean => "AVG".to_string(),
50            StatisticType::Sum => "SUM".to_string(),
51            StatisticType::StandardDeviation => "STDDEV".to_string(),
52            StatisticType::Variance => "VARIANCE".to_string(),
53            StatisticType::Median => "APPROX_PERCENTILE_CONT".to_string(),
54            StatisticType::Percentile(_) => "APPROX_PERCENTILE_CONT".to_string(),
55        }
56    }
57
58    /// Returns the SQL expression for this statistic.
59    fn sql_expression(&self, column: &str) -> String {
60        match self {
61            StatisticType::Median => {
62                let func = self.sql_function();
63                format!("{func}({column}, 0.5)")
64            }
65            StatisticType::Percentile(p) => {
66                let func = self.sql_function();
67                format!("{func}({column}, {p})")
68            }
69            _ => {
70                let func = self.sql_function();
71                format!("{func}({column})")
72            }
73        }
74    }
75
76    /// Returns a human-readable name for this statistic.
77    fn name(&self) -> &str {
78        match self {
79            StatisticType::Min => "minimum",
80            StatisticType::Max => "maximum",
81            StatisticType::Mean => "mean",
82            StatisticType::Sum => "sum",
83            StatisticType::StandardDeviation => "standard deviation",
84            StatisticType::Variance => "variance",
85            StatisticType::Median => "median",
86            StatisticType::Percentile(p) => {
87                if (*p - 0.5).abs() < f64::EPSILON {
88                    "median"
89                } else {
90                    "percentile"
91                }
92            }
93        }
94    }
95
96    /// Returns the constraint name for backward compatibility.
97    fn constraint_name(&self) -> &str {
98        match self {
99            StatisticType::Min => "min",
100            StatisticType::Max => "max",
101            StatisticType::Mean => "mean",
102            StatisticType::Sum => "sum",
103            StatisticType::StandardDeviation => "standard_deviation",
104            StatisticType::Variance => "variance",
105            StatisticType::Median => "median",
106            StatisticType::Percentile(_) => "percentile",
107        }
108    }
109}
110
111impl fmt::Display for StatisticType {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        match self {
114            StatisticType::Percentile(p) => write!(f, "{}({p})", self.name()),
115            _ => write!(f, "{}", self.name()),
116        }
117    }
118}
119
120/// A unified constraint that checks statistical properties of a column.
121///
122/// This constraint replaces the individual statistical constraints and provides
123/// a consistent interface for all statistical checks.
124///
125/// # Examples
126///
127/// ```rust
128/// use term_guard::constraints::{StatisticalConstraint, StatisticType, Assertion};
129/// use term_guard::core::Constraint;
130///
131/// // Check that mean is between 25 and 35
132/// let mean_check = StatisticalConstraint::new(
133///     "age",
134///     StatisticType::Mean,
135///     Assertion::Between(25.0, 35.0)
136/// );
137///
138/// // Check that maximum is less than 100
139/// let max_check = StatisticalConstraint::new(
140///     "score",
141///     StatisticType::Max,
142///     Assertion::LessThan(100.0)
143/// );
144///
145/// // Check 95th percentile for SLA
146/// let p95_check = StatisticalConstraint::new(
147///     "response_time",
148///     StatisticType::Percentile(0.95),
149///     Assertion::LessThan(1000.0)
150/// );
151/// ```
152#[derive(Debug, Clone)]
153pub struct StatisticalConstraint {
154    /// The column to compute statistics on
155    column: String,
156    /// The type of statistic to compute
157    statistic: StatisticType,
158    /// The assertion to evaluate against the statistic
159    assertion: Assertion,
160}
161
162impl StatisticalConstraint {
163    /// Creates a new statistical constraint.
164    ///
165    /// # Arguments
166    ///
167    /// * `column` - The column to check
168    /// * `statistic` - The type of statistic to compute
169    /// * `assertion` - The assertion to evaluate
170    ///
171    /// # Errors
172    ///
173    /// Returns error if column name is invalid or if percentile is out of range.
174    pub fn new(
175        column: impl Into<String>,
176        statistic: StatisticType,
177        assertion: Assertion,
178    ) -> Result<Self> {
179        let column_str = column.into();
180        SqlSecurity::validate_identifier(&column_str)?;
181
182        // Validate percentile range
183        if let StatisticType::Percentile(p) = &statistic {
184            if !(0.0..=1.0).contains(p) {
185                return Err(TermError::SecurityError(
186                    "Percentile must be between 0.0 and 1.0".to_string(),
187                ));
188            }
189        }
190
191        Ok(Self {
192            column: column_str,
193            statistic,
194            assertion,
195        })
196    }
197
198    /// Creates a minimum value constraint.
199    pub fn min(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
200        Self::new(column, StatisticType::Min, assertion)
201    }
202
203    /// Creates a maximum value constraint.
204    pub fn max(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
205        Self::new(column, StatisticType::Max, assertion)
206    }
207
208    /// Creates a mean/average constraint.
209    pub fn mean(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
210        Self::new(column, StatisticType::Mean, assertion)
211    }
212
213    /// Creates a sum constraint.
214    pub fn sum(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
215        Self::new(column, StatisticType::Sum, assertion)
216    }
217
218    /// Creates a standard deviation constraint.
219    pub fn standard_deviation(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
220        Self::new(column, StatisticType::StandardDeviation, assertion)
221    }
222
223    /// Creates a variance constraint.
224    pub fn variance(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
225        Self::new(column, StatisticType::Variance, assertion)
226    }
227
228    /// Creates a median constraint.
229    pub fn median(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
230        Self::new(column, StatisticType::Median, assertion)
231    }
232
233    /// Creates a percentile constraint.
234    ///
235    /// # Errors
236    ///
237    /// Returns error if column name is invalid or percentile is not between 0.0 and 1.0
238    pub fn percentile(
239        column: impl Into<String>,
240        percentile: f64,
241        assertion: Assertion,
242    ) -> Result<Self> {
243        Self::new(column, StatisticType::Percentile(percentile), assertion)
244    }
245}
246
247#[async_trait]
248impl Constraint for StatisticalConstraint {
249    #[instrument(skip(self, ctx), fields(
250        column = %self.column,
251        statistic = %self.statistic,
252        assertion = %self.assertion
253    ))]
254    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
255        let column_identifier = SqlSecurity::escape_identifier(&self.column)?;
256        let stat_expr = self.statistic.sql_expression(&column_identifier);
257        // Get the table name from the validation context
258
259        let validation_ctx = current_validation_context();
260
261        let table_name = validation_ctx.table_name();
262
263        let sql = format!("SELECT {stat_expr} as stat_value FROM {table_name}");
264
265        let df = ctx.sql(&sql).await?;
266        let batches = df.collect().await?;
267
268        if batches.is_empty() {
269            return Ok(ConstraintResult::skipped("No data to validate"));
270        }
271
272        let batch = &batches[0];
273        if batch.num_rows() == 0 {
274            return Ok(ConstraintResult::skipped("No data to validate"));
275        }
276
277        // Extract the statistic value - try Int64 first, then Float64
278        let value = if let Ok(array) = batch
279            .column(0)
280            .as_any()
281            .downcast_ref::<arrow::array::Int64Array>()
282            .ok_or_else(|| TermError::Internal("Failed to extract statistic value".to_string()))
283        {
284            if array.is_null(0) {
285                let stat_name = self.statistic.name();
286                return Ok(ConstraintResult::failure(format!(
287                    "{stat_name} is null (no non-null values)"
288                )));
289            }
290            array.value(0) as f64
291        } else if let Ok(array) = batch
292            .column(0)
293            .as_any()
294            .downcast_ref::<arrow::array::Float64Array>()
295            .ok_or_else(|| TermError::Internal("Failed to extract statistic value".to_string()))
296        {
297            if array.is_null(0) {
298                let stat_name = self.statistic.name();
299                return Ok(ConstraintResult::failure(format!(
300                    "{stat_name} is null (no non-null values)"
301                )));
302            }
303            array.value(0)
304        } else {
305            return Err(TermError::Internal(
306                "Failed to extract statistic value".to_string(),
307            ));
308        };
309
310        if self.assertion.evaluate(value) {
311            Ok(ConstraintResult::success_with_metric(value))
312        } else {
313            Ok(ConstraintResult::failure_with_metric(
314                value,
315                format!(
316                    "{} {value} does not {}",
317                    self.statistic.name(),
318                    self.assertion
319                ),
320            ))
321        }
322    }
323
324    fn name(&self) -> &str {
325        self.statistic.constraint_name()
326    }
327
328    fn column(&self) -> Option<&str> {
329        Some(&self.column)
330    }
331
332    fn metadata(&self) -> ConstraintMetadata {
333        let mut metadata = ConstraintMetadata::for_column(&self.column)
334            .with_description(format!(
335                "Checks that {} of {} {}",
336                self.statistic.name(),
337                self.column,
338                self.assertion.description()
339            ))
340            .with_custom("assertion", self.assertion.to_string())
341            .with_custom("statistic_type", self.statistic.to_string())
342            .with_custom("constraint_type", "statistical");
343
344        if let StatisticType::Percentile(p) = self.statistic {
345            metadata = metadata.with_custom("percentile", p.to_string());
346        }
347
348        metadata
349    }
350}
351
352/// A constraint that can compute multiple statistics in a single query for performance optimization.
353///
354/// This is useful when you need to validate multiple statistics on the same column,
355/// as it reduces the number of table scans required.
356///
357/// # Examples
358///
359/// ```rust
360/// use term_guard::constraints::{MultiStatisticalConstraint, StatisticType, Assertion};
361///
362/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
363/// // Check multiple statistics on the same column in one pass
364/// let multi_stats = MultiStatisticalConstraint::new(
365///     "response_time",
366///     vec![
367///         (StatisticType::Min, Assertion::GreaterThanOrEqual(0.0)),
368///         (StatisticType::Max, Assertion::LessThan(5000.0)),
369///         (StatisticType::Mean, Assertion::Between(100.0, 1000.0)),
370///         (StatisticType::Percentile(0.95), Assertion::LessThan(2000.0)),
371///     ]
372/// )?;
373/// # Ok(())
374/// # }
375/// ```
376#[derive(Debug, Clone)]
377pub struct MultiStatisticalConstraint {
378    column: String,
379    statistics: Vec<(StatisticType, Assertion)>,
380}
381
382impl MultiStatisticalConstraint {
383    /// Creates a new multi-statistical constraint.
384    ///
385    /// # Arguments
386    ///
387    /// * `column` - The column to compute statistics on
388    /// * `statistics` - A vector of (statistic_type, assertion) pairs to evaluate
389    ///
390    /// # Errors
391    ///
392    /// Returns error if column name is invalid or if any percentile is out of range.
393    pub fn new(
394        column: impl Into<String>,
395        statistics: Vec<(StatisticType, Assertion)>,
396    ) -> Result<Self> {
397        let column_str = column.into();
398        SqlSecurity::validate_identifier(&column_str)?;
399
400        // Validate all percentile values
401        for (stat, _) in &statistics {
402            if let StatisticType::Percentile(p) = stat {
403                if !(0.0..=1.0).contains(p) {
404                    return Err(TermError::SecurityError(
405                        "Percentile must be between 0.0 and 1.0".to_string(),
406                    ));
407                }
408            }
409        }
410
411        Ok(Self {
412            column: column_str,
413            statistics,
414        })
415    }
416}
417
418#[async_trait]
419impl Constraint for MultiStatisticalConstraint {
420    #[instrument(skip(self, ctx), fields(
421        column = %self.column,
422        num_statistics = %self.statistics.len()
423    ))]
424    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
425        let column_identifier = SqlSecurity::escape_identifier(&self.column)?;
426
427        // Build SQL with all statistics computed in one query
428        let sql_parts: Vec<String> = self
429            .statistics
430            .iter()
431            .enumerate()
432            .map(|(i, (stat, _))| {
433                let expr = stat.sql_expression(&column_identifier);
434                format!("{expr} as stat_{i}")
435            })
436            .collect();
437
438        let parts = sql_parts.join(", ");
439        // Get the table name from the validation context
440        let validation_ctx = current_validation_context();
441        let table_name = validation_ctx.table_name();
442
443        let sql = format!("SELECT {parts} FROM {table_name}");
444
445        let df = ctx.sql(&sql).await?;
446        let batches = df.collect().await?;
447
448        if batches.is_empty() {
449            return Ok(ConstraintResult::skipped("No data to validate"));
450        }
451
452        let batch = &batches[0];
453        if batch.num_rows() == 0 {
454            return Ok(ConstraintResult::skipped("No data to validate"));
455        }
456
457        // Check each statistic
458        let mut failures = Vec::new();
459        let mut all_metrics = Vec::new();
460
461        for (i, (stat_type, assertion)) in self.statistics.iter().enumerate() {
462            let column = batch.column(i);
463
464            // Extract value
465            let value = if let Some(array) =
466                column.as_any().downcast_ref::<arrow::array::Float64Array>()
467            {
468                if array.is_null(0) {
469                    let name = stat_type.name();
470                    failures.push(format!("{name} is null"));
471                    continue;
472                }
473                array.value(0)
474            } else if let Some(array) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
475                if array.is_null(0) {
476                    let name = stat_type.name();
477                    failures.push(format!("{name} is null"));
478                    continue;
479                }
480                array.value(0) as f64
481            } else {
482                let name = stat_type.name();
483                failures.push(format!("Failed to compute {name}"));
484                continue;
485            };
486
487            all_metrics.push((stat_type.name().to_string(), value));
488
489            if !assertion.evaluate(value) {
490                failures.push(format!(
491                    "{} is {value} which does not {assertion}",
492                    stat_type.name()
493                ));
494            }
495        }
496
497        if failures.is_empty() {
498            // All assertions passed - return the first metric as representative
499            let first_metric = all_metrics.first().map(|(_, v)| *v).unwrap_or(0.0);
500            Ok(ConstraintResult::success_with_metric(first_metric))
501        } else {
502            Ok(ConstraintResult::failure(failures.join("; ")))
503        }
504    }
505
506    fn name(&self) -> &str {
507        "multi_statistical"
508    }
509
510    fn column(&self) -> Option<&str> {
511        Some(&self.column)
512    }
513
514    fn metadata(&self) -> ConstraintMetadata {
515        let stat_names: Vec<String> = self
516            .statistics
517            .iter()
518            .map(|(s, _)| s.name().to_string())
519            .collect();
520
521        ConstraintMetadata::for_column(&self.column)
522            .with_description({
523                let stats = stat_names.join(", ");
524                format!(
525                    "Checks multiple statistics ({stats}) for column {}",
526                    self.column
527                )
528            })
529            .with_custom("statistics_count", self.statistics.len().to_string())
530            .with_custom("constraint_type", "multi_statistical")
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    /// Builds a session with a single table, `data`, holding one nullable
    /// Float64 column named `value` filled from `values`.
    async fn create_test_context(values: Vec<Option<f64>>) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            true,
        )]));
        let batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(Float64Array::from(values))])
                .unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let session = SessionContext::new();
        session.register_table("data", Arc::new(table)).unwrap();
        session
    }

    #[tokio::test]
    async fn test_mean_constraint() {
        let session = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let check = StatisticalConstraint::mean("value", Assertion::Equals(20.0)).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(20.0));
    }

    #[tokio::test]
    async fn test_min_max_constraints() {
        let session = create_test_context(vec![Some(5.0), Some(10.0), Some(15.0)]).await;

        let cases = [
            (
                StatisticalConstraint::min("value", Assertion::Equals(5.0)).unwrap(),
                5.0,
            ),
            (
                StatisticalConstraint::max("value", Assertion::Equals(15.0)).unwrap(),
                15.0,
            ),
        ];
        for (check, expected) in cases {
            let outcome = evaluate_constraint_with_context(&check, &session, "data")
                .await
                .unwrap();
            assert_eq!(outcome.status, ConstraintStatus::Success);
            assert_eq!(outcome.metric, Some(expected));
        }
    }

    #[tokio::test]
    async fn test_sum_constraint() {
        let session = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let check = StatisticalConstraint::sum("value", Assertion::Equals(60.0)).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(60.0));
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // NULLs are ignored by the aggregate, so the mean is over 10 and 20 only.
        let session = create_test_context(vec![Some(10.0), None, Some(20.0)]).await;
        let check = StatisticalConstraint::mean("value", Assertion::Equals(15.0)).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(15.0));
    }

    #[tokio::test]
    async fn test_all_nulls() {
        let session = create_test_context(vec![None, None, None]).await;
        let check = StatisticalConstraint::mean("value", Assertion::Equals(0.0)).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        let message = outcome.message.unwrap();
        assert!(message.contains("null"));
    }

    #[test]
    fn test_statistic_type_display() {
        let expectations = [
            (StatisticType::Min, "minimum"),
            (StatisticType::Mean, "mean"),
            (StatisticType::Percentile(0.95), "percentile(0.95)"),
            (StatisticType::Median, "median"),
        ];
        for (statistic, rendered) in expectations {
            assert_eq!(statistic.to_string(), rendered);
        }
    }

    #[tokio::test]
    async fn test_multi_statistical_constraint() {
        let session =
            create_test_context(vec![Some(10.0), Some(20.0), Some(30.0), Some(40.0)]).await;

        let checks = vec![
            (StatisticType::Min, Assertion::GreaterThanOrEqual(10.0)),
            (StatisticType::Max, Assertion::LessThanOrEqual(40.0)),
            (StatisticType::Mean, Assertion::Equals(25.0)),
            (StatisticType::Sum, Assertion::Equals(100.0)),
        ];
        let check = MultiStatisticalConstraint::new("value", checks).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_multi_statistical_constraint_failure() {
        let session = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;

        let checks = vec![
            // The minimum is 10, so this assertion fails.
            (StatisticType::Min, Assertion::Equals(5.0)),
            (StatisticType::Max, Assertion::Equals(30.0)),
        ];
        let check = MultiStatisticalConstraint::new("value", checks).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        let message = outcome.message.unwrap();
        assert!(message.contains("minimum is 10"));
    }

    #[test]
    fn test_invalid_percentile() {
        let outcome = StatisticalConstraint::new(
            "value",
            StatisticType::Percentile(1.5),
            Assertion::LessThan(100.0),
        );

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Percentile must be between 0.0 and 1.0"));
    }
}