term_guard/
security.rs

1//! Security utilities for Term data validation library.
2//!
3//! This module provides security hardening utilities to prevent SQL injection,
4//! validate inputs, and handle credentials securely.
5
6use crate::error::{Result, TermError};
7use once_cell::sync::Lazy;
8use regex::Regex;
9use std::collections::HashSet;
10use std::sync::OnceLock;
11use zeroize::{Zeroize, ZeroizeOnDrop};
12
/// A string wrapper intended for secrets such as credentials.
///
/// The `ZeroizeOnDrop` derive wipes the wrapped buffer when a value is dropped.
/// `Clone` produces an independent copy of the secret; every copy is wiped
/// separately when it is dropped.
///
/// The type has a hand-written `Debug` implementation that never prints the
/// wrapped value.
#[derive(Clone, ZeroizeOnDrop)]
pub struct SecureString(String);
16
17impl std::fmt::Debug for SecureString {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        write!(f, "SecureString(***)")
20    }
21}
22
23impl SecureString {
24    /// Create a new secure string.
25    pub fn new(value: impl Into<String>) -> Self {
26        Self(value.into())
27    }
28
29    /// Get the string value. Use carefully and avoid storing the result.
30    pub fn expose(&self) -> &str {
31        &self.0
32    }
33
34    /// Convert to a regular string. The SecureString will be zeroized.
35    pub fn into_string(mut self) -> String {
36        let value = std::mem::take(&mut self.0);
37        self.0.zeroize();
38        value
39    }
40}
41
42impl From<String> for SecureString {
43    fn from(value: String) -> Self {
44        Self(value)
45    }
46}
47
48impl From<&str> for SecureString {
49    fn from(value: &str) -> Self {
50        Self(value.to_string())
51    }
52}
53
/// SQL identifier validation and escaping utilities.
///
/// This is a unit struct used purely as a namespace: every helper is an
/// associated function, so no instance is ever needed.
pub struct SqlSecurity;
56
57impl SqlSecurity {
58    /// Validates and escapes a SQL identifier (table name, column name, etc.).
59    ///
60    /// This function ensures that user-provided identifiers are safe to use in SQL queries
61    /// by validating their format and properly escaping them.
62    ///
63    /// # Arguments
64    /// * `identifier` - The identifier to validate and escape
65    ///
66    /// # Returns
67    /// * `Ok(String)` - The safely escaped identifier ready for SQL use
68    /// * `Err(TermError)` - If the identifier is invalid or potentially malicious
69    ///
70    /// # Security
71    /// This function prevents SQL injection by:
72    /// - Validating identifier format against allowed patterns
73    /// - Checking against a blocklist of dangerous patterns
74    /// - Properly escaping identifiers using double quotes
75    /// - Limiting identifier length to prevent DoS attacks
76    ///
77    /// # Examples
78    /// ```rust
79    /// use term_guard::security::SqlSecurity;
80    ///
81    /// // Valid identifiers
82    /// assert!(SqlSecurity::escape_identifier("customer_id").is_ok());
83    /// assert!(SqlSecurity::escape_identifier("table1").is_ok());
84    ///
85    /// // Invalid identifiers  
86    /// assert!(SqlSecurity::escape_identifier("id; DROP TABLE users--").is_err());
87    /// assert!(SqlSecurity::escape_identifier(&"very_long_name_".repeat(100)).is_err());
88    /// ```
89    pub fn escape_identifier(identifier: &str) -> Result<String> {
90        // Input validation
91        Self::validate_identifier(identifier)?;
92
93        // Escape the identifier using double quotes and escape any internal double quotes
94        let escaped = identifier.replace('"', "\"\"");
95        Ok(format!("\"{escaped}\""))
96    }
97
98    /// Validates a SQL identifier without escaping it.
99    ///
100    /// This function checks if an identifier is safe to use but doesn't escape it.
101    /// Useful for cases where you need validation but will use the identifier in a
102    /// different context.
103    pub fn validate_identifier(identifier: &str) -> Result<()> {
104        // Check for empty identifier
105        if identifier.is_empty() {
106            return Err(TermError::SecurityError(
107                "SQL identifier cannot be empty".to_string(),
108            ));
109        }
110
111        // Check identifier length (prevent DoS)
112        if identifier.len() > 128 {
113            return Err(TermError::SecurityError(
114                "SQL identifier too long (max 128 characters)".to_string(),
115            ));
116        }
117
118        // Check for null bytes
119        if identifier.contains('\0') {
120            return Err(TermError::SecurityError(
121                "SQL identifier cannot contain null bytes".to_string(),
122            ));
123        }
124
125        // Validate identifier format using regex
126        static IDENTIFIER_REGEX: Lazy<Regex> = Lazy::new(|| {
127            // Allow letters, numbers, underscores, dots (for qualified names), and quotes
128            // Must start with letter or underscore (or quote for quoted identifiers)
129            // This regex is compile-time constant and known to be valid
130            #[allow(clippy::expect_used)]
131            Regex::new(r#"^[a-zA-Z_"][a-zA-Z0-9_"]*(\.[a-zA-Z_"][a-zA-Z0-9_"]*)*$"#)
132                .expect("Hard-coded regex pattern should be valid")
133        });
134        let regex = &*IDENTIFIER_REGEX;
135
136        if !regex.is_match(identifier) {
137            return Err(TermError::SecurityError(format!(
138                "Invalid SQL identifier format: '{identifier}'. Identifiers must start with a letter or underscore and contain only letters, numbers, underscores, and dots"
139            )));
140        }
141
142        // Check against dangerous patterns
143        Self::check_dangerous_patterns(identifier)?;
144
145        Ok(())
146    }
147
148    /// Validates a regex pattern for safety.
149    ///
150    /// This function ensures that user-provided regex patterns are safe to use
151    /// in SQL queries and won't cause ReDoS attacks or other security issues.
152    pub fn validate_regex_pattern(pattern: &str) -> Result<String> {
153        // Check pattern length (prevent DoS)
154        if pattern.len() > 1000 {
155            return Err(TermError::SecurityError(
156                "Regex pattern too long (max 1000 characters)".to_string(),
157            ));
158        }
159
160        // Check for null bytes
161        if pattern.contains('\0') {
162            return Err(TermError::SecurityError(
163                "Regex pattern cannot contain null bytes".to_string(),
164            ));
165        }
166
167        // Validate that it's a valid regex pattern
168        match Regex::new(pattern) {
169            Ok(_) => (),
170            Err(e) => {
171                return Err(TermError::SecurityError(format!(
172                    "Invalid regex pattern: {e}"
173                )));
174            }
175        }
176
177        // Check for patterns that might cause ReDoS
178        Self::check_redos_patterns(pattern)?;
179
180        // Escape single quotes for SQL
181        let escaped = pattern.replace('\'', "''");
182        Ok(escaped)
183    }
184
185    /// Validates a custom SQL expression for safety.
186    ///
187    /// This function performs security validation on user-provided SQL expressions
188    /// to prevent SQL injection and other attacks while still allowing legitimate
189    /// validation expressions.
190    pub fn validate_sql_expression(expression: &str) -> Result<()> {
191        // Check expression length (prevent DoS)
192        if expression.len() > 5000 {
193            return Err(TermError::SecurityError(
194                "SQL expression too long (max 5000 characters)".to_string(),
195            ));
196        }
197
198        // Check for null bytes
199        if expression.contains('\0') {
200            return Err(TermError::SecurityError(
201                "SQL expression cannot contain null bytes".to_string(),
202            ));
203        }
204
205        // Check against dangerous SQL keywords and patterns
206        Self::check_dangerous_sql_patterns(expression)?;
207
208        Ok(())
209    }
210
211    /// Checks for dangerous patterns in identifiers.
212    fn check_dangerous_patterns(identifier: &str) -> Result<()> {
213        let identifier_lower = identifier.to_lowercase();
214
215        // Check for SQL injection attempts - only block actual dangerous patterns
216        // Allow common column name patterns like created_at, updated_by, etc.
217
218        // Direct dangerous characters/sequences that should never appear in identifiers
219        let dangerous_chars = &[";", "--", "/*", "*/"];
220
221        for pattern in dangerous_chars {
222            if identifier_lower.contains(pattern) {
223                return Err(TermError::SecurityError(format!(
224                    "SQL identifier contains dangerous character sequence: '{pattern}'"
225                )));
226            }
227        }
228
229        // Check for SQL injection keywords only when they appear as complete statements
230        // or with suspicious patterns, not as part of legitimate column names
231        if identifier_lower.starts_with("xp_") || identifier_lower.starts_with("sp_") {
232            return Err(TermError::SecurityError(
233                "SQL identifier looks like a system stored procedure".to_string(),
234            ));
235        }
236
237        // Check for obvious SQL injection patterns
238        // Block keywords followed by space OR underscore (common in injection attempts)
239        let injection_patterns = &[
240            "union ", "union_", "select ", "select_", "insert ", "insert_", "update ", "update_",
241            "delete ", "delete_", "drop ", "drop_", "create ", "alter ", "exec ", "execute ",
242            "declare ", "cursor ", "fetch ", "open ", "close ",
243        ];
244
245        for pattern in injection_patterns {
246            if identifier_lower.contains(pattern) {
247                return Err(TermError::SecurityError(format!(
248                    "SQL identifier contains suspicious SQL keyword pattern: '{}'",
249                    pattern.trim_end_matches('_').trim()
250                )));
251            }
252        }
253
254        Ok(())
255    }
256
257    /// Checks for patterns that might cause ReDoS (Regular Expression Denial of Service).
258    fn check_redos_patterns(pattern: &str) -> Result<()> {
259        // For now, disable ReDoS checking as it's being too aggressive
260        // with legitimate patterns like email validation.
261        // In a production system, you'd want more sophisticated ReDoS detection
262        // that can distinguish between safe and dangerous patterns.
263
264        // Check for extremely obvious ReDoS patterns only
265        let dangerous_patterns = &[
266            "(.*)*", // Classic catastrophic backtracking
267            "(.*)+", // Another classic
268            "(a+)+", // Nested quantifiers on same pattern
269            "(a*)*", // Nested quantifiers on same pattern
270        ];
271
272        for dangerous in dangerous_patterns {
273            if pattern.contains(dangerous) {
274                return Err(TermError::SecurityError(
275                    "Regex pattern might cause ReDoS attack".to_string(),
276                ));
277            }
278        }
279
280        Ok(())
281    }
282
283    /// Checks for dangerous SQL patterns in expressions.
284    fn check_dangerous_sql_patterns(expression: &str) -> Result<()> {
285        let expression_lower = expression.to_lowercase();
286
287        // Dangerous SQL keywords and patterns
288        static DANGEROUS_KEYWORDS: OnceLock<HashSet<&'static str>> = OnceLock::new();
289        let keywords = DANGEROUS_KEYWORDS.get_or_init(|| {
290            [
291                // DDL operations
292                "drop",
293                "create",
294                "alter",
295                "truncate",
296                // DML operations
297                "insert",
298                "update",
299                "delete",
300                // System procedures
301                "exec",
302                "execute",
303                "xp_",
304                "sp_",
305                // Advanced operations
306                "declare",
307                "cursor",
308                "fetch",
309                "open",
310                "close",
311                "begin",
312                "commit",
313                "rollback",
314                "transaction",
315                // Information schema access
316                "information_schema",
317                "sys.",
318                "pg_",
319                // File operations (MSSQL/MySQL specific)
320                "bulk",
321                "openrowset",
322                "opendatasource",
323                "load_file",
324                "into outfile",
325                "into dumpfile",
326                // Comments that might hide attacks
327                "--",
328                "/*",
329                "*/",
330            ]
331            .into_iter()
332            .collect()
333        });
334
335        // Check for dangerous keywords
336        for keyword in keywords {
337            if expression_lower.contains(keyword) {
338                return Err(TermError::SecurityError(format!(
339                    "SQL expression contains dangerous keyword: '{keyword}'"
340                )));
341            }
342        }
343
344        // Check for suspicious patterns
345        let suspicious_patterns = &[
346            r";\s*\w+",                 // Commands after semicolon
347            r"union\s+select",          // Union-based injection
348            r"'\s*or\s+'",              // OR-based injection
349            r"'\s*and\s+'",             // AND-based injection
350            r"=\s*\(.*select.*\)",      // Subquery injection
351            r"\(\s*select\s+.*\)",      // Subqueries in general
352            r"in\s*\(\s*select\s+.*\)", // IN with subquery
353        ];
354
355        for pattern in suspicious_patterns {
356            if let Ok(regex) = Regex::new(pattern) {
357                if regex.is_match(&expression_lower) {
358                    return Err(TermError::SecurityError(format!(
359                        "SQL expression contains suspicious pattern matching: {pattern}"
360                    )));
361                }
362            }
363        }
364
365        Ok(())
366    }
367}
368
/// Input validation utilities for various data types.
///
/// This is a unit struct used purely as a namespace: every validator is an
/// associated function that reports failures as `TermError::SecurityError`.
pub struct InputValidator;
371
372impl InputValidator {
373    /// Validates a numeric threshold value.
374    pub fn validate_threshold(value: f64, name: &str) -> Result<()> {
375        if !value.is_finite() {
376            return Err(TermError::SecurityError(format!(
377                "Invalid {name} value: must be finite (not NaN or infinite)"
378            )));
379        }
380        Ok(())
381    }
382
383    /// Validates a percentage value (0.0 to 1.0).
384    pub fn validate_percentage(value: f64, name: &str) -> Result<()> {
385        Self::validate_threshold(value, name)?;
386
387        if !(0.0..=1.0).contains(&value) {
388            return Err(TermError::SecurityError(format!(
389                "Invalid {name} value: must be between 0.0 and 1.0, got {value}"
390            )));
391        }
392        Ok(())
393    }
394
395    /// Validates a string length.
396    pub fn validate_string_length(value: &str, max_length: usize, name: &str) -> Result<()> {
397        if value.len() > max_length {
398            return Err(TermError::SecurityError(format!(
399                "{name} too long: {} characters (max {max_length})",
400                value.len()
401            )));
402        }
403        Ok(())
404    }
405
406    /// Validates that a string doesn't contain null bytes.
407    pub fn validate_no_null_bytes(value: &str, name: &str) -> Result<()> {
408        if value.contains('\0') {
409            return Err(TermError::SecurityError(format!(
410                "{name} cannot contain null bytes"
411            )));
412        }
413        Ok(())
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the secret can be read back and then extracted.
    #[test]
    fn test_secure_string_zeroization() {
        let secure = SecureString::new("secret123");
        assert_eq!(secure.expose(), "secret123");

        // `into_string` consumes `secure`; what remains in memory afterwards
        // cannot be inspected without unsafe code, so only the call is exercised.
        let _extracted = secure.into_string();
    }

    /// Ordinary identifiers, including a dotted one, are accepted.
    #[test]
    fn test_valid_sql_identifiers() {
        for identifier in ["customer_id", "table1", "_private_col", "schema.table"] {
            assert!(
                SqlSecurity::validate_identifier(identifier).is_ok(),
                "expected `{identifier}` to be accepted"
            );
        }
    }

    /// Empty, oversized, dangerous and malformed identifiers are rejected.
    #[test]
    fn test_invalid_sql_identifiers() {
        let too_long = "a".repeat(200);
        let rejected = [
            // Empty identifier
            "",
            // Too long
            too_long.as_str(),
            // Contains dangerous patterns
            "id; DROP TABLE",
            "col--comment",
            "union_select",
            // Invalid characters
            "col name", // space
            "col-name", // dash
            "123col",   // starts with number
        ];

        for identifier in rejected {
            assert!(
                SqlSecurity::validate_identifier(identifier).is_err(),
                "expected `{identifier}` to be rejected"
            );
        }
    }

    /// Escaping wraps in double quotes and doubles embedded quotes.
    #[test]
    fn test_sql_identifier_escaping() {
        let cases = [
            ("customer_id", "\"customer_id\""),
            ("col\"with\"quotes", "\"col\"\"with\"\"quotes\""),
        ];

        for (input, expected) in cases {
            let result = SqlSecurity::escape_identifier(input).unwrap();
            assert_eq!(result, expected);
        }
    }

    /// Regex validation accepts sane patterns, rejects bad ones, escapes quotes.
    #[test]
    fn test_regex_pattern_validation() {
        for pattern in [r"^[A-Z]\d+$", r"email@domain\.com"] {
            assert!(SqlSecurity::validate_regex_pattern(pattern).is_ok());
        }

        // Invalid regex
        assert!(SqlSecurity::validate_regex_pattern(r"[unclosed").is_err());

        // Too long
        let oversized = "a".repeat(2000);
        assert!(SqlSecurity::validate_regex_pattern(&oversized).is_err());

        // Contains quotes that should be escaped
        let result = SqlSecurity::validate_regex_pattern("it's a pattern").unwrap();
        assert_eq!(result, "it''s a pattern");
    }

    /// Plain predicates pass; injection-looking expressions are rejected.
    #[test]
    fn test_sql_expression_validation() {
        // Valid expressions
        for expression in ["price > 100", "name IS NOT NULL", "age BETWEEN 18 AND 65"] {
            assert!(
                SqlSecurity::validate_sql_expression(expression).is_ok(),
                "expected `{expression}` to be accepted"
            );
        }

        // Dangerous expressions
        for expression in [
            "price > 0; DROP TABLE users",
            "name = '' OR '1'='1'",
            "id IN (SELECT * FROM passwords)",
            "EXEC sp_droplogin",
        ] {
            assert!(
                SqlSecurity::validate_sql_expression(expression).is_err(),
                "expected `{expression}` to be rejected"
            );
        }
    }

    /// Numeric and string validators accept good input and reject bad input.
    #[test]
    fn test_input_validation() {
        // Valid inputs
        assert!(InputValidator::validate_threshold(5.5, "threshold").is_ok());
        assert!(InputValidator::validate_percentage(0.95, "percentage").is_ok());
        assert!(InputValidator::validate_string_length("short", 100, "name").is_ok());

        // Invalid inputs
        for not_finite in [f64::NAN, f64::INFINITY] {
            assert!(InputValidator::validate_threshold(not_finite, "threshold").is_err());
        }
        for out_of_range in [1.5, -0.1] {
            assert!(InputValidator::validate_percentage(out_of_range, "percentage").is_err());
        }
        assert!(InputValidator::validate_string_length("too long", 5, "name").is_err());
        assert!(InputValidator::validate_no_null_bytes("contains\0null", "name").is_err());
    }
}