term_guard/constraints/
histogram.rs

1//! Histogram analysis constraint for value distribution analysis.
2
3use crate::core::{
4    current_validation_context, Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus,
5};
6use crate::prelude::*;
7use arrow::array::{Array, LargeStringArray, StringViewArray};
8use async_trait::async_trait;
9use datafusion::prelude::*;
10use std::fmt;
11use std::sync::Arc;
12use tracing::instrument;
/// A bucket in a histogram representing a value and its frequency information.
#[derive(Debug, Clone, PartialEq)]
pub struct HistogramBucket {
    /// The value in this bucket
    pub value: String,
    /// The count of occurrences
    pub count: i64,
    /// The ratio of this value to the total count
    pub ratio: f64,
}

/// A histogram representing the distribution of values in a column.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// The buckets in the histogram, ordered by frequency (descending)
    pub buckets: Vec<HistogramBucket>,
    /// Total number of values (including nulls if present)
    pub total_count: i64,
    /// Number of distinct values
    pub distinct_count: usize,
    /// Number of null values
    pub null_count: i64,
}

impl Histogram {
    /// Builds a histogram from pre-computed buckets.
    ///
    /// `distinct_count` is derived from the number of buckets; the buckets are
    /// expected to already be sorted by frequency (descending).
    pub fn new(buckets: Vec<HistogramBucket>, total_count: i64, null_count: i64) -> Self {
        Self {
            // Evaluated before `buckets` is moved into the struct.
            distinct_count: buckets.len(),
            buckets,
            total_count,
            null_count,
        }
    }

    /// Ratio of the most common value, or `0.0` for an empty histogram.
    pub fn most_common_ratio(&self) -> f64 {
        self.buckets.first().map_or(0.0, |bucket| bucket.ratio)
    }

    /// Ratio of the least common value, or `0.0` for an empty histogram.
    pub fn least_common_ratio(&self) -> f64 {
        self.buckets.last().map_or(0.0, |bucket| bucket.ratio)
    }

    /// Number of buckets, i.e. the number of distinct values observed.
    pub fn bucket_count(&self) -> usize {
        self.buckets.len()
    }

    /// Top `n` most common values paired with their ratios, in bucket order.
    pub fn top_n(&self, n: usize) -> Vec<(&str, f64)> {
        let mut top = Vec::with_capacity(n.min(self.buckets.len()));
        for bucket in self.buckets.iter().take(n) {
            top.push((bucket.value.as_str(), bucket.ratio));
        }
        top
    }

    /// Checks if the distribution is roughly uniform (all values have similar
    /// frequencies).
    ///
    /// A distribution counts as roughly uniform when the ratio between the
    /// most common and least common values does not exceed `threshold`.
    /// An empty histogram is trivially uniform; a zero least-common ratio is
    /// never uniform (it would make the comparison ill-defined).
    pub fn is_roughly_uniform(&self, threshold: f64) -> bool {
        match (self.buckets.first(), self.buckets.last()) {
            (None, _) => true,
            (Some(most), Some(least)) if least.ratio != 0.0 => {
                most.ratio / least.ratio <= threshold
            }
            _ => false,
        }
    }

    /// Ratio of a specific value, if it is present in the histogram.
    pub fn get_value_ratio(&self, value: &str) -> Option<f64> {
        for bucket in &self.buckets {
            if bucket.value == value {
                return Some(bucket.ratio);
            }
        }
        None
    }

    /// Shannon entropy of the distribution (natural logarithm).
    ///
    /// Higher entropy indicates a more uniform distribution. Zero-ratio
    /// buckets are ignored so `ln(0)` never enters the sum.
    pub fn entropy(&self) -> f64 {
        self.buckets.iter().fold(0.0, |acc, bucket| {
            if bucket.ratio > 0.0 {
                acc - bucket.ratio * bucket.ratio.ln()
            } else {
                acc
            }
        })
    }

    /// Checks whether the distribution follows a power law (few values
    /// dominate): true when the top `top_n` buckets together account for at
    /// least `threshold` of the distribution.
    pub fn follows_power_law(&self, top_n: usize, threshold: f64) -> bool {
        let mut cumulative = 0.0;
        for bucket in self.buckets.iter().take(top_n) {
            cumulative += bucket.ratio;
        }
        cumulative >= threshold
    }

