term_guard/constraints/
statistics.rs

1//! Unified statistical constraint that consolidates all statistical checks.
2//!
3//! This module provides a single, flexible statistical constraint that replaces:
4//! - `MinConstraint`
5//! - `MaxConstraint`
6//! - `MeanConstraint`
7//! - `SumConstraint`
8//! - `StandardDeviationConstraint`
9//!
10//! And adds support for new statistics like variance, median, and percentiles.
11
12use crate::constraints::Assertion;
13use crate::core::{Constraint, ConstraintMetadata, ConstraintResult};
14use crate::prelude::*;
15use crate::security::SqlSecurity;
16use arrow::array::Array;
17use async_trait::async_trait;
18use datafusion::prelude::*;
19use serde::{Deserialize, Serialize};
20use std::fmt;
21use tracing::instrument;
22
23/// Types of statistics that can be computed.
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub enum StatisticType {
26    /// Minimum value
27    Min,
28    /// Maximum value
29    Max,
30    /// Mean/average value
31    Mean,
32    /// Sum of all values
33    Sum,
34    /// Standard deviation
35    StandardDeviation,
36    /// Variance
37    Variance,
38    /// Median (50th percentile)
39    Median,
40    /// Specific percentile (0.0 to 1.0)
41    Percentile(f64),
42}
43
44impl StatisticType {
45    /// Returns the SQL function name for this statistic.
46    fn sql_function(&self) -> String {
47        match self {
48            StatisticType::Min => "MIN".to_string(),
49            StatisticType::Max => "MAX".to_string(),
50            StatisticType::Mean => "AVG".to_string(),
51            StatisticType::Sum => "SUM".to_string(),
52            StatisticType::StandardDeviation => "STDDEV".to_string(),
53            StatisticType::Variance => "VARIANCE".to_string(),
54            StatisticType::Median => "APPROX_PERCENTILE_CONT".to_string(),
55            StatisticType::Percentile(_) => "APPROX_PERCENTILE_CONT".to_string(),
56        }
57    }
58
59    /// Returns the SQL expression for this statistic.
60    fn sql_expression(&self, column: &str) -> String {
61        match self {
62            StatisticType::Median => {
63                let func = self.sql_function();
64                format!("{func}({column}, 0.5)")
65            }
66            StatisticType::Percentile(p) => {
67                let func = self.sql_function();
68                format!("{func}({column}, {p})")
69            }
70            _ => {
71                let func = self.sql_function();
72                format!("{func}({column})")
73            }
74        }
75    }
76
77    /// Returns a human-readable name for this statistic.
78    fn name(&self) -> &str {
79        match self {
80            StatisticType::Min => "minimum",
81            StatisticType::Max => "maximum",
82            StatisticType::Mean => "mean",
83            StatisticType::Sum => "sum",
84            StatisticType::StandardDeviation => "standard deviation",
85            StatisticType::Variance => "variance",
86            StatisticType::Median => "median",
87            StatisticType::Percentile(p) => {
88                if (*p - 0.5).abs() < f64::EPSILON {
89                    "median"
90                } else {
91                    "percentile"
92                }
93            }
94        }
95    }
96
97    /// Returns the constraint name for backward compatibility.
98    fn constraint_name(&self) -> &str {
99        match self {
100            StatisticType::Min => "min",
101            StatisticType::Max => "max",
102            StatisticType::Mean => "mean",
103            StatisticType::Sum => "sum",
104            StatisticType::StandardDeviation => "standard_deviation",
105            StatisticType::Variance => "variance",
106            StatisticType::Median => "median",
107            StatisticType::Percentile(_) => "percentile",
108        }
109    }
110}
111
112impl fmt::Display for StatisticType {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        match self {
115            StatisticType::Percentile(p) => write!(f, "{}({p})", self.name()),
116            _ => write!(f, "{}", self.name()),
117        }
118    }
119}
120
121/// A unified constraint that checks statistical properties of a column.
122///
123/// This constraint replaces the individual statistical constraints and provides
124/// a consistent interface for all statistical checks.
125///
126/// # Examples
127///
128/// ```rust
129/// use term_guard::constraints::{StatisticalConstraint, StatisticType, Assertion};
130/// use term_guard::core::Constraint;
131///
132/// // Check that mean is between 25 and 35
133/// let mean_check = StatisticalConstraint::new(
134///     "age",
135///     StatisticType::Mean,
136///     Assertion::Between(25.0, 35.0)
137/// );
138///
139/// // Check that maximum is less than 100
140/// let max_check = StatisticalConstraint::new(
141///     "score",
142///     StatisticType::Max,
143///     Assertion::LessThan(100.0)
144/// );
145///
146/// // Check 95th percentile for SLA
147/// let p95_check = StatisticalConstraint::new(
148///     "response_time",
149///     StatisticType::Percentile(0.95),
150///     Assertion::LessThan(1000.0)
151/// );
152/// ```
153#[derive(Debug, Clone)]
154pub struct StatisticalConstraint {
155    /// The column to compute statistics on
156    column: String,
157    /// The type of statistic to compute
158    statistic: StatisticType,
159    /// The assertion to evaluate against the statistic
160    assertion: Assertion,
161}
162
163impl StatisticalConstraint {
164    /// Creates a new statistical constraint.
165    ///
166    /// # Arguments
167    ///
168    /// * `column` - The column to check
169    /// * `statistic` - The type of statistic to compute
170    /// * `assertion` - The assertion to evaluate
171    ///
172    /// # Errors
173    ///
174    /// Returns error if column name is invalid or if percentile is out of range.
175    pub fn new(
176        column: impl Into<String>,
177        statistic: StatisticType,
178        assertion: Assertion,
179    ) -> Result<Self> {
180        let column_str = column.into();
181        SqlSecurity::validate_identifier(&column_str)?;
182
183        // Validate percentile range
184        if let StatisticType::Percentile(p) = &statistic {
185            if !(0.0..=1.0).contains(p) {
186                return Err(TermError::SecurityError(
187                    "Percentile must be between 0.0 and 1.0".to_string(),
188                ));
189            }
190        }
191
192        Ok(Self {
193            column: column_str,
194            statistic,
195            assertion,
196        })
197    }
198
199    /// Creates a minimum value constraint.
200    pub fn min(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
201        Self::new(column, StatisticType::Min, assertion)
202    }
203
204    /// Creates a maximum value constraint.
205    pub fn max(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
206        Self::new(column, StatisticType::Max, assertion)
207    }
208
209    /// Creates a mean/average constraint.
210    pub fn mean(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
211        Self::new(column, StatisticType::Mean, assertion)
212    }
213
214    /// Creates a sum constraint.
215    pub fn sum(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
216        Self::new(column, StatisticType::Sum, assertion)
217    }
218
219    /// Creates a standard deviation constraint.
220    pub fn standard_deviation(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
221        Self::new(column, StatisticType::StandardDeviation, assertion)
222    }
223
224    /// Creates a variance constraint.
225    pub fn variance(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
226        Self::new(column, StatisticType::Variance, assertion)
227    }
228
229    /// Creates a median constraint.
230    pub fn median(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
231        Self::new(column, StatisticType::Median, assertion)
232    }
233
234    /// Creates a percentile constraint.
235    ///
236    /// # Errors
237    ///
238    /// Returns error if column name is invalid or percentile is not between 0.0 and 1.0
239    pub fn percentile(
240        column: impl Into<String>,
241        percentile: f64,
242        assertion: Assertion,
243    ) -> Result<Self> {
244        Self::new(column, StatisticType::Percentile(percentile), assertion)
245    }
246}
247
248#[async_trait]
249impl Constraint for StatisticalConstraint {
250    #[instrument(skip(self, ctx), fields(
251        column = %self.column,
252        statistic = %self.statistic,
253        assertion = %self.assertion
254    ))]
255    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
256        let column_identifier = SqlSecurity::escape_identifier(&self.column)?;
257        let stat_expr = self.statistic.sql_expression(&column_identifier);
258        let sql = format!("SELECT {stat_expr} as stat_value FROM data");
259
260        let df = ctx.sql(&sql).await?;
261        let batches = df.collect().await?;
262
263        if batches.is_empty() {
264            return Ok(ConstraintResult::skipped("No data to validate"));
265        }
266
267        let batch = &batches[0];
268        if batch.num_rows() == 0 {
269            return Ok(ConstraintResult::skipped("No data to validate"));
270        }
271
272        // Extract the statistic value - try Int64 first, then Float64
273        let value = if let Ok(array) = batch
274            .column(0)
275            .as_any()
276            .downcast_ref::<arrow::array::Int64Array>()
277            .ok_or_else(|| TermError::Internal("Failed to extract statistic value".to_string()))
278        {
279            if array.is_null(0) {
280                let stat_name = self.statistic.name();
281                return Ok(ConstraintResult::failure(format!(
282                    "{stat_name} is null (no non-null values)"
283                )));
284            }
285            array.value(0) as f64
286        } else if let Ok(array) = batch
287            .column(0)
288            .as_any()
289            .downcast_ref::<arrow::array::Float64Array>()
290            .ok_or_else(|| TermError::Internal("Failed to extract statistic value".to_string()))
291        {
292            if array.is_null(0) {
293                let stat_name = self.statistic.name();
294                return Ok(ConstraintResult::failure(format!(
295                    "{stat_name} is null (no non-null values)"
296                )));
297            }
298            array.value(0)
299        } else {
300            return Err(TermError::Internal(
301                "Failed to extract statistic value".to_string(),
302            ));
303        };
304
305        if self.assertion.evaluate(value) {
306            Ok(ConstraintResult::success_with_metric(value))
307        } else {
308            Ok(ConstraintResult::failure_with_metric(
309                value,
310                format!(
311                    "{} {value} does not {}",
312                    self.statistic.name(),
313                    self.assertion
314                ),
315            ))
316        }
317    }
318
319    fn name(&self) -> &str {
320        self.statistic.constraint_name()
321    }
322
323    fn column(&self) -> Option<&str> {
324        Some(&self.column)
325    }
326
327    fn metadata(&self) -> ConstraintMetadata {
328        let mut metadata = ConstraintMetadata::for_column(&self.column)
329            .with_description(format!(
330                "Checks that {} of {} {}",
331                self.statistic.name(),
332                self.column,
333                self.assertion.description()
334            ))
335            .with_custom("assertion", self.assertion.to_string())
336            .with_custom("statistic_type", self.statistic.to_string())
337            .with_custom("constraint_type", "statistical");
338
339        if let StatisticType::Percentile(p) = self.statistic {
340            metadata = metadata.with_custom("percentile", p.to_string());
341        }
342
343        metadata
344    }
345}
346
347/// A constraint that can compute multiple statistics in a single query for performance optimization.
348///
349/// This is useful when you need to validate multiple statistics on the same column,
350/// as it reduces the number of table scans required.
351///
352/// # Examples
353///
354/// ```rust
355/// use term_guard::constraints::{MultiStatisticalConstraint, StatisticType, Assertion};
356///
357/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
358/// // Check multiple statistics on the same column in one pass
359/// let multi_stats = MultiStatisticalConstraint::new(
360///     "response_time",
361///     vec![
362///         (StatisticType::Min, Assertion::GreaterThanOrEqual(0.0)),
363///         (StatisticType::Max, Assertion::LessThan(5000.0)),
364///         (StatisticType::Mean, Assertion::Between(100.0, 1000.0)),
365///         (StatisticType::Percentile(0.95), Assertion::LessThan(2000.0)),
366///     ]
367/// )?;
368/// # Ok(())
369/// # }
370/// ```
371#[derive(Debug, Clone)]
372pub struct MultiStatisticalConstraint {
373    column: String,
374    statistics: Vec<(StatisticType, Assertion)>,
375}
376
377impl MultiStatisticalConstraint {
378    /// Creates a new multi-statistical constraint.
379    ///
380    /// # Arguments
381    ///
382    /// * `column` - The column to compute statistics on
383    /// * `statistics` - A vector of (statistic_type, assertion) pairs to evaluate
384    ///
385    /// # Errors
386    ///
387    /// Returns error if column name is invalid or if any percentile is out of range.
388    pub fn new(
389        column: impl Into<String>,
390        statistics: Vec<(StatisticType, Assertion)>,
391    ) -> Result<Self> {
392        let column_str = column.into();
393        SqlSecurity::validate_identifier(&column_str)?;
394
395        // Validate all percentile values
396        for (stat, _) in &statistics {
397            if let StatisticType::Percentile(p) = stat {
398                if !(0.0..=1.0).contains(p) {
399                    return Err(TermError::SecurityError(
400                        "Percentile must be between 0.0 and 1.0".to_string(),
401                    ));
402                }
403            }
404        }
405
406        Ok(Self {
407            column: column_str,
408            statistics,
409        })
410    }
411}
412
413#[async_trait]
414impl Constraint for MultiStatisticalConstraint {
415    #[instrument(skip(self, ctx), fields(
416        column = %self.column,
417        num_statistics = %self.statistics.len()
418    ))]
419    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
420        let column_identifier = SqlSecurity::escape_identifier(&self.column)?;
421
422        // Build SQL with all statistics computed in one query
423        let sql_parts: Vec<String> = self
424            .statistics
425            .iter()
426            .enumerate()
427            .map(|(i, (stat, _))| {
428                let expr = stat.sql_expression(&column_identifier);
429                format!("{expr} as stat_{i}")
430            })
431            .collect();
432
433        let parts = sql_parts.join(", ");
434        let sql = format!("SELECT {parts} FROM data");
435
436        let df = ctx.sql(&sql).await?;
437        let batches = df.collect().await?;
438
439        if batches.is_empty() {
440            return Ok(ConstraintResult::skipped("No data to validate"));
441        }
442
443        let batch = &batches[0];
444        if batch.num_rows() == 0 {
445            return Ok(ConstraintResult::skipped("No data to validate"));
446        }
447
448        // Check each statistic
449        let mut failures = Vec::new();
450        let mut all_metrics = Vec::new();
451
452        for (i, (stat_type, assertion)) in self.statistics.iter().enumerate() {
453            let column = batch.column(i);
454
455            // Extract value
456            let value = if let Some(array) =
457                column.as_any().downcast_ref::<arrow::array::Float64Array>()
458            {
459                if array.is_null(0) {
460                    let name = stat_type.name();
461                    failures.push(format!("{name} is null"));
462                    continue;
463                }
464                array.value(0)
465            } else if let Some(array) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
466                if array.is_null(0) {
467                    let name = stat_type.name();
468                    failures.push(format!("{name} is null"));
469                    continue;
470                }
471                array.value(0) as f64
472            } else {
473                let name = stat_type.name();
474                failures.push(format!("Failed to compute {name}"));
475                continue;
476            };
477
478            all_metrics.push((stat_type.name().to_string(), value));
479
480            if !assertion.evaluate(value) {
481                failures.push(format!(
482                    "{} is {value} which does not {assertion}",
483                    stat_type.name()
484                ));
485            }
486        }
487
488        if failures.is_empty() {
489            // All assertions passed - return the first metric as representative
490            let first_metric = all_metrics.first().map(|(_, v)| *v).unwrap_or(0.0);
491            Ok(ConstraintResult::success_with_metric(first_metric))
492        } else {
493            Ok(ConstraintResult::failure(failures.join("; ")))
494        }
495    }
496
497    fn name(&self) -> &str {
498        "multi_statistical"
499    }
500
501    fn column(&self) -> Option<&str> {
502        Some(&self.column)
503    }
504
505    fn metadata(&self) -> ConstraintMetadata {
506        let stat_names: Vec<String> = self
507            .statistics
508            .iter()
509            .map(|(s, _)| s.name().to_string())
510            .collect();
511
512        ConstraintMetadata::for_column(&self.column)
513            .with_description({
514                let stats = stat_names.join(", ");
515                format!(
516                    "Checks multiple statistics ({stats}) for column {}",
517                    self.column
518                )
519            })
520            .with_custom("statistics_count", self.statistics.len().to_string())
521            .with_custom("constraint_type", "multi_statistical")
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session whose table `data` has one nullable Float64 column,
    /// `value`, populated from `values`.
    async fn create_test_context(values: Vec<Option<f64>>) -> SessionContext {
        let field = Field::new("value", DataType::Float64, true);
        let schema = Arc::new(Schema::new(vec![field]));
        let batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(Float64Array::from(values))])
                .unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let session = SessionContext::new();
        session.register_table("data", Arc::new(table)).unwrap();
        session
    }

    #[tokio::test]
    async fn test_mean_constraint() {
        let session = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let check = StatisticalConstraint::mean("value", Assertion::Equals(20.0)).unwrap();

        let outcome = check.evaluate(&session).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(20.0));
    }

    #[tokio::test]
    async fn test_min_max_constraints() {
        let session = create_test_context(vec![Some(5.0), Some(10.0), Some(15.0)]).await;
        let cases = [
            (
                StatisticalConstraint::min("value", Assertion::Equals(5.0)).unwrap(),
                5.0,
            ),
            (
                StatisticalConstraint::max("value", Assertion::Equals(15.0)).unwrap(),
                15.0,
            ),
        ];

        for (check, expected) in cases {
            let outcome = check.evaluate(&session).await.unwrap();
            assert_eq!(outcome.status, ConstraintStatus::Success);
            assert_eq!(outcome.metric, Some(expected));
        }
    }

    #[tokio::test]
    async fn test_sum_constraint() {
        let session = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let check = StatisticalConstraint::sum("value", Assertion::Equals(60.0)).unwrap();

        let outcome = check.evaluate(&session).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(60.0));
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // NULLs are ignored by AVG, so the mean of 10 and 20 is 15.
        let session = create_test_context(vec![Some(10.0), None, Some(20.0)]).await;
        let check = StatisticalConstraint::mean("value", Assertion::Equals(15.0)).unwrap();

        let outcome = check.evaluate(&session).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(15.0));
    }

    #[tokio::test]
    async fn test_all_nulls() {
        let session = create_test_context(vec![None, None, None]).await;
        let check = StatisticalConstraint::mean("value", Assertion::Equals(0.0)).unwrap();

        let outcome = check.evaluate(&session).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        let message = outcome.message.unwrap();
        assert!(message.contains("null"));
    }

    #[test]
    fn test_statistic_type_display() {
        let expectations = [
            (StatisticType::Min, "minimum"),
            (StatisticType::Mean, "mean"),
            (StatisticType::Percentile(0.95), "percentile(0.95)"),
            (StatisticType::Median, "median"),
        ];

        for (statistic, expected) in expectations {
            assert_eq!(statistic.to_string(), expected);
        }
    }

    #[tokio::test]
    async fn test_multi_statistical_constraint() {
        let session =
            create_test_context(vec![Some(10.0), Some(20.0), Some(30.0), Some(40.0)]).await;
        let checks = vec![
            (StatisticType::Min, Assertion::GreaterThanOrEqual(10.0)),
            (StatisticType::Max, Assertion::LessThanOrEqual(40.0)),
            (StatisticType::Mean, Assertion::Equals(25.0)),
            (StatisticType::Sum, Assertion::Equals(100.0)),
        ];
        let constraint = MultiStatisticalConstraint::new("value", checks).unwrap();

        let outcome = constraint.evaluate(&session).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_multi_statistical_constraint_failure() {
        let session = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let checks = vec![
            // The real minimum is 10, so this pair fails.
            (StatisticType::Min, Assertion::Equals(5.0)),
            (StatisticType::Max, Assertion::Equals(30.0)),
        ];
        let constraint = MultiStatisticalConstraint::new("value", checks).unwrap();

        let outcome = constraint.evaluate(&session).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        let message = outcome.message.unwrap();
        assert!(message.contains("minimum is 10"));
    }

    #[test]
    fn test_invalid_percentile() {
        let error = StatisticalConstraint::new(
            "value",
            StatisticType::Percentile(1.5),
            Assertion::LessThan(100.0),
        )
        .expect_err("a percentile of 1.5 must be rejected");

        assert!(error
            .to_string()
            .contains("Percentile must be between 0.0 and 1.0"));
    }
}