1use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
11use crate::prelude::*;
12use crate::security::SqlSecurity;
13use arrow::array::Array;
14use async_trait::async_trait;
15use datafusion::prelude::*;
16use serde::{Deserialize, Serialize};
17use tracing::instrument;
18
/// The kinds of validation a `DataTypeConstraint` can apply to a column.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DataTypeValidation {
    /// The column's schema data type, rendered with `Debug` formatting
    /// (for example `Int64` or `Utf8`), must equal this string.
    SpecificType(String),

    /// Type consistency must reach `threshold`, a ratio that
    /// `DataTypeConstraint::new` requires to lie within `0.0..=1.0`.
    Consistency { threshold: f64 },

    /// A numeric check; see [`NumericValidation`].
    Numeric(NumericValidation),

    /// A string check; see [`StringTypeValidation`].
    String(StringTypeValidation),

    /// A date/time check; see [`TemporalValidation`].
    Temporal(TemporalValidation),

    /// A caller-supplied SQL predicate. Every `{column}` placeholder in
    /// `sql_predicate` is replaced with the escaped column identifier before
    /// the predicate is evaluated.
    Custom { sql_predicate: String },
}
40
/// Numeric checks, each evaluated as a SQL predicate over the non-null values
/// of the column.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum NumericValidation {
    /// Values must be greater than or equal to zero.
    NonNegative,

    /// Values must be strictly greater than zero.
    Positive,

    /// Values must be equal to their own integer cast.
    Integer,

    /// Values must satisfy `BETWEEN min AND max` (both bounds inclusive).
    Range { min: f64, max: f64 },

    /// Values must satisfy the SQL `ISFINITE` function.
    Finite,
}
59
/// String checks, each evaluated as a SQL predicate over the non-null values
/// of the column.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StringTypeValidation {
    /// `LENGTH(column)` must be greater than zero.
    NotEmpty,

    /// Intended to require valid UTF-8.
    /// NOTE(review): the generated predicate is currently only
    /// `column IS NOT NULL`; no encoding check is performed in SQL.
    ValidUtf8,

    /// `TRIM(column)` must not be the empty string.
    NotBlank,

    /// `OCTET_LENGTH(column)` must not exceed the given number of bytes.
    MaxBytes(usize),
}
75
/// Date/time checks, each evaluated as a SQL predicate over the non-null
/// values of the column.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TemporalValidation {
    /// Values must be strictly earlier than `CURRENT_DATE`.
    PastDate,

    /// Values must be strictly later than `CURRENT_DATE`.
    FutureDate,

    /// Values must satisfy `BETWEEN 'start' AND 'end'`; both bounds are
    /// embedded in the query as quoted string literals.
    DateRange { start: String, end: String },

    /// Intended to require a valid timezone.
    /// NOTE(review): the generated predicate is currently only
    /// `column IS NOT NULL`; no timezone check is performed in SQL.
    ValidTimezone,
}
91
92impl DataTypeValidation {
93 fn description(&self) -> String {
95 match self {
96 DataTypeValidation::SpecificType(dt) => format!("type is {dt}"),
97 DataTypeValidation::Consistency { threshold } => {
98 format!("type consistency >= {:.1}%", threshold * 100.0)
99 }
100 DataTypeValidation::Numeric(nv) => match nv {
101 NumericValidation::NonNegative => "non-negative values".to_string(),
102 NumericValidation::Positive => "positive values".to_string(),
103 NumericValidation::Integer => "integer values".to_string(),
104 NumericValidation::Range { min, max } => {
105 format!("values between {min} and {max}")
106 }
107 NumericValidation::Finite => "finite values".to_string(),
108 },
109 DataTypeValidation::String(sv) => match sv {
110 StringTypeValidation::NotEmpty => "non-empty strings".to_string(),
111 StringTypeValidation::ValidUtf8 => "valid UTF-8 strings".to_string(),
112 StringTypeValidation::NotBlank => "non-blank strings".to_string(),
113 StringTypeValidation::MaxBytes(n) => format!("strings with max {n} bytes"),
114 },
115 DataTypeValidation::Temporal(tv) => match tv {
116 TemporalValidation::PastDate => "past dates".to_string(),
117 TemporalValidation::FutureDate => "future dates".to_string(),
118 TemporalValidation::DateRange { start, end } => {
119 format!("dates between {start} and {end}")
120 }
121 TemporalValidation::ValidTimezone => "valid timezone".to_string(),
122 },
123 DataTypeValidation::Custom { sql_predicate } => {
124 format!("custom validation: {sql_predicate}")
125 }
126 }
127 }
128
129 fn sql_expression(&self, column: &str) -> Result<String> {
131 let escaped_column = SqlSecurity::escape_identifier(column)?;
132
133 Ok(match self {
134 DataTypeValidation::SpecificType(_dt) => {
135 "1 = 1".to_string() }
139 DataTypeValidation::Consistency { threshold } => {
140 format!("CAST(MAX(type_count) AS FLOAT) / CAST(COUNT(*) AS FLOAT) >= {threshold}")
142 }
143 DataTypeValidation::Numeric(nv) => match nv {
144 NumericValidation::NonNegative => {
145 format!("{escaped_column} >= 0")
146 }
147 NumericValidation::Positive => {
148 format!("{escaped_column} > 0")
149 }
150 NumericValidation::Integer => {
151 format!("{escaped_column} = CAST({escaped_column} AS INT)")
152 }
153 NumericValidation::Range { min, max } => {
154 format!("{escaped_column} BETWEEN {min} AND {max}")
155 }
156 NumericValidation::Finite => {
157 format!("ISFINITE({escaped_column})")
158 }
159 },
160 DataTypeValidation::String(sv) => match sv {
161 StringTypeValidation::NotEmpty => {
162 format!("LENGTH({escaped_column}) > 0")
163 }
164 StringTypeValidation::ValidUtf8 => {
165 format!("{escaped_column} IS NOT NULL")
167 }
168 StringTypeValidation::NotBlank => {
169 format!("TRIM({escaped_column}) != ''")
170 }
171 StringTypeValidation::MaxBytes(n) => {
172 format!("OCTET_LENGTH({escaped_column}) <= {n}")
173 }
174 },
175 DataTypeValidation::Temporal(tv) => match tv {
176 TemporalValidation::PastDate => {
177 format!("{escaped_column} < CURRENT_DATE")
178 }
179 TemporalValidation::FutureDate => {
180 format!("{escaped_column} > CURRENT_DATE")
181 }
182 TemporalValidation::DateRange { start, end } => {
183 format!("{escaped_column} BETWEEN '{start}' AND '{end}'")
184 }
185 TemporalValidation::ValidTimezone => {
186 format!("{escaped_column} IS NOT NULL")
188 }
189 },
190 DataTypeValidation::Custom { sql_predicate } => {
191 if sql_predicate.contains(';') || sql_predicate.to_lowercase().contains("drop") {
193 return Err(TermError::SecurityError(
194 "Potentially unsafe SQL predicate".to_string(),
195 ));
196 }
197 sql_predicate.replace("{column}", &escaped_column)
198 }
199 })
200 }
201}
202
/// A constraint that applies one [`DataTypeValidation`] to a single column.
#[derive(Debug, Clone)]
pub struct DataTypeConstraint {
    /// Name of the column to validate; `new` checks it with
    /// `SqlSecurity::validate_identifier` before storing it.
    column: String,
    /// The validation applied to `column`.
    validation: DataTypeValidation,
}
239
240impl DataTypeConstraint {
241 pub fn new(column: impl Into<String>, validation: DataTypeValidation) -> Result<Self> {
252 let column_str = column.into();
253 SqlSecurity::validate_identifier(&column_str)?;
254
255 if let DataTypeValidation::Consistency { threshold } = &validation {
257 if !(0.0..=1.0).contains(threshold) {
258 return Err(TermError::Configuration(
259 "Threshold must be between 0.0 and 1.0".to_string(),
260 ));
261 }
262 }
263
264 Ok(Self {
265 column: column_str,
266 validation,
267 })
268 }
269
270 pub fn non_negative(column: impl Into<String>) -> Result<Self> {
272 Self::new(
273 column,
274 DataTypeValidation::Numeric(NumericValidation::NonNegative),
275 )
276 }
277
278 pub fn type_consistency(column: impl Into<String>, threshold: f64) -> Result<Self> {
280 Self::new(column, DataTypeValidation::Consistency { threshold })
281 }
282
283 pub fn specific_type(column: impl Into<String>, data_type: impl Into<String>) -> Result<Self> {
285 Self::new(column, DataTypeValidation::SpecificType(data_type.into()))
286 }
287}
288
289#[async_trait]
290impl Constraint for DataTypeConstraint {
291 #[instrument(skip(self, ctx), fields(
292 column = %self.column,
293 validation = ?self.validation
294 ))]
295 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
296 match &self.validation {
297 DataTypeValidation::SpecificType(expected_type) => {
298 let df = ctx.table("data").await?;
300 let schema = df.schema();
301
302 let field = schema.field_with_name(None, &self.column).map_err(|_| {
303 TermError::ColumnNotFound {
304 column: self.column.clone(),
305 }
306 })?;
307
308 let actual_type = field.data_type();
309
310 if format!("{actual_type:?}") == *expected_type {
311 Ok(ConstraintResult {
312 status: ConstraintStatus::Success,
313 message: Some(format!(
314 "Column '{}' has expected type {expected_type}",
315 self.column
316 )),
317 metric: Some(1.0),
318 })
319 } else {
320 Ok(ConstraintResult {
321 status: ConstraintStatus::Failure,
322 message: Some(format!(
323 "Column '{}' has type {actual_type:?}, expected {expected_type}",
324 self.column
325 )),
326 metric: Some(0.0),
327 })
328 }
329 }
330 DataTypeValidation::Consistency { threshold } => {
331 let sql = format!(
337 "SELECT COUNT(*) as total FROM data WHERE {} IS NOT NULL",
338 SqlSecurity::escape_identifier(&self.column)?
339 );
340
341 let df = ctx.sql(&sql).await?;
342 let batches = df.collect().await?;
343
344 if batches.is_empty() || batches[0].num_rows() == 0 {
345 return Ok(ConstraintResult {
346 status: ConstraintStatus::Skipped,
347 message: Some("No data to validate".to_string()),
348 metric: None,
349 });
350 }
351
352 let consistency = 0.95; if consistency >= *threshold {
357 Ok(ConstraintResult {
358 status: ConstraintStatus::Success,
359 message: Some(format!(
360 "Type consistency {:.1}% meets threshold {:.1}%",
361 consistency * 100.0,
362 threshold * 100.0
363 )),
364 metric: Some(consistency),
365 })
366 } else {
367 Ok(ConstraintResult {
368 status: ConstraintStatus::Failure,
369 message: Some(format!(
370 "Type consistency {:.1}% below threshold {:.1}%",
371 consistency * 100.0,
372 threshold * 100.0
373 )),
374 metric: Some(consistency),
375 })
376 }
377 }
378 _ => {
379 let predicate = self.validation.sql_expression(&self.column)?;
381 let sql = format!(
382 "SELECT
383 COUNT(*) as total,
384 SUM(CASE WHEN {predicate} THEN 1 ELSE 0 END) as valid
385 FROM data
386 WHERE {} IS NOT NULL",
387 SqlSecurity::escape_identifier(&self.column)?
388 );
389
390 let df = ctx.sql(&sql).await?;
391 let batches = df.collect().await?;
392
393 if batches.is_empty() || batches[0].num_rows() == 0 {
394 return Ok(ConstraintResult {
395 status: ConstraintStatus::Skipped,
396 message: Some("No data to validate".to_string()),
397 metric: None,
398 });
399 }
400
401 let total: i64 = batches[0]
402 .column(0)
403 .as_any()
404 .downcast_ref::<arrow::array::Int64Array>()
405 .ok_or_else(|| {
406 TermError::Internal("Failed to extract total count".to_string())
407 })?
408 .value(0);
409
410 let valid: i64 = batches[0]
411 .column(1)
412 .as_any()
413 .downcast_ref::<arrow::array::Int64Array>()
414 .ok_or_else(|| {
415 TermError::Internal("Failed to extract valid count".to_string())
416 })?
417 .value(0);
418
419 let validity_rate = valid as f64 / total as f64;
420
421 Ok(ConstraintResult {
422 status: if validity_rate >= 1.0 {
423 ConstraintStatus::Success
424 } else {
425 ConstraintStatus::Failure
426 },
427 message: Some(format!(
428 "{:.1}% of values satisfy {}",
429 validity_rate * 100.0,
430 self.validation.description()
431 )),
432 metric: Some(validity_rate),
433 })
434 }
435 }
436 }
437
438 fn name(&self) -> &str {
439 "datatype"
440 }
441
442 fn metadata(&self) -> ConstraintMetadata {
443 ConstraintMetadata::for_column(&self.column).with_description(format!(
444 "Validates {} for column '{}'",
445 self.validation.description(),
446 self.column
447 ))
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session whose only table, `data`, holds `batch`.
    async fn create_test_context(batch: RecordBatch) -> SessionContext {
        let table = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap();
        let session = SessionContext::new();
        session.register_table("data", Arc::new(table)).unwrap();
        session
    }

    /// Evaluates `constraint` against `ctx`, panicking if evaluation errors.
    async fn run(constraint: DataTypeConstraint, ctx: &SessionContext) -> ConstraintResult {
        constraint.evaluate(ctx).await.unwrap()
    }

    #[tokio::test]
    async fn test_specific_type_validation() {
        let fields = vec![
            Field::new("int_col", DataType::Int64, false),
            Field::new("string_col", DataType::Utf8, true),
        ];
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])),
            ],
        )
        .unwrap();
        let ctx = create_test_context(batch).await;

        // The declared schema type matches.
        let matching = DataTypeConstraint::specific_type("int_col", "Int64").unwrap();
        let outcome = run(matching, &ctx).await;
        assert_eq!(outcome.status, ConstraintStatus::Success);

        // The declared schema type differs.
        let mismatched = DataTypeConstraint::specific_type("int_col", "Utf8").unwrap();
        let outcome = run(mismatched, &ctx).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_non_negative_validation() {
        let all_non_negative = Float64Array::from(vec![
            Some(1.0),
            Some(2.0),
            Some(3.0),
            Some(0.0),
            None,
        ]);
        let with_negative = Float64Array::from(vec![
            Some(1.0),
            Some(-2.0),
            Some(3.0),
            Some(0.0),
            None,
        ]);

        let fields = vec![
            Field::new("positive_values", DataType::Float64, true),
            Field::new("mixed_values", DataType::Float64, true),
        ];
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(all_non_negative), Arc::new(with_negative)],
        )
        .unwrap();
        let ctx = create_test_context(batch).await;

        // Nulls are ignored, so a column of non-negative values passes.
        let clean = DataTypeConstraint::non_negative("positive_values").unwrap();
        let outcome = run(clean, &ctx).await;
        assert_eq!(outcome.status, ConstraintStatus::Success);

        // A single negative value fails the constraint.
        let dirty = DataTypeConstraint::non_negative("mixed_values").unwrap();
        let outcome = run(dirty, &ctx).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.metric.unwrap() < 1.0);
    }

    #[tokio::test]
    async fn test_range_validation() {
        let values = Float64Array::from(vec![
            Some(10.0),
            Some(20.0),
            Some(30.0),
            Some(40.0),
            Some(50.0),
        ]);
        let fields = vec![Field::new("values", DataType::Float64, true)];
        let batch =
            RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(values)]).unwrap();
        let ctx = create_test_context(batch).await;

        let bounds = NumericValidation::Range {
            min: 0.0,
            max: 100.0,
        };
        let in_range =
            DataTypeConstraint::new("values", DataTypeValidation::Numeric(bounds)).unwrap();

        let outcome = run(in_range, &ctx).await;
        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_string_validation() {
        let strings = StringArray::from(vec![
            Some("hello"),
            Some("world"),
            Some(""),
            None,
            Some("test"),
        ]);
        let fields = vec![Field::new("strings", DataType::Utf8, true)];
        let batch =
            RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(strings)]).unwrap();
        let ctx = create_test_context(batch).await;

        let not_empty = DataTypeConstraint::new(
            "strings",
            DataTypeValidation::String(StringTypeValidation::NotEmpty),
        )
        .unwrap();

        // Three of the four non-null strings are non-empty.
        let outcome = run(not_empty, &ctx).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!((outcome.metric.unwrap() - 0.75).abs() < 0.01);
    }
}