1use crate::constraints::Assertion;
13use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
14use crate::prelude::*;
15use crate::security::SqlSecurity;
16use arrow::array::Array;
17use async_trait::async_trait;
18use datafusion::prelude::*;
19use serde::{Deserialize, Serialize};
20use std::fmt;
21use tracing::instrument;
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
24pub enum StatisticType {
25 Min,
27 Max,
29 Mean,
31 Sum,
33 StandardDeviation,
35 Variance,
37 Median,
39 Percentile(f64),
41}
42
43impl StatisticType {
44 fn sql_function(&self) -> String {
46 match self {
47 StatisticType::Min => "MIN".to_string(),
48 StatisticType::Max => "MAX".to_string(),
49 StatisticType::Mean => "AVG".to_string(),
50 StatisticType::Sum => "SUM".to_string(),
51 StatisticType::StandardDeviation => "STDDEV".to_string(),
52 StatisticType::Variance => "VARIANCE".to_string(),
53 StatisticType::Median => "APPROX_PERCENTILE_CONT".to_string(),
54 StatisticType::Percentile(_) => "APPROX_PERCENTILE_CONT".to_string(),
55 }
56 }
57
58 fn sql_expression(&self, column: &str) -> String {
60 match self {
61 StatisticType::Median => {
62 let func = self.sql_function();
63 format!("{func}({column}, 0.5)")
64 }
65 StatisticType::Percentile(p) => {
66 let func = self.sql_function();
67 format!("{func}({column}, {p})")
68 }
69 _ => {
70 let func = self.sql_function();
71 format!("{func}({column})")
72 }
73 }
74 }
75
76 fn name(&self) -> &str {
78 match self {
79 StatisticType::Min => "minimum",
80 StatisticType::Max => "maximum",
81 StatisticType::Mean => "mean",
82 StatisticType::Sum => "sum",
83 StatisticType::StandardDeviation => "standard deviation",
84 StatisticType::Variance => "variance",
85 StatisticType::Median => "median",
86 StatisticType::Percentile(p) => {
87 if (*p - 0.5).abs() < f64::EPSILON {
88 "median"
89 } else {
90 "percentile"
91 }
92 }
93 }
94 }
95
96 fn constraint_name(&self) -> &str {
98 match self {
99 StatisticType::Min => "min",
100 StatisticType::Max => "max",
101 StatisticType::Mean => "mean",
102 StatisticType::Sum => "sum",
103 StatisticType::StandardDeviation => "standard_deviation",
104 StatisticType::Variance => "variance",
105 StatisticType::Median => "median",
106 StatisticType::Percentile(_) => "percentile",
107 }
108 }
109}
110
111impl fmt::Display for StatisticType {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 match self {
114 StatisticType::Percentile(p) => write!(f, "{}({p})", self.name()),
115 _ => write!(f, "{}", self.name()),
116 }
117 }
118}
119
120#[derive(Debug, Clone)]
153pub struct StatisticalConstraint {
154 column: String,
156 statistic: StatisticType,
158 assertion: Assertion,
160}
161
162impl StatisticalConstraint {
163 pub fn new(
175 column: impl Into<String>,
176 statistic: StatisticType,
177 assertion: Assertion,
178 ) -> Result<Self> {
179 let column_str = column.into();
180 SqlSecurity::validate_identifier(&column_str)?;
181
182 if let StatisticType::Percentile(p) = &statistic {
184 if !(0.0..=1.0).contains(p) {
185 return Err(TermError::SecurityError(
186 "Percentile must be between 0.0 and 1.0".to_string(),
187 ));
188 }
189 }
190
191 Ok(Self {
192 column: column_str,
193 statistic,
194 assertion,
195 })
196 }
197
198 pub fn min(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
200 Self::new(column, StatisticType::Min, assertion)
201 }
202
203 pub fn max(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
205 Self::new(column, StatisticType::Max, assertion)
206 }
207
208 pub fn mean(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
210 Self::new(column, StatisticType::Mean, assertion)
211 }
212
213 pub fn sum(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
215 Self::new(column, StatisticType::Sum, assertion)
216 }
217
218 pub fn standard_deviation(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
220 Self::new(column, StatisticType::StandardDeviation, assertion)
221 }
222
223 pub fn variance(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
225 Self::new(column, StatisticType::Variance, assertion)
226 }
227
228 pub fn median(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
230 Self::new(column, StatisticType::Median, assertion)
231 }
232
233 pub fn percentile(
239 column: impl Into<String>,
240 percentile: f64,
241 assertion: Assertion,
242 ) -> Result<Self> {
243 Self::new(column, StatisticType::Percentile(percentile), assertion)
244 }
245}
246
247#[async_trait]
248impl Constraint for StatisticalConstraint {
249 #[instrument(skip(self, ctx), fields(
250 column = %self.column,
251 statistic = %self.statistic,
252 assertion = %self.assertion
253 ))]
254 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
255 let column_identifier = SqlSecurity::escape_identifier(&self.column)?;
256 let stat_expr = self.statistic.sql_expression(&column_identifier);
257 let validation_ctx = current_validation_context();
260
261 let table_name = validation_ctx.table_name();
262
263 let sql = format!("SELECT {stat_expr} as stat_value FROM {table_name}");
264
265 let df = ctx.sql(&sql).await?;
266 let batches = df.collect().await?;
267
268 if batches.is_empty() {
269 return Ok(ConstraintResult::skipped("No data to validate"));
270 }
271
272 let batch = &batches[0];
273 if batch.num_rows() == 0 {
274 return Ok(ConstraintResult::skipped("No data to validate"));
275 }
276
277 let value = if let Ok(array) = batch
279 .column(0)
280 .as_any()
281 .downcast_ref::<arrow::array::Int64Array>()
282 .ok_or_else(|| TermError::Internal("Failed to extract statistic value".to_string()))
283 {
284 if array.is_null(0) {
285 let stat_name = self.statistic.name();
286 return Ok(ConstraintResult::failure(format!(
287 "{stat_name} is null (no non-null values)"
288 )));
289 }
290 array.value(0) as f64
291 } else if let Ok(array) = batch
292 .column(0)
293 .as_any()
294 .downcast_ref::<arrow::array::Float64Array>()
295 .ok_or_else(|| TermError::Internal("Failed to extract statistic value".to_string()))
296 {
297 if array.is_null(0) {
298 let stat_name = self.statistic.name();
299 return Ok(ConstraintResult::failure(format!(
300 "{stat_name} is null (no non-null values)"
301 )));
302 }
303 array.value(0)
304 } else {
305 return Err(TermError::Internal(
306 "Failed to extract statistic value".to_string(),
307 ));
308 };
309
310 if self.assertion.evaluate(value) {
311 Ok(ConstraintResult::success_with_metric(value))
312 } else {
313 Ok(ConstraintResult::failure_with_metric(
314 value,
315 format!(
316 "{} {value} does not {}",
317 self.statistic.name(),
318 self.assertion
319 ),
320 ))
321 }
322 }
323
324 fn name(&self) -> &str {
325 self.statistic.constraint_name()
326 }
327
328 fn column(&self) -> Option<&str> {
329 Some(&self.column)
330 }
331
332 fn metadata(&self) -> ConstraintMetadata {
333 let mut metadata = ConstraintMetadata::for_column(&self.column)
334 .with_description(format!(
335 "Checks that {} of {} {}",
336 self.statistic.name(),
337 self.column,
338 self.assertion.description()
339 ))
340 .with_custom("assertion", self.assertion.to_string())
341 .with_custom("statistic_type", self.statistic.to_string())
342 .with_custom("constraint_type", "statistical");
343
344 if let StatisticType::Percentile(p) = self.statistic {
345 metadata = metadata.with_custom("percentile", p.to_string());
346 }
347
348 metadata
349 }
350}
351
352#[derive(Debug, Clone)]
377pub struct MultiStatisticalConstraint {
378 column: String,
379 statistics: Vec<(StatisticType, Assertion)>,
380}
381
382impl MultiStatisticalConstraint {
383 pub fn new(
394 column: impl Into<String>,
395 statistics: Vec<(StatisticType, Assertion)>,
396 ) -> Result<Self> {
397 let column_str = column.into();
398 SqlSecurity::validate_identifier(&column_str)?;
399
400 for (stat, _) in &statistics {
402 if let StatisticType::Percentile(p) = stat {
403 if !(0.0..=1.0).contains(p) {
404 return Err(TermError::SecurityError(
405 "Percentile must be between 0.0 and 1.0".to_string(),
406 ));
407 }
408 }
409 }
410
411 Ok(Self {
412 column: column_str,
413 statistics,
414 })
415 }
416}
417
418#[async_trait]
419impl Constraint for MultiStatisticalConstraint {
420 #[instrument(skip(self, ctx), fields(
421 column = %self.column,
422 num_statistics = %self.statistics.len()
423 ))]
424 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
425 let column_identifier = SqlSecurity::escape_identifier(&self.column)?;
426
427 let sql_parts: Vec<String> = self
429 .statistics
430 .iter()
431 .enumerate()
432 .map(|(i, (stat, _))| {
433 let expr = stat.sql_expression(&column_identifier);
434 format!("{expr} as stat_{i}")
435 })
436 .collect();
437
438 let parts = sql_parts.join(", ");
439 let validation_ctx = current_validation_context();
441 let table_name = validation_ctx.table_name();
442
443 let sql = format!("SELECT {parts} FROM {table_name}");
444
445 let df = ctx.sql(&sql).await?;
446 let batches = df.collect().await?;
447
448 if batches.is_empty() {
449 return Ok(ConstraintResult::skipped("No data to validate"));
450 }
451
452 let batch = &batches[0];
453 if batch.num_rows() == 0 {
454 return Ok(ConstraintResult::skipped("No data to validate"));
455 }
456
457 let mut failures = Vec::new();
459 let mut all_metrics = Vec::new();
460
461 for (i, (stat_type, assertion)) in self.statistics.iter().enumerate() {
462 let column = batch.column(i);
463
464 let value = if let Some(array) =
466 column.as_any().downcast_ref::<arrow::array::Float64Array>()
467 {
468 if array.is_null(0) {
469 let name = stat_type.name();
470 failures.push(format!("{name} is null"));
471 continue;
472 }
473 array.value(0)
474 } else if let Some(array) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
475 if array.is_null(0) {
476 let name = stat_type.name();
477 failures.push(format!("{name} is null"));
478 continue;
479 }
480 array.value(0) as f64
481 } else {
482 let name = stat_type.name();
483 failures.push(format!("Failed to compute {name}"));
484 continue;
485 };
486
487 all_metrics.push((stat_type.name().to_string(), value));
488
489 if !assertion.evaluate(value) {
490 failures.push(format!(
491 "{} is {value} which does not {assertion}",
492 stat_type.name()
493 ));
494 }
495 }
496
497 if failures.is_empty() {
498 let first_metric = all_metrics.first().map(|(_, v)| *v).unwrap_or(0.0);
500 Ok(ConstraintResult::success_with_metric(first_metric))
501 } else {
502 Ok(ConstraintResult::failure(failures.join("; ")))
503 }
504 }
505
506 fn name(&self) -> &str {
507 "multi_statistical"
508 }
509
510 fn column(&self) -> Option<&str> {
511 Some(&self.column)
512 }
513
514 fn metadata(&self) -> ConstraintMetadata {
515 let stat_names: Vec<String> = self
516 .statistics
517 .iter()
518 .map(|(s, _)| s.name().to_string())
519 .collect();
520
521 ConstraintMetadata::for_column(&self.column)
522 .with_description({
523 let stats = stat_names.join(", ");
524 format!(
525 "Checks multiple statistics ({stats}) for column {}",
526 self.column
527 )
528 })
529 .with_custom("statistics_count", self.statistics.len().to_string())
530 .with_custom("constraint_type", "multi_statistical")
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use crate::test_helpers::evaluate_constraint_with_context;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Name under which the in-memory fixture table is registered.
    const TABLE: &str = "data";

    /// Builds a session with a table `data` holding one nullable Float64
    /// column named `value` populated from `values`.
    async fn create_test_context(values: Vec<Option<f64>>) -> SessionContext {
        let value_field = Field::new("value", DataType::Float64, true);
        let schema = Arc::new(Schema::new(vec![value_field]));

        let column = Arc::new(Float64Array::from(values));
        let batch = RecordBatch::try_new(schema.clone(), vec![column]).unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table(TABLE, Arc::new(table)).unwrap();
        ctx
    }

    #[tokio::test]
    async fn test_mean_constraint() {
        let ctx = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let mean = StatisticalConstraint::mean("value", Assertion::Equals(20.0)).unwrap();

        let outcome = evaluate_constraint_with_context(&mean, &ctx, TABLE)
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(20.0));
    }

    #[tokio::test]
    async fn test_min_max_constraints() {
        let ctx = create_test_context(vec![Some(5.0), Some(10.0), Some(15.0)]).await;

        let cases = [
            (
                StatisticalConstraint::min("value", Assertion::Equals(5.0)).unwrap(),
                5.0,
            ),
            (
                StatisticalConstraint::max("value", Assertion::Equals(15.0)).unwrap(),
                15.0,
            ),
        ];

        for (constraint, expected) in &cases {
            let outcome = evaluate_constraint_with_context(constraint, &ctx, TABLE)
                .await
                .unwrap();
            assert_eq!(outcome.status, ConstraintStatus::Success);
            assert_eq!(outcome.metric, Some(*expected));
        }
    }

    #[tokio::test]
    async fn test_sum_constraint() {
        let ctx = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let sum = StatisticalConstraint::sum("value", Assertion::Equals(60.0)).unwrap();

        let outcome = evaluate_constraint_with_context(&sum, &ctx, TABLE)
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(60.0));
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // Null entries are ignored by the aggregate, so the mean is (10 + 20) / 2.
        let ctx = create_test_context(vec![Some(10.0), None, Some(20.0)]).await;
        let mean = StatisticalConstraint::mean("value", Assertion::Equals(15.0)).unwrap();

        let outcome = evaluate_constraint_with_context(&mean, &ctx, TABLE)
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(15.0));
    }

    #[tokio::test]
    async fn test_all_nulls() {
        let ctx = create_test_context(vec![None, None, None]).await;
        let mean = StatisticalConstraint::mean("value", Assertion::Equals(0.0)).unwrap();

        let outcome = evaluate_constraint_with_context(&mean, &ctx, TABLE)
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.message.unwrap().contains("null"));
    }

    #[test]
    fn test_statistic_type_display() {
        let expectations = [
            (StatisticType::Min, "minimum"),
            (StatisticType::Mean, "mean"),
            (StatisticType::Percentile(0.95), "percentile(0.95)"),
            (StatisticType::Median, "median"),
        ];

        for (statistic, rendered) in expectations {
            assert_eq!(statistic.to_string(), rendered);
        }
    }

    #[tokio::test]
    async fn test_multi_statistical_constraint() {
        let ctx = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0), Some(40.0)]).await;

        let checks = vec![
            (StatisticType::Min, Assertion::GreaterThanOrEqual(10.0)),
            (StatisticType::Max, Assertion::LessThanOrEqual(40.0)),
            (StatisticType::Mean, Assertion::Equals(25.0)),
            (StatisticType::Sum, Assertion::Equals(100.0)),
        ];
        let multi = MultiStatisticalConstraint::new("value", checks).unwrap();

        let outcome = evaluate_constraint_with_context(&multi, &ctx, TABLE)
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_multi_statistical_constraint_failure() {
        let ctx = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;

        let checks = vec![
            // The actual minimum is 10, so this assertion is the one that fails.
            (StatisticType::Min, Assertion::Equals(5.0)),
            (StatisticType::Max, Assertion::Equals(30.0)),
        ];
        let multi = MultiStatisticalConstraint::new("value", checks).unwrap();

        let outcome = evaluate_constraint_with_context(&multi, &ctx, TABLE)
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.message.unwrap().contains("minimum is 10"));
    }

    #[test]
    fn test_invalid_percentile() {
        let attempt = StatisticalConstraint::new(
            "value",
            StatisticType::Percentile(1.5),
            Assertion::LessThan(100.0),
        );

        assert!(attempt.is_err());
        let rendered = attempt.unwrap_err().to_string();
        assert!(rendered.contains("Percentile must be between 0.0 and 1.0"));
    }
}