1use crate::constraints::Assertion;
10use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
11use crate::prelude::*;
12use crate::security::SqlSecurity;
13use arrow::array::Array;
14use async_trait::async_trait;
15use datafusion::prelude::*;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use tracing::{debug, instrument};
19#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub enum QuantileMethod {
22 Approximate,
24 Exact,
26 Auto { threshold: usize },
28}
29
30impl Default for QuantileMethod {
31 fn default() -> Self {
32 Self::Auto { threshold: 10000 }
33 }
34}
35
36#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
38pub struct QuantileCheck {
39 pub quantile: f64,
41 pub assertion: Assertion,
43}
44
45impl QuantileCheck {
46 pub fn new(quantile: f64, assertion: Assertion) -> Result<Self> {
48 if !(0.0..=1.0).contains(&quantile) {
49 return Err(TermError::Configuration(
50 "Quantile must be between 0.0 and 1.0".to_string(),
51 ));
52 }
53 Ok(Self {
54 quantile,
55 assertion,
56 })
57 }
58}
59
60#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub struct DistributionConfig {
63 pub quantiles: Vec<f64>,
65 pub include_bounds: bool,
67 pub compute_iqr: bool,
69}
70
71impl Default for DistributionConfig {
72 fn default() -> Self {
73 Self {
74 quantiles: vec![0.25, 0.5, 0.75],
75 include_bounds: true,
76 compute_iqr: true,
77 }
78 }
79}
80
81#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83pub enum QuantileValidation {
84 Single(QuantileCheck),
86
87 Multiple(Vec<QuantileCheck>),
89
90 Distribution {
92 config: DistributionConfig,
93 iqr_assertion: Option<Assertion>,
95 quantile_assertions: HashMap<String, Assertion>,
97 },
98
99 Monotonic {
101 quantiles: Vec<f64>,
102 strict: bool,
104 },
105
106 Custom {
108 sql_expression: String,
109 assertion: Assertion,
110 },
111}
112
113#[derive(Debug, Clone)]
145pub struct QuantileConstraint {
146 column: String,
148 validation: QuantileValidation,
150 method: QuantileMethod,
152}
153
154impl QuantileConstraint {
155 pub fn new(column: impl Into<String>, validation: QuantileValidation) -> Result<Self> {
166 Self::with_method(column, validation, QuantileMethod::default())
167 }
168
169 pub fn with_method(
171 column: impl Into<String>,
172 validation: QuantileValidation,
173 method: QuantileMethod,
174 ) -> Result<Self> {
175 let column_str = column.into();
176 SqlSecurity::validate_identifier(&column_str)?;
177
178 Ok(Self {
179 column: column_str,
180 validation,
181 method,
182 })
183 }
184
185 pub fn median(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
187 Self::new(
188 column,
189 QuantileValidation::Single(QuantileCheck::new(0.5, assertion)?),
190 )
191 }
192
193 pub fn percentile(
195 column: impl Into<String>,
196 quantile: f64,
197 assertion: Assertion,
198 ) -> Result<Self> {
199 Self::new(
200 column,
201 QuantileValidation::Single(QuantileCheck::new(quantile, assertion)?),
202 )
203 }
204
205 pub fn multiple(column: impl Into<String>, checks: Vec<QuantileCheck>) -> Result<Self> {
207 if checks.is_empty() {
208 return Err(TermError::Configuration(
209 "At least one quantile check is required".to_string(),
210 ));
211 }
212 Self::new(column, QuantileValidation::Multiple(checks))
213 }
214
215 pub fn distribution(column: impl Into<String>, config: DistributionConfig) -> Result<Self> {
217 Self::new(
218 column,
219 QuantileValidation::Distribution {
220 config,
221 iqr_assertion: None,
222 quantile_assertions: HashMap::new(),
223 },
224 )
225 }
226
227 fn approx_quantile_sql(&self, quantile: f64) -> Result<String> {
229 let escaped_column = SqlSecurity::escape_identifier(&self.column)?;
230 Ok(format!(
231 "APPROX_PERCENTILE_CONT({quantile}) WITHIN GROUP (ORDER BY {escaped_column})"
232 ))
233 }
234
235 #[allow(dead_code)]
237 fn exact_quantile_sql(&self, quantile: f64) -> Result<String> {
238 self.approx_quantile_sql(quantile)
241 }
242
243 async fn should_use_exact(&self, ctx: &SessionContext) -> Result<bool> {
245 match &self.method {
246 QuantileMethod::Exact => Ok(true),
247 QuantileMethod::Approximate => Ok(false),
248 QuantileMethod::Auto { threshold } => {
249 let validation_ctx = current_validation_context();
251 let table_name = validation_ctx.table_name();
252
253 let count_sql = format!("SELECT COUNT(*) as cnt FROM {table_name}");
254 let df = ctx.sql(&count_sql).await?;
255 let batches = df.collect().await?;
256
257 if batches.is_empty() || batches[0].num_rows() == 0 {
258 return Ok(true);
259 }
260
261 let count = batches[0]
262 .column(0)
263 .as_any()
264 .downcast_ref::<arrow::array::Int64Array>()
265 .ok_or_else(|| {
266 TermError::Internal("Failed to downcast to Int64Array".to_string())
267 })?
268 .value(0) as usize;
269
270 Ok(count <= *threshold)
271 }
272 }
273 }
274}
275
276#[async_trait]
277impl Constraint for QuantileConstraint {
278 #[instrument(skip(self, ctx), fields(
279 column = %self.column,
280 validation = ?self.validation
281 ))]
282 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
283 let _use_exact = self.should_use_exact(ctx).await?;
284
285 match &self.validation {
286 QuantileValidation::Single(check) => {
287 let validation_ctx = current_validation_context();
291
292 let table_name = validation_ctx.table_name();
293
294 let sql = format!(
295 "SELECT {} as q_value FROM {table_name}",
296 self.approx_quantile_sql(check.quantile)?
297 );
298
299 debug!("Quantile SQL: {}", sql);
300 let df = ctx.sql(&sql).await?;
301 let batches = df.collect().await?;
302
303 if batches.is_empty() || batches[0].num_rows() == 0 {
304 return Ok(ConstraintResult::skipped("No data to validate"));
305 }
306
307 let column = batches[0].column(0);
309 let value = if let Some(arr) =
310 column.as_any().downcast_ref::<arrow::array::Float64Array>()
311 {
312 arr.value(0)
313 } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int64Array>()
314 {
315 arr.value(0) as f64
316 } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int32Array>()
317 {
318 arr.value(0) as f64
319 } else {
320 return Err(TermError::TypeMismatch {
321 expected: "Float64, Int64, or Int32".to_string(),
322 found: format!("{:?}", column.data_type()),
323 });
324 };
325
326 if check.assertion.evaluate(value) {
327 Ok(ConstraintResult::success_with_metric(value))
328 } else {
329 Ok(ConstraintResult::failure_with_metric(
330 value,
331 format!(
332 "Quantile {} is {value} which does not {}",
333 check.quantile, check.assertion
334 ),
335 ))
336 }
337 }
338 QuantileValidation::Multiple(checks) => {
339 let sql_parts: Vec<String> = checks
341 .iter()
342 .enumerate()
343 .map(|(i, check)| {
344 self.approx_quantile_sql(check.quantile)
346 .map(|q_sql| format!("{q_sql} as q_{i}"))
347 })
348 .collect::<Result<Vec<_>>>()?;
349
350 let parts = sql_parts.join(", ");
351 let validation_ctx = current_validation_context();
353 let table_name = validation_ctx.table_name();
354
355 let sql = format!("SELECT {parts} FROM {table_name}");
356 let df = ctx.sql(&sql).await?;
357 let batches = df.collect().await?;
358
359 if batches.is_empty() || batches[0].num_rows() == 0 {
360 return Ok(ConstraintResult::skipped("No data to validate"));
361 }
362
363 let mut failures = Vec::new();
364 let batch = &batches[0];
365
366 for (i, check) in checks.iter().enumerate() {
367 let column = batch.column(i);
368 let value = if let Some(arr) =
369 column.as_any().downcast_ref::<arrow::array::Float64Array>()
370 {
371 arr.value(0)
372 } else if let Some(arr) =
373 column.as_any().downcast_ref::<arrow::array::Int64Array>()
374 {
375 arr.value(0) as f64
376 } else if let Some(arr) =
377 column.as_any().downcast_ref::<arrow::array::Int32Array>()
378 {
379 arr.value(0) as f64
380 } else {
381 return Err(TermError::TypeMismatch {
382 expected: "Float64, Int64, or Int32".to_string(),
383 found: format!("{:?}", column.data_type()),
384 });
385 };
386
387 if !check.assertion.evaluate(value) {
388 let q_pct = (check.quantile * 100.0) as i32;
389 failures.push(format!(
390 "Q{q_pct} is {value} which does not {}",
391 check.assertion
392 ));
393 }
394 }
395
396 if failures.is_empty() {
397 Ok(ConstraintResult::success())
398 } else {
399 Ok(ConstraintResult::failure(failures.join("; ")))
400 }
401 }
402 QuantileValidation::Monotonic { quantiles, strict } => {
403 let sql_parts: Vec<String> = quantiles
405 .iter()
406 .enumerate()
407 .map(|(i, q)| {
408 self.approx_quantile_sql(*q)
410 .map(|q_sql| format!("{q_sql} as q_{i}"))
411 })
412 .collect::<Result<Vec<_>>>()?;
413
414 let parts = sql_parts.join(", ");
415 let validation_ctx = current_validation_context();
417 let table_name = validation_ctx.table_name();
418
419 let sql = format!("SELECT {parts} FROM {table_name}");
420 let df = ctx.sql(&sql).await?;
421 let batches = df.collect().await?;
422
423 if batches.is_empty() || batches[0].num_rows() == 0 {
424 return Ok(ConstraintResult::skipped("No data to validate"));
425 }
426
427 let batch = &batches[0];
428 let mut values = Vec::new();
429
430 for i in 0..quantiles.len() {
431 let column = batch.column(i);
432 let value = if let Some(arr) =
433 column.as_any().downcast_ref::<arrow::array::Float64Array>()
434 {
435 arr.value(0)
436 } else if let Some(arr) =
437 column.as_any().downcast_ref::<arrow::array::Int64Array>()
438 {
439 arr.value(0) as f64
440 } else if let Some(arr) =
441 column.as_any().downcast_ref::<arrow::array::Int32Array>()
442 {
443 arr.value(0) as f64
444 } else {
445 return Err(TermError::TypeMismatch {
446 expected: "Float64, Int64, or Int32".to_string(),
447 found: format!("{:?}", column.data_type()),
448 });
449 };
450 values.push(value);
451 }
452
453 let mut is_monotonic = true;
455 for i in 1..values.len() {
456 if *strict {
457 if values[i] <= values[i - 1] {
458 is_monotonic = false;
459 break;
460 }
461 } else if values[i] < values[i - 1] {
462 is_monotonic = false;
463 break;
464 }
465 }
466
467 if is_monotonic {
468 Ok(ConstraintResult::success())
469 } else {
470 let monotonic_type = if *strict { "strictly" } else { "" };
471 Ok(ConstraintResult::failure(format!(
472 "Quantiles are not {monotonic_type} monotonic: {values:?}"
473 )))
474 }
475 }
476 _ => {
477 Ok(ConstraintResult::skipped(
479 "Validation type not yet implemented",
480 ))
481 }
482 }
483 }
484
485 fn name(&self) -> &str {
486 "quantile"
487 }
488
489 fn metadata(&self) -> ConstraintMetadata {
490 ConstraintMetadata::for_column(&self.column).with_description(format!(
491 "Validates quantile properties for column '{}'",
492 self.column
493 ))
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use crate::test_helpers::evaluate_constraint_with_context;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Name of the table every test registers its data under.
    const TABLE: &str = "data";

    /// Builds a session whose `data` table has one nullable Float64 column, `value`.
    async fn create_test_context(values: Vec<Option<f64>>) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            true,
        )]));
        let column: arrow::array::ArrayRef = Arc::new(Float64Array::from(values));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![column]).unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table(TABLE, Arc::new(table)).unwrap();
        ctx
    }

    /// The integers 1 through 100 as non-NULL Float64 values.
    fn one_to_hundred() -> Vec<Option<f64>> {
        (1_i32..=100).map(|n| Some(f64::from(n))).collect()
    }

    /// Evaluates `constraint` against a table holding `one_to_hundred()`.
    async fn evaluate_over_one_to_hundred(constraint: &QuantileConstraint) -> ConstraintResult {
        let ctx = create_test_context(one_to_hundred()).await;
        evaluate_constraint_with_context(constraint, &ctx, TABLE)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_median_check() {
        let constraint =
            QuantileConstraint::median("value", Assertion::Between(45.0, 55.0)).unwrap();

        let result = evaluate_over_one_to_hundred(&constraint).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_percentile_check() {
        let constraint =
            QuantileConstraint::percentile("value", 0.95, Assertion::Between(94.0, 96.0)).unwrap();

        let result = evaluate_over_one_to_hundred(&constraint).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_multiple_quantiles() {
        let lower_quartile = QuantileCheck::new(0.25, Assertion::Between(24.0, 26.0)).unwrap();
        let upper_quartile = QuantileCheck::new(0.75, Assertion::Between(74.0, 76.0)).unwrap();
        let constraint =
            QuantileConstraint::multiple("value", vec![lower_quartile, upper_quartile]).unwrap();

        let result = evaluate_over_one_to_hundred(&constraint).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_monotonic_check() {
        let validation = QuantileValidation::Monotonic {
            quantiles: vec![0.1, 0.5, 0.9],
            strict: true,
        };
        let constraint = QuantileConstraint::new("value", validation).unwrap();

        let result = evaluate_over_one_to_hundred(&constraint).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[test]
    fn test_invalid_quantile() {
        let err = QuantileCheck::new(1.5, Assertion::LessThan(100.0))
            .expect_err("1.5 lies outside [0.0, 1.0]");
        assert!(err
            .to_string()
            .contains("Quantile must be between 0.0 and 1.0"));
    }
}