1use crate::core::{Constraint, ConstraintMetadata, ConstraintResult};
4use crate::prelude::*;
5use crate::security::SqlSecurity;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::{HashMap, HashSet};
11use std::sync::RwLock;
12use tracing::instrument;
13
14static REGEX_CACHE: Lazy<RwLock<HashMap<String, Regex>>> =
16 Lazy::new(|| RwLock::new(HashMap::new()));
17
/// A constraint expressed as a user-supplied SQL boolean expression.
///
/// Instances are created through [`CustomSqlConstraint::new`] or
/// [`CustomSqlConstraint::try_new`], both of which validate the expression first.
#[derive(Clone, Debug)]
pub struct CustomSqlConstraint {
    /// The SQL condition every row is expected to satisfy.
    expression: String,
    /// Optional human-readable text used as the lead of the failure message.
    hint: Option<String>,
}
48
49impl CustomSqlConstraint {
50 pub fn new(expression: impl Into<String>, hint: Option<impl Into<String>>) -> Result<Self> {
61 let expression = expression.into();
62
63 validate_sql_expression(&expression)?;
65 SqlSecurity::validate_sql_expression(&expression)?;
66
67 Ok(Self {
68 expression,
69 hint: hint.map(Into::into),
70 })
71 }
72
73 pub fn try_new(expression: impl Into<String>, hint: Option<impl Into<String>>) -> Result<Self> {
84 let expression = expression.into();
85
86 validate_sql_expression(&expression)?;
88 SqlSecurity::validate_sql_expression(&expression)?;
89
90 Ok(Self {
91 expression,
92 hint: hint.map(Into::into),
93 })
94 }
95}
96
97fn validate_sql_expression(sql: &str) -> Result<()> {
102 let sql_upper = sql.to_uppercase();
104
105 let dangerous_keywords: HashSet<&str> = [
107 "DROP",
108 "DELETE",
109 "INSERT",
110 "UPDATE",
111 "CREATE",
112 "ALTER",
113 "TRUNCATE",
114 "GRANT",
115 "REVOKE",
116 "EXECUTE",
117 "EXEC",
118 "CALL",
119 "MERGE",
120 "REPLACE",
121 "RENAME",
122 "MODIFY",
123 "SET",
124 "COMMIT",
125 "ROLLBACK",
126 "SAVEPOINT",
127 "BEGIN",
128 "START",
129 "TRANSACTION",
130 "LOCK",
131 "UNLOCK",
132 ]
133 .iter()
134 .copied()
135 .collect();
136
137 for keyword in dangerous_keywords {
139 let pattern = format!(r"\b{keyword}\b");
141
142 let matches = {
144 let cache = REGEX_CACHE.read().map_err(|_| {
145 TermError::Internal("Failed to acquire read lock on regex cache".to_string())
146 })?;
147
148 if let Some(regex) = cache.get(&pattern) {
149 regex.is_match(&sql_upper)
150 } else {
151 drop(cache);
153 let mut write_cache = REGEX_CACHE.write().map_err(|_| {
154 TermError::Internal("Failed to acquire write lock on regex cache".to_string())
155 })?;
156
157 let regex = Regex::new(&pattern).map_err(|e| {
158 TermError::Internal(format!("Failed to compile regex pattern: {e}"))
159 })?;
160 let is_match = regex.is_match(&sql_upper);
161 write_cache.insert(pattern.clone(), regex);
162 is_match
163 }
164 };
165
166 if matches {
167 return Err(TermError::validation_failed(
168 "custom_sql",
169 format!("SQL expression contains forbidden operation: {keyword}"),
170 ));
171 }
172 }
173
174 if sql.contains(';') {
176 return Err(TermError::validation_failed(
177 "custom_sql",
178 "SQL expression cannot contain semicolons",
179 ));
180 }
181
182 if sql.contains("--") || sql.contains("/*") || sql.contains("*/") {
184 return Err(TermError::validation_failed(
185 "custom_sql",
186 "SQL expression cannot contain comments",
187 ));
188 }
189
190 Ok(())
191}
192
193#[async_trait]
194impl Constraint for CustomSqlConstraint {
195 #[instrument(skip(self, ctx), fields(expression = %self.expression))]
196 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
197 let sql = format!(
199 "SELECT
200 COUNT(CASE WHEN {} THEN 1 END) as satisfied,
201 COUNT(*) as total
202 FROM data",
203 self.expression
204 );
205
206 let df = match ctx.sql(&sql).await {
208 Ok(df) => df,
209 Err(e) => {
210 return Ok(ConstraintResult::failure(format!(
212 "SQL expression error: {e}. Expression: '{}'",
213 self.expression
214 )));
215 }
216 };
217
218 let batches = match df.collect().await {
219 Ok(batches) => batches,
220 Err(e) => {
221 return Ok(ConstraintResult::failure(format!(
223 "SQL execution error: {e}. Expression: '{}'",
224 self.expression
225 )));
226 }
227 };
228
229 if batches.is_empty() {
230 return Ok(ConstraintResult::skipped("No data to validate"));
231 }
232
233 let batch = &batches[0];
234 if batch.num_rows() == 0 {
235 return Ok(ConstraintResult::skipped("No data to validate"));
236 }
237
238 let satisfied = batch
240 .column(0)
241 .as_any()
242 .downcast_ref::<arrow::array::Int64Array>()
243 .ok_or_else(|| TermError::Internal("Failed to extract satisfied count".to_string()))?
244 .value(0) as f64;
245
246 let total = batch
247 .column(1)
248 .as_any()
249 .downcast_ref::<arrow::array::Int64Array>()
250 .ok_or_else(|| TermError::Internal("Failed to extract total count".to_string()))?
251 .value(0) as f64;
252
253 if total == 0.0 {
254 return Ok(ConstraintResult::skipped("No data to validate"));
255 }
256
257 let satisfaction_ratio = satisfied / total;
258
259 if satisfaction_ratio == 1.0 {
260 Ok(ConstraintResult::success_with_metric(satisfaction_ratio))
261 } else {
262 let failed_count = total - satisfied;
263 let message = if let Some(hint) = &self.hint {
264 format!("{hint} ({} rows failed the condition)", failed_count as i64)
265 } else {
266 format!(
267 "Custom SQL condition not satisfied for {} rows. Expression: '{}'",
268 failed_count as i64, self.expression
269 )
270 };
271
272 Ok(ConstraintResult::failure_with_metric(
273 satisfaction_ratio,
274 message,
275 ))
276 }
277 }
278
279 fn name(&self) -> &str {
280 "custom_sql"
281 }
282
283 fn metadata(&self) -> ConstraintMetadata {
284 let mut metadata = ConstraintMetadata::new()
285 .with_description(format!(
286 "Checks that all rows satisfy the SQL expression: {}",
287 self.expression
288 ))
289 .with_custom("expression", self.expression.clone())
290 .with_custom("constraint_type", "custom");
291
292 if let Some(hint) = &self.hint {
293 metadata = metadata.with_custom("hint", hint.clone());
294 }
295
296 metadata
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session whose `data` table holds five rows:
    ///
    /// | price | quantity | status   |
    /// |-------|----------|----------|
    /// | 10.5  | 5        | active   |
    /// | 25.0  | 10       | active   |
    /// | 5.0   | 0        | inactive |
    /// | 100.0 | 20       | active   |
    /// | NULL  | 15       | pending  |
    async fn create_test_context() -> SessionContext {
        let schema = Arc::new(Schema::new(vec![
            Field::new("price", DataType::Float64, true),
            Field::new("quantity", DataType::Int64, true),
            Field::new("status", DataType::Utf8, true),
        ]));

        let prices = Float64Array::from(vec![Some(10.5), Some(25.0), Some(5.0), Some(100.0), None]);
        let quantities = Int64Array::from(vec![Some(5), Some(10), Some(0), Some(20), Some(15)]);
        let statuses = StringArray::from(vec![
            Some("active"),
            Some("active"),
            Some("inactive"),
            Some("active"),
            Some("pending"),
        ]);

        let rows = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(prices), Arc::new(quantities), Arc::new(statuses)],
        )
        .unwrap();
        let table = MemTable::try_new(schema, vec![vec![rows]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// Evaluates `constraint` against a fresh copy of the fixture table.
    async fn evaluate_on_fixture(constraint: &CustomSqlConstraint) -> ConstraintResult {
        let ctx = create_test_context().await;
        constraint.evaluate(&ctx).await.unwrap()
    }

    #[test]
    fn test_sql_validation_accepts_safe_expressions() {
        let safe_expressions = [
            "price > 0",
            "quantity BETWEEN 1 AND 100",
            "status = 'active' AND price < 1000",
            "LENGTH(name) > 3",
            "order_date <= ship_date",
        ];

        for expression in safe_expressions {
            assert!(validate_sql_expression(expression).is_ok());
        }
    }

    #[test]
    fn test_sql_validation_rejects_dangerous_operations() {
        let rejected_expressions = [
            "DROP TABLE users",
            "DELETE FROM data WHERE 1=1",
            "UPDATE data SET price = 0",
            "price > 0; DROP TABLE data",
            "INSERT INTO data VALUES (1, 2, 3)",
            "CREATE TABLE new_table (id INT)",
            "ALTER TABLE data ADD COLUMN new_col",
            "TRUNCATE TABLE data",
            "-- comment\nprice > 0",
            "price > 0 /* comment */",
        ];

        for expression in rejected_expressions {
            assert!(validate_sql_expression(expression).is_err());
        }
    }

    #[test]
    fn test_sql_validation_case_insensitive() {
        let mixed_case = ["drop table users", "DeLeTe FROM data", "UpDaTe data SET x = 1"];

        for expression in mixed_case {
            assert!(validate_sql_expression(expression).is_err());
        }
    }

    #[test]
    fn test_sql_validation_word_boundaries() {
        // Identifiers that merely contain a forbidden keyword must still be accepted.
        let keyword_lookalikes = [
            "updated_at > '2024-01-01'",
            "is_deleted = false",
            "created_by = 'admin'",
        ];

        for expression in keyword_lookalikes {
            assert!(validate_sql_expression(expression).is_ok());
        }
    }

    #[tokio::test]
    async fn test_custom_sql_with_nulls_expression() {
        let constraint = CustomSqlConstraint::new("price > 0", None::<String>).unwrap();

        // The row with a NULL price does not satisfy the comparison: 4 of 5 rows pass.
        let outcome = evaluate_on_fixture(&constraint).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.8));
    }

    #[tokio::test]
    async fn test_custom_sql_all_satisfy() {
        let constraint = CustomSqlConstraint::new("quantity >= 0", None::<String>).unwrap();

        let outcome = evaluate_on_fixture(&constraint).await;
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(1.0));
    }

    #[tokio::test]
    async fn test_custom_sql_partial_satisfy() {
        let constraint =
            CustomSqlConstraint::new("quantity > 0", Some("Quantity must be positive")).unwrap();

        // Exactly one row has quantity 0.
        let outcome = evaluate_on_fixture(&constraint).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.8));

        let message = outcome.message.as_ref().unwrap();
        assert!(message.contains("Quantity must be positive"));
        assert!(message.contains("1 rows failed"));
    }

    #[tokio::test]
    async fn test_custom_sql_complex_expression() {
        let constraint = CustomSqlConstraint::new(
            "status = 'active' AND price >= 10",
            Some("Active items must have price >= 10"),
        )
        .unwrap();

        // Three active rows have price >= 10; the inactive and pending rows fail.
        let outcome = evaluate_on_fixture(&constraint).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.6));
    }

    #[tokio::test]
    async fn test_custom_sql_with_nulls() {
        let constraint = CustomSqlConstraint::new("price IS NOT NULL", None::<String>).unwrap();

        let outcome = evaluate_on_fixture(&constraint).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.8));
    }

    #[tokio::test]
    async fn test_custom_sql_invalid_expression() {
        // Passes validation, but the column does not exist in the fixture table.
        let constraint = CustomSqlConstraint::new("invalid_column > 0", None::<String>).unwrap();

        let outcome = evaluate_on_fixture(&constraint).await;
        assert_eq!(outcome.status, ConstraintStatus::Failure);

        let message = outcome.message.as_ref().unwrap();
        assert!(message.contains("SQL expression error"));
    }

    #[test]
    fn test_new_returns_error_on_dangerous_sql_new() {
        let attempt = CustomSqlConstraint::new("DROP TABLE data", None::<String>);
        assert!(attempt.is_err());

        let rendered = attempt.unwrap_err().to_string();
        assert!(rendered.contains("forbidden operation: DROP"));
    }

    #[test]
    fn test_try_new_returns_error_on_dangerous_sql() {
        let attempt = CustomSqlConstraint::try_new("DELETE FROM data", None::<String>);
        assert!(attempt.is_err());

        let rendered = attempt.unwrap_err().to_string();
        assert!(rendered.contains("forbidden operation: DELETE"));
    }
}