use crate::core::{
    current_validation_context, Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus,
};
use crate::prelude::*;
use arrow::array::{Array, LargeStringArray, StringViewArray};
use async_trait::async_trait;
use datafusion::prelude::*;
use std::fmt;
use std::sync::Arc;
use tracing::instrument;

/// A single bucket in a column's value histogram.
#[derive(Debug, Clone, PartialEq)]
pub struct HistogramBucket {
    /// The bucketed value, cast to a string.
    pub value: String,
    /// Number of non-null rows holding this value.
    pub count: i64,
    /// Share of non-null rows holding this value.
    pub ratio: f64,
}

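/// Value distribution of a column: one bucket per distinct non-null value,
/// ordered from most to least common.
///
/// Assertion closures typically combine the helper methods below. A small
/// illustrative sketch (the thresholds are hypothetical, not defaults of this
/// crate; marked `ignore` because it is not a standalone doctest):
///
/// ```ignore
/// let acceptable = |hist: &Histogram| {
///     hist.bucket_count() >= 2               // more than one distinct value
///         && hist.most_common_ratio() <= 0.9 // no single value dominates
///         && hist.null_ratio() < 0.1         // fewer than 10% of rows are null
/// };
/// ```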
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Buckets sorted by descending count.
    pub buckets: Vec<HistogramBucket>,
    /// Total number of rows, including nulls.
    pub total_count: i64,
    /// Number of distinct non-null values.
    pub distinct_count: usize,
    /// Number of null rows.
    pub null_count: i64,
}

impl Histogram {
    pub fn new(buckets: Vec<HistogramBucket>, total_count: i64, null_count: i64) -> Self {
        // The number of distinct non-null values equals the number of buckets.
        let distinct_count = buckets.len();
        Self {
            buckets,
            total_count,
            distinct_count,
            null_count,
        }
    }

    pub fn most_common_ratio(&self) -> f64 {
        // Buckets are ordered by descending count, so the first bucket is the most common value.
        self.buckets.first().map(|b| b.ratio).unwrap_or(0.0)
    }

    pub fn least_common_ratio(&self) -> f64 {
        // Buckets are ordered by descending count, so the last bucket is the least common value.
        self.buckets.last().map(|b| b.ratio).unwrap_or(0.0)
    }

    pub fn bucket_count(&self) -> usize {
        // One bucket per distinct non-null value.
        self.buckets.len()
    }

    pub fn top_n(&self, n: usize) -> Vec<(&str, f64)> {
        // The n most common values with their ratios, most common first.
        self.buckets
            .iter()
            .take(n)
            .map(|b| (b.value.as_str(), b.ratio))
            .collect()
    }

    pub fn is_roughly_uniform(&self, threshold: f64) -> bool {
        // A distribution counts as roughly uniform when the most common value is
        // at most `threshold` times as frequent as the least common one. An empty
        // histogram is treated as uniform; a zero minimum ratio is treated as
        // non-uniform to avoid dividing by zero.
        if self.buckets.is_empty() {
            return true;
        }

        let max_ratio = self.most_common_ratio();
        let min_ratio = self.least_common_ratio();

        if min_ratio == 0.0 {
            return false;
        }

        max_ratio / min_ratio <= threshold
    }

    pub fn get_value_ratio(&self, value: &str) -> Option<f64> {
        // Ratio of a specific value, or None if the value does not occur.
        self.buckets
            .iter()
            .find(|b| b.value == value)
            .map(|b| b.ratio)
    }

    pub fn entropy(&self) -> f64 {
        // Shannon entropy of the ratio distribution, in nats:
        //   H = -sum(p_i * ln(p_i)) over buckets with p_i > 0.
        // Higher values indicate a more even spread of values.
        self.buckets
            .iter()
            .filter(|b| b.ratio > 0.0)
            .map(|b| -b.ratio * b.ratio.ln())
            .sum()
    }

    pub fn follows_power_law(&self, top_n: usize, threshold: f64) -> bool {
        // Heuristic power-law check: the `top_n` most common values together must
        // account for at least `threshold` of the non-null rows.
        let top_sum: f64 = self.buckets.iter().take(top_n).map(|b| b.ratio).sum();
        top_sum >= threshold
    }

    pub fn null_ratio(&self) -> f64 {
        // Fraction of all rows (including nulls) that are null.
        if self.total_count == 0 {
            0.0
        } else {
            self.null_count as f64 / self.total_count as f64
        }
    }
}

/// Assertion applied to a computed [`Histogram`]; `true` means the distribution is acceptable.
pub type HistogramAssertion = Arc<dyn Fn(&Histogram) -> bool + Send + Sync>;

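/// Constraint that computes a value [`Histogram`] for a column and passes it to
/// a user-supplied assertion closure; the constraint succeeds when the closure
/// returns `true`.
///
/// A usage sketch (illustrative; it assumes the constraint is evaluated through
/// the crate's usual check runner against a registered table, so the block is
/// marked `ignore` rather than compiled as a doctest):
///
/// ```ignore
/// use std::sync::Arc;
///
/// // Fail validation if any single status value covers more than 80% of the
/// // non-null rows, or if the column collapses to a single value.
/// let constraint = HistogramConstraint::new_with_description(
///     "status",
///     Arc::new(|hist| hist.most_common_ratio() <= 0.8 && hist.bucket_count() >= 2),
///     "no single status value dominates the column",
/// );
/// ```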
#[derive(Clone)]
pub struct HistogramConstraint {
    column: String,
    assertion: HistogramAssertion,
    assertion_description: String,
}

impl fmt::Debug for HistogramConstraint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The assertion closure has no Debug impl, so only the descriptive fields are shown.
        f.debug_struct("HistogramConstraint")
            .field("column", &self.column)
            .field("assertion_description", &self.assertion_description)
            .finish()
    }
}

impl HistogramConstraint {
    /// Creates a constraint whose failure messages use the generic
    /// "custom assertion" description.
    pub fn new(column: impl Into<String>, assertion: HistogramAssertion) -> Self {
        Self {
            column: column.into(),
            assertion,
            assertion_description: "custom assertion".to_string(),
        }
    }

    /// Creates a constraint with a human-readable description of the assertion,
    /// used in failure messages and metadata.
    pub fn new_with_description(
        column: impl Into<String>,
        assertion: HistogramAssertion,
        description: impl Into<String>,
    ) -> Self {
        Self {
            column: column.into(),
            assertion,
            assertion_description: description.into(),
        }
    }
}

#[async_trait]
impl Constraint for HistogramConstraint {
    #[instrument(skip(self, ctx), fields(column = %self.column))]
    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
        // The table name is taken from the ambient validation context.
        let validation_ctx = current_validation_context();
        let table_name = validation_ctx.table_name();

        // One result row per distinct non-null value with its count and ratio;
        // the table-wide total and null counts are repeated on every row.
        let sql = format!(
            r#"
            WITH value_counts AS (
                SELECT
                    CAST({} AS VARCHAR) as value,
                    COUNT(*) as count
                FROM {table_name}
                WHERE {} IS NOT NULL
                GROUP BY {}
            ),
            totals AS (
                SELECT
                    COUNT(*) as total_cnt,
                    SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) as null_cnt
                FROM {table_name}
            )
            SELECT
                vc.value,
                vc.count,
                vc.count * 1.0 / (t.total_cnt - t.null_cnt) as ratio,
                t.total_cnt as total_count,
                t.null_cnt as null_count
            FROM value_counts vc
            CROSS JOIN totals t
            ORDER BY vc.count DESC, vc.value
            "#,
            self.column, self.column, self.column, self.column
        );

        let df = ctx.sql(&sql).await.map_err(|e| {
            TermError::constraint_evaluation(
                self.name(),
                format!("Failed to execute histogram query: {e}"),
            )
        })?;

        let batches = df.collect().await?;

        if batches.is_empty() || batches[0].num_rows() == 0 {
            return Ok(ConstraintResult::skipped("No data to analyze"));
        }

        // Accumulate buckets across all record batches.
        let mut buckets = Vec::new();
        let mut total_count = 0i64;
        let mut null_count = 0i64;

        for batch in &batches {
            let values_col = batch.column(0);
            // The value column may arrive as Utf8, Utf8View, or LargeUtf8 depending
            // on the data source, so each variant is handled explicitly.
            let value_strings: Vec<String> = match values_col.data_type() {
                arrow::datatypes::DataType::Utf8 => {
                    let arr = values_col
                        .as_any()
                        .downcast_ref::<arrow::array::StringArray>()
                        .ok_or_else(|| {
                            TermError::constraint_evaluation(
                                self.name(),
                                "Failed to extract string values",
                            )
                        })?;
                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
                }
                arrow::datatypes::DataType::Utf8View => {
                    let arr = values_col
                        .as_any()
                        .downcast_ref::<StringViewArray>()
                        .ok_or_else(|| {
                            TermError::constraint_evaluation(
                                self.name(),
                                "Failed to extract string view values",
                            )
                        })?;
                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
                }
                arrow::datatypes::DataType::LargeUtf8 => {
                    let arr = values_col
                        .as_any()
                        .downcast_ref::<LargeStringArray>()
                        .ok_or_else(|| {
                            TermError::constraint_evaluation(
                                self.name(),
                                "Failed to extract large string values",
                            )
                        })?;
                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
                }
                _ => {
                    return Err(TermError::constraint_evaluation(
                        self.name(),
                        format!("Unexpected value column type: {:?}", values_col.data_type()),
                    ));
                }
            };

            let count_array = batch
                .column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract counts")
                })?;

            let ratio_array = batch
                .column(2)
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract ratios")
                })?;

            let total_array = batch
                .column(3)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract total count")
                })?;

            let null_array = batch
                .column(4)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract null count")
                })?;

            if batch.num_rows() > 0 {
                // Every row carries the same totals, so reading the first row is enough.
                total_count = total_array.value(0);
                null_count = null_array.value(0);
            }

            for (i, value) in value_strings.into_iter().enumerate() {
                let count = count_array.value(i);
                let ratio = ratio_array.value(i);

                buckets.push(HistogramBucket {
                    value,
                    count,
                    ratio,
                });
            }
        }

        let histogram = Histogram::new(buckets, total_count, null_count);

        let assertion_result = (self.assertion)(&histogram);

        let status = if assertion_result {
            ConstraintStatus::Success
        } else {
            ConstraintStatus::Failure
        };

        let message = if status == ConstraintStatus::Failure {
            let most_common_pct = histogram.most_common_ratio() * 100.0;
            let null_pct = histogram.null_ratio() * 100.0;
            Some(format!(
                "Histogram assertion '{}' failed for column '{}'. Distribution: {} distinct values, most common ratio: {most_common_pct:.2}%, null ratio: {null_pct:.2}%",
                self.assertion_description,
                self.column,
                histogram.distinct_count
            ))
        } else {
            None
        };

        Ok(ConstraintResult {
            status,
            // The distribution's entropy is reported as this constraint's metric.
            metric: Some(histogram.entropy()),
            message,
        })
    }

    fn name(&self) -> &str {
        "histogram"
    }

    fn column(&self) -> Option<&str> {
        Some(&self.column)
    }

    fn metadata(&self) -> ConstraintMetadata {
        ConstraintMetadata::for_column(&self.column)
            .with_description(format!(
                "Analyzes value distribution in column '{}' and applies assertion: {}",
                self.column, self.assertion_description
            ))
            .with_custom("assertion", &self.assertion_description)
            .with_custom("constraint_type", "histogram")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    // Builds a single-column in-memory table named `data` holding the given values.
    async fn create_test_context_with_data(values: Vec<Option<&str>>) -> SessionContext {
        let ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "test_col",
            DataType::Utf8,
            true,
        )]));

        let array = StringArray::from(values);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();

        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();

        ctx
    }

    #[test]
    fn test_histogram_basic() {
        let buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 50,
                ratio: 0.5,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 30,
                ratio: 0.3,
            },
            HistogramBucket {
                value: "C".to_string(),
                count: 20,
                ratio: 0.2,
            },
        ];

        let histogram = Histogram::new(buckets, 100, 0);

        assert_eq!(histogram.most_common_ratio(), 0.5);
        assert_eq!(histogram.least_common_ratio(), 0.2);
        assert_eq!(histogram.bucket_count(), 3);
        assert_eq!(histogram.null_ratio(), 0.0);
    }

    #[test]
    fn test_histogram_entropy() {
        // Uniform distribution: four values at 25% each.
        let uniform_buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "C".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "D".to_string(),
                count: 25,
                ratio: 0.25,
            },
        ];

        let uniform_hist = Histogram::new(uniform_buckets, 100, 0);

        // Heavily skewed distribution: 90% / 10%.
        let skewed_buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 90,
                ratio: 0.9,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 10,
                ratio: 0.1,
            },
        ];

        let skewed_hist = Histogram::new(skewed_buckets, 100, 0);

        // Entropy is maximal for a uniform distribution, so it must exceed the skewed one.
        assert!(uniform_hist.entropy() > skewed_hist.entropy());
    }

    #[tokio::test]
    async fn test_most_common_ratio_constraint() {
        // "A" accounts for 6 of 10 rows (60%); "B" and "C" for 20% each.
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;

        // 60% is not below 50%, so this assertion must fail.
        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.most_common_ratio() < 0.5),
            "most common value appears less than 50%",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert!(result.message.is_some());

        // 60% is below 70%, so this assertion succeeds.
        let constraint =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.most_common_ratio() < 0.7));

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_bucket_count_constraint() {
        // Four distinct colors.
        let values = vec![
            Some("RED"),
            Some("BLUE"),
            Some("GREEN"),
            Some("YELLOW"),
            Some("RED"),
            Some("BLUE"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.bucket_count() >= 3 && hist.bucket_count() <= 5),
            "has between 3 and 5 distinct values",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_uniform_distribution_check() {
        // Perfectly uniform: each of the four values appears twice (25% each).
        let values = vec![
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
            Some("D"),
            Some("D"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.is_roughly_uniform(1.5)));

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_power_law_distribution() {
        // The top two values cover 7 of 10 rows (70%).
        let values = vec![
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Rare1"),
            Some("Rare2"),
            Some("Rare3"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.follows_power_law(2, 0.7)),
            "top 2 values account for 70% of distribution",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // 3 of 8 rows are null, so the null ratio is 0.375.
        let values = vec![
            Some("A"),
            Some("A"),
            None,
            None,
            None,
            Some("B"),
            Some("B"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| hist.null_ratio() > 0.3 && hist.null_ratio() < 0.4),
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_empty_data() {
        let ctx = create_test_context_with_data(vec![]).await;

        let constraint = HistogramConstraint::new("test_col", Arc::new(|_| true));

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Skipped);
    }

    #[tokio::test]
    async fn test_specific_value_check() {
        let values = vec![
            Some("PENDING"),
            Some("PENDING"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("REJECTED"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| {
                // APPROVED accounts for 3 of 6 rows (50%), which clears the 40% bar.
                hist.get_value_ratio("APPROVED").unwrap_or(0.0) > 0.4
            }),
            "APPROVED status is most common",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_top_n_values() {
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"), // A: 40%
            Some("B"),
            Some("B"),
            Some("B"), // B: 30%
            Some("C"),
            Some("C"), // C: 20%
            Some("D"), // D: 10%
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| {
                let top_2 = hist.top_n(2);
                top_2.len() == 2 && top_2[0].1 == 0.4 && top_2[1].1 == 0.3
            }),
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_numeric_data_histogram() {
        use arrow::array::Int64Array;
        use arrow::datatypes::{DataType, Field, Schema};

        let ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![Field::new("age", DataType::Int64, true)]));

        let values = vec![
            Some(25),
            Some(25),
            Some(30),
            Some(30),
            Some(30),
            Some(35),
            Some(35),
            Some(40),
            Some(45),
            Some(50),
        ];
        let array = Int64Array::from(values);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();

        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();

        let constraint = HistogramConstraint::new_with_description(
            "age",
            Arc::new(|hist| {
                // Six distinct ages; the mode (30) covers 30% of rows.
                hist.bucket_count() >= 5 && hist.most_common_ratio() < 0.4
            }),
            "age distribution is reasonable",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
}