term_guard/constraints/datatype.rs

//! Unified data type constraint that consolidates type-related validations.
//!
//! This module provides a single, flexible data type constraint that replaces:
//! - `DataTypeConstraint` (the earlier, single-purpose version) - Validate specific data types
//! - `DataTypeConsistencyConstraint` - Check type consistency across rows
//! - `NonNegativeConstraint` - Ensure non-negative numeric values
//!
//! It also adds support for more complex type validations (numeric, string,
//! temporal, and custom SQL predicates).
9
10use crate::core::{
11    current_validation_context, Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus,
12};
13use crate::prelude::*;
14use crate::security::SqlSecurity;
15use arrow::array::Array;
16use async_trait::async_trait;
17use datafusion::prelude::*;
18use serde::{Deserialize, Serialize};
19use tracing::instrument;
/// Types of data type validation that can be performed.
///
/// `SpecificType` and `Consistency` are handled specially during evaluation;
/// every other variant is checked with a per-value SQL predicate applied to the
/// column's non-null values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DataTypeValidation {
    /// Validate that the column has a specific data type.
    ///
    /// The string is compared with the `Debug` rendering of the column's Arrow
    /// type, e.g. `"Int64"` or `"Utf8"`.
    SpecificType(String),

    /// Validate type consistency across rows.
    ///
    /// `threshold` is the minimum acceptable consistency ratio and must lie in
    /// `0.0..=1.0` (checked by `DataTypeConstraint::new`).
    /// NOTE(review): evaluation currently uses a fixed placeholder score of 0.95
    /// instead of measuring the data.
    Consistency { threshold: f64 },

    /// Validate numeric constraints
    Numeric(NumericValidation),

    /// Validate string type constraints
    String(StringTypeValidation),

    /// Validate temporal type constraints
    Temporal(TemporalValidation),

    /// Custom type validation with a SQL predicate.
    ///
    /// Every `{column}` placeholder in the predicate is replaced with the escaped
    /// column name. Predicates containing `;` or the substring `drop` (in any
    /// letter case) are rejected as potentially unsafe.
    Custom { sql_predicate: String },
}
41
/// Numeric type validations.
///
/// Each variant is checked per non-null value with a SQL predicate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum NumericValidation {
    /// Values must be non-negative (>= 0)
    NonNegative,

    /// Values must be positive (> 0)
    Positive,

    /// Values must be integers (no fractional part).
    ///
    /// Checked by comparing each value with its cast to an integer SQL type.
    Integer,

    /// Values must be within a specific range.
    ///
    /// Both bounds are inclusive (SQL `BETWEEN min AND max`).
    Range { min: f64, max: f64 },

    /// Values must be finite (not NaN or Infinity)
    Finite,
}
60
/// String type validations.
///
/// Each variant is checked per non-null value with a SQL predicate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StringTypeValidation {
    /// Strings must not be empty (`LENGTH(value) > 0`).
    NotEmpty,

    /// Strings must have valid UTF-8 encoding.
    ///
    /// NOTE(review): evaluated as `value IS NOT NULL`, and NULLs are already
    /// excluded before evaluation, so this currently always passes. Arrow `Utf8`
    /// data is presumably UTF-8-valid by construction — confirm whether a real
    /// check is needed.
    ValidUtf8,

    /// Strings must not contain only whitespace.
    ///
    /// NOTE(review): implemented as `TRIM(value) != ''`; SQL `TRIM` strips spaces
    /// by default, so confirm whether tabs/newlines should also count as blank.
    NotBlank,

    /// Maximum length in bytes (measured with `OCTET_LENGTH`), inclusive.
    MaxBytes(usize),
}
76
/// Temporal type validations.
///
/// Each variant is checked per non-null value with a SQL predicate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TemporalValidation {
    /// Dates must be in the past (strictly before `CURRENT_DATE`).
    PastDate,

    /// Dates must be in the future (strictly after `CURRENT_DATE`).
    FutureDate,

    /// Dates must be within a range.
    ///
    /// `start` and `end` are embedded as SQL string literals in an inclusive
    /// `BETWEEN` comparison. NOTE(review): they presumably need a format the
    /// query engine can coerce to the column's type (e.g. `"2024-01-01"`).
    DateRange { start: String, end: String },

    /// Timestamps must have a valid timezone.
    ///
    /// NOTE(review): placeholder — currently evaluated as `value IS NOT NULL`,
    /// which always passes because NULLs are excluded before evaluation.
    ValidTimezone,
}
92
93impl DataTypeValidation {
94    /// Returns a human-readable description of the validation.
95    fn description(&self) -> String {
96        match self {
97            DataTypeValidation::SpecificType(dt) => format!("type is {dt}"),
98            DataTypeValidation::Consistency { threshold } => {
99                format!("type consistency >= {:.1}%", threshold * 100.0)
100            }
101            DataTypeValidation::Numeric(nv) => match nv {
102                NumericValidation::NonNegative => "non-negative values".to_string(),
103                NumericValidation::Positive => "positive values".to_string(),
104                NumericValidation::Integer => "integer values".to_string(),
105                NumericValidation::Range { min, max } => {
106                    format!("values between {min} and {max}")
107                }
108                NumericValidation::Finite => "finite values".to_string(),
109            },
110            DataTypeValidation::String(sv) => match sv {
111                StringTypeValidation::NotEmpty => "non-empty strings".to_string(),
112                StringTypeValidation::ValidUtf8 => "valid UTF-8 strings".to_string(),
113                StringTypeValidation::NotBlank => "non-blank strings".to_string(),
114                StringTypeValidation::MaxBytes(n) => format!("strings with max {n} bytes"),
115            },
116            DataTypeValidation::Temporal(tv) => match tv {
117                TemporalValidation::PastDate => "past dates".to_string(),
118                TemporalValidation::FutureDate => "future dates".to_string(),
119                TemporalValidation::DateRange { start, end } => {
120                    format!("dates between {start} and {end}")
121                }
122                TemporalValidation::ValidTimezone => "valid timezone".to_string(),
123            },
124            DataTypeValidation::Custom { sql_predicate } => {
125                format!("custom validation: {sql_predicate}")
126            }
127        }
128    }
129
130    /// Generates the SQL expression for this validation.
131    fn sql_expression(&self, column: &str) -> Result<String> {
132        let escaped_column = SqlSecurity::escape_identifier(column)?;
133
134        Ok(match self {
135            DataTypeValidation::SpecificType(_dt) => {
136                // For specific type validation, we check the schema
137                // This would be handled differently in evaluate()
138                "1 = 1".to_string() // Placeholder
139            }
140            DataTypeValidation::Consistency { threshold } => {
141                // Count the most common type and compare to threshold
142                format!("CAST(MAX(type_count) AS FLOAT) / CAST(COUNT(*) AS FLOAT) >= {threshold}")
143            }
144            DataTypeValidation::Numeric(nv) => match nv {
145                NumericValidation::NonNegative => {
146                    format!("{escaped_column} >= 0")
147                }
148                NumericValidation::Positive => {
149                    format!("{escaped_column} > 0")
150                }
151                NumericValidation::Integer => {
152                    format!("{escaped_column} = CAST({escaped_column} AS INT)")
153                }
154                NumericValidation::Range { min, max } => {
155                    format!("{escaped_column} BETWEEN {min} AND {max}")
156                }
157                NumericValidation::Finite => {
158                    format!("ISFINITE({escaped_column})")
159                }
160            },
161            DataTypeValidation::String(sv) => match sv {
162                StringTypeValidation::NotEmpty => {
163                    format!("LENGTH({escaped_column}) > 0")
164                }
165                StringTypeValidation::ValidUtf8 => {
166                    // DataFusion handles UTF-8 validation internally
167                    format!("{escaped_column} IS NOT NULL")
168                }
169                StringTypeValidation::NotBlank => {
170                    format!("TRIM({escaped_column}) != ''")
171                }
172                StringTypeValidation::MaxBytes(n) => {
173                    format!("OCTET_LENGTH({escaped_column}) <= {n}")
174                }
175            },
176            DataTypeValidation::Temporal(tv) => match tv {
177                TemporalValidation::PastDate => {
178                    format!("{escaped_column} < CURRENT_DATE")
179                }
180                TemporalValidation::FutureDate => {
181                    format!("{escaped_column} > CURRENT_DATE")
182                }
183                TemporalValidation::DateRange { start, end } => {
184                    format!("{escaped_column} BETWEEN '{start}' AND '{end}'")
185                }
186                TemporalValidation::ValidTimezone => {
187                    // This would need custom implementation
188                    format!("{escaped_column} IS NOT NULL")
189                }
190            },
191            DataTypeValidation::Custom { sql_predicate } => {
192                // Basic validation to prevent obvious SQL injection
193                if sql_predicate.contains(';') || sql_predicate.to_lowercase().contains("drop") {
194                    return Err(TermError::SecurityError(
195                        "Potentially unsafe SQL predicate".to_string(),
196                    ));
197                }
198                sql_predicate.replace("{column}", &escaped_column)
199            }
200        })
201    }
202}
203
/// A unified constraint that validates data types and type-related properties.
///
/// This constraint replaces individual type constraints and provides a consistent
/// interface for all data type validations. Construction validates its inputs
/// (e.g. the column name and the `Consistency` threshold), so every constructor
/// returns a `Result`.
///
/// # Examples
///
/// ```rust
/// use term_guard::constraints::{DataTypeConstraint, DataTypeValidation, NumericValidation};
/// use term_guard::core::Constraint;
///
/// // Check for specific data type
/// let type_check = DataTypeConstraint::new(
///     "user_id",
///     DataTypeValidation::SpecificType("Int64".to_string())
/// )
/// .expect("valid column name");
/// assert_eq!(type_check.name(), "datatype");
///
/// // Check for non-negative values
/// let non_negative = DataTypeConstraint::new(
///     "amount",
///     DataTypeValidation::Numeric(NumericValidation::NonNegative)
/// )
/// .expect("valid column name");
///
/// // Check type consistency
/// let consistency = DataTypeConstraint::new(
///     "mixed_column",
///     DataTypeValidation::Consistency { threshold: 0.95 }
/// )
/// .expect("valid column name and threshold");
/// ```
#[derive(Debug, Clone)]
pub struct DataTypeConstraint {
    /// The column to validate
    column: String,
    /// The type of validation to perform
    validation: DataTypeValidation,
}
240
241impl DataTypeConstraint {
242    /// Creates a new unified data type constraint.
243    ///
244    /// # Arguments
245    ///
246    /// * `column` - The column to validate
247    /// * `validation` - The type of validation to perform
248    ///
249    /// # Errors
250    ///
251    /// Returns an error if the column name is invalid.
252    pub fn new(column: impl Into<String>, validation: DataTypeValidation) -> Result<Self> {
253        let column_str = column.into();
254        SqlSecurity::validate_identifier(&column_str)?;
255
256        // Validate threshold for consistency check
257        if let DataTypeValidation::Consistency { threshold } = &validation {
258            if !(0.0..=1.0).contains(threshold) {
259                return Err(TermError::Configuration(
260                    "Threshold must be between 0.0 and 1.0".to_string(),
261                ));
262            }
263        }
264
265        Ok(Self {
266            column: column_str,
267            validation,
268        })
269    }
270
271    /// Convenience constructor for non-negative constraint.
272    pub fn non_negative(column: impl Into<String>) -> Result<Self> {
273        Self::new(
274            column,
275            DataTypeValidation::Numeric(NumericValidation::NonNegative),
276        )
277    }
278
279    /// Convenience constructor for type consistency constraint.
280    pub fn type_consistency(column: impl Into<String>, threshold: f64) -> Result<Self> {
281        Self::new(column, DataTypeValidation::Consistency { threshold })
282    }
283
284    /// Convenience constructor for specific type constraint.
285    pub fn specific_type(column: impl Into<String>, data_type: impl Into<String>) -> Result<Self> {
286        Self::new(column, DataTypeValidation::SpecificType(data_type.into()))
287    }
288}
289
290#[async_trait]
291impl Constraint for DataTypeConstraint {
292    #[instrument(skip(self, ctx), fields(
293        column = %self.column,
294        validation = ?self.validation
295    ))]
296    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
297        // Get the table name from the validation context
298        let validation_ctx = current_validation_context();
299        let table_name = validation_ctx.table_name();
300
301        match &self.validation {
302            DataTypeValidation::SpecificType(expected_type) => {
303                // Check the schema for the column type
304                let df = ctx.table(table_name).await?;
305                let schema = df.schema();
306
307                let field = schema.field_with_name(None, &self.column).map_err(|_| {
308                    TermError::ColumnNotFound {
309                        column: self.column.clone(),
310                    }
311                })?;
312
313                let actual_type = field.data_type();
314
315                if format!("{actual_type:?}") == *expected_type {
316                    Ok(ConstraintResult {
317                        status: ConstraintStatus::Success,
318                        message: Some(format!(
319                            "Column '{}' has expected type {expected_type}",
320                            self.column
321                        )),
322                        metric: Some(1.0),
323                    })
324                } else {
325                    Ok(ConstraintResult {
326                        status: ConstraintStatus::Failure,
327                        message: Some(format!(
328                            "Column '{}' has type {actual_type:?}, expected {expected_type}",
329                            self.column
330                        )),
331                        metric: Some(0.0),
332                    })
333                }
334            }
335            DataTypeValidation::Consistency { threshold } => {
336                // For type consistency, we need to analyze the actual values
337                // DataFusion doesn't have typeof() function, so we'll check if all values
338                // have consistent formatting/structure
339
340                // For now, just check that the column exists and return a placeholder result
341                let sql = format!(
342                    "SELECT COUNT(*) as total FROM {table_name} WHERE {} IS NOT NULL",
343                    SqlSecurity::escape_identifier(&self.column)?
344                );
345
346                let df = ctx.sql(&sql).await?;
347                let batches = df.collect().await?;
348
349                if batches.is_empty() || batches[0].num_rows() == 0 {
350                    return Ok(ConstraintResult {
351                        status: ConstraintStatus::Skipped,
352                        message: Some("No data to validate".to_string()),
353                        metric: None,
354                    });
355                }
356
357                // For now, assume consistency is high (would need actual implementation)
358                // In a real implementation, we'd analyze value patterns, formats, etc.
359                let consistency = 0.95; // Placeholder
360
361                if consistency >= *threshold {
362                    Ok(ConstraintResult {
363                        status: ConstraintStatus::Success,
364                        message: Some(format!(
365                            "Type consistency {:.1}% meets threshold {:.1}%",
366                            consistency * 100.0,
367                            threshold * 100.0
368                        )),
369                        metric: Some(consistency),
370                    })
371                } else {
372                    Ok(ConstraintResult {
373                        status: ConstraintStatus::Failure,
374                        message: Some(format!(
375                            "Type consistency {:.1}% below threshold {:.1}%",
376                            consistency * 100.0,
377                            threshold * 100.0
378                        )),
379                        metric: Some(consistency),
380                    })
381                }
382            }
383            _ => {
384                // For other validations, use SQL predicates
385                let predicate = self.validation.sql_expression(&self.column)?;
386                let sql = format!(
387                    "SELECT 
388                        COUNT(*) as total,
389                        SUM(CASE WHEN {predicate} THEN 1 ELSE 0 END) as valid
390                     FROM {table_name}
391                     WHERE {} IS NOT NULL",
392                    SqlSecurity::escape_identifier(&self.column)?
393                );
394
395                let df = ctx.sql(&sql).await?;
396                let batches = df.collect().await?;
397
398                if batches.is_empty() || batches[0].num_rows() == 0 {
399                    return Ok(ConstraintResult {
400                        status: ConstraintStatus::Skipped,
401                        message: Some("No data to validate".to_string()),
402                        metric: None,
403                    });
404                }
405
406                let total: i64 = batches[0]
407                    .column(0)
408                    .as_any()
409                    .downcast_ref::<arrow::array::Int64Array>()
410                    .ok_or_else(|| {
411                        TermError::Internal("Failed to extract total count".to_string())
412                    })?
413                    .value(0);
414
415                let valid: i64 = batches[0]
416                    .column(1)
417                    .as_any()
418                    .downcast_ref::<arrow::array::Int64Array>()
419                    .ok_or_else(|| {
420                        TermError::Internal("Failed to extract valid count".to_string())
421                    })?
422                    .value(0);
423
424                let validity_rate = valid as f64 / total as f64;
425
426                Ok(ConstraintResult {
427                    status: if validity_rate >= 1.0 {
428                        ConstraintStatus::Success
429                    } else {
430                        ConstraintStatus::Failure
431                    },
432                    message: Some(format!(
433                        "{:.1}% of values satisfy {}",
434                        validity_rate * 100.0,
435                        self.validation.description()
436                    )),
437                    metric: Some(validity_rate),
438                })
439            }
440        }
441    }
442
443    fn name(&self) -> &str {
444        "datatype"
445    }
446
447    fn metadata(&self) -> ConstraintMetadata {
448        ConstraintMetadata::for_column(&self.column).with_description(format!(
449            "Validates {} for column '{}'",
450            self.validation.description(),
451            self.column
452        ))
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    /// Wraps `batch` in an in-memory table registered as `data` on a fresh session.
    async fn create_test_context(batch: RecordBatch) -> SessionContext {
        let table = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap();
        let session = SessionContext::new();
        session.register_table("data", Arc::new(table)).unwrap();
        session
    }

    /// Assembles a record batch from `(name, type, nullable, values)` column specs.
    fn batch_of(columns: Vec<(&str, DataType, bool, ArrayRef)>) -> RecordBatch {
        let (fields, arrays): (Vec<Field>, Vec<ArrayRef>) = columns
            .into_iter()
            .map(|(name, data_type, nullable, values)| {
                (Field::new(name, data_type, nullable), values)
            })
            .unzip();
        RecordBatch::try_new(Arc::new(Schema::new(fields)), arrays).unwrap()
    }

    #[tokio::test]
    async fn test_specific_type_validation() {
        let session = create_test_context(batch_of(vec![
            (
                "int_col",
                DataType::Int64,
                false,
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef,
            ),
            (
                "string_col",
                DataType::Utf8,
                true,
                Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])) as ArrayRef,
            ),
        ]))
        .await;

        // A matching type name succeeds...
        let matching = DataTypeConstraint::specific_type("int_col", "Int64").unwrap();
        let outcome = evaluate_constraint_with_context(&matching, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);

        // ...and a mismatched one fails.
        let mismatched = DataTypeConstraint::specific_type("int_col", "Utf8").unwrap();
        let outcome = evaluate_constraint_with_context(&mismatched, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_non_negative_validation() {
        let session = create_test_context(batch_of(vec![
            (
                "positive_values",
                DataType::Float64,
                true,
                Arc::new(Float64Array::from(vec![
                    Some(1.0),
                    Some(2.0),
                    Some(3.0),
                    Some(0.0),
                    None,
                ])) as ArrayRef,
            ),
            (
                "mixed_values",
                DataType::Float64,
                true,
                Arc::new(Float64Array::from(vec![
                    Some(1.0),
                    Some(-2.0),
                    Some(3.0),
                    Some(0.0),
                    None,
                ])) as ArrayRef,
            ),
        ]))
        .await;

        // Zero and positive values (NULLs ignored) all pass.
        let all_non_negative = DataTypeConstraint::non_negative("positive_values").unwrap();
        let outcome = evaluate_constraint_with_context(&all_non_negative, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);

        // A single negative value is enough to fail and lower the metric.
        let has_negative = DataTypeConstraint::non_negative("mixed_values").unwrap();
        let outcome = evaluate_constraint_with_context(&has_negative, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.metric.unwrap() < 1.0);
    }

    #[tokio::test]
    async fn test_range_validation() {
        let session = create_test_context(batch_of(vec![(
            "values",
            DataType::Float64,
            true,
            Arc::new(Float64Array::from(vec![
                Some(10.0),
                Some(20.0),
                Some(30.0),
                Some(40.0),
                Some(50.0),
            ])) as ArrayRef,
        )]))
        .await;

        let within_bounds = DataTypeConstraint::new(
            "values",
            DataTypeValidation::Numeric(NumericValidation::Range {
                min: 0.0,
                max: 100.0,
            }),
        )
        .unwrap();

        let outcome = evaluate_constraint_with_context(&within_bounds, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_string_validation() {
        let session = create_test_context(batch_of(vec![(
            "strings",
            DataType::Utf8,
            true,
            Arc::new(StringArray::from(vec![
                Some("hello"),
                Some("world"),
                Some(""),
                None,
                Some("test"),
            ])) as ArrayRef,
        )]))
        .await;

        let not_empty = DataTypeConstraint::new(
            "strings",
            DataTypeValidation::String(StringTypeValidation::NotEmpty),
        )
        .unwrap();

        let outcome = evaluate_constraint_with_context(&not_empty, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        // The NULL is excluded; of the 4 remaining values only "" is empty -> 3/4.
        assert!((outcome.metric.unwrap() - 0.75).abs() < 0.01);
    }
}