1use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
4use crate::prelude::*;
5use arrow::array::Array;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use tracing::instrument;
/// Logical value types that the textual content of a column can be checked against.
///
/// Every variant is backed by a regular expression (see `pattern`) and a short
/// lowercase label (see `name`).
#[derive(Debug, Clone, PartialEq)]
#[allow(dead_code)]
pub enum DataType {
    /// Whole numbers with an optional leading minus sign, e.g. `-42`.
    Integer,
    /// Decimal numbers with an optional leading minus sign and optional exponent,
    /// e.g. `100`, `.5`, `-1.5e-3`. A trailing dot without digits (`1.`) is not accepted.
    Float,
    /// One of `true`/`false` (lowercase, uppercase or capitalised), `0` or `1`.
    Boolean,
    /// Text shaped like `YYYY-MM-DD`. Only the digit layout is checked, not calendar validity.
    Date,
    /// Text starting with `YYYY-MM-DD`, a space or `T`, then `HH:MM:SS`.
    Timestamp,
    /// Any text at all.
    String,
}

impl DataType {
    /// Returns the regular expression used to recognise values of this type.
    ///
    /// The returned string is a `'static` literal, so it can outlive `self`.
    fn pattern(&self) -> &'static str {
        match self {
            DataType::Integer => r"^-?\d+$",
            DataType::Float => r"^-?\d*\.?\d+([eE][+-]?\d+)?$",
            DataType::Boolean => r"^(true|false|TRUE|FALSE|True|False|0|1)$",
            DataType::Date => r"^\d{4}-\d{2}-\d{2}$",
            // NOTE(review): unlike the other patterns this one has no `$` anchor, so any
            // trailing text (fractional seconds, zone suffix, or garbage) is accepted.
            // Presumably intentional to allow suffixes; confirm before tightening.
            DataType::Timestamp => r"^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}",
            DataType::String => r".*",
        }
    }

    /// Returns the lowercase label of this type, as used in tracing fields and metadata.
    ///
    /// The returned string is a `'static` literal, so it can outlive `self`.
    fn name(&self) -> &'static str {
        match self {
            DataType::Integer => "integer",
            DataType::Float => "float",
            DataType::Boolean => "boolean",
            DataType::Date => "date",
            DataType::Timestamp => "timestamp",
            DataType::String => "string",
        }
    }
}
52
53#[derive(Debug, Clone)]
68#[allow(dead_code)]
69pub struct DataTypeConstraint {
70 column: String,
71 data_type: DataType,
72 threshold: f64,
73}
74
75#[allow(dead_code)]
76impl DataTypeConstraint {
77 pub fn new(column: impl Into<String>, data_type: DataType, threshold: f64) -> Self {
89 assert!(
90 (0.0..=1.0).contains(&threshold),
91 "Threshold must be between 0.0 and 1.0"
92 );
93 Self {
94 column: column.into(),
95 data_type,
96 threshold,
97 }
98 }
99}
100
101#[async_trait]
102impl Constraint for DataTypeConstraint {
103 #[instrument(skip(self, ctx), fields(column = %self.column, data_type = %self.data_type.name(), threshold = %self.threshold))]
104 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
105 let pattern = self.data_type.pattern();
106
107 let validation_ctx = current_validation_context();
110
111 let table_name = validation_ctx.table_name();
112
113 let sql = format!(
114 "SELECT
115 COUNT(CASE WHEN {} ~ '{pattern}' THEN 1 END) as matches,
116 COUNT(*) as total
117 FROM {table_name}
118 WHERE {} IS NOT NULL",
119 self.column, self.column
120 );
121
122 let df = ctx.sql(&sql).await?;
123 let batches = df.collect().await?;
124
125 if batches.is_empty() {
126 return Ok(ConstraintResult::skipped("No data to validate"));
127 }
128
129 let batch = &batches[0];
130 if batch.num_rows() == 0 {
131 return Ok(ConstraintResult::skipped("No data to validate"));
132 }
133
134 let matches = batch
135 .column(0)
136 .as_any()
137 .downcast_ref::<arrow::array::Int64Array>()
138 .ok_or_else(|| TermError::Internal("Failed to extract match count".to_string()))?
139 .value(0) as f64;
140
141 let total = batch
142 .column(1)
143 .as_any()
144 .downcast_ref::<arrow::array::Int64Array>()
145 .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
146 .value(0) as f64;
147
148 if total == 0.0 {
149 return Ok(ConstraintResult::skipped("No non-null data to validate"));
150 }
151
152 let type_ratio = matches / total;
153
154 if type_ratio >= self.threshold {
155 Ok(ConstraintResult::success_with_metric(type_ratio))
156 } else {
157 Ok(ConstraintResult::failure_with_metric(
158 type_ratio,
159 format!(
160 "Data type conformance {type_ratio} is below threshold {}",
161 self.threshold
162 ),
163 ))
164 }
165 }
166
167 fn name(&self) -> &str {
168 "data_type"
169 }
170
171 fn column(&self) -> Option<&str> {
172 Some(&self.column)
173 }
174
175 fn metadata(&self) -> ConstraintMetadata {
176 ConstraintMetadata::for_column(&self.column)
177 .with_description(format!(
178 "Checks that at least {:.1}% of values in '{}' conform to {} type",
179 self.threshold * 100.0,
180 self.column,
181 self.data_type.name()
182 ))
183 .with_custom("data_type", self.data_type.name())
184 .with_custom("threshold", self.threshold.to_string())
185 .with_custom("constraint_type", "data_type")
186 }
187}
188
/// Constraint requiring the non-null values of a column to belong to a fixed set of strings.
#[derive(Debug, Clone)]
pub struct ContainmentConstraint {
    /// Name of the column whose values are inspected.
    column: String,
    /// The values a column entry is allowed to take.
    allowed_values: Vec<String>,
}

