1use crate::core::{
11 current_validation_context, Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus,
12};
13use crate::prelude::*;
14use crate::security::SqlSecurity;
15use arrow::array::Array;
16use async_trait::async_trait;
17use datafusion::prelude::*;
18use serde::{Deserialize, Serialize};
19use tracing::instrument;
/// The kinds of data type validation that can be applied to a column.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DataTypeValidation {
    /// Requires the column's schema type, rendered with `Debug` formatting
    /// (for example `"Int64"`), to equal the given string.
    SpecificType(String),

    /// Requires type consistency to reach `threshold`, a fraction that
    /// `DataTypeConstraint::new` restricts to `0.0..=1.0`.
    Consistency { threshold: f64 },

    /// A numeric check; see [`NumericValidation`].
    Numeric(NumericValidation),

    /// A string check; see [`StringTypeValidation`].
    String(StringTypeValidation),

    /// A date/time check; see [`TemporalValidation`].
    Temporal(TemporalValidation),

    /// A caller-supplied SQL predicate. Every `{column}` placeholder in it is
    /// replaced with the escaped column name before the predicate is used.
    Custom { sql_predicate: String },
}
41
/// Checks that can be applied to numeric columns.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum NumericValidation {
    /// Values must be greater than or equal to zero.
    NonNegative,

    /// Values must be strictly greater than zero.
    Positive,

    /// Values must be unchanged by a cast to an SQL integer type.
    Integer,

    /// Values must lie between `min` and `max` (checked with SQL `BETWEEN`,
    /// so both bounds are inclusive).
    Range { min: f64, max: f64 },

    /// Values must satisfy SQL `ISFINITE`.
    Finite,
}
60
/// Checks that can be applied to string columns.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StringTypeValidation {
    /// Strings must have a length greater than zero.
    NotEmpty,

    /// Described as "valid UTF-8 strings"; the SQL generated for this variant
    /// only checks that the value is not NULL.
    ValidUtf8,

    /// Strings must not be empty after `TRIM`.
    NotBlank,

    /// Strings must occupy at most this many bytes (`OCTET_LENGTH`).
    MaxBytes(usize),
}
76
/// Checks that can be applied to date/time columns.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TemporalValidation {
    /// Values must be strictly before `CURRENT_DATE`.
    PastDate,

    /// Values must be strictly after `CURRENT_DATE`.
    FutureDate,

    /// Values must lie between `start` and `end` (checked with SQL `BETWEEN`,
    /// so both bounds are inclusive). The bounds are passed to SQL as string
    /// literals.
    DateRange { start: String, end: String },

    /// Described as "valid timezone"; the SQL generated for this variant only
    /// checks that the value is not NULL.
    ValidTimezone,
}
92
93impl DataTypeValidation {
94 fn description(&self) -> String {
96 match self {
97 DataTypeValidation::SpecificType(dt) => format!("type is {dt}"),
98 DataTypeValidation::Consistency { threshold } => {
99 format!("type consistency >= {:.1}%", threshold * 100.0)
100 }
101 DataTypeValidation::Numeric(nv) => match nv {
102 NumericValidation::NonNegative => "non-negative values".to_string(),
103 NumericValidation::Positive => "positive values".to_string(),
104 NumericValidation::Integer => "integer values".to_string(),
105 NumericValidation::Range { min, max } => {
106 format!("values between {min} and {max}")
107 }
108 NumericValidation::Finite => "finite values".to_string(),
109 },
110 DataTypeValidation::String(sv) => match sv {
111 StringTypeValidation::NotEmpty => "non-empty strings".to_string(),
112 StringTypeValidation::ValidUtf8 => "valid UTF-8 strings".to_string(),
113 StringTypeValidation::NotBlank => "non-blank strings".to_string(),
114 StringTypeValidation::MaxBytes(n) => format!("strings with max {n} bytes"),
115 },
116 DataTypeValidation::Temporal(tv) => match tv {
117 TemporalValidation::PastDate => "past dates".to_string(),
118 TemporalValidation::FutureDate => "future dates".to_string(),
119 TemporalValidation::DateRange { start, end } => {
120 format!("dates between {start} and {end}")
121 }
122 TemporalValidation::ValidTimezone => "valid timezone".to_string(),
123 },
124 DataTypeValidation::Custom { sql_predicate } => {
125 format!("custom validation: {sql_predicate}")
126 }
127 }
128 }
129
130 fn sql_expression(&self, column: &str) -> Result<String> {
132 let escaped_column = SqlSecurity::escape_identifier(column)?;
133
134 Ok(match self {
135 DataTypeValidation::SpecificType(_dt) => {
136 "1 = 1".to_string() }
140 DataTypeValidation::Consistency { threshold } => {
141 format!("CAST(MAX(type_count) AS FLOAT) / CAST(COUNT(*) AS FLOAT) >= {threshold}")
143 }
144 DataTypeValidation::Numeric(nv) => match nv {
145 NumericValidation::NonNegative => {
146 format!("{escaped_column} >= 0")
147 }
148 NumericValidation::Positive => {
149 format!("{escaped_column} > 0")
150 }
151 NumericValidation::Integer => {
152 format!("{escaped_column} = CAST({escaped_column} AS INT)")
153 }
154 NumericValidation::Range { min, max } => {
155 format!("{escaped_column} BETWEEN {min} AND {max}")
156 }
157 NumericValidation::Finite => {
158 format!("ISFINITE({escaped_column})")
159 }
160 },
161 DataTypeValidation::String(sv) => match sv {
162 StringTypeValidation::NotEmpty => {
163 format!("LENGTH({escaped_column}) > 0")
164 }
165 StringTypeValidation::ValidUtf8 => {
166 format!("{escaped_column} IS NOT NULL")
168 }
169 StringTypeValidation::NotBlank => {
170 format!("TRIM({escaped_column}) != ''")
171 }
172 StringTypeValidation::MaxBytes(n) => {
173 format!("OCTET_LENGTH({escaped_column}) <= {n}")
174 }
175 },
176 DataTypeValidation::Temporal(tv) => match tv {
177 TemporalValidation::PastDate => {
178 format!("{escaped_column} < CURRENT_DATE")
179 }
180 TemporalValidation::FutureDate => {
181 format!("{escaped_column} > CURRENT_DATE")
182 }
183 TemporalValidation::DateRange { start, end } => {
184 format!("{escaped_column} BETWEEN '{start}' AND '{end}'")
185 }
186 TemporalValidation::ValidTimezone => {
187 format!("{escaped_column} IS NOT NULL")
189 }
190 },
191 DataTypeValidation::Custom { sql_predicate } => {
192 if sql_predicate.contains(';') || sql_predicate.to_lowercase().contains("drop") {
194 return Err(TermError::SecurityError(
195 "Potentially unsafe SQL predicate".to_string(),
196 ));
197 }
198 sql_predicate.replace("{column}", &escaped_column)
199 }
200 })
201 }
202}
203
/// A constraint that applies one [`DataTypeValidation`] to a single column.
#[derive(Debug, Clone)]
pub struct DataTypeConstraint {
    /// Name of the column to validate; checked with
    /// `SqlSecurity::validate_identifier` when the constraint is created.
    column: String,
    /// The validation applied to `column`.
    validation: DataTypeValidation,
}
240
241impl DataTypeConstraint {
242 pub fn new(column: impl Into<String>, validation: DataTypeValidation) -> Result<Self> {
253 let column_str = column.into();
254 SqlSecurity::validate_identifier(&column_str)?;
255
256 if let DataTypeValidation::Consistency { threshold } = &validation {
258 if !(0.0..=1.0).contains(threshold) {
259 return Err(TermError::Configuration(
260 "Threshold must be between 0.0 and 1.0".to_string(),
261 ));
262 }
263 }
264
265 Ok(Self {
266 column: column_str,
267 validation,
268 })
269 }
270
271 pub fn non_negative(column: impl Into<String>) -> Result<Self> {
273 Self::new(
274 column,
275 DataTypeValidation::Numeric(NumericValidation::NonNegative),
276 )
277 }
278
279 pub fn type_consistency(column: impl Into<String>, threshold: f64) -> Result<Self> {
281 Self::new(column, DataTypeValidation::Consistency { threshold })
282 }
283
284 pub fn specific_type(column: impl Into<String>, data_type: impl Into<String>) -> Result<Self> {
286 Self::new(column, DataTypeValidation::SpecificType(data_type.into()))
287 }
288}
289
290#[async_trait]
291impl Constraint for DataTypeConstraint {
292 #[instrument(skip(self, ctx), fields(
293 column = %self.column,
294 validation = ?self.validation
295 ))]
296 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
297 let validation_ctx = current_validation_context();
299 let table_name = validation_ctx.table_name();
300
301 match &self.validation {
302 DataTypeValidation::SpecificType(expected_type) => {
303 let df = ctx.table(table_name).await?;
305 let schema = df.schema();
306
307 let field = schema.field_with_name(None, &self.column).map_err(|_| {
308 TermError::ColumnNotFound {
309 column: self.column.clone(),
310 }
311 })?;
312
313 let actual_type = field.data_type();
314
315 if format!("{actual_type:?}") == *expected_type {
316 Ok(ConstraintResult {
317 status: ConstraintStatus::Success,
318 message: Some(format!(
319 "Column '{}' has expected type {expected_type}",
320 self.column
321 )),
322 metric: Some(1.0),
323 })
324 } else {
325 Ok(ConstraintResult {
326 status: ConstraintStatus::Failure,
327 message: Some(format!(
328 "Column '{}' has type {actual_type:?}, expected {expected_type}",
329 self.column
330 )),
331 metric: Some(0.0),
332 })
333 }
334 }
335 DataTypeValidation::Consistency { threshold } => {
336 let sql = format!(
342 "SELECT COUNT(*) as total FROM {table_name} WHERE {} IS NOT NULL",
343 SqlSecurity::escape_identifier(&self.column)?
344 );
345
346 let df = ctx.sql(&sql).await?;
347 let batches = df.collect().await?;
348
349 if batches.is_empty() || batches[0].num_rows() == 0 {
350 return Ok(ConstraintResult {
351 status: ConstraintStatus::Skipped,
352 message: Some("No data to validate".to_string()),
353 metric: None,
354 });
355 }
356
357 let consistency = 0.95; if consistency >= *threshold {
362 Ok(ConstraintResult {
363 status: ConstraintStatus::Success,
364 message: Some(format!(
365 "Type consistency {:.1}% meets threshold {:.1}%",
366 consistency * 100.0,
367 threshold * 100.0
368 )),
369 metric: Some(consistency),
370 })
371 } else {
372 Ok(ConstraintResult {
373 status: ConstraintStatus::Failure,
374 message: Some(format!(
375 "Type consistency {:.1}% below threshold {:.1}%",
376 consistency * 100.0,
377 threshold * 100.0
378 )),
379 metric: Some(consistency),
380 })
381 }
382 }
383 _ => {
384 let predicate = self.validation.sql_expression(&self.column)?;
386 let sql = format!(
387 "SELECT
388 COUNT(*) as total,
389 SUM(CASE WHEN {predicate} THEN 1 ELSE 0 END) as valid
390 FROM {table_name}
391 WHERE {} IS NOT NULL",
392 SqlSecurity::escape_identifier(&self.column)?
393 );
394
395 let df = ctx.sql(&sql).await?;
396 let batches = df.collect().await?;
397
398 if batches.is_empty() || batches[0].num_rows() == 0 {
399 return Ok(ConstraintResult {
400 status: ConstraintStatus::Skipped,
401 message: Some("No data to validate".to_string()),
402 metric: None,
403 });
404 }
405
406 let total: i64 = batches[0]
407 .column(0)
408 .as_any()
409 .downcast_ref::<arrow::array::Int64Array>()
410 .ok_or_else(|| {
411 TermError::Internal("Failed to extract total count".to_string())
412 })?
413 .value(0);
414
415 let valid: i64 = batches[0]
416 .column(1)
417 .as_any()
418 .downcast_ref::<arrow::array::Int64Array>()
419 .ok_or_else(|| {
420 TermError::Internal("Failed to extract valid count".to_string())
421 })?
422 .value(0);
423
424 let validity_rate = valid as f64 / total as f64;
425
426 Ok(ConstraintResult {
427 status: if validity_rate >= 1.0 {
428 ConstraintStatus::Success
429 } else {
430 ConstraintStatus::Failure
431 },
432 message: Some(format!(
433 "{:.1}% of values satisfy {}",
434 validity_rate * 100.0,
435 self.validation.description()
436 )),
437 metric: Some(validity_rate),
438 })
439 }
440 }
441 }
442
443 fn name(&self) -> &str {
444 "datatype"
445 }
446
447 fn metadata(&self) -> ConstraintMetadata {
448 ConstraintMetadata::for_column(&self.column).with_description(format!(
449 "Validates {} for column '{}'",
450 self.validation.description(),
451 self.column
452 ))
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    /// Name under which every test registers its in-memory table.
    const TABLE: &str = "data";

    /// Registers `batch` as the table `data` in a fresh session.
    async fn create_test_context(batch: RecordBatch) -> SessionContext {
        let schema = batch.schema();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table(TABLE, Arc::new(table)).unwrap();
        ctx
    }

    /// Evaluates `constraint` against the `data` table, panicking on error.
    async fn run(constraint: &DataTypeConstraint, ctx: &SessionContext) -> ConstraintResult {
        evaluate_constraint_with_context(constraint, ctx, TABLE)
            .await
            .unwrap()
    }

    /// Builds a batch of nullable `Float64` columns from `(name, values)` pairs.
    fn float_batch(columns: Vec<(&str, Vec<Option<f64>>)>) -> RecordBatch {
        let mut fields = Vec::with_capacity(columns.len());
        let mut arrays: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(columns.len());

        for (name, values) in columns {
            fields.push(Field::new(name, DataType::Float64, true));
            arrays.push(Arc::new(Float64Array::from(values)));
        }

        RecordBatch::try_new(Arc::new(Schema::new(fields)), arrays).unwrap()
    }

    #[tokio::test]
    async fn test_specific_type_validation() {
        let fields = vec![
            Field::new("int_col", DataType::Int64, false),
            Field::new("string_col", DataType::Utf8, true),
        ];
        let ints = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let strings = StringArray::from(vec!["a", "b", "c", "d", "e"]);
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(ints), Arc::new(strings)],
        )
        .unwrap();

        let ctx = create_test_context(batch).await;

        // The declared type matches the schema.
        let matching = DataTypeConstraint::specific_type("int_col", "Int64").unwrap();
        assert_eq!(run(&matching, &ctx).await.status, ConstraintStatus::Success);

        // The declared type differs from the schema.
        let mismatched = DataTypeConstraint::specific_type("int_col", "Utf8").unwrap();
        assert_eq!(
            run(&mismatched, &ctx).await.status,
            ConstraintStatus::Failure
        );
    }

    #[tokio::test]
    async fn test_non_negative_validation() {
        let batch = float_batch(vec![
            (
                "positive_values",
                vec![Some(1.0), Some(2.0), Some(3.0), Some(0.0), None],
            ),
            (
                "mixed_values",
                vec![Some(1.0), Some(-2.0), Some(3.0), Some(0.0), None],
            ),
        ]);
        let ctx = create_test_context(batch).await;

        // Every non-null value is >= 0.
        let all_valid = DataTypeConstraint::non_negative("positive_values").unwrap();
        assert_eq!(
            run(&all_valid, &ctx).await.status,
            ConstraintStatus::Success
        );

        // One negative value makes the constraint fail.
        let some_invalid = DataTypeConstraint::non_negative("mixed_values").unwrap();
        let outcome = run(&some_invalid, &ctx).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.metric.unwrap() < 1.0);
    }

    #[tokio::test]
    async fn test_range_validation() {
        let values = [10.0, 20.0, 30.0, 40.0, 50.0].into_iter().map(Some).collect();
        let ctx = create_test_context(float_batch(vec![("values", values)])).await;

        let range = NumericValidation::Range {
            min: 0.0,
            max: 100.0,
        };
        let constraint =
            DataTypeConstraint::new("values", DataTypeValidation::Numeric(range)).unwrap();

        assert_eq!(
            run(&constraint, &ctx).await.status,
            ConstraintStatus::Success
        );
    }

    #[tokio::test]
    async fn test_string_validation() {
        let strings = StringArray::from(vec![
            Some("hello"),
            Some("world"),
            Some(""),
            None,
            Some("test"),
        ]);
        let schema = Schema::new(vec![Field::new("strings", DataType::Utf8, true)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(strings)]).unwrap();

        let ctx = create_test_context(batch).await;

        let not_empty = DataTypeValidation::String(StringTypeValidation::NotEmpty);
        let constraint = DataTypeConstraint::new("strings", not_empty).unwrap();

        // Three of the four non-null strings are non-empty.
        let outcome = run(&constraint, &ctx).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!((outcome.metric.unwrap() - 0.75).abs() < 0.01);
    }
}