1use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
4use crate::prelude::*;
5use crate::security::SqlSecurity;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12use tracing::instrument;
13static REGEX_CACHE: Lazy<RwLock<HashMap<String, Regex>>> =
15 Lazy::new(|| RwLock::new(HashMap::new()));
16
17#[derive(Debug, Clone)]
43pub struct CustomSqlConstraint {
44 expression: String,
45 hint: Option<String>,
46}
47
48impl CustomSqlConstraint {
49 pub fn new(expression: impl Into<String>, hint: Option<impl Into<String>>) -> Result<Self> {
60 let expression = expression.into();
61
62 validate_sql_expression(&expression)?;
64 SqlSecurity::validate_sql_expression(&expression)?;
65
66 Ok(Self {
67 expression,
68 hint: hint.map(Into::into),
69 })
70 }
71
72 pub fn try_new(expression: impl Into<String>, hint: Option<impl Into<String>>) -> Result<Self> {
83 let expression = expression.into();
84
85 validate_sql_expression(&expression)?;
87 SqlSecurity::validate_sql_expression(&expression)?;
88
89 Ok(Self {
90 expression,
91 hint: hint.map(Into::into),
92 })
93 }
94}
95
96fn validate_sql_expression(sql: &str) -> Result<()> {
101 let sql_upper = sql.to_uppercase();
103
104 let dangerous_keywords: HashSet<&str> = [
106 "DROP",
107 "DELETE",
108 "INSERT",
109 "UPDATE",
110 "CREATE",
111 "ALTER",
112 "TRUNCATE",
113 "GRANT",
114 "REVOKE",
115 "EXECUTE",
116 "EXEC",
117 "CALL",
118 "MERGE",
119 "REPLACE",
120 "RENAME",
121 "MODIFY",
122 "SET",
123 "COMMIT",
124 "ROLLBACK",
125 "SAVEPOINT",
126 "BEGIN",
127 "START",
128 "TRANSACTION",
129 "LOCK",
130 "UNLOCK",
131 ]
132 .iter()
133 .copied()
134 .collect();
135
136 for keyword in dangerous_keywords {
138 let pattern = format!(r"\b{keyword}\b");
140
141 let matches = {
143 let cache = REGEX_CACHE.read().map_err(|_| {
144 TermError::Internal("Failed to acquire read lock on regex cache".to_string())
145 })?;
146
147 if let Some(regex) = cache.get(&pattern) {
148 regex.is_match(&sql_upper)
149 } else {
150 drop(cache);
152 let mut write_cache = REGEX_CACHE.write().map_err(|_| {
153 TermError::Internal("Failed to acquire write lock on regex cache".to_string())
154 })?;
155
156 let regex = Regex::new(&pattern).map_err(|e| {
157 TermError::Internal(format!("Failed to compile regex pattern: {e}"))
158 })?;
159 let is_match = regex.is_match(&sql_upper);
160 write_cache.insert(pattern.clone(), regex);
161 is_match
162 }
163 };
164
165 if matches {
166 return Err(TermError::validation_failed(
167 "custom_sql",
168 format!("SQL expression contains forbidden operation: {keyword}"),
169 ));
170 }
171 }
172
173 if sql.contains(';') {
175 return Err(TermError::validation_failed(
176 "custom_sql",
177 "SQL expression cannot contain semicolons",
178 ));
179 }
180
181 if sql.contains("--") || sql.contains("/*") || sql.contains("*/") {
183 return Err(TermError::validation_failed(
184 "custom_sql",
185 "SQL expression cannot contain comments",
186 ));
187 }
188
189 Ok(())
190}
191
192#[async_trait]
193impl Constraint for CustomSqlConstraint {
194 #[instrument(skip(self, ctx), fields(expression = %self.expression))]
195 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
196 let validation_ctx = current_validation_context();
200
201 let table_name = validation_ctx.table_name();
202
203 let sql = format!(
204 "SELECT
205 COUNT(CASE WHEN {} THEN 1 END) as satisfied,
206 COUNT(*) as total
207 FROM {table_name}",
208 self.expression
209 );
210
211 let df = match ctx.sql(&sql).await {
213 Ok(df) => df,
214 Err(e) => {
215 return Ok(ConstraintResult::failure(format!(
217 "SQL expression error: {e}. Expression: '{}'",
218 self.expression
219 )));
220 }
221 };
222
223 let batches = match df.collect().await {
224 Ok(batches) => batches,
225 Err(e) => {
226 return Ok(ConstraintResult::failure(format!(
228 "SQL execution error: {e}. Expression: '{}'",
229 self.expression
230 )));
231 }
232 };
233
234 if batches.is_empty() {
235 return Ok(ConstraintResult::skipped("No data to validate"));
236 }
237
238 let batch = &batches[0];
239 if batch.num_rows() == 0 {
240 return Ok(ConstraintResult::skipped("No data to validate"));
241 }
242
243 let satisfied = batch
245 .column(0)
246 .as_any()
247 .downcast_ref::<arrow::array::Int64Array>()
248 .ok_or_else(|| TermError::Internal("Failed to extract satisfied count".to_string()))?
249 .value(0) as f64;
250
251 let total = batch
252 .column(1)
253 .as_any()
254 .downcast_ref::<arrow::array::Int64Array>()
255 .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
256 .value(0) as f64;
257
258 if total == 0.0 {
259 return Ok(ConstraintResult::skipped("No data to validate"));
260 }
261
262 let satisfaction_ratio = satisfied / total;
263
264 if satisfaction_ratio == 1.0 {
265 Ok(ConstraintResult::success_with_metric(satisfaction_ratio))
266 } else {
267 let failed_count = total - satisfied;
268 let message = if let Some(hint) = &self.hint {
269 format!("{hint} ({} rows failed the condition)", failed_count as i64)
270 } else {
271 format!(
272 "Custom SQL condition not satisfied for {} rows. Expression: '{}'",
273 failed_count as i64, self.expression
274 )
275 };
276
277 Ok(ConstraintResult::failure_with_metric(
278 satisfaction_ratio,
279 message,
280 ))
281 }
282 }
283
284 fn name(&self) -> &str {
285 "custom_sql"
286 }
287
288 fn metadata(&self) -> ConstraintMetadata {
289 let mut metadata = ConstraintMetadata::new()
290 .with_description(format!(
291 "Checks that all rows satisfy the SQL expression: {}",
292 self.expression
293 ))
294 .with_custom("expression", self.expression.clone())
295 .with_custom("constraint_type", "custom");
296
297 if let Some(hint) = &self.hint {
298 metadata = metadata.with_custom("hint", hint.clone());
299 }
300
301 metadata
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    /// Builds a session with a five-row table named `data` with columns
    /// `price` (one NULL), `quantity` (one zero) and `status`.
    async fn create_test_context() -> SessionContext {
        let schema = Arc::new(Schema::new(vec![
            Field::new("price", DataType::Float64, true),
            Field::new("quantity", DataType::Int64, true),
            Field::new("status", DataType::Utf8, true),
        ]));

        let prices =
            Float64Array::from(vec![Some(10.5), Some(25.0), Some(5.0), Some(100.0), None]);
        let quantities = Int64Array::from(vec![Some(5), Some(10), Some(0), Some(20), Some(15)]);
        let statuses = StringArray::from(vec![
            Some("active"),
            Some("active"),
            Some("inactive"),
            Some("active"),
            Some("pending"),
        ]);

        let rows = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(prices), Arc::new(quantities), Arc::new(statuses)],
        )
        .unwrap();

        let session = SessionContext::new();
        let table = MemTable::try_new(schema, vec![vec![rows]]).unwrap();
        session.register_table("data", Arc::new(table)).unwrap();

        session
    }

    #[test]
    fn test_sql_validation_accepts_safe_expressions() {
        let accepted = [
            "price > 0",
            "quantity BETWEEN 1 AND 100",
            "status = 'active' AND price < 1000",
            "LENGTH(name) > 3",
            "order_date <= ship_date",
        ];

        for expr in &accepted {
            assert!(
                validate_sql_expression(expr).is_ok(),
                "expected expression to be accepted: {expr}"
            );
        }
    }

    #[test]
    fn test_sql_validation_rejects_dangerous_operations() {
        let rejected = [
            "DROP TABLE users",
            "DELETE FROM {table_name} WHERE 1=1",
            "UPDATE data SET price = 0",
            "price > 0; DROP TABLE data",
            "INSERT INTO data VALUES (1, 2, 3)",
            "CREATE TABLE new_table (id INT)",
            "ALTER TABLE data ADD COLUMN new_col",
            "TRUNCATE TABLE data",
            "-- comment\nprice > 0",
            "price > 0 /* comment */",
        ];

        for expr in &rejected {
            assert!(
                validate_sql_expression(expr).is_err(),
                "expected expression to be rejected: {expr}"
            );
        }
    }

    #[test]
    fn test_sql_validation_case_insensitive() {
        let rejected = [
            "drop table users",
            "DeLeTe FROM {table_name}",
            "UpDaTe data SET x = 1",
        ];

        for expr in &rejected {
            assert!(
                validate_sql_expression(expr).is_err(),
                "expected expression to be rejected: {expr}"
            );
        }
    }

    #[test]
    fn test_sql_validation_word_boundaries() {
        // Forbidden keywords embedded in longer identifiers must not trigger.
        let accepted = [
            "updated_at > '2024-01-01'",
            "is_deleted = false",
            "created_by = 'admin'",
        ];

        for expr in &accepted {
            assert!(
                validate_sql_expression(expr).is_ok(),
                "expected expression to be accepted: {expr}"
            );
        }
    }

    #[tokio::test]
    async fn test_custom_sql_with_nulls_expression() {
        let session = create_test_context().await;
        let check = CustomSqlConstraint::new("price > 0", None::<String>).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();

        // The row with a NULL price does not count as satisfied.
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.8));
    }

    #[tokio::test]
    async fn test_custom_sql_all_satisfy() {
        let session = create_test_context().await;
        let check = CustomSqlConstraint::new("quantity >= 0", None::<String>).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(1.0));
    }

    #[tokio::test]
    async fn test_custom_sql_partial_satisfy() {
        let session = create_test_context().await;
        let check =
            CustomSqlConstraint::new("quantity > 0", Some("Quantity must be positive")).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.8));

        let message = outcome.message.as_ref().unwrap();
        assert!(message.contains("Quantity must be positive"));
        assert!(message.contains("1 rows failed"));
    }

    #[tokio::test]
    async fn test_custom_sql_complex_expression() {
        let session = create_test_context().await;
        let check = CustomSqlConstraint::new(
            "status = 'active' AND price >= 10",
            Some("Active items must have price >= 10"),
        )
        .unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.6));
    }

    #[tokio::test]
    async fn test_custom_sql_with_nulls() {
        let session = create_test_context().await;
        let check = CustomSqlConstraint::new("price IS NOT NULL", None::<String>).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.8));
    }

    #[tokio::test]
    async fn test_custom_sql_invalid_expression() {
        let session = create_test_context().await;
        let check = CustomSqlConstraint::new("invalid_column > 0", None::<String>).unwrap();

        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Failure);

        let message = outcome.message.as_ref().unwrap();
        assert!(message.contains("SQL expression error"));
    }

    #[test]
    fn test_new_returns_error_on_dangerous_sql_new() {
        let attempt = CustomSqlConstraint::new("DROP TABLE data", None::<String>);
        assert!(attempt.is_err());

        let rendered = attempt.unwrap_err().to_string();
        assert!(rendered.contains("forbidden operation: DROP"));
    }

    #[test]
    fn test_try_new_returns_error_on_dangerous_sql() {
        let attempt = CustomSqlConstraint::try_new("DELETE FROM {table_name}", None::<String>);
        assert!(attempt.is_err());

        let rendered = attempt.unwrap_err().to_string();
        assert!(rendered.contains("forbidden operation: DELETE"));
    }
}