impl ContainmentConstraint {
    /// Creates a constraint for `column` whose permitted values are `allowed_values`.
    ///
    /// Both the column name and each allowed value are converted into owned `String`s;
    /// the allowed values keep the order in which the iterator yields them.
    pub fn new<I, S>(column: impl Into<String>, allowed_values: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let column = column.into();
        let allowed_values: Vec<String> = allowed_values.into_iter().map(S::into).collect();
        Self {
            column,
            allowed_values,
        }
    }
}
228
229#[async_trait]
230impl Constraint for ContainmentConstraint {
231 #[instrument(skip(self, ctx), fields(column = %self.column, allowed_count = %self.allowed_values.len()))]
232 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
233 let validation_ctx = current_validation_context();
235 let table_name = validation_ctx.table_name();
236
237 let values_list = self
239 .allowed_values
240 .iter()
241 .map(|v| format!("'{}'", v.replace('\'', "''"))) .collect::<Vec<_>>()
243 .join(", ");
244
245 let sql = format!(
246 "SELECT
247 COUNT(CASE WHEN {} IN ({values_list}) THEN 1 END) as valid_values,
248 COUNT(*) as total
249 FROM {table_name}
250 WHERE {} IS NOT NULL",
251 self.column, self.column
252 );
253
254 let df = ctx.sql(&sql).await?;
255 let batches = df.collect().await?;
256
257 if batches.is_empty() {
258 return Ok(ConstraintResult::skipped("No data to validate"));
259 }
260
261 let batch = &batches[0];
262 if batch.num_rows() == 0 {
263 return Ok(ConstraintResult::skipped("No data to validate"));
264 }
265
266 let valid_values = batch
267 .column(0)
268 .as_any()
269 .downcast_ref::<arrow::array::Int64Array>()
270 .ok_or_else(|| TermError::Internal("Failed to extract valid count".to_string()))?
271 .value(0) as f64;
272
273 let total = batch
274 .column(1)
275 .as_any()
276 .downcast_ref::<arrow::array::Int64Array>()
277 .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
278 .value(0) as f64;
279
280 if total == 0.0 {
281 return Ok(ConstraintResult::skipped("No non-null data to validate"));
282 }
283
284 let containment_ratio = valid_values / total;
285
286 if containment_ratio == 1.0 {
287 Ok(ConstraintResult::success_with_metric(containment_ratio))
288 } else {
289 let invalid_count = total - valid_values;
290 Ok(ConstraintResult::failure_with_metric(
291 containment_ratio,
292 format!("{invalid_count} values are not in the allowed set"),
293 ))
294 }
295 }
296
297 fn name(&self) -> &str {
298 "containment"
299 }
300
301 fn column(&self) -> Option<&str> {
302 Some(&self.column)
303 }
304
305 fn metadata(&self) -> ConstraintMetadata {
306 ConstraintMetadata::for_column(&self.column)
307 .with_description(format!(
308 "Checks that all values in '{}' are contained in the allowed set",
309 self.column
310 ))
311 .with_custom(
312 "allowed_values",
313 format!("[{}]", self.allowed_values.join(", ")),
314 )
315 .with_custom("constraint_type", "containment")
316 }
317}
318
/// Constraint requiring the non-null values of a column to be greater than or equal to zero.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct NonNegativeConstraint {
    /// Name of the column whose values are inspected.
    column: String,
}

