1use crate::error::{Result, TermError};
7use once_cell::sync::Lazy;
8use regex::Regex;
9use std::collections::HashSet;
10use std::sync::OnceLock;
11use zeroize::{Zeroize, ZeroizeOnDrop};
12
13#[derive(Clone, ZeroizeOnDrop)]
15pub struct SecureString(String);
16
17impl std::fmt::Debug for SecureString {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 write!(f, "SecureString(***)")
20 }
21}
22
23impl SecureString {
24 pub fn new(value: impl Into<String>) -> Self {
26 Self(value.into())
27 }
28
29 pub fn expose(&self) -> &str {
31 &self.0
32 }
33
34 pub fn into_string(mut self) -> String {
36 let value = std::mem::take(&mut self.0);
37 self.0.zeroize();
38 value
39 }
40}
41
42impl From<String> for SecureString {
43 fn from(value: String) -> Self {
44 Self(value)
45 }
46}
47
48impl From<&str> for SecureString {
49 fn from(value: &str) -> Self {
50 Self(value.to_string())
51 }
52}
53
54pub struct SqlSecurity;
56
57impl SqlSecurity {
58 pub fn escape_identifier(identifier: &str) -> Result<String> {
90 Self::validate_identifier(identifier)?;
92
93 let escaped = identifier.replace('"', "\"\"");
95 Ok(format!("\"{escaped}\""))
96 }
97
98 pub fn validate_identifier(identifier: &str) -> Result<()> {
104 if identifier.is_empty() {
106 return Err(TermError::SecurityError(
107 "SQL identifier cannot be empty".to_string(),
108 ));
109 }
110
111 if identifier.len() > 128 {
113 return Err(TermError::SecurityError(
114 "SQL identifier too long (max 128 characters)".to_string(),
115 ));
116 }
117
118 if identifier.contains('\0') {
120 return Err(TermError::SecurityError(
121 "SQL identifier cannot contain null bytes".to_string(),
122 ));
123 }
124
125 static IDENTIFIER_REGEX: Lazy<Regex> = Lazy::new(|| {
127 #[allow(clippy::expect_used)]
131 Regex::new(r#"^[a-zA-Z_"][a-zA-Z0-9_"]*(\.[a-zA-Z_"][a-zA-Z0-9_"]*)*$"#)
132 .expect("Hard-coded regex pattern should be valid")
133 });
134 let regex = &*IDENTIFIER_REGEX;
135
136 if !regex.is_match(identifier) {
137 return Err(TermError::SecurityError(format!(
138 "Invalid SQL identifier format: '{identifier}'. Identifiers must start with a letter or underscore and contain only letters, numbers, underscores, and dots"
139 )));
140 }
141
142 Self::check_dangerous_patterns(identifier)?;
144
145 Ok(())
146 }
147
148 pub fn validate_regex_pattern(pattern: &str) -> Result<String> {
153 if pattern.len() > 1000 {
155 return Err(TermError::SecurityError(
156 "Regex pattern too long (max 1000 characters)".to_string(),
157 ));
158 }
159
160 if pattern.contains('\0') {
162 return Err(TermError::SecurityError(
163 "Regex pattern cannot contain null bytes".to_string(),
164 ));
165 }
166
167 match Regex::new(pattern) {
169 Ok(_) => (),
170 Err(e) => {
171 return Err(TermError::SecurityError(format!(
172 "Invalid regex pattern: {e}"
173 )));
174 }
175 }
176
177 Self::check_redos_patterns(pattern)?;
179
180 let escaped = pattern.replace('\'', "''");
182 Ok(escaped)
183 }
184
185 pub fn validate_sql_expression(expression: &str) -> Result<()> {
191 if expression.len() > 5000 {
193 return Err(TermError::SecurityError(
194 "SQL expression too long (max 5000 characters)".to_string(),
195 ));
196 }
197
198 if expression.contains('\0') {
200 return Err(TermError::SecurityError(
201 "SQL expression cannot contain null bytes".to_string(),
202 ));
203 }
204
205 Self::check_dangerous_sql_patterns(expression)?;
207
208 Ok(())
209 }
210
211 fn check_dangerous_patterns(identifier: &str) -> Result<()> {
213 let identifier_lower = identifier.to_lowercase();
214
215 let dangerous_chars = &[";", "--", "/*", "*/"];
220
221 for pattern in dangerous_chars {
222 if identifier_lower.contains(pattern) {
223 return Err(TermError::SecurityError(format!(
224 "SQL identifier contains dangerous character sequence: '{pattern}'"
225 )));
226 }
227 }
228
229 if identifier_lower.starts_with("xp_") || identifier_lower.starts_with("sp_") {
232 return Err(TermError::SecurityError(
233 "SQL identifier looks like a system stored procedure".to_string(),
234 ));
235 }
236
237 let injection_patterns = &[
240 "union ", "union_", "select ", "select_", "insert ", "insert_", "update ", "update_",
241 "delete ", "delete_", "drop ", "drop_", "create ", "alter ", "exec ", "execute ",
242 "declare ", "cursor ", "fetch ", "open ", "close ",
243 ];
244
245 for pattern in injection_patterns {
246 if identifier_lower.contains(pattern) {
247 return Err(TermError::SecurityError(format!(
248 "SQL identifier contains suspicious SQL keyword pattern: '{}'",
249 pattern.trim_end_matches('_').trim()
250 )));
251 }
252 }
253
254 Ok(())
255 }
256
257 fn check_redos_patterns(pattern: &str) -> Result<()> {
259 let dangerous_patterns = &[
266 "(.*)*", "(.*)+", "(a+)+", "(a*)*", ];
271
272 for dangerous in dangerous_patterns {
273 if pattern.contains(dangerous) {
274 return Err(TermError::SecurityError(
275 "Regex pattern might cause ReDoS attack".to_string(),
276 ));
277 }
278 }
279
280 Ok(())
281 }
282
283 fn check_dangerous_sql_patterns(expression: &str) -> Result<()> {
285 let expression_lower = expression.to_lowercase();
286
287 static DANGEROUS_KEYWORDS: OnceLock<HashSet<&'static str>> = OnceLock::new();
289 let keywords = DANGEROUS_KEYWORDS.get_or_init(|| {
290 [
291 "drop",
293 "create",
294 "alter",
295 "truncate",
296 "insert",
298 "update",
299 "delete",
300 "exec",
302 "execute",
303 "xp_",
304 "sp_",
305 "declare",
307 "cursor",
308 "fetch",
309 "open",
310 "close",
311 "begin",
312 "commit",
313 "rollback",
314 "transaction",
315 "information_schema",
317 "sys.",
318 "pg_",
319 "bulk",
321 "openrowset",
322 "opendatasource",
323 "load_file",
324 "into outfile",
325 "into dumpfile",
326 "--",
328 "/*",
329 "*/",
330 ]
331 .into_iter()
332 .collect()
333 });
334
335 for keyword in keywords {
337 if expression_lower.contains(keyword) {
338 return Err(TermError::SecurityError(format!(
339 "SQL expression contains dangerous keyword: '{keyword}'"
340 )));
341 }
342 }
343
344 let suspicious_patterns = &[
346 r";\s*\w+", r"union\s+select", r"'\s*or\s+'", r"'\s*and\s+'", r"=\s*\(.*select.*\)", r"\(\s*select\s+.*\)", r"in\s*\(\s*select\s+.*\)", ];
354
355 for pattern in suspicious_patterns {
356 if let Ok(regex) = Regex::new(pattern) {
357 if regex.is_match(&expression_lower) {
358 return Err(TermError::SecurityError(format!(
359 "SQL expression contains suspicious pattern matching: {pattern}"
360 )));
361 }
362 }
363 }
364
365 Ok(())
366 }
367}
368
369pub struct InputValidator;
371
372impl InputValidator {
373 pub fn validate_threshold(value: f64, name: &str) -> Result<()> {
375 if !value.is_finite() {
376 return Err(TermError::SecurityError(format!(
377 "Invalid {name} value: must be finite (not NaN or infinite)"
378 )));
379 }
380 Ok(())
381 }
382
383 pub fn validate_percentage(value: f64, name: &str) -> Result<()> {
385 Self::validate_threshold(value, name)?;
386
387 if !(0.0..=1.0).contains(&value) {
388 return Err(TermError::SecurityError(format!(
389 "Invalid {name} value: must be between 0.0 and 1.0, got {value}"
390 )));
391 }
392 Ok(())
393 }
394
395 pub fn validate_string_length(value: &str, max_length: usize, name: &str) -> Result<()> {
397 if value.len() > max_length {
398 return Err(TermError::SecurityError(format!(
399 "{name} too long: {} characters (max {max_length})",
400 value.len()
401 )));
402 }
403 Ok(())
404 }
405
406 pub fn validate_no_null_bytes(value: &str, name: &str) -> Result<()> {
408 if value.contains('\0') {
409 return Err(TermError::SecurityError(format!(
410 "{name} cannot contain null bytes"
411 )));
412 }
413 Ok(())
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secure_string_zeroization() {
        let password = "secret123";
        let secure = SecureString::new(password);
        assert_eq!(secure.expose(), "secret123");

        // `into_string` must hand back the full secret, not the emptied slot.
        let extracted = secure.into_string();
        assert_eq!(extracted, "secret123");
    }

    #[test]
    fn test_secure_string_debug_is_redacted_and_conversions_work() {
        let secure = SecureString::from("hunter2");
        let rendered = format!("{secure:?}");
        assert_eq!(rendered, "SecureString(***)");
        assert!(!rendered.contains("hunter2"));

        let owned = SecureString::from(String::from("abc"));
        assert_eq!(owned.expose(), "abc");
        let cloned = owned.clone();
        assert_eq!(cloned.expose(), "abc");
    }

    #[test]
    fn test_valid_sql_identifiers() {
        assert!(SqlSecurity::validate_identifier("customer_id").is_ok());
        assert!(SqlSecurity::validate_identifier("table1").is_ok());
        assert!(SqlSecurity::validate_identifier("_private_col").is_ok());
        assert!(SqlSecurity::validate_identifier("schema.table").is_ok());
        // Exactly at the 128-byte limit.
        assert!(SqlSecurity::validate_identifier(&"a".repeat(128)).is_ok());
    }

    #[test]
    fn test_invalid_sql_identifiers() {
        assert!(SqlSecurity::validate_identifier("").is_err());

        assert!(SqlSecurity::validate_identifier(&"a".repeat(129)).is_err());
        assert!(SqlSecurity::validate_identifier(&"a".repeat(200)).is_err());

        assert!(SqlSecurity::validate_identifier("col\0name").is_err());

        assert!(SqlSecurity::validate_identifier("id; DROP TABLE").is_err());
        assert!(SqlSecurity::validate_identifier("col--comment").is_err());
        assert!(SqlSecurity::validate_identifier("union_select").is_err());

        // System stored procedure prefixes.
        assert!(SqlSecurity::validate_identifier("xp_cmdshell").is_err());
        assert!(SqlSecurity::validate_identifier("sp_helptext").is_err());

        // Characters outside the allowed set, or a leading digit.
        assert!(SqlSecurity::validate_identifier("col name").is_err());
        assert!(SqlSecurity::validate_identifier("col-name").is_err());
        assert!(SqlSecurity::validate_identifier("123col").is_err());
    }

    #[test]
    fn test_sql_identifier_escaping() {
        let result = SqlSecurity::escape_identifier("customer_id").unwrap();
        assert_eq!(result, "\"customer_id\"");

        let result = SqlSecurity::escape_identifier("col\"with\"quotes").unwrap();
        assert_eq!(result, "\"col\"\"with\"\"quotes\"");

        // Escaping never bypasses validation.
        assert!(SqlSecurity::escape_identifier("bad;name").is_err());
    }

    #[test]
    fn test_regex_pattern_validation() {
        assert!(SqlSecurity::validate_regex_pattern(r"^[A-Z]\d+$").is_ok());
        assert!(SqlSecurity::validate_regex_pattern(r"email@domain\.com").is_ok());

        assert!(SqlSecurity::validate_regex_pattern(r"[unclosed").is_err());

        assert!(SqlSecurity::validate_regex_pattern(&"a".repeat(2000)).is_err());

        assert!(SqlSecurity::validate_regex_pattern("abc\0").is_err());

        // Known catastrophic-backtracking shape.
        assert!(SqlSecurity::validate_regex_pattern("(a+)+").is_err());

        let result = SqlSecurity::validate_regex_pattern("it's a pattern").unwrap();
        assert_eq!(result, "it''s a pattern");
    }

    #[test]
    fn test_sql_expression_validation() {
        assert!(SqlSecurity::validate_sql_expression("price > 100").is_ok());
        assert!(SqlSecurity::validate_sql_expression("name IS NOT NULL").is_ok());
        assert!(SqlSecurity::validate_sql_expression("age BETWEEN 18 AND 65").is_ok());

        assert!(SqlSecurity::validate_sql_expression("price > 0; DROP TABLE users").is_err());
        assert!(SqlSecurity::validate_sql_expression("name = '' OR '1'='1'").is_err());
        assert!(SqlSecurity::validate_sql_expression("id IN (SELECT * FROM passwords)").is_err());
        assert!(SqlSecurity::validate_sql_expression("EXEC sp_droplogin").is_err());

        // Comment sequences, null bytes and oversize input.
        assert!(SqlSecurity::validate_sql_expression("a = 1 -- trailing").is_err());
        assert!(SqlSecurity::validate_sql_expression("a = 1\0").is_err());
        assert!(SqlSecurity::validate_sql_expression(&"x".repeat(5001)).is_err());
    }

    #[test]
    fn test_input_validation() {
        assert!(InputValidator::validate_threshold(5.5, "threshold").is_ok());
        assert!(InputValidator::validate_percentage(0.95, "percentage").is_ok());
        assert!(InputValidator::validate_percentage(0.0, "percentage").is_ok());
        assert!(InputValidator::validate_percentage(1.0, "percentage").is_ok());
        assert!(InputValidator::validate_string_length("short", 100, "name").is_ok());
        assert!(InputValidator::validate_string_length("exact", 5, "name").is_ok());
        assert!(InputValidator::validate_no_null_bytes("clean", "name").is_ok());

        assert!(InputValidator::validate_threshold(f64::NAN, "threshold").is_err());
        assert!(InputValidator::validate_threshold(f64::INFINITY, "threshold").is_err());
        assert!(InputValidator::validate_percentage(1.5, "percentage").is_err());
        assert!(InputValidator::validate_percentage(-0.1, "percentage").is_err());
        assert!(InputValidator::validate_string_length("too long", 5, "name").is_err());
        assert!(InputValidator::validate_no_null_bytes("contains\0null", "name").is_err());
    }
}