term_guard/constraints/
datatype.rs

1//! Unified data type constraint that consolidates type-related validations.
2//!
3//! This module provides a single, flexible data type constraint that replaces:
4//! - `DataTypeConstraint` - Validate specific data types
5//! - `DataTypeConsistencyConstraint` - Check type consistency across rows
6//! - `NonNegativeConstraint` - Ensure non-negative numeric values
7//!
8//! And adds support for more complex type validations.
9
10use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
11use crate::prelude::*;
12use crate::security::SqlSecurity;
13use arrow::array::Array;
14use async_trait::async_trait;
15use datafusion::prelude::*;
16use serde::{Deserialize, Serialize};
17use tracing::instrument;
18
19/// Types of data type validation that can be performed.
20#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub enum DataTypeValidation {
22    /// Validate that column has a specific data type (use type name as string)
23    SpecificType(String),
24
25    /// Validate type consistency across rows
26    Consistency { threshold: f64 },
27
28    /// Validate numeric constraints
29    Numeric(NumericValidation),
30
31    /// Validate string type constraints
32    String(StringTypeValidation),
33
34    /// Validate temporal type constraints
35    Temporal(TemporalValidation),
36
37    /// Custom type validation with SQL predicate
38    Custom { sql_predicate: String },
39}
40
41/// Numeric type validations
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
43pub enum NumericValidation {
44    /// Values must be non-negative (>= 0)
45    NonNegative,
46
47    /// Values must be positive (> 0)
48    Positive,
49
50    /// Values must be integers (no fractional part)
51    Integer,
52
53    /// Values must be within a specific range
54    Range { min: f64, max: f64 },
55
56    /// Values must be finite (not NaN or Infinity)
57    Finite,
58}
59
60/// String type validations
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub enum StringTypeValidation {
63    /// Strings must not be empty
64    NotEmpty,
65
66    /// Strings must have valid UTF-8 encoding
67    ValidUtf8,
68
69    /// Strings must not contain only whitespace
70    NotBlank,
71
72    /// Maximum byte length
73    MaxBytes(usize),
74}
75
76/// Temporal type validations
77#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
78pub enum TemporalValidation {
79    /// Dates must be in the past
80    PastDate,
81
82    /// Dates must be in the future
83    FutureDate,
84
85    /// Dates must be within a range
86    DateRange { start: String, end: String },
87
88    /// Timestamps must have valid timezone
89    ValidTimezone,
90}
91
92impl DataTypeValidation {
93    /// Returns a human-readable description of the validation.
94    fn description(&self) -> String {
95        match self {
96            DataTypeValidation::SpecificType(dt) => format!("type is {dt}"),
97            DataTypeValidation::Consistency { threshold } => {
98                format!("type consistency >= {:.1}%", threshold * 100.0)
99            }
100            DataTypeValidation::Numeric(nv) => match nv {
101                NumericValidation::NonNegative => "non-negative values".to_string(),
102                NumericValidation::Positive => "positive values".to_string(),
103                NumericValidation::Integer => "integer values".to_string(),
104                NumericValidation::Range { min, max } => {
105                    format!("values between {min} and {max}")
106                }
107                NumericValidation::Finite => "finite values".to_string(),
108            },
109            DataTypeValidation::String(sv) => match sv {
110                StringTypeValidation::NotEmpty => "non-empty strings".to_string(),
111                StringTypeValidation::ValidUtf8 => "valid UTF-8 strings".to_string(),
112                StringTypeValidation::NotBlank => "non-blank strings".to_string(),
113                StringTypeValidation::MaxBytes(n) => format!("strings with max {n} bytes"),
114            },
115            DataTypeValidation::Temporal(tv) => match tv {
116                TemporalValidation::PastDate => "past dates".to_string(),
117                TemporalValidation::FutureDate => "future dates".to_string(),
118                TemporalValidation::DateRange { start, end } => {
119                    format!("dates between {start} and {end}")
120                }
121                TemporalValidation::ValidTimezone => "valid timezone".to_string(),
122            },
123            DataTypeValidation::Custom { sql_predicate } => {
124                format!("custom validation: {sql_predicate}")
125            }
126        }
127    }
128
129    /// Generates the SQL expression for this validation.
130    fn sql_expression(&self, column: &str) -> Result<String> {
131        let escaped_column = SqlSecurity::escape_identifier(column)?;
132
133        Ok(match self {
134            DataTypeValidation::SpecificType(_dt) => {
135                // For specific type validation, we check the schema
136                // This would be handled differently in evaluate()
137                "1 = 1".to_string() // Placeholder
138            }
139            DataTypeValidation::Consistency { threshold } => {
140                // Count the most common type and compare to threshold
141                format!("CAST(MAX(type_count) AS FLOAT) / CAST(COUNT(*) AS FLOAT) >= {threshold}")
142            }
143            DataTypeValidation::Numeric(nv) => match nv {
144                NumericValidation::NonNegative => {
145                    format!("{escaped_column} >= 0")
146                }
147                NumericValidation::Positive => {
148                    format!("{escaped_column} > 0")
149                }
150                NumericValidation::Integer => {
151                    format!("{escaped_column} = CAST({escaped_column} AS INT)")
152                }
153                NumericValidation::Range { min, max } => {
154                    format!("{escaped_column} BETWEEN {min} AND {max}")
155                }
156                NumericValidation::Finite => {
157                    format!("ISFINITE({escaped_column})")
158                }
159            },
160            DataTypeValidation::String(sv) => match sv {
161                StringTypeValidation::NotEmpty => {
162                    format!("LENGTH({escaped_column}) > 0")
163                }
164                StringTypeValidation::ValidUtf8 => {
165                    // DataFusion handles UTF-8 validation internally
166                    format!("{escaped_column} IS NOT NULL")
167                }
168                StringTypeValidation::NotBlank => {
169                    format!("TRIM({escaped_column}) != ''")
170                }
171                StringTypeValidation::MaxBytes(n) => {
172                    format!("OCTET_LENGTH({escaped_column}) <= {n}")
173                }
174            },
175            DataTypeValidation::Temporal(tv) => match tv {
176                TemporalValidation::PastDate => {
177                    format!("{escaped_column} < CURRENT_DATE")
178                }
179                TemporalValidation::FutureDate => {
180                    format!("{escaped_column} > CURRENT_DATE")
181                }
182                TemporalValidation::DateRange { start, end } => {
183                    format!("{escaped_column} BETWEEN '{start}' AND '{end}'")
184                }
185                TemporalValidation::ValidTimezone => {
186                    // This would need custom implementation
187                    format!("{escaped_column} IS NOT NULL")
188                }
189            },
190            DataTypeValidation::Custom { sql_predicate } => {
191                // Basic validation to prevent obvious SQL injection
192                if sql_predicate.contains(';') || sql_predicate.to_lowercase().contains("drop") {
193                    return Err(TermError::SecurityError(
194                        "Potentially unsafe SQL predicate".to_string(),
195                    ));
196                }
197                sql_predicate.replace("{column}", &escaped_column)
198            }
199        })
200    }
201}
202
203/// A unified constraint that validates data types and type-related properties.
204///
205/// This constraint replaces individual type constraints and provides a consistent
206/// interface for all data type validations.
207///
208/// # Examples
209///
210/// ```rust
211/// use term_guard::constraints::{DataTypeConstraint, DataTypeValidation, NumericValidation};
212/// use term_guard::core::Constraint;
213///
214/// // Check for specific data type
215/// let type_check = DataTypeConstraint::new(
216///     "user_id",
217///     DataTypeValidation::SpecificType("Int64".to_string())
218/// );
219///
220/// // Check for non-negative values
221/// let non_negative = DataTypeConstraint::new(
222///     "amount",
223///     DataTypeValidation::Numeric(NumericValidation::NonNegative)
224/// );
225///
226/// // Check type consistency
227/// let consistency = DataTypeConstraint::new(
228///     "mixed_column",
229///     DataTypeValidation::Consistency { threshold: 0.95 }
230/// );
231/// ```
232#[derive(Debug, Clone)]
233pub struct DataTypeConstraint {
234    /// The column to validate
235    column: String,
236    /// The type of validation to perform
237    validation: DataTypeValidation,
238}
239
240impl DataTypeConstraint {
241    /// Creates a new unified data type constraint.
242    ///
243    /// # Arguments
244    ///
245    /// * `column` - The column to validate
246    /// * `validation` - The type of validation to perform
247    ///
248    /// # Errors
249    ///
250    /// Returns an error if the column name is invalid.
251    pub fn new(column: impl Into<String>, validation: DataTypeValidation) -> Result<Self> {
252        let column_str = column.into();
253        SqlSecurity::validate_identifier(&column_str)?;
254
255        // Validate threshold for consistency check
256        if let DataTypeValidation::Consistency { threshold } = &validation {
257            if !(0.0..=1.0).contains(threshold) {
258                return Err(TermError::Configuration(
259                    "Threshold must be between 0.0 and 1.0".to_string(),
260                ));
261            }
262        }
263
264        Ok(Self {
265            column: column_str,
266            validation,
267        })
268    }
269
270    /// Convenience constructor for non-negative constraint.
271    pub fn non_negative(column: impl Into<String>) -> Result<Self> {
272        Self::new(
273            column,
274            DataTypeValidation::Numeric(NumericValidation::NonNegative),
275        )
276    }
277
278    /// Convenience constructor for type consistency constraint.
279    pub fn type_consistency(column: impl Into<String>, threshold: f64) -> Result<Self> {
280        Self::new(column, DataTypeValidation::Consistency { threshold })
281    }
282
283    /// Convenience constructor for specific type constraint.
284    pub fn specific_type(column: impl Into<String>, data_type: impl Into<String>) -> Result<Self> {
285        Self::new(column, DataTypeValidation::SpecificType(data_type.into()))
286    }
287}
288
289#[async_trait]
290impl Constraint for DataTypeConstraint {
291    #[instrument(skip(self, ctx), fields(
292        column = %self.column,
293        validation = ?self.validation
294    ))]
295    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
296        match &self.validation {
297            DataTypeValidation::SpecificType(expected_type) => {
298                // Check the schema for the column type
299                let df = ctx.table("data").await?;
300                let schema = df.schema();
301
302                let field = schema.field_with_name(None, &self.column).map_err(|_| {
303                    TermError::ColumnNotFound {
304                        column: self.column.clone(),
305                    }
306                })?;
307
308                let actual_type = field.data_type();
309
310                if format!("{actual_type:?}") == *expected_type {
311                    Ok(ConstraintResult {
312                        status: ConstraintStatus::Success,
313                        message: Some(format!(
314                            "Column '{}' has expected type {expected_type}",
315                            self.column
316                        )),
317                        metric: Some(1.0),
318                    })
319                } else {
320                    Ok(ConstraintResult {
321                        status: ConstraintStatus::Failure,
322                        message: Some(format!(
323                            "Column '{}' has type {actual_type:?}, expected {expected_type}",
324                            self.column
325                        )),
326                        metric: Some(0.0),
327                    })
328                }
329            }
330            DataTypeValidation::Consistency { threshold } => {
331                // For type consistency, we need to analyze the actual values
332                // DataFusion doesn't have typeof() function, so we'll check if all values
333                // have consistent formatting/structure
334
335                // For now, just check that the column exists and return a placeholder result
336                let sql = format!(
337                    "SELECT COUNT(*) as total FROM data WHERE {} IS NOT NULL",
338                    SqlSecurity::escape_identifier(&self.column)?
339                );
340
341                let df = ctx.sql(&sql).await?;
342                let batches = df.collect().await?;
343
344                if batches.is_empty() || batches[0].num_rows() == 0 {
345                    return Ok(ConstraintResult {
346                        status: ConstraintStatus::Skipped,
347                        message: Some("No data to validate".to_string()),
348                        metric: None,
349                    });
350                }
351
352                // For now, assume consistency is high (would need actual implementation)
353                // In a real implementation, we'd analyze value patterns, formats, etc.
354                let consistency = 0.95; // Placeholder
355
356                if consistency >= *threshold {
357                    Ok(ConstraintResult {
358                        status: ConstraintStatus::Success,
359                        message: Some(format!(
360                            "Type consistency {:.1}% meets threshold {:.1}%",
361                            consistency * 100.0,
362                            threshold * 100.0
363                        )),
364                        metric: Some(consistency),
365                    })
366                } else {
367                    Ok(ConstraintResult {
368                        status: ConstraintStatus::Failure,
369                        message: Some(format!(
370                            "Type consistency {:.1}% below threshold {:.1}%",
371                            consistency * 100.0,
372                            threshold * 100.0
373                        )),
374                        metric: Some(consistency),
375                    })
376                }
377            }
378            _ => {
379                // For other validations, use SQL predicates
380                let predicate = self.validation.sql_expression(&self.column)?;
381                let sql = format!(
382                    "SELECT 
383                        COUNT(*) as total,
384                        SUM(CASE WHEN {predicate} THEN 1 ELSE 0 END) as valid
385                     FROM data
386                     WHERE {} IS NOT NULL",
387                    SqlSecurity::escape_identifier(&self.column)?
388                );
389
390                let df = ctx.sql(&sql).await?;
391                let batches = df.collect().await?;
392
393                if batches.is_empty() || batches[0].num_rows() == 0 {
394                    return Ok(ConstraintResult {
395                        status: ConstraintStatus::Skipped,
396                        message: Some("No data to validate".to_string()),
397                        metric: None,
398                    });
399                }
400
401                let total: i64 = batches[0]
402                    .column(0)
403                    .as_any()
404                    .downcast_ref::<arrow::array::Int64Array>()
405                    .ok_or_else(|| {
406                        TermError::Internal("Failed to extract total count".to_string())
407                    })?
408                    .value(0);
409
410                let valid: i64 = batches[0]
411                    .column(1)
412                    .as_any()
413                    .downcast_ref::<arrow::array::Int64Array>()
414                    .ok_or_else(|| {
415                        TermError::Internal("Failed to extract valid count".to_string())
416                    })?
417                    .value(0);
418
419                let validity_rate = valid as f64 / total as f64;
420
421                Ok(ConstraintResult {
422                    status: if validity_rate >= 1.0 {
423                        ConstraintStatus::Success
424                    } else {
425                        ConstraintStatus::Failure
426                    },
427                    message: Some(format!(
428                        "{:.1}% of values satisfy {}",
429                        validity_rate * 100.0,
430                        self.validation.description()
431                    )),
432                    metric: Some(validity_rate),
433                })
434            }
435        }
436    }
437
438    fn name(&self) -> &str {
439        "datatype"
440    }
441
442    fn metadata(&self) -> ConstraintMetadata {
443        ConstraintMetadata::for_column(&self.column).with_description(format!(
444            "Validates {} for column '{}'",
445            self.validation.description(),
446            self.column
447        ))
448    }
449}
450
451#[cfg(test)]
452mod tests {
453    use super::*;
454    use arrow::array::{Float64Array, Int64Array, StringArray};
455    use arrow::datatypes::{DataType, Field, Schema};
456    use arrow::record_batch::RecordBatch;
457    use datafusion::datasource::MemTable;
458    use std::sync::Arc;
459
460    async fn create_test_context(batch: RecordBatch) -> SessionContext {
461        let ctx = SessionContext::new();
462        let provider = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap();
463        ctx.register_table("data", Arc::new(provider)).unwrap();
464        ctx
465    }
466
467    #[tokio::test]
468    async fn test_specific_type_validation() {
469        let schema = Arc::new(Schema::new(vec![
470            Field::new("int_col", DataType::Int64, false),
471            Field::new("string_col", DataType::Utf8, true),
472        ]));
473
474        let batch = RecordBatch::try_new(
475            schema,
476            vec![
477                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
478                Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])),
479            ],
480        )
481        .unwrap();
482
483        let ctx = create_test_context(batch).await;
484
485        // Test correct type
486        let constraint = DataTypeConstraint::specific_type("int_col", "Int64").unwrap();
487        let result = constraint.evaluate(&ctx).await.unwrap();
488        assert_eq!(result.status, ConstraintStatus::Success);
489
490        // Test incorrect type
491        let constraint = DataTypeConstraint::specific_type("int_col", "Utf8").unwrap();
492        let result = constraint.evaluate(&ctx).await.unwrap();
493        assert_eq!(result.status, ConstraintStatus::Failure);
494    }
495
496    #[tokio::test]
497    async fn test_non_negative_validation() {
498        let schema = Arc::new(Schema::new(vec![
499            Field::new("positive_values", DataType::Float64, true),
500            Field::new("mixed_values", DataType::Float64, true),
501        ]));
502
503        let batch = RecordBatch::try_new(
504            schema,
505            vec![
506                Arc::new(Float64Array::from(vec![
507                    Some(1.0),
508                    Some(2.0),
509                    Some(3.0),
510                    Some(0.0),
511                    None,
512                ])),
513                Arc::new(Float64Array::from(vec![
514                    Some(1.0),
515                    Some(-2.0),
516                    Some(3.0),
517                    Some(0.0),
518                    None,
519                ])),
520            ],
521        )
522        .unwrap();
523
524        let ctx = create_test_context(batch).await;
525
526        // Test all non-negative values
527        let constraint = DataTypeConstraint::non_negative("positive_values").unwrap();
528        let result = constraint.evaluate(&ctx).await.unwrap();
529        assert_eq!(result.status, ConstraintStatus::Success);
530
531        // Test mixed values
532        let constraint = DataTypeConstraint::non_negative("mixed_values").unwrap();
533        let result = constraint.evaluate(&ctx).await.unwrap();
534        assert_eq!(result.status, ConstraintStatus::Failure);
535        assert!(result.metric.unwrap() < 1.0);
536    }
537
538    #[tokio::test]
539    async fn test_range_validation() {
540        let schema = Arc::new(Schema::new(vec![Field::new(
541            "values",
542            DataType::Float64,
543            true,
544        )]));
545
546        let batch = RecordBatch::try_new(
547            schema,
548            vec![Arc::new(Float64Array::from(vec![
549                Some(10.0),
550                Some(20.0),
551                Some(30.0),
552                Some(40.0),
553                Some(50.0),
554            ]))],
555        )
556        .unwrap();
557
558        let ctx = create_test_context(batch).await;
559
560        let constraint = DataTypeConstraint::new(
561            "values",
562            DataTypeValidation::Numeric(NumericValidation::Range {
563                min: 0.0,
564                max: 100.0,
565            }),
566        )
567        .unwrap();
568
569        let result = constraint.evaluate(&ctx).await.unwrap();
570        assert_eq!(result.status, ConstraintStatus::Success);
571    }
572
573    #[tokio::test]
574    async fn test_string_validation() {
575        let schema = Arc::new(Schema::new(vec![Field::new(
576            "strings",
577            DataType::Utf8,
578            true,
579        )]));
580
581        let batch = RecordBatch::try_new(
582            schema,
583            vec![Arc::new(StringArray::from(vec![
584                Some("hello"),
585                Some("world"),
586                Some(""),
587                None,
588                Some("test"),
589            ]))],
590        )
591        .unwrap();
592
593        let ctx = create_test_context(batch).await;
594
595        let constraint = DataTypeConstraint::new(
596            "strings",
597            DataTypeValidation::String(StringTypeValidation::NotEmpty),
598        )
599        .unwrap();
600
601        let result = constraint.evaluate(&ctx).await.unwrap();
602        assert_eq!(result.status, ConstraintStatus::Failure);
603        // 3 out of 4 non-null values are not empty (empty string counts as empty)
604        assert!((result.metric.unwrap() - 0.75).abs() < 0.01);
605    }
606}