1use crate::constraints::Assertion;
4use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
5use crate::prelude::*;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use tracing::instrument;
9
10#[derive(Debug, Clone)]
31pub struct ApproxCountDistinctConstraint {
32 column: String,
33 assertion: Assertion,
34}
35
36impl ApproxCountDistinctConstraint {
37 pub fn new<S: Into<String>>(column: S, assertion: Assertion) -> Self {
44 Self {
45 column: column.into(),
46 assertion,
47 }
48 }
49}
50
51#[async_trait]
52impl Constraint for ApproxCountDistinctConstraint {
53 #[instrument(skip(self, ctx), fields(column = %self.column, assertion = ?self.assertion))]
54 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
55 let sql = format!(
57 "SELECT APPROX_DISTINCT({}) as approx_distinct_count FROM data",
58 self.column
59 );
60
61 let df = ctx.sql(&sql).await.map_err(|e| {
62 TermError::constraint_evaluation(
63 self.name(),
64 format!("Failed to execute approximate count distinct query: {e}"),
65 )
66 })?;
67
68 let batches = df.collect().await?;
69
70 if batches.is_empty() || batches[0].num_rows() == 0 {
71 return Ok(ConstraintResult::skipped("No data to validate"));
72 }
73
74 let batch = &batches[0];
75
76 let approx_count = batch
78 .column(0)
79 .as_any()
80 .downcast_ref::<arrow::array::UInt64Array>()
81 .ok_or_else(|| {
82 TermError::constraint_evaluation(
83 self.name(),
84 "Failed to extract approximate distinct count from result",
85 )
86 })?
87 .value(0) as f64;
88
89 let assertion_result = self.assertion.evaluate(approx_count);
91
92 let status = if assertion_result {
93 ConstraintStatus::Success
94 } else {
95 ConstraintStatus::Failure
96 };
97
98 let message = if status == ConstraintStatus::Failure {
99 Some(format!(
100 "Approximate distinct count {approx_count} does not satisfy assertion {} for column '{}'",
101 self.assertion.description(),
102 self.column
103 ))
104 } else {
105 None
106 };
107
108 Ok(ConstraintResult {
109 status,
110 metric: Some(approx_count),
111 message,
112 })
113 }
114
115 fn name(&self) -> &str {
116 "approx_count_distinct"
117 }
118
119 fn column(&self) -> Option<&str> {
120 Some(&self.column)
121 }
122
123 fn metadata(&self) -> ConstraintMetadata {
124 ConstraintMetadata::for_column(&self.column)
125 .with_description(format!(
126 "Checks that the approximate distinct count of column '{}' {}",
127 self.column,
128 self.assertion.description()
129 ))
130 .with_custom("assertion", self.assertion.description())
131 .with_custom("constraint_type", "cardinality")
132 .with_custom("algorithm", "HyperLogLog")
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Schema with a single nullable column named `test_col` of `data_type`.
    fn test_col_schema(data_type: DataType) -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("test_col", data_type, true)]))
    }

    /// Wraps `batch` in an in-memory table registered under the name `data`.
    fn context_with_data_table(schema: Arc<Schema>, batch: RecordBatch) -> SessionContext {
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// Session whose `data` table holds `values` in an `Int64` `test_col`.
    async fn create_test_context_with_data(values: Vec<Option<i64>>) -> SessionContext {
        let schema = test_col_schema(DataType::Int64);
        let array = Int64Array::from(values);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap();
        context_with_data_table(schema, batch)
    }

    /// Session whose `data` table holds `values` in a `Utf8` `test_col`.
    async fn create_string_context_with_data(values: Vec<Option<&str>>) -> SessionContext {
        let schema = test_col_schema(DataType::Utf8);
        let array = StringArray::from(values);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap();
        context_with_data_table(schema, batch)
    }

    /// Evaluates a constraint on `test_col` with `assertion` against `ctx`.
    async fn check(ctx: &SessionContext, assertion: Assertion) -> ConstraintResult {
        ApproxCountDistinctConstraint::new("test_col", assertion)
            .evaluate(ctx)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_high_cardinality() {
        // 1000 rows, every value unique.
        let ctx = create_test_context_with_data((0..1000).map(Some).collect()).await;

        let result = check(&ctx, Assertion::GreaterThan(990.0)).await;

        assert_eq!(result.status, ConstraintStatus::Success);
        let metric = result.metric.unwrap();
        assert!(metric > 990.0);
    }

    #[tokio::test]
    async fn test_low_cardinality() {
        // 300 rows cycling through only three distinct values.
        let values: Vec<Option<i64>> = [Some(1), Some(2), Some(3)].repeat(100);
        let ctx = create_test_context_with_data(values).await;

        let result = check(&ctx, Assertion::LessThan(10.0)).await;

        assert_eq!(result.status, ConstraintStatus::Success);
        let metric = result.metric.unwrap();
        assert!(metric < 10.0);
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // Three distinct non-null values interleaved with nulls.
        let values = vec![
            Some(1), None, Some(2), None, Some(3), None, Some(1), Some(2), Some(3), None,
        ];
        let ctx = create_test_context_with_data(values).await;

        let result = check(&ctx, Assertion::Between(2.0, 5.0)).await;

        assert_eq!(result.status, ConstraintStatus::Success);
        let metric = result.metric.unwrap();
        assert!((2.0..=5.0).contains(&metric));
    }

    #[tokio::test]
    async fn test_constraint_failure() {
        // Only ten distinct values, but the assertion demands more than 100.
        let ctx =
            create_test_context_with_data((0..50).map(|i| Some(i % 10)).collect()).await;

        let result = check(&ctx, Assertion::GreaterThan(100.0)).await;

        assert_eq!(result.status, ConstraintStatus::Failure);
        assert!(result.metric.unwrap() < 20.0);
        assert!(result.message.is_some());
    }

    #[tokio::test]
    async fn test_string_column() {
        // Five distinct fruit names, two of them repeated, plus one null.
        let fruits = vec![
            Some("apple"),
            Some("banana"),
            Some("cherry"),
            Some("apple"),
            Some("banana"),
            Some("date"),
            Some("elderberry"),
            None,
        ];
        let ctx = create_string_context_with_data(fruits).await;

        let result = check(&ctx, Assertion::Between(4.0, 6.0)).await;

        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_empty_data() {
        let ctx = create_test_context_with_data(Vec::new()).await;

        let result = check(&ctx, Assertion::Equals(0.0)).await;

        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(0.0));
    }

    #[tokio::test]
    async fn test_all_null_values() {
        let ctx = create_test_context_with_data(vec![None; 5]).await;

        let result = check(&ctx, Assertion::Equals(0.0)).await;

        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(0.0));
    }

    #[tokio::test]
    async fn test_accuracy_comparison() {
        // 10_000 rows with exactly 1000 distinct values; the estimate is
        // expected to land within 30 of the true count.
        let values: Vec<Option<i64>> = (0..10000).map(|i| Some(i % 1000)).collect();
        let ctx = create_test_context_with_data(values).await;

        let result = check(&ctx, Assertion::Between(970.0, 1030.0)).await;

        assert_eq!(result.status, ConstraintStatus::Success);
        let metric = result.metric.unwrap();
        assert!((970.0..=1030.0).contains(&metric));
    }

    #[tokio::test]
    async fn test_metadata() {
        let metadata =
            ApproxCountDistinctConstraint::new("user_id", Assertion::GreaterThan(1000000.0))
                .metadata();

        assert_eq!(metadata.columns, vec!["user_id".to_string()]);

        let description = metadata.description.unwrap_or_default();
        assert!(description.contains("approximate distinct count"));
        assert!(description.contains("greater than 1000000"));

        assert_eq!(
            metadata.custom.get("algorithm").map(String::as_str),
            Some("HyperLogLog")
        );
        assert_eq!(
            metadata.custom.get("constraint_type").map(String::as_str),
            Some("cardinality")
        );
    }
}