term_guard/constraints/
custom_sql.rs

1//! Custom SQL validation constraints.
2
3use crate::core::{Constraint, ConstraintMetadata, ConstraintResult};
4use crate::prelude::*;
5use crate::security::SqlSecurity;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12use tracing::instrument;
13
14/// Cache for compiled regex patterns to avoid recompiling
15static REGEX_CACHE: Lazy<RwLock<HashMap<String, Regex>>> =
16    Lazy::new(|| RwLock::new(HashMap::new()));
17
/// Constraint backed by a user-supplied SQL boolean expression.
///
/// Every row of the table under validation is tested against the expression, and the
/// constraint passes only when all rows satisfy it. Expressions are screened when the
/// constraint is built: anything containing data- or schema-modifying operations (such
/// as `DROP`, `DELETE` or `UPDATE`) is refused.
///
/// # Examples
///
/// ```rust
/// use term_guard::constraints::CustomSqlConstraint;
/// use term_guard::core::Constraint;
///
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// // Require every price to fall inside an allowed range
/// let constraint = CustomSqlConstraint::new("price > 0 AND price < 1000000", None::<String>)?;
/// assert_eq!(constraint.name(), "custom_sql");
///
/// // Attach a hint that is shown to the user when rows fail the check
/// let constraint = CustomSqlConstraint::new(
///     "order_date <= ship_date",
///     Some("Shipping date must be after or equal to order date")
/// )?;
/// # Ok(())
/// # }
/// # example().unwrap();
/// ```
#[derive(Clone, Debug)]
pub struct CustomSqlConstraint {
    /// SQL expression evaluated for each row; expected to yield a boolean.
    expression: String,
    /// Optional message used as the prefix of the failure report when rows fail.
    hint: Option<String>,
}
48
49impl CustomSqlConstraint {
50    /// Creates a new custom SQL constraint.
51    ///
52    /// # Arguments
53    ///
54    /// * `expression` - The SQL expression to evaluate (should return a boolean)
55    /// * `hint` - Optional hint message to provide context when the constraint fails
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the SQL expression contains dangerous operations
60    pub fn new(expression: impl Into<String>, hint: Option<impl Into<String>>) -> Result<Self> {
61        let expression = expression.into();
62
63        // Validate the SQL expression for safety using both local and security module validation
64        validate_sql_expression(&expression)?;
65        SqlSecurity::validate_sql_expression(&expression)?;
66
67        Ok(Self {
68            expression,
69            hint: hint.map(Into::into),
70        })
71    }
72
73    /// Attempts to create a new custom SQL constraint, returning an error if validation fails.
74    ///
75    /// # Arguments
76    ///
77    /// * `expression` - The SQL expression to evaluate
78    /// * `hint` - Optional hint message
79    ///
80    /// # Returns
81    ///
82    /// A Result containing the constraint or a validation error
83    pub fn try_new(expression: impl Into<String>, hint: Option<impl Into<String>>) -> Result<Self> {
84        let expression = expression.into();
85
86        // Validate the SQL expression for safety using both local and security module validation
87        validate_sql_expression(&expression)?;
88        SqlSecurity::validate_sql_expression(&expression)?;
89
90        Ok(Self {
91            expression,
92            hint: hint.map(Into::into),
93        })
94    }
95}
96
97/// Validates that a SQL expression doesn't contain dangerous operations.
98///
99/// This function checks for keywords that could modify data or schema,
100/// ensuring the expression is read-only.
101fn validate_sql_expression(sql: &str) -> Result<()> {
102    // Convert to uppercase for case-insensitive comparison
103    let sql_upper = sql.to_uppercase();
104
105    // Define dangerous keywords that should not be allowed
106    let dangerous_keywords: HashSet<&str> = [
107        "DROP",
108        "DELETE",
109        "INSERT",
110        "UPDATE",
111        "CREATE",
112        "ALTER",
113        "TRUNCATE",
114        "GRANT",
115        "REVOKE",
116        "EXECUTE",
117        "EXEC",
118        "CALL",
119        "MERGE",
120        "REPLACE",
121        "RENAME",
122        "MODIFY",
123        "SET",
124        "COMMIT",
125        "ROLLBACK",
126        "SAVEPOINT",
127        "BEGIN",
128        "START",
129        "TRANSACTION",
130        "LOCK",
131        "UNLOCK",
132    ]
133    .iter()
134    .copied()
135    .collect();
136
137    // Check for dangerous keywords
138    for keyword in dangerous_keywords {
139        // Use word boundaries to avoid false positives (e.g., "UPDATE" in "UPDATED_AT")
140        let pattern = format!(r"\b{keyword}\b");
141
142        // Check cache first
143        let matches = {
144            let cache = REGEX_CACHE.read().map_err(|_| {
145                TermError::Internal("Failed to acquire read lock on regex cache".to_string())
146            })?;
147
148            if let Some(regex) = cache.get(&pattern) {
149                regex.is_match(&sql_upper)
150            } else {
151                // Need to compile and cache the regex
152                drop(cache);
153                let mut write_cache = REGEX_CACHE.write().map_err(|_| {
154                    TermError::Internal("Failed to acquire write lock on regex cache".to_string())
155                })?;
156
157                let regex = Regex::new(&pattern).map_err(|e| {
158                    TermError::Internal(format!("Failed to compile regex pattern: {e}"))
159                })?;
160                let is_match = regex.is_match(&sql_upper);
161                write_cache.insert(pattern.clone(), regex);
162                is_match
163            }
164        };
165
166        if matches {
167            return Err(TermError::validation_failed(
168                "custom_sql",
169                format!("SQL expression contains forbidden operation: {keyword}"),
170            ));
171        }
172    }
173
174    // Check for semicolons which could be used to inject multiple statements
175    if sql.contains(';') {
176        return Err(TermError::validation_failed(
177            "custom_sql",
178            "SQL expression cannot contain semicolons",
179        ));
180    }
181
182    // Check for comment sequences that could be used to bypass validation
183    if sql.contains("--") || sql.contains("/*") || sql.contains("*/") {
184        return Err(TermError::validation_failed(
185            "custom_sql",
186            "SQL expression cannot contain comments",
187        ));
188    }
189
190    Ok(())
191}
192
193#[async_trait]
194impl Constraint for CustomSqlConstraint {
195    #[instrument(skip(self, ctx), fields(expression = %self.expression))]
196    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
197        // Wrap the expression in a query that counts rows where the condition is true
198        let sql = format!(
199            "SELECT 
200                COUNT(CASE WHEN {} THEN 1 END) as satisfied,
201                COUNT(*) as total
202             FROM data",
203            self.expression
204        );
205
206        // Try to execute the SQL
207        let df = match ctx.sql(&sql).await {
208            Ok(df) => df,
209            Err(e) => {
210                // Return a clear error message for SQL errors
211                return Ok(ConstraintResult::failure(format!(
212                    "SQL expression error: {e}. Expression: '{}'",
213                    self.expression
214                )));
215            }
216        };
217
218        let batches = match df.collect().await {
219            Ok(batches) => batches,
220            Err(e) => {
221                // Return a clear error message for execution errors
222                return Ok(ConstraintResult::failure(format!(
223                    "SQL execution error: {e}. Expression: '{}'",
224                    self.expression
225                )));
226            }
227        };
228
229        if batches.is_empty() {
230            return Ok(ConstraintResult::skipped("No data to validate"));
231        }
232
233        let batch = &batches[0];
234        if batch.num_rows() == 0 {
235            return Ok(ConstraintResult::skipped("No data to validate"));
236        }
237
238        // Extract results
239        let satisfied = batch
240            .column(0)
241            .as_any()
242            .downcast_ref::<arrow::array::Int64Array>()
243            .ok_or_else(|| TermError::Internal("Failed to extract satisfied count".to_string()))?
244            .value(0) as f64;
245
246        let total = batch
247            .column(1)
248            .as_any()
249            .downcast_ref::<arrow::array::Int64Array>()
250            .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
251            .value(0) as f64;
252
253        if total == 0.0 {
254            return Ok(ConstraintResult::skipped("No data to validate"));
255        }
256
257        let satisfaction_ratio = satisfied / total;
258
259        if satisfaction_ratio == 1.0 {
260            Ok(ConstraintResult::success_with_metric(satisfaction_ratio))
261        } else {
262            let failed_count = total - satisfied;
263            let message = if let Some(hint) = &self.hint {
264                format!("{hint} ({} rows failed the condition)", failed_count as i64)
265            } else {
266                format!(
267                    "Custom SQL condition not satisfied for {} rows. Expression: '{}'",
268                    failed_count as i64, self.expression
269                )
270            };
271
272            Ok(ConstraintResult::failure_with_metric(
273                satisfaction_ratio,
274                message,
275            ))
276        }
277    }
278
279    fn name(&self) -> &str {
280        "custom_sql"
281    }
282
283    fn metadata(&self) -> ConstraintMetadata {
284        let mut metadata = ConstraintMetadata::new()
285            .with_description(format!(
286                "Checks that all rows satisfy the SQL expression: {}",
287                self.expression
288            ))
289            .with_custom("expression", self.expression.clone())
290            .with_custom("constraint_type", "custom");
291
292        if let Some(hint) = &self.hint {
293            metadata = metadata.with_custom("hint", hint.clone());
294        }
295
296        metadata
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Registers a five-row in-memory `data` table used by the evaluation tests:
    ///
    /// | price | quantity | status   |
    /// |-------|----------|----------|
    /// | 10.5  | 5        | active   |
    /// | 25.0  | 10       | active   |
    /// | 5.0   | 0        | inactive |
    /// | 100.0 | 20       | active   |
    /// | NULL  | 15       | pending  |
    async fn create_test_context() -> SessionContext {
        let ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![
            Field::new("price", DataType::Float64, true),
            Field::new("quantity", DataType::Int64, true),
            Field::new("status", DataType::Utf8, true),
        ]));

        let price_array =
            Float64Array::from(vec![Some(10.5), Some(25.0), Some(5.0), Some(100.0), None]);
        let quantity_array = Int64Array::from(vec![Some(5), Some(10), Some(0), Some(20), Some(15)]);
        let status_array = StringArray::from(vec![
            Some("active"),
            Some("active"),
            Some("inactive"),
            Some("active"),
            Some("pending"),
        ]);

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(price_array),
                Arc::new(quantity_array),
                Arc::new(status_array),
            ],
        )
        .unwrap();

        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();

        ctx
    }

    #[test]
    fn test_sql_validation_accepts_safe_expressions() {
        // These should all be accepted
        assert!(validate_sql_expression("price > 0").is_ok());
        assert!(validate_sql_expression("quantity BETWEEN 1 AND 100").is_ok());
        assert!(validate_sql_expression("status = 'active' AND price < 1000").is_ok());
        assert!(validate_sql_expression("LENGTH(name) > 3").is_ok());
        assert!(validate_sql_expression("order_date <= ship_date").is_ok());
    }

    #[test]
    fn test_sql_validation_rejects_dangerous_operations() {
        // These should all be rejected
        assert!(validate_sql_expression("DROP TABLE users").is_err());
        assert!(validate_sql_expression("DELETE FROM data WHERE 1=1").is_err());
        assert!(validate_sql_expression("UPDATE data SET price = 0").is_err());
        assert!(validate_sql_expression("price > 0; DROP TABLE data").is_err());
        assert!(validate_sql_expression("INSERT INTO data VALUES (1, 2, 3)").is_err());
        assert!(validate_sql_expression("CREATE TABLE new_table (id INT)").is_err());
        assert!(validate_sql_expression("ALTER TABLE data ADD COLUMN new_col").is_err());
        assert!(validate_sql_expression("TRUNCATE TABLE data").is_err());
        assert!(validate_sql_expression("-- comment\nprice > 0").is_err());
        assert!(validate_sql_expression("price > 0 /* comment */").is_err());
    }

    #[test]
    fn test_sql_validation_reports_reason_for_separators_and_comments() {
        let err = validate_sql_expression("price > 0;").unwrap_err().to_string();
        assert!(err.contains("cannot contain semicolons"));

        let err = validate_sql_expression("price > 0 -- trailing")
            .unwrap_err()
            .to_string();
        assert!(err.contains("cannot contain comments"));
    }

    #[test]
    fn test_sql_validation_case_insensitive() {
        // Should reject regardless of case
        assert!(validate_sql_expression("drop table users").is_err());
        assert!(validate_sql_expression("DeLeTe FROM data").is_err());
        assert!(validate_sql_expression("UpDaTe data SET x = 1").is_err());
    }

    #[test]
    fn test_sql_validation_word_boundaries() {
        // Should not reject if keyword is part of a larger word
        assert!(validate_sql_expression("updated_at > '2024-01-01'").is_ok());
        assert!(validate_sql_expression("is_deleted = false").is_ok());
        assert!(validate_sql_expression("created_by = 'admin'").is_ok());
        // `SET` and `START` only occur here as parts of longer identifiers
        assert!(validate_sql_expression("settings_count > 0 AND start_date < end_date").is_ok());
    }

    #[tokio::test]
    async fn test_custom_sql_with_nulls_expression() {
        let ctx = create_test_context().await;

        let constraint = CustomSqlConstraint::new("price > 0", None::<String>).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(0.8)); // 4 out of 5 rows satisfy (NULL doesn't satisfy)
    }

    #[tokio::test]
    async fn test_custom_sql_all_satisfy() {
        let ctx = create_test_context().await;

        // Every quantity in the fixture is >= 0, so all rows satisfy the condition
        let constraint = CustomSqlConstraint::new("quantity >= 0", None::<String>).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(1.0)); // All rows satisfy
    }

    #[tokio::test]
    async fn test_custom_sql_partial_satisfy() {
        let ctx = create_test_context().await;

        let constraint =
            CustomSqlConstraint::new("quantity > 0", Some("Quantity must be positive")).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(0.8)); // 4 out of 5 have quantity > 0
        assert!(result
            .message
            .as_ref()
            .unwrap()
            .contains("Quantity must be positive"));
        assert!(result.message.as_ref().unwrap().contains("1 rows failed"));
    }

    #[tokio::test]
    async fn test_custom_sql_complex_expression() {
        let ctx = create_test_context().await;

        let constraint = CustomSqlConstraint::new(
            "status = 'active' AND price >= 10",
            Some("Active items must have price >= 10"),
        )
        .unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        // Only 3 rows have status='active' AND price >= 10
        assert_eq!(result.metric, Some(0.6));
    }

    #[tokio::test]
    async fn test_custom_sql_with_nulls() {
        let ctx = create_test_context().await;

        let constraint = CustomSqlConstraint::new("price IS NOT NULL", None::<String>).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(0.8)); // 4 out of 5 have non-null price
    }

    #[tokio::test]
    async fn test_custom_sql_invalid_expression() {
        let ctx = create_test_context().await;

        let constraint = CustomSqlConstraint::new("invalid_column > 0", None::<String>).unwrap();

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert!(result
            .message
            .as_ref()
            .unwrap()
            .contains("SQL expression error"));
    }

    #[test]
    fn test_new_returns_error_on_dangerous_sql_new() {
        let result = CustomSqlConstraint::new("DROP TABLE data", None::<String>);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("forbidden operation: DROP"));
    }

    #[test]
    fn test_try_new_returns_error_on_dangerous_sql() {
        let result = CustomSqlConstraint::try_new("DELETE FROM data", None::<String>);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("forbidden operation: DELETE"));
    }

    #[test]
    fn test_try_new_matches_new_for_safe_expression() {
        let via_new = CustomSqlConstraint::new("price > 0", Some("hint")).unwrap();
        let via_try_new = CustomSqlConstraint::try_new("price > 0", Some("hint")).unwrap();
        assert_eq!(via_new.expression, via_try_new.expression);
        assert_eq!(via_new.hint, via_try_new.hint);
    }
}