term_guard/
security.rs

1//! Security utilities for Term data validation library.
2//!
3//! This module provides security hardening utilities to prevent SQL injection,
4//! validate inputs, and handle credentials securely.
5
6use crate::error::{Result, TermError};
7use once_cell::sync::Lazy;
8use regex::Regex;
9use std::collections::HashSet;
10use std::sync::OnceLock;
11use zeroize::{Zeroize, ZeroizeOnDrop};
12
13/// A secure string that automatically clears its contents when dropped.
14#[derive(Clone, ZeroizeOnDrop)]
15pub struct SecureString(String);
16
17impl std::fmt::Debug for SecureString {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        write!(f, "SecureString(***)")
20    }
21}
22
23impl SecureString {
24    /// Create a new secure string.
25    pub fn new(value: impl Into<String>) -> Self {
26        Self(value.into())
27    }
28
29    /// Get the string value. Use carefully and avoid storing the result.
30    pub fn expose(&self) -> &str {
31        &self.0
32    }
33
34    /// Convert to a regular string. The SecureString will be zeroized.
35    pub fn into_string(mut self) -> String {
36        let value = std::mem::take(&mut self.0);
37        self.0.zeroize();
38        value
39    }
40}
41
42impl From<String> for SecureString {
43    fn from(value: String) -> Self {
44        Self(value)
45    }
46}
47
48impl From<&str> for SecureString {
49    fn from(value: &str) -> Self {
50        Self(value.to_string())
51    }
52}
53
54/// SQL identifier validation and escaping utilities.
55pub struct SqlSecurity;
56
57impl SqlSecurity {
58    /// Validates and escapes a SQL identifier (table name, column name, etc.).
59    ///
60    /// This function ensures that user-provided identifiers are safe to use in SQL queries
61    /// by validating their format and properly escaping them.
62    ///
63    /// # Arguments
64    /// * `identifier` - The identifier to validate and escape
65    ///
66    /// # Returns
67    /// * `Ok(String)` - The safely escaped identifier ready for SQL use
68    /// * `Err(TermError)` - If the identifier is invalid or potentially malicious
69    ///
70    /// # Security
71    /// This function prevents SQL injection by:
72    /// - Validating identifier format against allowed patterns
73    /// - Checking against a blocklist of dangerous patterns
74    /// - Properly escaping identifiers using double quotes
75    /// - Limiting identifier length to prevent DoS attacks
76    ///
77    /// # Examples
78    /// ```rust
79    /// use term_guard::security::SqlSecurity;
80    ///
81    /// // Valid identifiers
82    /// assert!(SqlSecurity::escape_identifier("customer_id").is_ok());
83    /// assert!(SqlSecurity::escape_identifier("table1").is_ok());
84    ///
85    /// // Invalid identifiers  
86    /// assert!(SqlSecurity::escape_identifier("id; DROP TABLE users--").is_err());
87    /// assert!(SqlSecurity::escape_identifier(&"very_long_name_".repeat(100)).is_err());
88    /// ```
89    pub fn escape_identifier(identifier: &str) -> Result<String> {
90        // Input validation
91        Self::validate_identifier(identifier)?;
92
93        // Escape the identifier using double quotes and escape any internal double quotes
94        let escaped = identifier.replace('"', "\"\"");
95        Ok(format!("\"{escaped}\""))
96    }
97
98    /// Validates a SQL identifier without escaping it.
99    ///
100    /// This function checks if an identifier is safe to use but doesn't escape it.
101    /// Useful for cases where you need validation but will use the identifier in a
102    /// different context.
103    pub fn validate_identifier(identifier: &str) -> Result<()> {
104        // Check for empty identifier
105        if identifier.is_empty() {
106            return Err(TermError::SecurityError(
107                "SQL identifier cannot be empty".to_string(),
108            ));
109        }
110
111        // Check identifier length (prevent DoS)
112        if identifier.len() > 128 {
113            return Err(TermError::SecurityError(
114                "SQL identifier too long (max 128 characters)".to_string(),
115            ));
116        }
117
118        // Check for null bytes
119        if identifier.contains('\0') {
120            return Err(TermError::SecurityError(
121                "SQL identifier cannot contain null bytes".to_string(),
122            ));
123        }
124
125        // Validate identifier format using regex
126        static IDENTIFIER_REGEX: Lazy<Regex> = Lazy::new(|| {
127            // Allow letters, numbers, underscores, dots (for qualified names), and quotes
128            // Must start with letter or underscore (or quote for quoted identifiers)
129            // This regex is compile-time constant and known to be valid
130            #[allow(clippy::expect_used)]
131            Regex::new(r#"^[a-zA-Z_"][a-zA-Z0-9_"]*(\.[a-zA-Z_"][a-zA-Z0-9_"]*)*$"#)
132                .expect("Hard-coded regex pattern should be valid")
133        });
134        let regex = &*IDENTIFIER_REGEX;
135
136        if !regex.is_match(identifier) {
137            return Err(TermError::SecurityError(format!(
138                "Invalid SQL identifier format: '{identifier}'. Identifiers must start with a letter or underscore and contain only letters, numbers, underscores, and dots"
139            )));
140        }
141
142        // Check against dangerous patterns
143        Self::check_dangerous_patterns(identifier)?;
144
145        Ok(())
146    }
147
148    /// Validates a regex pattern for safety.
149    ///
150    /// This function ensures that user-provided regex patterns are safe to use
151    /// in SQL queries and won't cause ReDoS attacks or other security issues.
152    pub fn validate_regex_pattern(pattern: &str) -> Result<String> {
153        // Check pattern length (prevent DoS)
154        if pattern.len() > 1000 {
155            return Err(TermError::SecurityError(
156                "Regex pattern too long (max 1000 characters)".to_string(),
157            ));
158        }
159
160        // Check for null bytes
161        if pattern.contains('\0') {
162            return Err(TermError::SecurityError(
163                "Regex pattern cannot contain null bytes".to_string(),
164            ));
165        }
166
167        // Validate that it's a valid regex pattern
168        match Regex::new(pattern) {
169            Ok(_) => (),
170            Err(e) => {
171                return Err(TermError::SecurityError(format!(
172                    "Invalid regex pattern: {e}"
173                )));
174            }
175        }
176
177        // Check for patterns that might cause ReDoS
178        Self::check_redos_patterns(pattern)?;
179
180        // Escape single quotes for SQL
181        let escaped = pattern.replace('\'', "''");
182        Ok(escaped)
183    }
184
185    /// Validates a custom SQL expression for safety.
186    ///
187    /// This function performs security validation on user-provided SQL expressions
188    /// to prevent SQL injection and other attacks while still allowing legitimate
189    /// validation expressions.
190    pub fn validate_sql_expression(expression: &str) -> Result<()> {
191        // Check expression length (prevent DoS)
192        if expression.len() > 5000 {
193            return Err(TermError::SecurityError(
194                "SQL expression too long (max 5000 characters)".to_string(),
195            ));
196        }
197
198        // Check for null bytes
199        if expression.contains('\0') {
200            return Err(TermError::SecurityError(
201                "SQL expression cannot contain null bytes".to_string(),
202            ));
203        }
204
205        // Check against dangerous SQL keywords and patterns
206        Self::check_dangerous_sql_patterns(expression)?;
207
208        Ok(())
209    }
210
211    /// Checks for dangerous patterns in identifiers.
212    fn check_dangerous_patterns(identifier: &str) -> Result<()> {
213        let identifier_lower = identifier.to_lowercase();
214
215        // Check for SQL injection attempts
216        let dangerous_patterns = &[
217            ";", "--", "/*", "*/", "xp_", "sp_", "union", "select", "insert", "update", "delete",
218            "drop", "create", "alter", "exec", "execute", "declare", "cursor", "fetch", "open",
219            "close",
220        ];
221
222        for pattern in dangerous_patterns {
223            if identifier_lower.contains(pattern) {
224                return Err(TermError::SecurityError(format!(
225                    "SQL identifier contains dangerous pattern: '{pattern}'"
226                )));
227            }
228        }
229
230        Ok(())
231    }
232
233    /// Checks for patterns that might cause ReDoS (Regular Expression Denial of Service).
234    fn check_redos_patterns(pattern: &str) -> Result<()> {
235        // For now, disable ReDoS checking as it's being too aggressive
236        // with legitimate patterns like email validation.
237        // In a production system, you'd want more sophisticated ReDoS detection
238        // that can distinguish between safe and dangerous patterns.
239
240        // Check for extremely obvious ReDoS patterns only
241        let dangerous_patterns = &[
242            "(.*)*", // Classic catastrophic backtracking
243            "(.*)+", // Another classic
244            "(a+)+", // Nested quantifiers on same pattern
245            "(a*)*", // Nested quantifiers on same pattern
246        ];
247
248        for dangerous in dangerous_patterns {
249            if pattern.contains(dangerous) {
250                return Err(TermError::SecurityError(
251                    "Regex pattern might cause ReDoS attack".to_string(),
252                ));
253            }
254        }
255
256        Ok(())
257    }
258
259    /// Checks for dangerous SQL patterns in expressions.
260    fn check_dangerous_sql_patterns(expression: &str) -> Result<()> {
261        let expression_lower = expression.to_lowercase();
262
263        // Dangerous SQL keywords and patterns
264        static DANGEROUS_KEYWORDS: OnceLock<HashSet<&'static str>> = OnceLock::new();
265        let keywords = DANGEROUS_KEYWORDS.get_or_init(|| {
266            [
267                // DDL operations
268                "drop",
269                "create",
270                "alter",
271                "truncate",
272                // DML operations
273                "insert",
274                "update",
275                "delete",
276                // System procedures
277                "exec",
278                "execute",
279                "xp_",
280                "sp_",
281                // Advanced operations
282                "declare",
283                "cursor",
284                "fetch",
285                "open",
286                "close",
287                "begin",
288                "commit",
289                "rollback",
290                "transaction",
291                // Information schema access
292                "information_schema",
293                "sys.",
294                "pg_",
295                // File operations (MSSQL/MySQL specific)
296                "bulk",
297                "openrowset",
298                "opendatasource",
299                "load_file",
300                "into outfile",
301                "into dumpfile",
302                // Comments that might hide attacks
303                "--",
304                "/*",
305                "*/",
306            ]
307            .into_iter()
308            .collect()
309        });
310
311        // Check for dangerous keywords
312        for keyword in keywords {
313            if expression_lower.contains(keyword) {
314                return Err(TermError::SecurityError(format!(
315                    "SQL expression contains dangerous keyword: '{keyword}'"
316                )));
317            }
318        }
319
320        // Check for suspicious patterns
321        let suspicious_patterns = &[
322            r";\s*\w+",                 // Commands after semicolon
323            r"union\s+select",          // Union-based injection
324            r"'\s*or\s+'",              // OR-based injection
325            r"'\s*and\s+'",             // AND-based injection
326            r"=\s*\(.*select.*\)",      // Subquery injection
327            r"\(\s*select\s+.*\)",      // Subqueries in general
328            r"in\s*\(\s*select\s+.*\)", // IN with subquery
329        ];
330
331        for pattern in suspicious_patterns {
332            if let Ok(regex) = Regex::new(pattern) {
333                if regex.is_match(&expression_lower) {
334                    return Err(TermError::SecurityError(format!(
335                        "SQL expression contains suspicious pattern matching: {pattern}"
336                    )));
337                }
338            }
339        }
340
341        Ok(())
342    }
343}
344
345/// Input validation utilities for various data types.
346pub struct InputValidator;
347
348impl InputValidator {
349    /// Validates a numeric threshold value.
350    pub fn validate_threshold(value: f64, name: &str) -> Result<()> {
351        if !value.is_finite() {
352            return Err(TermError::SecurityError(format!(
353                "Invalid {name} value: must be finite (not NaN or infinite)"
354            )));
355        }
356        Ok(())
357    }
358
359    /// Validates a percentage value (0.0 to 1.0).
360    pub fn validate_percentage(value: f64, name: &str) -> Result<()> {
361        Self::validate_threshold(value, name)?;
362
363        if !(0.0..=1.0).contains(&value) {
364            return Err(TermError::SecurityError(format!(
365                "Invalid {name} value: must be between 0.0 and 1.0, got {value}"
366            )));
367        }
368        Ok(())
369    }
370
371    /// Validates a string length.
372    pub fn validate_string_length(value: &str, max_length: usize, name: &str) -> Result<()> {
373        if value.len() > max_length {
374            return Err(TermError::SecurityError(format!(
375                "{name} too long: {} characters (max {max_length})",
376                value.len()
377            )));
378        }
379        Ok(())
380    }
381
382    /// Validates that a string doesn't contain null bytes.
383    pub fn validate_no_null_bytes(value: &str, name: &str) -> Result<()> {
384        if value.contains('\0') {
385            return Err(TermError::SecurityError(format!(
386                "{name} cannot contain null bytes"
387            )));
388        }
389        Ok(())
390    }
391}
392
#[cfg(test)]
mod tests {
    //! Unit tests for the security helpers in this module.

    use super::*;

    #[test]
    fn test_secure_string_zeroization() {
        let secure = SecureString::new("secret123");
        assert_eq!(secure.expose(), "secret123");

        // Debug output must never reveal the secret.
        assert_eq!(format!("{secure:?}"), "SecureString(***)");

        // Zeroization itself can't easily be observed without unsafe code, but the
        // extracted value must come back intact.
        let extracted = secure.into_string();
        assert_eq!(extracted, "secret123");
    }

    #[test]
    fn test_valid_sql_identifiers() {
        for identifier in ["customer_id", "table1", "_private_col", "schema.table"] {
            assert!(
                SqlSecurity::validate_identifier(identifier).is_ok(),
                "expected {identifier:?} to be accepted"
            );
        }

        // Exactly at the length limit is still fine.
        assert!(SqlSecurity::validate_identifier(&"a".repeat(128)).is_ok());
    }

    #[test]
    fn test_invalid_sql_identifiers() {
        // Empty identifier
        assert!(SqlSecurity::validate_identifier("").is_err());

        // Too long (well over, and one past the limit)
        assert!(SqlSecurity::validate_identifier(&"a".repeat(200)).is_err());
        assert!(SqlSecurity::validate_identifier(&"a".repeat(129)).is_err());

        // Null byte
        assert!(SqlSecurity::validate_identifier("col\0name").is_err());

        // Contains dangerous patterns
        assert!(SqlSecurity::validate_identifier("id; DROP TABLE").is_err());
        assert!(SqlSecurity::validate_identifier("col--comment").is_err());
        assert!(SqlSecurity::validate_identifier("union_select").is_err());

        // Invalid characters
        assert!(SqlSecurity::validate_identifier("col name").is_err()); // space
        assert!(SqlSecurity::validate_identifier("col-name").is_err()); // dash
        assert!(SqlSecurity::validate_identifier("123col").is_err()); // starts with number
    }

    #[test]
    fn test_sql_identifier_escaping() {
        let result = SqlSecurity::escape_identifier("customer_id").unwrap();
        assert_eq!(result, "\"customer_id\"");

        let result = SqlSecurity::escape_identifier("col\"with\"quotes").unwrap();
        assert_eq!(result, "\"col\"\"with\"\"quotes\"");

        // Escaping must refuse anything validation refuses.
        assert!(SqlSecurity::escape_identifier("col name").is_err());
        assert!(SqlSecurity::escape_identifier("").is_err());
    }

    #[test]
    fn test_regex_pattern_validation() {
        assert!(SqlSecurity::validate_regex_pattern(r"^[A-Z]\d+$").is_ok());
        assert!(SqlSecurity::validate_regex_pattern(r"email@domain\.com").is_ok());

        // Invalid regex
        assert!(SqlSecurity::validate_regex_pattern(r"[unclosed").is_err());

        // Too long
        assert!(SqlSecurity::validate_regex_pattern(&"a".repeat(2000)).is_err());

        // Null byte
        assert!(SqlSecurity::validate_regex_pattern("a\0b").is_err());

        // Textbook catastrophic-backtracking constructs
        assert!(SqlSecurity::validate_regex_pattern("(a+)+").is_err());
        assert!(SqlSecurity::validate_regex_pattern("^(a*)*$").is_err());

        // Contains quotes that should be escaped
        let result = SqlSecurity::validate_regex_pattern("it's a pattern").unwrap();
        assert_eq!(result, "it''s a pattern");
    }

    #[test]
    fn test_sql_expression_validation() {
        // Valid expressions
        for expression in ["price > 100", "name IS NOT NULL", "age BETWEEN 18 AND 65"] {
            assert!(
                SqlSecurity::validate_sql_expression(expression).is_ok(),
                "expected {expression:?} to be accepted"
            );
        }

        // Dangerous expressions
        for expression in [
            "price > 0; DROP TABLE users",
            "name = '' OR '1'='1'",
            "id IN (SELECT * FROM passwords)",
            "EXEC sp_droplogin",
            "price > 0 -- trailing comment",
        ] {
            assert!(
                SqlSecurity::validate_sql_expression(expression).is_err(),
                "expected {expression:?} to be rejected"
            );
        }

        // Size and null-byte guards
        assert!(SqlSecurity::validate_sql_expression(&"x".repeat(5001)).is_err());
        assert!(SqlSecurity::validate_sql_expression("price > 0\0").is_err());
    }

    #[test]
    fn test_input_validation() {
        // Valid inputs
        assert!(InputValidator::validate_threshold(5.5, "threshold").is_ok());
        assert!(InputValidator::validate_percentage(0.95, "percentage").is_ok());
        assert!(InputValidator::validate_percentage(0.0, "percentage").is_ok());
        assert!(InputValidator::validate_percentage(1.0, "percentage").is_ok());
        assert!(InputValidator::validate_string_length("short", 100, "name").is_ok());
        assert!(InputValidator::validate_string_length("exact", 5, "name").is_ok());
        assert!(InputValidator::validate_no_null_bytes("clean", "name").is_ok());

        // Invalid inputs
        assert!(InputValidator::validate_threshold(f64::NAN, "threshold").is_err());
        assert!(InputValidator::validate_threshold(f64::INFINITY, "threshold").is_err());
        assert!(InputValidator::validate_threshold(f64::NEG_INFINITY, "threshold").is_err());
        assert!(InputValidator::validate_percentage(1.5, "percentage").is_err());
        assert!(InputValidator::validate_percentage(-0.1, "percentage").is_err());
        assert!(InputValidator::validate_percentage(f64::NAN, "percentage").is_err());
        assert!(InputValidator::validate_string_length("too long", 5, "name").is_err());
        assert!(InputValidator::validate_no_null_bytes("contains\0null", "name").is_err());
    }
}