1use crate::error::{Result, TermError};
7use once_cell::sync::Lazy;
8use regex::Regex;
9use std::collections::HashSet;
10use std::sync::OnceLock;
11use zeroize::{Zeroize, ZeroizeOnDrop};
12
13#[derive(Clone, ZeroizeOnDrop)]
15pub struct SecureString(String);
16
17impl std::fmt::Debug for SecureString {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 write!(f, "SecureString(***)")
20 }
21}
22
23impl SecureString {
24 pub fn new(value: impl Into<String>) -> Self {
26 Self(value.into())
27 }
28
29 pub fn expose(&self) -> &str {
31 &self.0
32 }
33
34 pub fn into_string(mut self) -> String {
36 let value = std::mem::take(&mut self.0);
37 self.0.zeroize();
38 value
39 }
40}
41
42impl From<String> for SecureString {
43 fn from(value: String) -> Self {
44 Self(value)
45 }
46}
47
48impl From<&str> for SecureString {
49 fn from(value: &str) -> Self {
50 Self(value.to_string())
51 }
52}
53
54pub struct SqlSecurity;
56
57impl SqlSecurity {
58 pub fn escape_identifier(identifier: &str) -> Result<String> {
90 Self::validate_identifier(identifier)?;
92
93 let escaped = identifier.replace('"', "\"\"");
95 Ok(format!("\"{escaped}\""))
96 }
97
98 pub fn validate_identifier(identifier: &str) -> Result<()> {
104 if identifier.is_empty() {
106 return Err(TermError::SecurityError(
107 "SQL identifier cannot be empty".to_string(),
108 ));
109 }
110
111 if identifier.len() > 128 {
113 return Err(TermError::SecurityError(
114 "SQL identifier too long (max 128 characters)".to_string(),
115 ));
116 }
117
118 if identifier.contains('\0') {
120 return Err(TermError::SecurityError(
121 "SQL identifier cannot contain null bytes".to_string(),
122 ));
123 }
124
125 static IDENTIFIER_REGEX: Lazy<Regex> = Lazy::new(|| {
127 #[allow(clippy::expect_used)]
131 Regex::new(r#"^[a-zA-Z_"][a-zA-Z0-9_"]*(\.[a-zA-Z_"][a-zA-Z0-9_"]*)*$"#)
132 .expect("Hard-coded regex pattern should be valid")
133 });
134 let regex = &*IDENTIFIER_REGEX;
135
136 if !regex.is_match(identifier) {
137 return Err(TermError::SecurityError(format!(
138 "Invalid SQL identifier format: '{identifier}'. Identifiers must start with a letter or underscore and contain only letters, numbers, underscores, and dots"
139 )));
140 }
141
142 Self::check_dangerous_patterns(identifier)?;
144
145 Ok(())
146 }
147
148 pub fn validate_regex_pattern(pattern: &str) -> Result<String> {
153 if pattern.len() > 1000 {
155 return Err(TermError::SecurityError(
156 "Regex pattern too long (max 1000 characters)".to_string(),
157 ));
158 }
159
160 if pattern.contains('\0') {
162 return Err(TermError::SecurityError(
163 "Regex pattern cannot contain null bytes".to_string(),
164 ));
165 }
166
167 match Regex::new(pattern) {
169 Ok(_) => (),
170 Err(e) => {
171 return Err(TermError::SecurityError(format!(
172 "Invalid regex pattern: {e}"
173 )));
174 }
175 }
176
177 Self::check_redos_patterns(pattern)?;
179
180 let escaped = pattern.replace('\'', "''");
182 Ok(escaped)
183 }
184
185 pub fn validate_sql_expression(expression: &str) -> Result<()> {
191 if expression.len() > 5000 {
193 return Err(TermError::SecurityError(
194 "SQL expression too long (max 5000 characters)".to_string(),
195 ));
196 }
197
198 if expression.contains('\0') {
200 return Err(TermError::SecurityError(
201 "SQL expression cannot contain null bytes".to_string(),
202 ));
203 }
204
205 Self::check_dangerous_sql_patterns(expression)?;
207
208 Ok(())
209 }
210
211 fn check_dangerous_patterns(identifier: &str) -> Result<()> {
213 let identifier_lower = identifier.to_lowercase();
214
215 let dangerous_patterns = &[
217 ";", "--", "/*", "*/", "xp_", "sp_", "union", "select", "insert", "update", "delete",
218 "drop", "create", "alter", "exec", "execute", "declare", "cursor", "fetch", "open",
219 "close",
220 ];
221
222 for pattern in dangerous_patterns {
223 if identifier_lower.contains(pattern) {
224 return Err(TermError::SecurityError(format!(
225 "SQL identifier contains dangerous pattern: '{pattern}'"
226 )));
227 }
228 }
229
230 Ok(())
231 }
232
233 fn check_redos_patterns(pattern: &str) -> Result<()> {
235 let dangerous_patterns = &[
242 "(.*)*", "(.*)+", "(a+)+", "(a*)*", ];
247
248 for dangerous in dangerous_patterns {
249 if pattern.contains(dangerous) {
250 return Err(TermError::SecurityError(
251 "Regex pattern might cause ReDoS attack".to_string(),
252 ));
253 }
254 }
255
256 Ok(())
257 }
258
259 fn check_dangerous_sql_patterns(expression: &str) -> Result<()> {
261 let expression_lower = expression.to_lowercase();
262
263 static DANGEROUS_KEYWORDS: OnceLock<HashSet<&'static str>> = OnceLock::new();
265 let keywords = DANGEROUS_KEYWORDS.get_or_init(|| {
266 [
267 "drop",
269 "create",
270 "alter",
271 "truncate",
272 "insert",
274 "update",
275 "delete",
276 "exec",
278 "execute",
279 "xp_",
280 "sp_",
281 "declare",
283 "cursor",
284 "fetch",
285 "open",
286 "close",
287 "begin",
288 "commit",
289 "rollback",
290 "transaction",
291 "information_schema",
293 "sys.",
294 "pg_",
295 "bulk",
297 "openrowset",
298 "opendatasource",
299 "load_file",
300 "into outfile",
301 "into dumpfile",
302 "--",
304 "/*",
305 "*/",
306 ]
307 .into_iter()
308 .collect()
309 });
310
311 for keyword in keywords {
313 if expression_lower.contains(keyword) {
314 return Err(TermError::SecurityError(format!(
315 "SQL expression contains dangerous keyword: '{keyword}'"
316 )));
317 }
318 }
319
320 let suspicious_patterns = &[
322 r";\s*\w+", r"union\s+select", r"'\s*or\s+'", r"'\s*and\s+'", r"=\s*\(.*select.*\)", r"\(\s*select\s+.*\)", r"in\s*\(\s*select\s+.*\)", ];
330
331 for pattern in suspicious_patterns {
332 if let Ok(regex) = Regex::new(pattern) {
333 if regex.is_match(&expression_lower) {
334 return Err(TermError::SecurityError(format!(
335 "SQL expression contains suspicious pattern matching: {pattern}"
336 )));
337 }
338 }
339 }
340
341 Ok(())
342 }
343}
344
345pub struct InputValidator;
347
348impl InputValidator {
349 pub fn validate_threshold(value: f64, name: &str) -> Result<()> {
351 if !value.is_finite() {
352 return Err(TermError::SecurityError(format!(
353 "Invalid {name} value: must be finite (not NaN or infinite)"
354 )));
355 }
356 Ok(())
357 }
358
359 pub fn validate_percentage(value: f64, name: &str) -> Result<()> {
361 Self::validate_threshold(value, name)?;
362
363 if !(0.0..=1.0).contains(&value) {
364 return Err(TermError::SecurityError(format!(
365 "Invalid {name} value: must be between 0.0 and 1.0, got {value}"
366 )));
367 }
368 Ok(())
369 }
370
371 pub fn validate_string_length(value: &str, max_length: usize, name: &str) -> Result<()> {
373 if value.len() > max_length {
374 return Err(TermError::SecurityError(format!(
375 "{name} too long: {} characters (max {max_length})",
376 value.len()
377 )));
378 }
379 Ok(())
380 }
381
382 pub fn validate_no_null_bytes(value: &str, name: &str) -> Result<()> {
384 if value.contains('\0') {
385 return Err(TermError::SecurityError(format!(
386 "{name} cannot contain null bytes"
387 )));
388 }
389 Ok(())
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// `SecureString` round-trips its value and never prints it via `Debug`.
    #[test]
    fn test_secure_string_zeroization() {
        let password = "secret123";
        let secure = SecureString::new(password);
        assert_eq!(secure.expose(), "secret123");

        // Debug output must be redacted.
        assert_eq!(format!("{secure:?}"), "SecureString(***)");

        // `into_string` must hand the secret back intact.
        let extracted = secure.into_string();
        assert_eq!(extracted, "secret123");
    }

    /// Both `From` conversions preserve the contents.
    #[test]
    fn test_secure_string_conversions() {
        assert_eq!(SecureString::from("abc").expose(), "abc");
        assert_eq!(SecureString::from(String::from("xyz")).expose(), "xyz");
    }

    #[test]
    fn test_valid_sql_identifiers() {
        assert!(SqlSecurity::validate_identifier("customer_id").is_ok());
        assert!(SqlSecurity::validate_identifier("table1").is_ok());
        assert!(SqlSecurity::validate_identifier("_private_col").is_ok());
        assert!(SqlSecurity::validate_identifier("schema.table").is_ok());

        // Exactly at the 128-byte limit.
        assert!(SqlSecurity::validate_identifier(&"a".repeat(128)).is_ok());
    }

    #[test]
    fn test_invalid_sql_identifiers() {
        // Empty.
        assert!(SqlSecurity::validate_identifier("").is_err());

        // Too long, including one past the boundary.
        assert!(SqlSecurity::validate_identifier(&"a".repeat(200)).is_err());
        assert!(SqlSecurity::validate_identifier(&"a".repeat(129)).is_err());

        // NUL byte.
        assert!(SqlSecurity::validate_identifier("col\0name").is_err());

        // Injection attempts and deny-listed keywords.
        assert!(SqlSecurity::validate_identifier("id; DROP TABLE").is_err());
        assert!(SqlSecurity::validate_identifier("col--comment").is_err());
        assert!(SqlSecurity::validate_identifier("union_select").is_err());

        // Format violations.
        assert!(SqlSecurity::validate_identifier("col name").is_err());
        assert!(SqlSecurity::validate_identifier("col-name").is_err());
        assert!(SqlSecurity::validate_identifier("123col").is_err());
    }

    #[test]
    fn test_sql_identifier_escaping() {
        let result = SqlSecurity::escape_identifier("customer_id").unwrap();
        assert_eq!(result, "\"customer_id\"");

        let result = SqlSecurity::escape_identifier("col\"with\"quotes").unwrap();
        assert_eq!(result, "\"col\"\"with\"\"quotes\"");

        // Escaping must refuse anything validation refuses.
        assert!(SqlSecurity::escape_identifier("").is_err());
        assert!(SqlSecurity::escape_identifier("drop_me").is_err());
    }

    #[test]
    fn test_regex_pattern_validation() {
        assert!(SqlSecurity::validate_regex_pattern(r"^[A-Z]\d+$").is_ok());
        assert!(SqlSecurity::validate_regex_pattern(r"email@domain\.com").is_ok());

        // Does not compile.
        assert!(SqlSecurity::validate_regex_pattern(r"[unclosed").is_err());

        // Too long.
        assert!(SqlSecurity::validate_regex_pattern(&"a".repeat(2000)).is_err());

        // NUL byte.
        assert!(SqlSecurity::validate_regex_pattern("a\0b").is_err());

        // Caught by the ReDoS heuristic.
        assert!(SqlSecurity::validate_regex_pattern("(a+)+b").is_err());

        // Single quotes are doubled for SQL string literals.
        let result = SqlSecurity::validate_regex_pattern("it's a pattern").unwrap();
        assert_eq!(result, "it''s a pattern");
    }

    #[test]
    fn test_sql_expression_validation() {
        assert!(SqlSecurity::validate_sql_expression("price > 100").is_ok());
        assert!(SqlSecurity::validate_sql_expression("name IS NOT NULL").is_ok());
        assert!(SqlSecurity::validate_sql_expression("age BETWEEN 18 AND 65").is_ok());

        assert!(SqlSecurity::validate_sql_expression("price > 0; DROP TABLE users").is_err());
        assert!(SqlSecurity::validate_sql_expression("name = '' OR '1'='1'").is_err());
        assert!(SqlSecurity::validate_sql_expression("id IN (SELECT * FROM passwords)").is_err());
        assert!(SqlSecurity::validate_sql_expression("EXEC sp_droplogin").is_err());

        // Length limit and NUL bytes.
        assert!(SqlSecurity::validate_sql_expression(&"a".repeat(5001)).is_err());
        assert!(SqlSecurity::validate_sql_expression("a\0b").is_err());
    }

    #[test]
    fn test_input_validation() {
        assert!(InputValidator::validate_threshold(5.5, "threshold").is_ok());
        assert!(InputValidator::validate_percentage(0.95, "percentage").is_ok());
        assert!(InputValidator::validate_string_length("short", 100, "name").is_ok());

        // Inclusive boundaries.
        assert!(InputValidator::validate_percentage(0.0, "percentage").is_ok());
        assert!(InputValidator::validate_percentage(1.0, "percentage").is_ok());
        assert!(InputValidator::validate_string_length("exact", 5, "name").is_ok());
        assert!(InputValidator::validate_no_null_bytes("clean", "name").is_ok());

        assert!(InputValidator::validate_threshold(f64::NAN, "threshold").is_err());
        assert!(InputValidator::validate_threshold(f64::INFINITY, "threshold").is_err());
        assert!(InputValidator::validate_percentage(f64::NAN, "percentage").is_err());
        assert!(InputValidator::validate_percentage(1.5, "percentage").is_err());
        assert!(InputValidator::validate_percentage(-0.1, "percentage").is_err());
        assert!(InputValidator::validate_string_length("too long", 5, "name").is_err());
        assert!(InputValidator::validate_no_null_bytes("contains\0null", "name").is_err());
    }
}