// term_guard/constraints/custom_sql.rs

//! Custom SQL validation constraints.
2
3use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
4use crate::prelude::*;
5use crate::security::SqlSecurity;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12use tracing::instrument;
13/// Cache for compiled regex patterns to avoid recompiling
14static REGEX_CACHE: Lazy<RwLock<HashMap<String, Regex>>> =
15    Lazy::new(|| RwLock::new(HashMap::new()));
16
17/// A constraint that evaluates custom SQL expressions.
18///
19/// This constraint allows users to define custom validation logic using SQL expressions
20/// while preventing dangerous operations like DROP, DELETE, UPDATE, etc.
21///
22/// # Examples
23///
24/// ```rust
25/// use term_guard::constraints::CustomSqlConstraint;
26/// use term_guard::core::Constraint;
27///
28/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
29/// // Check that values in a column meet a custom condition
30/// let constraint = CustomSqlConstraint::new("price > 0 AND price < 1000000", None::<String>)?;
31/// assert_eq!(constraint.name(), "custom_sql");
32///
33/// // With a custom hint message
34/// let constraint = CustomSqlConstraint::new(
35///     "order_date <= ship_date",
36///     Some("Shipping date must be after or equal to order date")
37/// )?;
38/// # Ok(())
39/// # }
40/// # example().unwrap();
41/// ```
42#[derive(Debug, Clone)]
43pub struct CustomSqlConstraint {
44    expression: String,
45    hint: Option<String>,
46}
47
48impl CustomSqlConstraint {
49    /// Creates a new custom SQL constraint.
50    ///
51    /// # Arguments
52    ///
53    /// * `expression` - The SQL expression to evaluate (should return a boolean)
54    /// * `hint` - Optional hint message to provide context when the constraint fails
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if the SQL expression contains dangerous operations
59    pub fn new(expression: impl Into<String>, hint: Option<impl Into<String>>) -> Result<Self> {
60        let expression = expression.into();
61
62        // Validate the SQL expression for safety using both local and security module validation
63        validate_sql_expression(&expression)?;
64        SqlSecurity::validate_sql_expression(&expression)?;
65
66        Ok(Self {
67            expression,
68            hint: hint.map(Into::into),
69        })
70    }
71
72    /// Attempts to create a new custom SQL constraint, returning an error if validation fails.
73    ///
74    /// # Arguments
75    ///
76    /// * `expression` - The SQL expression to evaluate
77    /// * `hint` - Optional hint message
78    ///
79    /// # Returns
80    ///
81    /// A Result containing the constraint or a validation error
82    pub fn try_new(expression: impl Into<String>, hint: Option<impl Into<String>>) -> Result<Self> {
83        let expression = expression.into();
84
85        // Validate the SQL expression for safety using both local and security module validation
86        validate_sql_expression(&expression)?;
87        SqlSecurity::validate_sql_expression(&expression)?;
88
89        Ok(Self {
90            expression,
91            hint: hint.map(Into::into),
92        })
93    }
94}
95
96/// Validates that a SQL expression doesn't contain dangerous operations.
97///
98/// This function checks for keywords that could modify data or schema,
99/// ensuring the expression is read-only.
100fn validate_sql_expression(sql: &str) -> Result<()> {
101    // Convert to uppercase for case-insensitive comparison
102    let sql_upper = sql.to_uppercase();
103
104    // Define dangerous keywords that should not be allowed
105    let dangerous_keywords: HashSet<&str> = [
106        "DROP",
107        "DELETE",
108        "INSERT",
109        "UPDATE",
110        "CREATE",
111        "ALTER",
112        "TRUNCATE",
113        "GRANT",
114        "REVOKE",
115        "EXECUTE",
116        "EXEC",
117        "CALL",
118        "MERGE",
119        "REPLACE",
120        "RENAME",
121        "MODIFY",
122        "SET",
123        "COMMIT",
124        "ROLLBACK",
125        "SAVEPOINT",
126        "BEGIN",
127        "START",
128        "TRANSACTION",
129        "LOCK",
130        "UNLOCK",
131    ]
132    .iter()
133    .copied()
134    .collect();
135
136    // Check for dangerous keywords
137    for keyword in dangerous_keywords {
138        // Use word boundaries to avoid false positives (e.g., "UPDATE" in "UPDATED_AT")
139        let pattern = format!(r"\b{keyword}\b");
140
141        // Check cache first
142        let matches = {
143            let cache = REGEX_CACHE.read().map_err(|_| {
144                TermError::Internal("Failed to acquire read lock on regex cache".to_string())
145            })?;
146
147            if let Some(regex) = cache.get(&pattern) {
148                regex.is_match(&sql_upper)
149            } else {
150                // Need to compile and cache the regex
151                drop(cache);
152                let mut write_cache = REGEX_CACHE.write().map_err(|_| {
153                    TermError::Internal("Failed to acquire write lock on regex cache".to_string())
154                })?;
155
156                let regex = Regex::new(&pattern).map_err(|e| {
157                    TermError::Internal(format!("Failed to compile regex pattern: {e}"))
158                })?;
159                let is_match = regex.is_match(&sql_upper);
160                write_cache.insert(pattern.clone(), regex);
161                is_match
162            }
163        };
164
165        if matches {
166            return Err(TermError::validation_failed(
167                "custom_sql",
168                format!("SQL expression contains forbidden operation: {keyword}"),
169            ));
170        }
171    }
172
173    // Check for semicolons which could be used to inject multiple statements
174    if sql.contains(';') {
175        return Err(TermError::validation_failed(
176            "custom_sql",
177            "SQL expression cannot contain semicolons",
178        ));
179    }
180
181    // Check for comment sequences that could be used to bypass validation
182    if sql.contains("--") || sql.contains("/*") || sql.contains("*/") {
183        return Err(TermError::validation_failed(
184            "custom_sql",
185            "SQL expression cannot contain comments",
186        ));
187    }
188
189    Ok(())
190}
191
192#[async_trait]
193impl Constraint for CustomSqlConstraint {
194    #[instrument(skip(self, ctx), fields(expression = %self.expression))]
195    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
196        // Wrap the expression in a query that counts rows where the condition is true
197        // Get the table name from the validation context
198
199        let validation_ctx = current_validation_context();
200
201        let table_name = validation_ctx.table_name();
202
203        let sql = format!(
204            "SELECT 
205                COUNT(CASE WHEN {} THEN 1 END) as satisfied,
206                COUNT(*) as total
207             FROM {table_name}",
208            self.expression
209        );
210
211        // Try to execute the SQL
212        let df = match ctx.sql(&sql).await {
213            Ok(df) => df,
214            Err(e) => {
215                // Return a clear error message for SQL errors
216                return Ok(ConstraintResult::failure(format!(
217                    "SQL expression error: {e}. Expression: '{}'",
218                    self.expression
219                )));
220            }
221        };
222
223        let batches = match df.collect().await {
224            Ok(batches) => batches,
225            Err(e) => {
226                // Return a clear error message for execution errors
227                return Ok(ConstraintResult::failure(format!(
228                    "SQL execution error: {e}. Expression: '{}'",
229                    self.expression
230                )));
231            }
232        };
233
234        if batches.is_empty() {
235            return Ok(ConstraintResult::skipped("No data to validate"));
236        }
237
238        let batch = &batches[0];
239        if batch.num_rows() == 0 {
240            return Ok(ConstraintResult::skipped("No data to validate"));
241        }
242
243        // Extract results
244        let satisfied = batch
245            .column(0)
246            .as_any()
247            .downcast_ref::<arrow::array::Int64Array>()
248            .ok_or_else(|| TermError::Internal("Failed to extract satisfied count".to_string()))?
249            .value(0) as f64;
250
251        let total = batch
252            .column(1)
253            .as_any()
254            .downcast_ref::<arrow::array::Int64Array>()
255            .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
256            .value(0) as f64;
257
258        if total == 0.0 {
259            return Ok(ConstraintResult::skipped("No data to validate"));
260        }
261
262        let satisfaction_ratio = satisfied / total;
263
264        if satisfaction_ratio == 1.0 {
265            Ok(ConstraintResult::success_with_metric(satisfaction_ratio))
266        } else {
267            let failed_count = total - satisfied;
268            let message = if let Some(hint) = &self.hint {
269                format!("{hint} ({} rows failed the condition)", failed_count as i64)
270            } else {
271                format!(
272                    "Custom SQL condition not satisfied for {} rows. Expression: '{}'",
273                    failed_count as i64, self.expression
274                )
275            };
276
277            Ok(ConstraintResult::failure_with_metric(
278                satisfaction_ratio,
279                message,
280            ))
281        }
282    }
283
284    fn name(&self) -> &str {
285        "custom_sql"
286    }
287
288    fn metadata(&self) -> ConstraintMetadata {
289        let mut metadata = ConstraintMetadata::new()
290            .with_description(format!(
291                "Checks that all rows satisfy the SQL expression: {}",
292                self.expression
293            ))
294            .with_custom("expression", self.expression.clone())
295            .with_custom("constraint_type", "custom");
296
297        if let Some(hint) = &self.hint {
298            metadata = metadata.with_custom("hint", hint.clone());
299        }
300
301        metadata
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    /// Builds a context with a five-row table registered as `data`:
    ///
    /// | price | quantity | status   |
    /// |-------|----------|----------|
    /// | 10.5  | 5        | active   |
    /// | 25.0  | 10       | active   |
    /// | 5.0   | 0        | inactive |
    /// | 100.0 | 20       | active   |
    /// | NULL  | 15       | pending  |
    async fn create_test_context() -> SessionContext {
        let ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![
            Field::new("price", DataType::Float64, true),
            Field::new("quantity", DataType::Int64, true),
            Field::new("status", DataType::Utf8, true),
        ]));

        let price_array =
            Float64Array::from(vec![Some(10.5), Some(25.0), Some(5.0), Some(100.0), None]);
        let quantity_array = Int64Array::from(vec![Some(5), Some(10), Some(0), Some(20), Some(15)]);
        let status_array = StringArray::from(vec![
            Some("active"),
            Some("active"),
            Some("inactive"),
            Some("active"),
            Some("pending"),
        ]);

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(price_array),
                Arc::new(quantity_array),
                Arc::new(status_array),
            ],
        )
        .unwrap();

        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();

        ctx
    }

    #[test]
    fn test_sql_validation_accepts_safe_expressions() {
        // These should all be accepted
        assert!(validate_sql_expression("price > 0").is_ok());
        assert!(validate_sql_expression("quantity BETWEEN 1 AND 100").is_ok());
        assert!(validate_sql_expression("status = 'active' AND price < 1000").is_ok());
        assert!(validate_sql_expression("LENGTH(name) > 3").is_ok());
        assert!(validate_sql_expression("order_date <= ship_date").is_ok());
    }

    #[test]
    fn test_sql_validation_rejects_dangerous_operations() {
        // These should all be rejected
        assert!(validate_sql_expression("DROP TABLE users").is_err());
        assert!(validate_sql_expression("DELETE FROM data WHERE 1=1").is_err());
        assert!(validate_sql_expression("UPDATE data SET price = 0").is_err());
        assert!(validate_sql_expression("price > 0; DROP TABLE data").is_err());
        assert!(validate_sql_expression("INSERT INTO data VALUES (1, 2, 3)").is_err());
        assert!(validate_sql_expression("CREATE TABLE new_table (id INT)").is_err());
        assert!(validate_sql_expression("ALTER TABLE data ADD COLUMN new_col").is_err());
        assert!(validate_sql_expression("TRUNCATE TABLE data").is_err());
        assert!(validate_sql_expression("-- comment\nprice > 0").is_err());
        assert!(validate_sql_expression("price > 0 /* comment */").is_err());
    }

    #[test]
    fn test_sql_validation_rejects_semicolons_and_comments() {
        // No forbidden keyword here, so the syntax checks are what reject these.
        let err = validate_sql_expression("price > 0; price < 10").unwrap_err();
        assert!(err.to_string().contains("semicolons"));

        let err = validate_sql_expression("price > 0 */").unwrap_err();
        assert!(err.to_string().contains("comments"));
    }

    #[test]
    fn test_sql_validation_case_insensitive() {
        // Should reject regardless of case
        assert!(validate_sql_expression("drop table users").is_err());
        assert!(validate_sql_expression("DeLeTe FROM data").is_err());
        assert!(validate_sql_expression("UpDaTe data SET x = 1").is_err());
    }

    #[test]
    fn test_sql_validation_word_boundaries() {
        // Should not reject if keyword is part of a larger word
        assert!(validate_sql_expression("updated_at > '2024-01-01'").is_ok());
        assert!(validate_sql_expression("is_deleted = false").is_ok());
        assert!(validate_sql_expression("created_by = 'admin'").is_ok());
        assert!(validate_sql_expression("reset_count >= 0").is_ok());
        assert!(validate_sql_expression("executed_at IS NOT NULL").is_ok());
    }

    #[tokio::test]
    async fn test_custom_sql_with_nulls_expression() {
        let ctx = create_test_context().await;

        let constraint = CustomSqlConstraint::new("price > 0", None::<String>).unwrap();

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(0.8)); // 4 out of 5 rows satisfy (NULL doesn't satisfy)
        assert!(result
            .message
            .as_ref()
            .unwrap()
            .contains("Custom SQL condition not satisfied for 1 rows"));
    }

    #[tokio::test]
    async fn test_custom_sql_all_satisfy() {
        let ctx = create_test_context().await;

        // `quantity >= 0` holds for every row: the smallest quantity is 0.
        let constraint = CustomSqlConstraint::new("quantity >= 0", None::<String>).unwrap();

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(1.0)); // All rows satisfy
    }

    #[tokio::test]
    async fn test_custom_sql_partial_satisfy() {
        let ctx = create_test_context().await;

        let constraint =
            CustomSqlConstraint::new("quantity > 0", Some("Quantity must be positive")).unwrap();

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(0.8)); // 4 out of 5 have quantity > 0
        assert!(result
            .message
            .as_ref()
            .unwrap()
            .contains("Quantity must be positive"));
        assert!(result.message.as_ref().unwrap().contains("1 rows failed"));
    }

    #[tokio::test]
    async fn test_custom_sql_complex_expression() {
        let ctx = create_test_context().await;

        let constraint = CustomSqlConstraint::new(
            "status = 'active' AND price >= 10",
            Some("Active items must have price >= 10"),
        )
        .unwrap();

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        // Only 3 rows have status='active' AND price >= 10
        assert_eq!(result.metric, Some(0.6));
    }

    #[tokio::test]
    async fn test_custom_sql_with_nulls() {
        let ctx = create_test_context().await;

        let constraint = CustomSqlConstraint::new("price IS NOT NULL", None::<String>).unwrap();

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(0.8)); // 4 out of 5 have non-null price
    }

    #[tokio::test]
    async fn test_custom_sql_invalid_expression() {
        let ctx = create_test_context().await;

        let constraint = CustomSqlConstraint::new("invalid_column > 0", None::<String>).unwrap();

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert!(result
            .message
            .as_ref()
            .unwrap()
            .contains("SQL expression error"));
    }

    #[test]
    fn test_new_returns_error_on_dangerous_sql() {
        let result = CustomSqlConstraint::new("DROP TABLE data", None::<String>);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("forbidden operation: DROP"));
    }

    #[test]
    fn test_try_new_returns_error_on_dangerous_sql() {
        let result = CustomSqlConstraint::try_new("DELETE FROM data", None::<String>);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("forbidden operation: DELETE"));
    }

    #[test]
    fn test_try_new_accepts_safe_expression() {
        let constraint =
            CustomSqlConstraint::try_new("quantity > 0", Some("Quantity must be positive"))
                .unwrap();
        assert_eq!(constraint.name(), "custom_sql");
    }
}