vex_api/
sanitize.rs

1//! Input sanitization and validation for security
2//!
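//! Provides functions that normalize and validate user-supplied text (length limits,
//! character allowlists, control-character stripping) and reject inputs containing
//! known prompt-injection phrases. The injection check is a heuristic denylist: it
//! reduces, but does not eliminate, injection risk.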
5
6use thiserror::Error;
7
8/// Sanitization errors
9#[derive(Debug, Error)]
10pub enum SanitizeError {
11    #[error("Input too long: {actual} chars (max {max})")]
12    TooLong { actual: usize, max: usize },
13
14    #[error("Input too short: {actual} chars (min {min})")]
15    TooShort { actual: usize, min: usize },
16
17    #[error("Input contains forbidden pattern: {pattern}")]
18    ForbiddenPattern { pattern: String },
19
20    #[error("Input contains invalid characters")]
21    InvalidCharacters,
22
23    #[error("Input is empty or whitespace only")]
24    EmptyInput,
25}
26
/// Tunable rules that [`sanitize`] applies to a piece of user-supplied text.
///
/// Start from one of the presets ([`SanitizeConfig::default`], [`SanitizeConfig::strict`],
/// [`SanitizeConfig::role`], [`SanitizeConfig::prompt`]) and override individual fields
/// with struct-update syntax when something in between is needed.
#[derive(Clone, Debug)]
pub struct SanitizeConfig {
    /// Longest input accepted; longer input is rejected with `SanitizeError::TooLong`.
    pub max_length: usize,
    /// Shortest input accepted; shorter input is rejected with `SanitizeError::TooShort`.
    pub min_length: usize,
    /// Remove leading and trailing whitespace before validating.
    pub trim: bool,
    /// Reject input that contains a known prompt-injection phrase.
    pub check_injection: bool,
    /// Permit line breaks (`\n`); tabs are also only kept when this is set.
    pub allow_newlines: bool,
    /// Permit characters other than alphanumerics, space, `-` and `_`.
    pub allow_special_chars: bool,
}

impl Default for SanitizeConfig {
    /// General-purpose preset: max length 10 000, min length 1, trimming and injection
    /// checks on, newlines and special characters allowed.
    fn default() -> Self {
        SanitizeConfig {
            max_length: 10_000,
            min_length: 1,
            trim: true,
            check_injection: true,
            allow_newlines: true,
            allow_special_chars: true,
        }
    }
}

impl SanitizeConfig {
    /// Preset for names and identifiers: single line, only alphanumerics, space, `-`
    /// and `_`, max length 100. Everything else matches [`Default`].
    pub fn strict() -> Self {
        Self {
            max_length: 100,
            allow_newlines: false,
            allow_special_chars: false,
            ..Self::default()
        }
    }

    /// Preset for role descriptions: length between 3 and 500. Everything else matches
    /// [`Default`].
    pub fn role() -> Self {
        Self {
            max_length: 500,
            min_length: 3,
            ..Self::default()
        }
    }

    /// Permissive preset for prompts: max length 50 000. Everything else matches
    /// [`Default`].
    pub fn prompt() -> Self {
        Self {
            max_length: 50_000,
            ..Self::default()
        }
    }
}
94
/// Lower-case phrases whose presence suggests a prompt-injection attempt.
///
/// `sanitize` lower-cases the text and does a plain substring search for each entry in
/// the order listed, reporting the first hit. Entries must therefore be lower-case.
/// This is a heuristic denylist: broad entries such as `"unlock"` or `"bypass"` also
/// match benign text, and reworded attacks are not caught.
const INJECTION_PATTERNS: &[&str] = &[
    // Attempts to override the system prompt or reassign the model's persona.
    "ignore previous instructions",
    "ignore all previous",
    "disregard previous",
    "forget previous",
    "new instructions:",
    "system prompt:",
    "you are now",
    "pretend you are",
    "act as if",
    "roleplay as",
    // Common jailbreak vocabulary.
    "dan mode",
    "developer mode",
    "jailbreak",
    "unlock",
    "bypass",
    // Encoded or escaped payload markers, matched literally (e.g. a backslash then `x`).
    "base64:",
    "\\x",
    "\\u00",
];
119
120/// Sanitize and validate input text
121pub fn sanitize(input: &str, config: &SanitizeConfig) -> Result<String, SanitizeError> {
122    // Trim if configured
123    let text = if config.trim { input.trim() } else { input };
124
125    // Check empty
126    if text.is_empty() {
127        return Err(SanitizeError::EmptyInput);
128    }
129
130    // Check length
131    if text.len() < config.min_length {
132        return Err(SanitizeError::TooShort {
133            actual: text.len(),
134            min: config.min_length,
135        });
136    }
137
138    if text.len() > config.max_length {
139        return Err(SanitizeError::TooLong {
140            actual: text.len(),
141            max: config.max_length,
142        });
143    }
144
145    // Check for newlines if not allowed
146    if !config.allow_newlines && text.contains('\n') {
147        return Err(SanitizeError::InvalidCharacters);
148    }
149
150    // Check for special characters if not allowed
151    if !config.allow_special_chars {
152        for c in text.chars() {
153            if !c.is_alphanumeric() && c != ' ' && c != '-' && c != '_' {
154                return Err(SanitizeError::InvalidCharacters);
155            }
156        }
157    }
158
159    // Check for injection patterns
160    if config.check_injection {
161        let lower = text.to_lowercase();
162        for pattern in INJECTION_PATTERNS {
163            if lower.contains(pattern) {
164                tracing::warn!(pattern = pattern, "Potential prompt injection detected");
165                return Err(SanitizeError::ForbiddenPattern {
166                    pattern: pattern.to_string(),
167                });
168            }
169        }
170    }
171
172    // Remove null bytes and other control characters (except newlines/tabs if allowed)
173    let sanitized: String = text
174        .chars()
175        .filter(|c| {
176            if *c == '\n' || *c == '\t' {
177                config.allow_newlines
178            } else {
179                !c.is_control()
180            }
181        })
182        .collect();
183
184    Ok(sanitized)
185}
186
187/// Sanitize a name field (strict)
188pub fn sanitize_name(input: &str) -> Result<String, SanitizeError> {
189    sanitize(input, &SanitizeConfig::strict())
190}
191
192/// Sanitize a role description
193pub fn sanitize_role(input: &str) -> Result<String, SanitizeError> {
194    sanitize(input, &SanitizeConfig::role())
195}
196
197/// Sanitize a prompt
198pub fn sanitize_prompt(input: &str) -> Result<String, SanitizeError> {
199    sanitize(input, &SanitizeConfig::prompt())
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `sanitize` with the general-purpose default preset.
    fn run_default(input: &str) -> Result<String, SanitizeError> {
        sanitize(input, &SanitizeConfig::default())
    }

    #[test]
    fn test_sanitize_valid_input() {
        assert_eq!(run_default("Hello world").ok().as_deref(), Some("Hello world"));
    }

    #[test]
    fn test_sanitize_trims_whitespace() {
        assert_eq!(run_default("  Hello  ").ok().as_deref(), Some("Hello"));
    }

    #[test]
    fn test_sanitize_rejects_empty() {
        assert!(matches!(run_default(""), Err(SanitizeError::EmptyInput)));
    }

    #[test]
    fn test_sanitize_rejects_too_long() {
        let oversized = "a".repeat(101);
        let outcome = sanitize(&oversized, &SanitizeConfig::strict());
        assert!(matches!(outcome, Err(SanitizeError::TooLong { .. })));
    }

    #[test]
    fn test_sanitize_detects_injection() {
        let outcome = run_default("Please ignore previous instructions");
        assert!(matches!(outcome, Err(SanitizeError::ForbiddenPattern { .. })));
    }

    #[test]
    fn test_sanitize_name_rejects_special_chars() {
        let outcome = sanitize_name("agent<script>");
        assert!(matches!(outcome, Err(SanitizeError::InvalidCharacters)));
    }

    #[test]
    fn test_sanitize_removes_control_chars() {
        assert_eq!(run_default("Hello\x00World").ok().as_deref(), Some("HelloWorld"));
    }
}