#[allow(dead_code)]
impl NonNegativeConstraint {
    /// Creates a non-negativity constraint for `column`.
    pub fn new(column: impl Into<String>) -> Self {
        let column: String = column.into();
        Self { column }
    }
}
353
354#[async_trait]
355impl Constraint for NonNegativeConstraint {
356 #[instrument(skip(self, ctx), fields(column = %self.column))]
357 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
358 let validation_ctx = current_validation_context();
360 let table_name = validation_ctx.table_name();
361
362 let sql = format!(
364 "SELECT
365 COUNT(CASE WHEN CAST({} AS DOUBLE) >= 0 THEN 1 END) as non_negative,
366 COUNT(*) as total
367 FROM {table_name}
368 WHERE {} IS NOT NULL",
369 self.column, self.column
370 );
371
372 let df = ctx.sql(&sql).await?;
373 let batches = df.collect().await?;
374
375 if batches.is_empty() {
376 return Ok(ConstraintResult::skipped("No data to validate"));
377 }
378
379 let batch = &batches[0];
380 if batch.num_rows() == 0 {
381 return Ok(ConstraintResult::skipped("No data to validate"));
382 }
383
384 let non_negative = batch
385 .column(0)
386 .as_any()
387 .downcast_ref::<arrow::array::Int64Array>()
388 .ok_or_else(|| TermError::Internal("Failed to extract non-negative count".to_string()))?
389 .value(0) as f64;
390
391 let total = batch
392 .column(1)
393 .as_any()
394 .downcast_ref::<arrow::array::Int64Array>()
395 .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
396 .value(0) as f64;
397
398 if total == 0.0 {
399 return Ok(ConstraintResult::skipped("No non-null data to validate"));
400 }
401
402 let non_negative_ratio = non_negative / total;
403
404 if non_negative_ratio == 1.0 {
405 Ok(ConstraintResult::success_with_metric(non_negative_ratio))
406 } else {
407 let negative_count = total - non_negative;
408 Ok(ConstraintResult::failure_with_metric(
409 non_negative_ratio,
410 format!("{negative_count} values are negative"),
411 ))
412 }
413 }
414
415 fn name(&self) -> &str {
416 "non_negative"
417 }
418
419 fn column(&self) -> Option<&str> {
420 Some(&self.column)
421 }
422
423 fn metadata(&self) -> ConstraintMetadata {
424 ConstraintMetadata::for_column(&self.column)
425 .with_description(format!(
426 "Checks that all values in '{}' are non-negative",
427 self.column
428 ))
429 .with_custom("constraint_type", "value_range")
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::{Float64Array, StringArray};
    use arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    /// Registers a one-column, one-batch in-memory table named `data` in a fresh session.
    fn session_with_single_column(field: Field, column: arrow::array::ArrayRef) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![field]));
        let batch = RecordBatch::try_new(schema.clone(), vec![column]).unwrap();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(provider)).unwrap();
        ctx
    }

    /// Session whose `data` table has a single nullable Utf8 column `text_col`.
    async fn create_string_test_context(values: Vec<Option<&str>>) -> SessionContext {
        session_with_single_column(
            Field::new("text_col", ArrowDataType::Utf8, true),
            Arc::new(StringArray::from(values)),
        )
    }

    /// Session whose `data` table has a single nullable Float64 column `num_col`.
    async fn create_numeric_test_context(values: Vec<Option<f64>>) -> SessionContext {
        session_with_single_column(
            Field::new("num_col", ArrowDataType::Float64, true),
            Arc::new(Float64Array::from(values)),
        )
    }

    #[tokio::test]
    async fn test_data_type_integer() {
        // Three of four values are integers: 0.75 clears the 0.7 threshold.
        let rows = vec![Some("123"), Some("456"), Some("not_number"), Some("789")];
        let ctx = create_string_test_context(rows).await;
        let check = DataTypeConstraint::new("text_col", DataType::Integer, 0.7);

        let outcome = evaluate_constraint_with_context(&check, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(0.75));
    }

    #[tokio::test]
    async fn test_data_type_float() {
        // "100" has no decimal point but still matches the float pattern.
        let rows = vec![Some("123.45"), Some("67.89"), Some("invalid"), Some("100")];
        let ctx = create_string_test_context(rows).await;
        let check = DataTypeConstraint::new("text_col", DataType::Float, 0.7);

        let outcome = evaluate_constraint_with_context(&check, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(0.75));
    }

    #[tokio::test]
    async fn test_data_type_boolean() {
        // "1" counts as a boolean alongside "true" and "false".
        let rows = vec![Some("true"), Some("false"), Some("invalid"), Some("1")];
        let ctx = create_string_test_context(rows).await;
        let check = DataTypeConstraint::new("text_col", DataType::Boolean, 0.7);

        let outcome = evaluate_constraint_with_context(&check, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(0.75));
    }

    #[tokio::test]
    async fn test_containment_constraint() {
        // One value outside the allowed set is enough to fail.
        let rows = vec![
            Some("active"),
            Some("inactive"),
            Some("pending"),
            Some("invalid_status"),
        ];
        let ctx = create_string_test_context(rows).await;
        let allowed = vec!["active", "inactive", "pending", "archived"];
        let check = ContainmentConstraint::new("text_col", allowed);

        let outcome = evaluate_constraint_with_context(&check, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.75));
    }

    #[tokio::test]
    async fn test_containment_all_valid() {
        let rows = vec![Some("active"), Some("inactive"), Some("pending")];
        let ctx = create_string_test_context(rows).await;
        let allowed = vec!["active", "inactive", "pending", "archived"];
        let check = ContainmentConstraint::new("text_col", allowed);

        let outcome = evaluate_constraint_with_context(&check, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(1.0));
    }

    #[tokio::test]
    async fn test_non_negative_constraint() {
        // Zero is non-negative.
        let rows = vec![Some(1.0), Some(0.0), Some(5.5), Some(100.0)];
        let ctx = create_numeric_test_context(rows).await;
        let check = NonNegativeConstraint::new("num_col");

        let outcome = evaluate_constraint_with_context(&check, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(1.0));
    }

    #[tokio::test]
    async fn test_non_negative_with_negative() {
        let rows = vec![Some(1.0), Some(-2.0), Some(5.5), Some(100.0)];
        let ctx = create_numeric_test_context(rows).await;
        let check = NonNegativeConstraint::new("num_col");

        let outcome = evaluate_constraint_with_context(&check, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.75));
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // Nulls are excluded from both counts, so the two non-null values give 1.0.
        let rows = vec![Some("active"), None, Some("inactive"), None];
        let ctx = create_string_test_context(rows).await;
        let check = ContainmentConstraint::new("text_col", vec!["active", "inactive"]);

        let outcome = evaluate_constraint_with_context(&check, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(1.0));
    }

    #[test]
    #[should_panic(expected = "Threshold must be between 0.0 and 1.0")]
    fn test_invalid_threshold() {
        let _ = DataTypeConstraint::new("col", DataType::Integer, 1.5);
    }
}