1use crate::constraints::Assertion;
13use crate::core::{Constraint, ConstraintMetadata, ConstraintResult};
14use crate::prelude::*;
15use crate::security::SqlSecurity;
16use arrow::array::Array;
17use async_trait::async_trait;
18use datafusion::prelude::*;
19use serde::{Deserialize, Serialize};
20use std::fmt;
21use tracing::instrument;
22
/// The statistic a constraint computes over a single column.
///
/// Each variant maps to one SQL aggregate expression that is evaluated against
/// the table being validated.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StatisticType {
    /// Smallest value in the column (SQL `MIN`).
    Min,
    /// Largest value in the column (SQL `MAX`).
    Max,
    /// Arithmetic mean of the column (SQL `AVG`).
    Mean,
    /// Sum of the column (SQL `SUM`).
    Sum,
    /// Standard deviation of the column (SQL `STDDEV`).
    StandardDeviation,
    /// Variance of the column (SQL `VARIANCE`).
    Variance,
    /// Median of the column, computed as `APPROX_PERCENTILE_CONT` at `0.5`.
    Median,
    /// Percentile of the column, computed with `APPROX_PERCENTILE_CONT`.
    ///
    /// The payload is the percentile as a fraction; the constraint
    /// constructors in this file reject values outside `0.0..=1.0`.
    Percentile(f64),
}
43
44impl StatisticType {
45 fn sql_function(&self) -> String {
47 match self {
48 StatisticType::Min => "MIN".to_string(),
49 StatisticType::Max => "MAX".to_string(),
50 StatisticType::Mean => "AVG".to_string(),
51 StatisticType::Sum => "SUM".to_string(),
52 StatisticType::StandardDeviation => "STDDEV".to_string(),
53 StatisticType::Variance => "VARIANCE".to_string(),
54 StatisticType::Median => "APPROX_PERCENTILE_CONT".to_string(),
55 StatisticType::Percentile(_) => "APPROX_PERCENTILE_CONT".to_string(),
56 }
57 }
58
59 fn sql_expression(&self, column: &str) -> String {
61 match self {
62 StatisticType::Median => {
63 let func = self.sql_function();
64 format!("{func}({column}, 0.5)")
65 }
66 StatisticType::Percentile(p) => {
67 let func = self.sql_function();
68 format!("{func}({column}, {p})")
69 }
70 _ => {
71 let func = self.sql_function();
72 format!("{func}({column})")
73 }
74 }
75 }
76
77 fn name(&self) -> &str {
79 match self {
80 StatisticType::Min => "minimum",
81 StatisticType::Max => "maximum",
82 StatisticType::Mean => "mean",
83 StatisticType::Sum => "sum",
84 StatisticType::StandardDeviation => "standard deviation",
85 StatisticType::Variance => "variance",
86 StatisticType::Median => "median",
87 StatisticType::Percentile(p) => {
88 if (*p - 0.5).abs() < f64::EPSILON {
89 "median"
90 } else {
91 "percentile"
92 }
93 }
94 }
95 }
96
97 fn constraint_name(&self) -> &str {
99 match self {
100 StatisticType::Min => "min",
101 StatisticType::Max => "max",
102 StatisticType::Mean => "mean",
103 StatisticType::Sum => "sum",
104 StatisticType::StandardDeviation => "standard_deviation",
105 StatisticType::Variance => "variance",
106 StatisticType::Median => "median",
107 StatisticType::Percentile(_) => "percentile",
108 }
109 }
110}
111
112impl fmt::Display for StatisticType {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 match self {
115 StatisticType::Percentile(p) => write!(f, "{}({p})", self.name()),
116 _ => write!(f, "{}", self.name()),
117 }
118 }
119}
120
/// A constraint that computes one statistic over one column and checks the
/// result against an assertion.
///
/// Instances are built through [`StatisticalConstraint::new`] or one of the
/// per-statistic convenience constructors, all of which validate their input.
#[derive(Debug, Clone)]
pub struct StatisticalConstraint {
    /// Name of the column to analyse; validated as a SQL identifier on construction.
    column: String,
    /// Which statistic to compute over `column`.
    statistic: StatisticType,
    /// Condition the computed statistic must satisfy.
    assertion: Assertion,
}
162
163impl StatisticalConstraint {
164 pub fn new(
176 column: impl Into<String>,
177 statistic: StatisticType,
178 assertion: Assertion,
179 ) -> Result<Self> {
180 let column_str = column.into();
181 SqlSecurity::validate_identifier(&column_str)?;
182
183 if let StatisticType::Percentile(p) = &statistic {
185 if !(0.0..=1.0).contains(p) {
186 return Err(TermError::SecurityError(
187 "Percentile must be between 0.0 and 1.0".to_string(),
188 ));
189 }
190 }
191
192 Ok(Self {
193 column: column_str,
194 statistic,
195 assertion,
196 })
197 }
198
199 pub fn min(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
201 Self::new(column, StatisticType::Min, assertion)
202 }
203
204 pub fn max(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
206 Self::new(column, StatisticType::Max, assertion)
207 }
208
209 pub fn mean(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
211 Self::new(column, StatisticType::Mean, assertion)
212 }
213
214 pub fn sum(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
216 Self::new(column, StatisticType::Sum, assertion)
217 }
218
219 pub fn standard_deviation(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
221 Self::new(column, StatisticType::StandardDeviation, assertion)
222 }
223
224 pub fn variance(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
226 Self::new(column, StatisticType::Variance, assertion)
227 }
228
229 pub fn median(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
231 Self::new(column, StatisticType::Median, assertion)
232 }
233
234 pub fn percentile(
240 column: impl Into<String>,
241 percentile: f64,
242 assertion: Assertion,
243 ) -> Result<Self> {
244 Self::new(column, StatisticType::Percentile(percentile), assertion)
245 }
246}
247
248#[async_trait]
249impl Constraint for StatisticalConstraint {
250 #[instrument(skip(self, ctx), fields(
251 column = %self.column,
252 statistic = %self.statistic,
253 assertion = %self.assertion
254 ))]
255 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
256 let column_identifier = SqlSecurity::escape_identifier(&self.column)?;
257 let stat_expr = self.statistic.sql_expression(&column_identifier);
258 let sql = format!("SELECT {stat_expr} as stat_value FROM data");
259
260 let df = ctx.sql(&sql).await?;
261 let batches = df.collect().await?;
262
263 if batches.is_empty() {
264 return Ok(ConstraintResult::skipped("No data to validate"));
265 }
266
267 let batch = &batches[0];
268 if batch.num_rows() == 0 {
269 return Ok(ConstraintResult::skipped("No data to validate"));
270 }
271
272 let value = if let Ok(array) = batch
274 .column(0)
275 .as_any()
276 .downcast_ref::<arrow::array::Int64Array>()
277 .ok_or_else(|| TermError::Internal("Failed to extract statistic value".to_string()))
278 {
279 if array.is_null(0) {
280 let stat_name = self.statistic.name();
281 return Ok(ConstraintResult::failure(format!(
282 "{stat_name} is null (no non-null values)"
283 )));
284 }
285 array.value(0) as f64
286 } else if let Ok(array) = batch
287 .column(0)
288 .as_any()
289 .downcast_ref::<arrow::array::Float64Array>()
290 .ok_or_else(|| TermError::Internal("Failed to extract statistic value".to_string()))
291 {
292 if array.is_null(0) {
293 let stat_name = self.statistic.name();
294 return Ok(ConstraintResult::failure(format!(
295 "{stat_name} is null (no non-null values)"
296 )));
297 }
298 array.value(0)
299 } else {
300 return Err(TermError::Internal(
301 "Failed to extract statistic value".to_string(),
302 ));
303 };
304
305 if self.assertion.evaluate(value) {
306 Ok(ConstraintResult::success_with_metric(value))
307 } else {
308 Ok(ConstraintResult::failure_with_metric(
309 value,
310 format!(
311 "{} {value} does not {}",
312 self.statistic.name(),
313 self.assertion
314 ),
315 ))
316 }
317 }
318
319 fn name(&self) -> &str {
320 self.statistic.constraint_name()
321 }
322
323 fn column(&self) -> Option<&str> {
324 Some(&self.column)
325 }
326
327 fn metadata(&self) -> ConstraintMetadata {
328 let mut metadata = ConstraintMetadata::for_column(&self.column)
329 .with_description(format!(
330 "Checks that {} of {} {}",
331 self.statistic.name(),
332 self.column,
333 self.assertion.description()
334 ))
335 .with_custom("assertion", self.assertion.to_string())
336 .with_custom("statistic_type", self.statistic.to_string())
337 .with_custom("constraint_type", "statistical");
338
339 if let StatisticType::Percentile(p) = self.statistic {
340 metadata = metadata.with_custom("percentile", p.to_string());
341 }
342
343 metadata
344 }
345}
346
/// A constraint that checks several statistics of one column at once.
///
/// All statistics are computed in a single query over the column, and each is
/// checked against its own assertion.
#[derive(Debug, Clone)]
pub struct MultiStatisticalConstraint {
    /// Name of the column to analyse; validated as a SQL identifier on construction.
    column: String,
    /// Statistics to compute, each paired with the assertion it must satisfy.
    statistics: Vec<(StatisticType, Assertion)>,
}
376
377impl MultiStatisticalConstraint {
378 pub fn new(
389 column: impl Into<String>,
390 statistics: Vec<(StatisticType, Assertion)>,
391 ) -> Result<Self> {
392 let column_str = column.into();
393 SqlSecurity::validate_identifier(&column_str)?;
394
395 for (stat, _) in &statistics {
397 if let StatisticType::Percentile(p) = stat {
398 if !(0.0..=1.0).contains(p) {
399 return Err(TermError::SecurityError(
400 "Percentile must be between 0.0 and 1.0".to_string(),
401 ));
402 }
403 }
404 }
405
406 Ok(Self {
407 column: column_str,
408 statistics,
409 })
410 }
411}
412
413#[async_trait]
414impl Constraint for MultiStatisticalConstraint {
415 #[instrument(skip(self, ctx), fields(
416 column = %self.column,
417 num_statistics = %self.statistics.len()
418 ))]
419 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
420 let column_identifier = SqlSecurity::escape_identifier(&self.column)?;
421
422 let sql_parts: Vec<String> = self
424 .statistics
425 .iter()
426 .enumerate()
427 .map(|(i, (stat, _))| {
428 let expr = stat.sql_expression(&column_identifier);
429 format!("{expr} as stat_{i}")
430 })
431 .collect();
432
433 let parts = sql_parts.join(", ");
434 let sql = format!("SELECT {parts} FROM data");
435
436 let df = ctx.sql(&sql).await?;
437 let batches = df.collect().await?;
438
439 if batches.is_empty() {
440 return Ok(ConstraintResult::skipped("No data to validate"));
441 }
442
443 let batch = &batches[0];
444 if batch.num_rows() == 0 {
445 return Ok(ConstraintResult::skipped("No data to validate"));
446 }
447
448 let mut failures = Vec::new();
450 let mut all_metrics = Vec::new();
451
452 for (i, (stat_type, assertion)) in self.statistics.iter().enumerate() {
453 let column = batch.column(i);
454
455 let value = if let Some(array) =
457 column.as_any().downcast_ref::<arrow::array::Float64Array>()
458 {
459 if array.is_null(0) {
460 let name = stat_type.name();
461 failures.push(format!("{name} is null"));
462 continue;
463 }
464 array.value(0)
465 } else if let Some(array) = column.as_any().downcast_ref::<arrow::array::Int64Array>() {
466 if array.is_null(0) {
467 let name = stat_type.name();
468 failures.push(format!("{name} is null"));
469 continue;
470 }
471 array.value(0) as f64
472 } else {
473 let name = stat_type.name();
474 failures.push(format!("Failed to compute {name}"));
475 continue;
476 };
477
478 all_metrics.push((stat_type.name().to_string(), value));
479
480 if !assertion.evaluate(value) {
481 failures.push(format!(
482 "{} is {value} which does not {assertion}",
483 stat_type.name()
484 ));
485 }
486 }
487
488 if failures.is_empty() {
489 let first_metric = all_metrics.first().map(|(_, v)| *v).unwrap_or(0.0);
491 Ok(ConstraintResult::success_with_metric(first_metric))
492 } else {
493 Ok(ConstraintResult::failure(failures.join("; ")))
494 }
495 }
496
497 fn name(&self) -> &str {
498 "multi_statistical"
499 }
500
501 fn column(&self) -> Option<&str> {
502 Some(&self.column)
503 }
504
505 fn metadata(&self) -> ConstraintMetadata {
506 let stat_names: Vec<String> = self
507 .statistics
508 .iter()
509 .map(|(s, _)| s.name().to_string())
510 .collect();
511
512 ConstraintMetadata::for_column(&self.column)
513 .with_description({
514 let stats = stat_names.join(", ");
515 format!(
516 "Checks multiple statistics ({stats}) for column {}",
517 self.column
518 )
519 })
520 .with_custom("statistics_count", self.statistics.len().to_string())
521 .with_custom("constraint_type", "multi_statistical")
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session whose `data` table has one nullable `Float64` column
    /// named `value`, populated with `values`.
    async fn create_test_context(values: Vec<Option<f64>>) -> SessionContext {
        let field = Field::new("value", DataType::Float64, true);
        let schema = Arc::new(Schema::new(vec![field]));

        let column = Float64Array::from(values);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(column)]).unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// The mean of 10, 20, 30 is exactly 20.
    #[tokio::test]
    async fn test_mean_constraint() {
        let ctx = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let mean_check = StatisticalConstraint::mean("value", Assertion::Equals(20.0)).unwrap();

        let outcome = mean_check.evaluate(&ctx).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(20.0));
    }

    /// Minimum and maximum are reported as the metric of their constraints.
    #[tokio::test]
    async fn test_min_max_constraints() {
        let ctx = create_test_context(vec![Some(5.0), Some(10.0), Some(15.0)]).await;

        let min_check = StatisticalConstraint::min("value", Assertion::Equals(5.0)).unwrap();
        let min_outcome = min_check.evaluate(&ctx).await.unwrap();
        assert_eq!(min_outcome.status, ConstraintStatus::Success);
        assert_eq!(min_outcome.metric, Some(5.0));

        let max_check = StatisticalConstraint::max("value", Assertion::Equals(15.0)).unwrap();
        let max_outcome = max_check.evaluate(&ctx).await.unwrap();
        assert_eq!(max_outcome.status, ConstraintStatus::Success);
        assert_eq!(max_outcome.metric, Some(15.0));
    }

    /// The sum of 10, 20, 30 is exactly 60.
    #[tokio::test]
    async fn test_sum_constraint() {
        let ctx = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let sum_check = StatisticalConstraint::sum("value", Assertion::Equals(60.0)).unwrap();

        let outcome = sum_check.evaluate(&ctx).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(60.0));
    }

    /// Null entries are ignored by the aggregate: the mean of 10 and 20 is 15.
    #[tokio::test]
    async fn test_with_nulls() {
        let ctx = create_test_context(vec![Some(10.0), None, Some(20.0)]).await;
        let mean_check = StatisticalConstraint::mean("value", Assertion::Equals(15.0)).unwrap();

        let outcome = mean_check.evaluate(&ctx).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(15.0));
    }

    /// A column holding only nulls yields a null statistic, which is a failure.
    #[tokio::test]
    async fn test_all_nulls() {
        let ctx = create_test_context(vec![None, None, None]).await;
        let mean_check = StatisticalConstraint::mean("value", Assertion::Equals(0.0)).unwrap();

        let outcome = mean_check.evaluate(&ctx).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.message.unwrap().contains("null"));
    }

    /// `Display` uses the human-readable name; percentiles append their fraction.
    #[test]
    fn test_statistic_type_display() {
        assert_eq!(StatisticType::Min.to_string(), "minimum");
        assert_eq!(StatisticType::Mean.to_string(), "mean");
        assert_eq!(
            StatisticType::Percentile(0.95).to_string(),
            "percentile(0.95)"
        );
        assert_eq!(StatisticType::Median.to_string(), "median");
    }

    /// Every statistic satisfies its assertion, so the combined check succeeds.
    #[tokio::test]
    async fn test_multi_statistical_constraint() {
        let ctx = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0), Some(40.0)]).await;

        let checks = vec![
            (StatisticType::Min, Assertion::GreaterThanOrEqual(10.0)),
            (StatisticType::Max, Assertion::LessThanOrEqual(40.0)),
            (StatisticType::Mean, Assertion::Equals(25.0)),
            (StatisticType::Sum, Assertion::Equals(100.0)),
        ];
        let multi_check = MultiStatisticalConstraint::new("value", checks).unwrap();

        let outcome = multi_check.evaluate(&ctx).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    /// One failing statistic fails the whole check and is named in the message.
    #[tokio::test]
    async fn test_multi_statistical_constraint_failure() {
        let ctx = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;

        let checks = vec![
            // The actual minimum is 10, so this assertion does not hold.
            (StatisticType::Min, Assertion::Equals(5.0)),
            // The actual maximum is 30, so this assertion holds.
            (StatisticType::Max, Assertion::Equals(30.0)),
        ];
        let multi_check = MultiStatisticalConstraint::new("value", checks).unwrap();

        let outcome = multi_check.evaluate(&ctx).await.unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.message.unwrap().contains("minimum is 10"));
    }

    /// A percentile fraction above 1.0 is rejected at construction time.
    #[test]
    fn test_invalid_percentile() {
        let attempt = StatisticalConstraint::new(
            "value",
            StatisticType::Percentile(1.5),
            Assertion::LessThan(100.0),
        );

        assert!(attempt.is_err());
        assert!(attempt
            .unwrap_err()
            .to_string()
            .contains("Percentile must be between 0.0 and 1.0"));
    }
}