    /// Fraction of nulls among all values; `0.0` when there is no data at all.
    pub fn null_ratio(&self) -> f64 {
        if self.total_count == 0 {
            return 0.0;
        }
        self.null_count as f64 / self.total_count as f64
    }
}
128
/// Type alias for histogram assertion function.
///
/// The closure receives the computed [`Histogram`] and returns `true` when the
/// distribution satisfies the constraint. It is wrapped in an `Arc` with
/// `Send + Sync` bounds so the owning constraint stays cheaply cloneable and
/// usable across threads.
pub type HistogramAssertion = Arc<dyn Fn(&Histogram) -> bool + Send + Sync>;
131
/// A constraint that analyzes value distribution in a column and applies custom assertions.
///
/// This constraint computes a histogram of value frequencies and allows custom assertion
/// functions to validate the distribution characteristics.
///
/// # Examples
///
/// ```rust
/// use term_guard::constraints::{HistogramConstraint, Histogram};
/// use term_guard::core::Constraint;
/// use std::sync::Arc;
///
/// // Check that no single value dominates
/// let constraint = HistogramConstraint::new("status", Arc::new(|hist: &Histogram| {
///     hist.most_common_ratio() < 0.5
/// }));
///
/// // Verify distribution has expected number of categories
/// let constraint = HistogramConstraint::new("category", Arc::new(|hist| {
///     hist.bucket_count() >= 5 && hist.bucket_count() <= 10
/// }));
/// ```
#[derive(Clone)]
pub struct HistogramConstraint {
    // Name of the column whose value distribution is analyzed.
    column: String,
    // Predicate applied to the computed histogram; `true` means the constraint passes.
    assertion: HistogramAssertion,
    // Human-readable summary of what `assertion` checks, used in failure messages
    // and metadata (the closure itself is not `Debug`).
    assertion_description: String,
}
160
161impl fmt::Debug for HistogramConstraint {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.debug_struct("HistogramConstraint")
164            .field("column", &self.column)
165            .field("assertion_description", &self.assertion_description)
166            .finish()
167    }
168}
169
170impl HistogramConstraint {
171    /// Creates a new histogram constraint.
172    ///
173    /// # Arguments
174    ///
175    /// * `column` - The column to analyze
176    /// * `assertion` - The assertion function to apply to the histogram
177    pub fn new(column: impl Into<String>, assertion: HistogramAssertion) -> Self {
178        Self {
179            column: column.into(),
180            assertion,
181            assertion_description: "custom assertion".to_string(),
182        }
183    }
184
185    /// Creates a new histogram constraint with a description.
186    ///
187    /// # Arguments
188    ///
189    /// * `column` - The column to analyze
190    /// * `assertion` - The assertion function to apply to the histogram
191    /// * `description` - A description of what the assertion checks
192    pub fn new_with_description(
193        column: impl Into<String>,
194        assertion: HistogramAssertion,
195        description: impl Into<String>,
196    ) -> Self {
197        Self {
198            column: column.into(),
199            assertion,
200            assertion_description: description.into(),
201        }
202    }
203}
204
205#[async_trait]
206impl Constraint for HistogramConstraint {
207    #[instrument(skip(self, ctx), fields(column = %self.column))]
208    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
209        // Get the table name from the validation context
210        let validation_ctx = current_validation_context();
211        let table_name = validation_ctx.table_name();
212
213        // SQL query to compute value frequencies
214        let sql = format!(
215            r#"
216            WITH value_counts AS (
217                SELECT 
218                    CAST({} AS VARCHAR) as value,
219                    COUNT(*) as count
220                FROM {table_name}
221                WHERE {} IS NOT NULL
222                GROUP BY {}
223            ),
224            totals AS (
225                SELECT 
226                    COUNT(*) as total_cnt,
227                    SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) as null_cnt
228                FROM {table_name}
229            )
230            SELECT 
231                vc.value,
232                vc.count,
233                vc.count * 1.0 / (t.total_cnt - t.null_cnt) as ratio,
234                t.total_cnt as total_count,
235                t.null_cnt as null_count
236            FROM value_counts vc
237            CROSS JOIN totals t
238            ORDER BY vc.count DESC, vc.value
239            "#,
240            self.column, self.column, self.column, self.column
241        );
242
243        let df = ctx.sql(&sql).await.map_err(|e| {
244            TermError::constraint_evaluation(
245                self.name(),
246                format!("Failed to execute histogram query: {e}"),
247            )
248        })?;
249
250        let batches = df.collect().await?;
251
252        if batches.is_empty() || batches[0].num_rows() == 0 {
253            return Ok(ConstraintResult::skipped("No data to analyze"));
254        }
255
256        // Extract histogram data from results
257        let mut buckets = Vec::new();
258        let mut total_count = 0i64;
259        let mut null_count = 0i64;
260
261        for batch in &batches {
262            // DataFusion might return various string types
263            let values_col = batch.column(0);
264            let value_strings: Vec<String> = match values_col.data_type() {
265                arrow::datatypes::DataType::Utf8 => {
266                    let arr = values_col
267                        .as_any()
268                        .downcast_ref::<arrow::array::StringArray>()
269                        .ok_or_else(|| {
270                            TermError::constraint_evaluation(
271                                self.name(),
272                                "Failed to extract string values",
273                            )
274                        })?;
275                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
276                }
277                arrow::datatypes::DataType::Utf8View => {
278                    let arr = values_col
279                        .as_any()
280                        .downcast_ref::<StringViewArray>()
281                        .ok_or_else(|| {
282                            TermError::constraint_evaluation(
283                                self.name(),
284                                "Failed to extract string view values",
285                            )
286                        })?;
287                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
288                }
289                arrow::datatypes::DataType::LargeUtf8 => {
290                    let arr = values_col
291                        .as_any()
292                        .downcast_ref::<LargeStringArray>()
293                        .ok_or_else(|| {
294                            TermError::constraint_evaluation(
295                                self.name(),
296                                "Failed to extract large string values",
297                            )
298                        })?;
299                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
300                }
301                _ => {
302                    return Err(TermError::constraint_evaluation(
303                        self.name(),
304                        format!("Unexpected value column type: {:?}", values_col.data_type()),
305                    ));
306                }
307            };
308
309            let count_array = batch
310                .column(1)
311                .as_any()
312                .downcast_ref::<arrow::array::Int64Array>()
313                .ok_or_else(|| {
314                    TermError::constraint_evaluation(self.name(), "Failed to extract counts")
315                })?;
316
317            let ratio_array = batch
318                .column(2)
319                .as_any()
320                .downcast_ref::<arrow::array::Float64Array>()
321                .ok_or_else(|| {
322                    TermError::constraint_evaluation(self.name(), "Failed to extract ratios")
323                })?;
324
325            let total_array = batch
326                .column(3)
327                .as_any()
328                .downcast_ref::<arrow::array::Int64Array>()
329                .ok_or_else(|| {
330                    TermError::constraint_evaluation(self.name(), "Failed to extract total count")
331                })?;
332
333            let null_array = batch
334                .column(4)
335                .as_any()
336                .downcast_ref::<arrow::array::Int64Array>()
337                .ok_or_else(|| {
338                    TermError::constraint_evaluation(self.name(), "Failed to extract null count")
339                })?;
340
341            // Get total and null counts from first row
342            if batch.num_rows() > 0 {
343                total_count = total_array.value(0);
344                null_count = null_array.value(0);
345            }
346
347            // Collect buckets
348            for (i, value) in value_strings.into_iter().enumerate() {
349                let count = count_array.value(i);
350                let ratio = ratio_array.value(i);
351
352                buckets.push(HistogramBucket {
353                    value,
354                    count,
355                    ratio,
356                });
357            }
358        }
359
360        // Create histogram
361        let histogram = Histogram::new(buckets, total_count, null_count);
362
363        // Apply assertion
364        let assertion_result = (self.assertion)(&histogram);
365
366        let status = if assertion_result {
367            ConstraintStatus::Success
368        } else {
369            ConstraintStatus::Failure
370        };
371
372        let message = if status == ConstraintStatus::Failure {
373            let most_common_pct = histogram.most_common_ratio() * 100.0;
374            let null_pct = histogram.null_ratio() * 100.0;
375            Some(format!(
376                "Histogram assertion '{}' failed for column '{}'. Distribution: {} distinct values, most common ratio: {most_common_pct:.2}%, null ratio: {null_pct:.2}%",
377                self.assertion_description,
378                self.column,
379                histogram.distinct_count
380            ))
381        } else {
382            None
383        };
384
385        // Store histogram entropy as metric
386        Ok(ConstraintResult {
387            status,
388            metric: Some(histogram.entropy()),
389            message,
390        })
391    }
392
393    fn name(&self) -> &str {
394        "histogram"
395    }
396
397    fn column(&self) -> Option<&str> {
398        Some(&self.column)
399    }
400
401    fn metadata(&self) -> ConstraintMetadata {
402        ConstraintMetadata::for_column(&self.column)
403            .with_description(format!(
404                "Analyzes value distribution in column '{}' and applies assertion: {}",
405                self.column, self.assertion_description
406            ))
407            .with_custom("assertion", &self.assertion_description)
408            .with_custom("constraint_type", "histogram")
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    // Registers an in-memory table `data` with a single nullable Utf8 column
    // `test_col` containing `values`, and returns the session context.
    async fn create_test_context_with_data(values: Vec<Option<&str>>) -> SessionContext {
        let ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "test_col",
            DataType::Utf8,
            true,
        )]));

        let array = StringArray::from(values);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();

        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();

        ctx
    }

    // Pure-struct test: accessor methods on a hand-built histogram.
    #[test]
    fn test_histogram_basic() {
        let buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 50,
                ratio: 0.5,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 30,
                ratio: 0.3,
            },
            HistogramBucket {
                value: "C".to_string(),
                count: 20,
                ratio: 0.2,
            },
        ];

        let histogram = Histogram::new(buckets, 100, 0);

        assert_eq!(histogram.most_common_ratio(), 0.5);
        assert_eq!(histogram.least_common_ratio(), 0.2);
        assert_eq!(histogram.bucket_count(), 3);
        assert_eq!(histogram.null_ratio(), 0.0);
    }

    #[test]
    fn test_histogram_entropy() {
        // Uniform distribution should have higher entropy
        let uniform_buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "C".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "D".to_string(),
                count: 25,
                ratio: 0.25,
            },
        ];

        let uniform_hist = Histogram::new(uniform_buckets, 100, 0);

        // Skewed distribution should have lower entropy
        let skewed_buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 90,
                ratio: 0.9,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 10,
                ratio: 0.1,
            },
        ];

        let skewed_hist = Histogram::new(skewed_buckets, 100, 0);

        assert!(uniform_hist.entropy() > skewed_hist.entropy());
    }

    // End-to-end: assertion on most_common_ratio, both failing and passing.
    #[tokio::test]
    async fn test_most_common_ratio_constraint() {
        // Create data where "A" appears 60% of the time
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;

        // Constraint that fails: most common should be < 50%
        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.most_common_ratio() < 0.5),
            "most common value appears less than 50%",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert!(result.message.is_some());

        // Constraint that passes: most common should be < 70%
        let constraint =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.most_common_ratio() < 0.7));

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_bucket_count_constraint() {
        // Create data with 4 distinct values
        let values = vec![
            Some("RED"),
            Some("BLUE"),
            Some("GREEN"),
            Some("YELLOW"),
            Some("RED"),
            Some("BLUE"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.bucket_count() >= 3 && hist.bucket_count() <= 5),
            "has between 3 and 5 distinct values",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_uniform_distribution_check() {
        // Create roughly uniform distribution
        let values = vec![
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
            Some("D"),
            Some("D"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.is_roughly_uniform(1.5)));

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_power_law_distribution() {
        // Create power law distribution where top 2 values dominate
        let values = vec![
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Rare1"),
            Some("Rare2"),
            Some("Rare3"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.follows_power_law(2, 0.7)),
            "top 2 values account for 70% of distribution",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    // Nulls are excluded from buckets but reflected in null_ratio (3/8 here).
    #[tokio::test]
    async fn test_with_nulls() {
        let values = vec![
            Some("A"),
            Some("A"),
            None,
            None,
            None,
            Some("B"),
            Some("B"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| hist.null_ratio() > 0.3 && hist.null_ratio() < 0.4),
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    // An empty table yields no histogram rows and the constraint is skipped.
    #[tokio::test]
    async fn test_empty_data() {
        let ctx = create_test_context_with_data(vec![]).await;

        let constraint = HistogramConstraint::new("test_col", Arc::new(|_| true));

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Skipped);
    }

    #[tokio::test]
    async fn test_specific_value_check() {
        let values = vec![
            Some("PENDING"),
            Some("PENDING"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("REJECTED"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| {
                // Check that APPROVED is the most common status
                hist.get_value_ratio("APPROVED").unwrap_or(0.0) > 0.4
            }),
            "APPROVED status is most common",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_top_n_values() {
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"), // 40%
            Some("B"),
            Some("B"),
            Some("B"), // 30%
            Some("C"),
            Some("C"), // 20%
            Some("D"), // 10%
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| {
                let top_2 = hist.top_n(2);
                top_2.len() == 2 && top_2[0].1 == 0.4 && top_2[1].1 == 0.3
            }),
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    // Non-string columns are CAST to VARCHAR by the histogram SQL, so numeric
    // data works without a string schema.
    #[tokio::test]
    async fn test_numeric_data_histogram() {
        use arrow::array::Int64Array;
        use arrow::datatypes::{DataType, Field, Schema};

        let ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![Field::new("age", DataType::Int64, true)]));

        let values = vec![
            Some(25),
            Some(25),
            Some(30),
            Some(30),
            Some(30),
            Some(35),
            Some(35),
            Some(40),
            Some(45),
            Some(50),
        ];
        let array = Int64Array::from(values);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();

        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();

        let constraint = HistogramConstraint::new_with_description(
            "age",
            Arc::new(|hist| {
                // Check we have reasonable age distribution
                hist.bucket_count() >= 5 && hist.most_common_ratio() < 0.4
            }),
            "age distribution is reasonable",
        );

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
}