term_guard/constraints/
approx_count_distinct.rs

1//! Approximate count distinct validation constraint.
2
3use crate::constraints::Assertion;
4use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
5use crate::prelude::*;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use tracing::instrument;
9
10/// A constraint that validates the approximate count of distinct values in a column.
11///
12/// This constraint uses DataFusion's APPROX_DISTINCT function which provides
13/// an approximate count using HyperLogLog algorithm. This is much faster than
14/// exact COUNT(DISTINCT) for large datasets while maintaining accuracy within
15/// 2-3% error margin.
16///
17/// # Examples
18///
19/// ```rust
20/// use term_guard::constraints::{ApproxCountDistinctConstraint, Assertion};
21/// use term_guard::core::Constraint;
22///
23/// // Check for high cardinality (e.g., user IDs)
24/// let constraint = ApproxCountDistinctConstraint::new("user_id", Assertion::GreaterThan(1000000.0));
25/// assert_eq!(constraint.name(), "approx_count_distinct");
26///
27/// // Check for low cardinality (e.g., country codes)
28/// let constraint = ApproxCountDistinctConstraint::new("country_code", Assertion::LessThan(200.0));
29/// ```
30#[derive(Debug, Clone)]
31pub struct ApproxCountDistinctConstraint {
32    column: String,
33    assertion: Assertion,
34}
35
36impl ApproxCountDistinctConstraint {
37    /// Creates a new approximate count distinct constraint.
38    ///
39    /// # Arguments
40    ///
41    /// * `column` - The column to count distinct values in
42    /// * `assertion` - The assertion to apply to the approximate distinct count
43    pub fn new<S: Into<String>>(column: S, assertion: Assertion) -> Self {
44        Self {
45            column: column.into(),
46            assertion,
47        }
48    }
49}
50
51#[async_trait]
52impl Constraint for ApproxCountDistinctConstraint {
53    #[instrument(skip(self, ctx), fields(column = %self.column, assertion = ?self.assertion))]
54    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
55        // Build the SQL query using APPROX_DISTINCT
56        let sql = format!(
57            "SELECT APPROX_DISTINCT({}) as approx_distinct_count FROM data",
58            self.column
59        );
60
61        let df = ctx.sql(&sql).await.map_err(|e| {
62            TermError::constraint_evaluation(
63                self.name(),
64                format!("Failed to execute approximate count distinct query: {e}"),
65            )
66        })?;
67
68        let batches = df.collect().await?;
69
70        if batches.is_empty() || batches[0].num_rows() == 0 {
71            return Ok(ConstraintResult::skipped("No data to validate"));
72        }
73
74        let batch = &batches[0];
75
76        // Extract the approximate distinct count
77        let approx_count = batch
78            .column(0)
79            .as_any()
80            .downcast_ref::<arrow::array::UInt64Array>()
81            .ok_or_else(|| {
82                TermError::constraint_evaluation(
83                    self.name(),
84                    "Failed to extract approximate distinct count from result",
85                )
86            })?
87            .value(0) as f64;
88
89        // Evaluate the assertion
90        let assertion_result = self.assertion.evaluate(approx_count);
91
92        let status = if assertion_result {
93            ConstraintStatus::Success
94        } else {
95            ConstraintStatus::Failure
96        };
97
98        let message = if status == ConstraintStatus::Failure {
99            Some(format!(
100                "Approximate distinct count {approx_count} does not satisfy assertion {} for column '{}'",
101                self.assertion.description(),
102                self.column
103            ))
104        } else {
105            None
106        };
107
108        Ok(ConstraintResult {
109            status,
110            metric: Some(approx_count),
111            message,
112        })
113    }
114
115    fn name(&self) -> &str {
116        "approx_count_distinct"
117    }
118
119    fn column(&self) -> Option<&str> {
120        Some(&self.column)
121    }
122
123    fn metadata(&self) -> ConstraintMetadata {
124        ConstraintMetadata::for_column(&self.column)
125            .with_description(format!(
126                "Checks that the approximate distinct count of column '{}' {}",
127                self.column,
128                self.assertion.description()
129            ))
130            .with_custom("assertion", self.assertion.description())
131            .with_custom("constraint_type", "cardinality")
132            .with_custom("algorithm", "HyperLogLog")
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Registers a table named `data` with one nullable column `test_col` of the
    /// given type, backed by a single record batch holding `column`.
    fn context_with_column(
        data_type: DataType,
        column: Arc<dyn arrow::array::Array>,
    ) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![Field::new("test_col", data_type, true)]));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![column]).unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// Session whose `data.test_col` is an Int64 column with the given values.
    async fn create_test_context_with_data(values: Vec<Option<i64>>) -> SessionContext {
        context_with_column(DataType::Int64, Arc::new(Int64Array::from(values)))
    }

    /// Session whose `data.test_col` is a Utf8 column with the given values.
    async fn create_string_context_with_data(values: Vec<Option<&str>>) -> SessionContext {
        context_with_column(DataType::Utf8, Arc::new(StringArray::from(values)))
    }

    #[tokio::test]
    async fn test_high_cardinality() {
        // 1000 rows, every value unique.
        let unique_values: Vec<Option<i64>> = (0..1000).map(Some).collect();
        let ctx = create_test_context_with_data(unique_values).await;

        let constraint =
            ApproxCountDistinctConstraint::new("test_col", Assertion::GreaterThan(990.0));
        let outcome = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        // The estimate should land close to 1000.
        assert!(outcome.metric.unwrap() > 990.0);
    }

    #[tokio::test]
    async fn test_low_cardinality() {
        // Three distinct values, each repeated 100 times.
        let repeated: Vec<Option<i64>> = [Some(1), Some(2), Some(3)].repeat(100);
        let ctx = create_test_context_with_data(repeated).await;

        let constraint = ApproxCountDistinctConstraint::new("test_col", Assertion::LessThan(10.0));
        let outcome = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        // The estimate should be about 3.
        assert!(outcome.metric.unwrap() < 10.0);
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // Values 1..=3 interleaved with nulls.
        let mixed = vec![
            Some(1), None, Some(2), None, Some(3),
            None, Some(1), Some(2), Some(3), None,
        ];
        let ctx = create_test_context_with_data(mixed).await;

        // Only the distinct non-null values (about 3) are expected to be counted.
        let constraint =
            ApproxCountDistinctConstraint::new("test_col", Assertion::Between(2.0, 5.0));
        let outcome = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        let estimate = outcome.metric.unwrap();
        assert!((2.0..=5.0).contains(&estimate));
    }

    #[tokio::test]
    async fn test_constraint_failure() {
        // 50 rows cycling through 10 distinct values.
        let cycled: Vec<Option<i64>> = (0..50).map(|i| Some(i % 10)).collect();
        let ctx = create_test_context_with_data(cycled).await;

        // The assertion demands far more distinct values than the ~10 present.
        let constraint =
            ApproxCountDistinctConstraint::new("test_col", Assertion::GreaterThan(100.0));
        let outcome = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.metric.unwrap() < 20.0);
        assert!(outcome.message.is_some());
    }

    #[tokio::test]
    async fn test_string_column() {
        let fruits = vec![
            Some("apple"),
            Some("banana"),
            Some("cherry"),
            Some("apple"),
            Some("banana"),
            Some("date"),
            Some("elderberry"),
            None,
        ];
        let ctx = create_string_context_with_data(fruits).await;

        // Five distinct non-null strings are present.
        let constraint =
            ApproxCountDistinctConstraint::new("test_col", Assertion::Between(4.0, 6.0));
        let outcome = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_empty_data() {
        let ctx = create_test_context_with_data(Vec::new()).await;

        // An empty table is expected to yield a count of 0 rather than a skip.
        let constraint = ApproxCountDistinctConstraint::new("test_col", Assertion::Equals(0.0));
        let outcome = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(0.0));
    }

    #[tokio::test]
    async fn test_all_null_values() {
        let ctx = create_test_context_with_data(vec![None; 5]).await;

        // A column holding only nulls is expected to yield a count of 0.
        let constraint = ApproxCountDistinctConstraint::new("test_col", Assertion::Equals(0.0));
        let outcome = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(0.0));
    }

    #[tokio::test]
    async fn test_accuracy_comparison() {
        // 10000 rows over 1000 distinct values: large enough to exercise the
        // estimator, small enough that the true cardinality is known.
        let cycled: Vec<Option<i64>> = (0..10_000).map(|i| Some(i % 1000)).collect();
        let ctx = create_test_context_with_data(cycled).await;

        let constraint =
            ApproxCountDistinctConstraint::new("test_col", Assertion::Between(970.0, 1030.0));
        let outcome = constraint.evaluate(&ctx).await.unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        // The estimate must stay within 3% of 1000.
        let estimate = outcome.metric.unwrap();
        assert!((970.0..=1030.0).contains(&estimate));
    }

    #[tokio::test]
    async fn test_metadata() {
        let constraint =
            ApproxCountDistinctConstraint::new("user_id", Assertion::GreaterThan(1000000.0));
        let meta = constraint.metadata();

        assert_eq!(meta.columns, vec!["user_id".to_string()]);

        let description = meta.description.unwrap_or_default();
        assert!(description.contains("approximate distinct count"));
        assert!(description.contains("greater than 1000000"));

        assert_eq!(
            meta.custom.get("algorithm"),
            Some(&"HyperLogLog".to_string())
        );
        assert_eq!(
            meta.custom.get("constraint_type"),
            Some(&"cardinality".to_string())
        );
    }
}
345}