1use crate::constraints::Assertion;
10use crate::core::{Constraint, ConstraintMetadata, ConstraintResult};
11use crate::prelude::*;
12use crate::security::SqlSecurity;
13use arrow::array::Array;
14use async_trait::async_trait;
15use datafusion::prelude::*;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use tracing::{debug, instrument};
19
20#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
22pub enum QuantileMethod {
23 Approximate,
25 Exact,
27 Auto { threshold: usize },
29}
30
31impl Default for QuantileMethod {
32 fn default() -> Self {
33 Self::Auto { threshold: 10000 }
34 }
35}
36
37#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
39pub struct QuantileCheck {
40 pub quantile: f64,
42 pub assertion: Assertion,
44}
45
46impl QuantileCheck {
47 pub fn new(quantile: f64, assertion: Assertion) -> Result<Self> {
49 if !(0.0..=1.0).contains(&quantile) {
50 return Err(TermError::Configuration(
51 "Quantile must be between 0.0 and 1.0".to_string(),
52 ));
53 }
54 Ok(Self {
55 quantile,
56 assertion,
57 })
58 }
59}
60
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
63pub struct DistributionConfig {
64 pub quantiles: Vec<f64>,
66 pub include_bounds: bool,
68 pub compute_iqr: bool,
70}
71
72impl Default for DistributionConfig {
73 fn default() -> Self {
74 Self {
75 quantiles: vec![0.25, 0.5, 0.75],
76 include_bounds: true,
77 compute_iqr: true,
78 }
79 }
80}
81
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub enum QuantileValidation {
85 Single(QuantileCheck),
87
88 Multiple(Vec<QuantileCheck>),
90
91 Distribution {
93 config: DistributionConfig,
94 iqr_assertion: Option<Assertion>,
96 quantile_assertions: HashMap<String, Assertion>,
98 },
99
100 Monotonic {
102 quantiles: Vec<f64>,
103 strict: bool,
105 },
106
107 Custom {
109 sql_expression: String,
110 assertion: Assertion,
111 },
112}
113
114#[derive(Debug, Clone)]
146pub struct QuantileConstraint {
147 column: String,
149 validation: QuantileValidation,
151 method: QuantileMethod,
153}
154
155impl QuantileConstraint {
156 pub fn new(column: impl Into<String>, validation: QuantileValidation) -> Result<Self> {
167 Self::with_method(column, validation, QuantileMethod::default())
168 }
169
170 pub fn with_method(
172 column: impl Into<String>,
173 validation: QuantileValidation,
174 method: QuantileMethod,
175 ) -> Result<Self> {
176 let column_str = column.into();
177 SqlSecurity::validate_identifier(&column_str)?;
178
179 Ok(Self {
180 column: column_str,
181 validation,
182 method,
183 })
184 }
185
186 pub fn median(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
188 Self::new(
189 column,
190 QuantileValidation::Single(QuantileCheck::new(0.5, assertion)?),
191 )
192 }
193
194 pub fn percentile(
196 column: impl Into<String>,
197 quantile: f64,
198 assertion: Assertion,
199 ) -> Result<Self> {
200 Self::new(
201 column,
202 QuantileValidation::Single(QuantileCheck::new(quantile, assertion)?),
203 )
204 }
205
206 pub fn multiple(column: impl Into<String>, checks: Vec<QuantileCheck>) -> Result<Self> {
208 if checks.is_empty() {
209 return Err(TermError::Configuration(
210 "At least one quantile check is required".to_string(),
211 ));
212 }
213 Self::new(column, QuantileValidation::Multiple(checks))
214 }
215
216 pub fn distribution(column: impl Into<String>, config: DistributionConfig) -> Result<Self> {
218 Self::new(
219 column,
220 QuantileValidation::Distribution {
221 config,
222 iqr_assertion: None,
223 quantile_assertions: HashMap::new(),
224 },
225 )
226 }
227
228 fn approx_quantile_sql(&self, quantile: f64) -> Result<String> {
230 let escaped_column = SqlSecurity::escape_identifier(&self.column)?;
231 Ok(format!(
232 "APPROX_PERCENTILE_CONT({quantile}) WITHIN GROUP (ORDER BY {escaped_column})"
233 ))
234 }
235
236 #[allow(dead_code)]
238 fn exact_quantile_sql(&self, quantile: f64) -> Result<String> {
239 self.approx_quantile_sql(quantile)
242 }
243
244 async fn should_use_exact(&self, ctx: &SessionContext) -> Result<bool> {
246 match &self.method {
247 QuantileMethod::Exact => Ok(true),
248 QuantileMethod::Approximate => Ok(false),
249 QuantileMethod::Auto { threshold } => {
250 let count_sql = "SELECT COUNT(*) as cnt FROM data";
251 let df = ctx.sql(count_sql).await?;
252 let batches = df.collect().await?;
253
254 if batches.is_empty() || batches[0].num_rows() == 0 {
255 return Ok(true);
256 }
257
258 let count = batches[0]
259 .column(0)
260 .as_any()
261 .downcast_ref::<arrow::array::Int64Array>()
262 .ok_or_else(|| {
263 TermError::Internal("Failed to downcast to Int64Array".to_string())
264 })?
265 .value(0) as usize;
266
267 Ok(count <= *threshold)
268 }
269 }
270 }
271}
272
273#[async_trait]
274impl Constraint for QuantileConstraint {
275 #[instrument(skip(self, ctx), fields(
276 column = %self.column,
277 validation = ?self.validation
278 ))]
279 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
280 let _use_exact = self.should_use_exact(ctx).await?;
281
282 match &self.validation {
283 QuantileValidation::Single(check) => {
284 let sql = format!(
286 "SELECT {} as q_value FROM data",
287 self.approx_quantile_sql(check.quantile)?
288 );
289
290 debug!("Quantile SQL: {}", sql);
291 let df = ctx.sql(&sql).await?;
292 let batches = df.collect().await?;
293
294 if batches.is_empty() || batches[0].num_rows() == 0 {
295 return Ok(ConstraintResult::skipped("No data to validate"));
296 }
297
298 let column = batches[0].column(0);
300 let value = if let Some(arr) =
301 column.as_any().downcast_ref::<arrow::array::Float64Array>()
302 {
303 arr.value(0)
304 } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int64Array>()
305 {
306 arr.value(0) as f64
307 } else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int32Array>()
308 {
309 arr.value(0) as f64
310 } else {
311 return Err(TermError::TypeMismatch {
312 expected: "Float64, Int64, or Int32".to_string(),
313 found: format!("{:?}", column.data_type()),
314 });
315 };
316
317 if check.assertion.evaluate(value) {
318 Ok(ConstraintResult::success_with_metric(value))
319 } else {
320 Ok(ConstraintResult::failure_with_metric(
321 value,
322 format!(
323 "Quantile {} is {value} which does not {}",
324 check.quantile, check.assertion
325 ),
326 ))
327 }
328 }
329 QuantileValidation::Multiple(checks) => {
330 let sql_parts: Vec<String> = checks
332 .iter()
333 .enumerate()
334 .map(|(i, check)| {
335 self.approx_quantile_sql(check.quantile)
337 .map(|q_sql| format!("{q_sql} as q_{i}"))
338 })
339 .collect::<Result<Vec<_>>>()?;
340
341 let parts = sql_parts.join(", ");
342 let sql = format!("SELECT {parts} FROM data");
343 let df = ctx.sql(&sql).await?;
344 let batches = df.collect().await?;
345
346 if batches.is_empty() || batches[0].num_rows() == 0 {
347 return Ok(ConstraintResult::skipped("No data to validate"));
348 }
349
350 let mut failures = Vec::new();
351 let batch = &batches[0];
352
353 for (i, check) in checks.iter().enumerate() {
354 let column = batch.column(i);
355 let value = if let Some(arr) =
356 column.as_any().downcast_ref::<arrow::array::Float64Array>()
357 {
358 arr.value(0)
359 } else if let Some(arr) =
360 column.as_any().downcast_ref::<arrow::array::Int64Array>()
361 {
362 arr.value(0) as f64
363 } else if let Some(arr) =
364 column.as_any().downcast_ref::<arrow::array::Int32Array>()
365 {
366 arr.value(0) as f64
367 } else {
368 return Err(TermError::TypeMismatch {
369 expected: "Float64, Int64, or Int32".to_string(),
370 found: format!("{:?}", column.data_type()),
371 });
372 };
373
374 if !check.assertion.evaluate(value) {
375 let q_pct = (check.quantile * 100.0) as i32;
376 failures.push(format!(
377 "Q{q_pct} is {value} which does not {}",
378 check.assertion
379 ));
380 }
381 }
382
383 if failures.is_empty() {
384 Ok(ConstraintResult::success())
385 } else {
386 Ok(ConstraintResult::failure(failures.join("; ")))
387 }
388 }
389 QuantileValidation::Monotonic { quantiles, strict } => {
390 let sql_parts: Vec<String> = quantiles
392 .iter()
393 .enumerate()
394 .map(|(i, q)| {
395 self.approx_quantile_sql(*q)
397 .map(|q_sql| format!("{q_sql} as q_{i}"))
398 })
399 .collect::<Result<Vec<_>>>()?;
400
401 let parts = sql_parts.join(", ");
402 let sql = format!("SELECT {parts} FROM data");
403 let df = ctx.sql(&sql).await?;
404 let batches = df.collect().await?;
405
406 if batches.is_empty() || batches[0].num_rows() == 0 {
407 return Ok(ConstraintResult::skipped("No data to validate"));
408 }
409
410 let batch = &batches[0];
411 let mut values = Vec::new();
412
413 for i in 0..quantiles.len() {
414 let column = batch.column(i);
415 let value = if let Some(arr) =
416 column.as_any().downcast_ref::<arrow::array::Float64Array>()
417 {
418 arr.value(0)
419 } else if let Some(arr) =
420 column.as_any().downcast_ref::<arrow::array::Int64Array>()
421 {
422 arr.value(0) as f64
423 } else if let Some(arr) =
424 column.as_any().downcast_ref::<arrow::array::Int32Array>()
425 {
426 arr.value(0) as f64
427 } else {
428 return Err(TermError::TypeMismatch {
429 expected: "Float64, Int64, or Int32".to_string(),
430 found: format!("{:?}", column.data_type()),
431 });
432 };
433 values.push(value);
434 }
435
436 let mut is_monotonic = true;
438 for i in 1..values.len() {
439 if *strict {
440 if values[i] <= values[i - 1] {
441 is_monotonic = false;
442 break;
443 }
444 } else if values[i] < values[i - 1] {
445 is_monotonic = false;
446 break;
447 }
448 }
449
450 if is_monotonic {
451 Ok(ConstraintResult::success())
452 } else {
453 let monotonic_type = if *strict { "strictly" } else { "" };
454 Ok(ConstraintResult::failure(format!(
455 "Quantiles are not {monotonic_type} monotonic: {values:?}"
456 )))
457 }
458 }
459 _ => {
460 Ok(ConstraintResult::skipped(
462 "Validation type not yet implemented",
463 ))
464 }
465 }
466 }
467
468 fn name(&self) -> &str {
469 "quantile"
470 }
471
472 fn metadata(&self) -> ConstraintMetadata {
473 ConstraintMetadata::for_column(&self.column).with_description(format!(
474 "Validates quantile properties for column '{}'",
475 self.column
476 ))
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Registers `values` as the nullable Float64 column `value` of table `data`.
    async fn create_test_context(values: Vec<Option<f64>>) -> SessionContext {
        let value_field = Field::new("value", DataType::Float64, true);
        let schema = Arc::new(Schema::new(vec![value_field]));

        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(Float64Array::from(values))])
                .unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// The values 1.0 through 100.0, all non-null.
    fn one_to_hundred() -> Vec<Option<f64>> {
        (1..=100).map(|i| Some(f64::from(i))).collect()
    }

    #[tokio::test]
    async fn test_median_check() {
        let ctx = create_test_context(one_to_hundred()).await;
        let median =
            QuantileConstraint::median("value", Assertion::Between(45.0, 55.0)).unwrap();

        let outcome = median.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_percentile_check() {
        let ctx = create_test_context(one_to_hundred()).await;
        let p95 =
            QuantileConstraint::percentile("value", 0.95, Assertion::Between(94.0, 96.0)).unwrap();

        let outcome = p95.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_multiple_quantiles() {
        let ctx = create_test_context(one_to_hundred()).await;
        let lower = QuantileCheck::new(0.25, Assertion::Between(24.0, 26.0)).unwrap();
        let upper = QuantileCheck::new(0.75, Assertion::Between(74.0, 76.0)).unwrap();
        let quartiles = QuantileConstraint::multiple("value", vec![lower, upper]).unwrap();

        let outcome = quartiles.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_monotonic_check() {
        let ctx = create_test_context(one_to_hundred()).await;
        let validation = QuantileValidation::Monotonic {
            quantiles: vec![0.1, 0.5, 0.9],
            strict: true,
        };
        let monotonic = QuantileConstraint::new("value", validation).unwrap();

        let outcome = monotonic.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[test]
    fn test_invalid_quantile() {
        let attempt = QuantileCheck::new(1.5, Assertion::LessThan(100.0));
        assert!(attempt.is_err());

        let message = attempt.unwrap_err().to_string();
        assert!(message.contains("Quantile must be between 0.0 and 1.0"));
    }
}