Skip to main content

rustyclaw_core/security/
validator.rs

//! Input validation for the safety layer (inspired by IronClaw)
//!
//! Validates input text and tool parameters for security issues:
//! - Length limits (prevent DoS via huge inputs)
//! - Forbidden patterns
//! - Excessive whitespace/repetition (padding attacks)
//! - Null bytes and encoding issues
//!
//! # Attribution
//!
//! Input validation patterns inspired by [IronClaw](https://github.com/nearai/ironclaw) (Apache-2.0).
12
13use std::collections::HashSet;
14
15/// Result of validating input.
16#[derive(Debug, Clone)]
17pub struct ValidationResult {
18    /// Whether the input is valid.
19    pub is_valid: bool,
20    /// Validation errors if any.
21    pub errors: Vec<ValidationError>,
22    /// Warnings that don't block processing.
23    pub warnings: Vec<String>,
24}
25
26impl ValidationResult {
27    /// Create a successful validation result.
28    pub fn ok() -> Self {
29        Self {
30            is_valid: true,
31            errors: vec![],
32            warnings: vec![],
33        }
34    }
35
36    /// Create a validation result with an error.
37    pub fn error(error: ValidationError) -> Self {
38        Self {
39            is_valid: false,
40            errors: vec![error],
41            warnings: vec![],
42        }
43    }
44
45    /// Add a warning to the result.
46    pub fn with_warning(mut self, warning: impl Into<String>) -> Self {
47        self.warnings.push(warning.into());
48        self
49    }
50
51    /// Merge another validation result into this one.
52    pub fn merge(mut self, other: Self) -> Self {
53        self.is_valid = self.is_valid && other.is_valid;
54        self.errors.extend(other.errors);
55        self.warnings.extend(other.warnings);
56        self
57    }
58}
59
60impl Default for ValidationResult {
61    fn default() -> Self {
62        Self::ok()
63    }
64}
65
66/// A validation error.
67#[derive(Debug, Clone)]
68pub struct ValidationError {
69    /// Field or aspect that failed validation.
70    pub field: String,
71    /// Error message.
72    pub message: String,
73    /// Error code for programmatic handling.
74    pub code: ValidationErrorCode,
75}
76
/// Machine-readable categories attached to a validation error.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ValidationErrorCode {
    /// The input was empty.
    Empty,
    /// The input exceeded the configured maximum length.
    TooLong,
    /// The input was shorter than the configured minimum length.
    TooShort,
    /// The input did not have the expected format.
    InvalidFormat,
    /// The input contained a forbidden pattern.
    ForbiddenContent,
    /// The input had an encoding problem (for example, an embedded null byte).
    InvalidEncoding,
    /// The input matched a suspicious pattern.
    SuspiciousPattern,
}
88
/// Input validator with configurable rules.
///
/// Holds the length limits and the set of forbidden substrings that are
/// applied to each piece of input.
#[derive(Debug, Clone)]
pub struct InputValidator {
    /// Maximum input length, in bytes (compared against `str::len`).
    max_length: usize,
    /// Minimum input length, in bytes (compared against `str::len`).
    min_length: usize,
    /// Forbidden substrings, stored lower-cased so matching is
    /// case-insensitive.
    forbidden_patterns: HashSet<String>,
}
98
99impl InputValidator {
100    /// Create a new validator with default settings.
101    pub fn new() -> Self {
102        Self {
103            max_length: 100_000,
104            min_length: 1,
105            forbidden_patterns: HashSet::new(),
106        }
107    }
108
109    /// Set maximum input length.
110    pub fn with_max_length(mut self, max: usize) -> Self {
111        self.max_length = max;
112        self
113    }
114
115    /// Set minimum input length.
116    pub fn with_min_length(mut self, min: usize) -> Self {
117        self.min_length = min;
118        self
119    }
120
121    /// Add a forbidden pattern (case-insensitive).
122    pub fn forbid_pattern(mut self, pattern: impl Into<String>) -> Self {
123        self.forbidden_patterns
124            .insert(pattern.into().to_lowercase());
125        self
126    }
127
128    /// Validate input text.
129    pub fn validate(&self, input: &str) -> ValidationResult {
130        let mut result = ValidationResult::ok();
131
132        // Check empty
133        if input.is_empty() {
134            return ValidationResult::error(ValidationError {
135                field: "input".to_string(),
136                message: "Input cannot be empty".to_string(),
137                code: ValidationErrorCode::Empty,
138            });
139        }
140
141        // Check length
142        if input.len() > self.max_length {
143            result = result.merge(ValidationResult::error(ValidationError {
144                field: "input".to_string(),
145                message: format!(
146                    "Input too long: {} bytes (max {})",
147                    input.len(),
148                    self.max_length
149                ),
150                code: ValidationErrorCode::TooLong,
151            }));
152        }
153
154        if input.len() < self.min_length {
155            result = result.merge(ValidationResult::error(ValidationError {
156                field: "input".to_string(),
157                message: format!(
158                    "Input too short: {} bytes (min {})",
159                    input.len(),
160                    self.min_length
161                ),
162                code: ValidationErrorCode::TooShort,
163            }));
164        }
165
166        // Check for null bytes (invalid in most contexts)
167        if input.chars().any(|c| c == '\x00') {
168            result = result.merge(ValidationResult::error(ValidationError {
169                field: "input".to_string(),
170                message: "Input contains null bytes".to_string(),
171                code: ValidationErrorCode::InvalidEncoding,
172            }));
173        }
174
175        // Check forbidden patterns
176        let lower_input = input.to_lowercase();
177        for pattern in &self.forbidden_patterns {
178            if lower_input.contains(pattern) {
179                result = result.merge(ValidationResult::error(ValidationError {
180                    field: "input".to_string(),
181                    message: format!("Input contains forbidden pattern: {}", pattern),
182                    code: ValidationErrorCode::ForbiddenContent,
183                }));
184            }
185        }
186
187        // Check for excessive whitespace (might indicate padding attacks)
188        let whitespace_ratio =
189            input.chars().filter(|c| c.is_whitespace()).count() as f64 / input.len() as f64;
190        if whitespace_ratio > 0.9 && input.len() > 100 {
191            result = result.with_warning("Input has unusually high whitespace ratio");
192        }
193
194        // Check for repeated characters (might indicate padding)
195        if has_excessive_repetition(input) {
196            result = result.with_warning("Input has excessive character repetition");
197        }
198
199        result
200    }
201
202    /// Validate tool parameters (recursively checks all string values in JSON).
203    pub fn validate_tool_params(&self, params: &serde_json::Value) -> ValidationResult {
204        let mut result = ValidationResult::ok();
205
206        fn check_strings(
207            value: &serde_json::Value,
208            validator: &InputValidator,
209            result: &mut ValidationResult,
210        ) {
211            match value {
212                serde_json::Value::String(s) => {
213                    let string_result = validator.validate(s);
214                    *result = std::mem::take(result).merge(string_result);
215                }
216                serde_json::Value::Array(arr) => {
217                    for item in arr {
218                        check_strings(item, validator, result);
219                    }
220                }
221                serde_json::Value::Object(obj) => {
222                    for (_, v) in obj {
223                        check_strings(v, validator, result);
224                    }
225                }
226                _ => {}
227            }
228        }
229
230        check_strings(params, self, &mut result);
231        result
232    }
233}
234
235impl Default for InputValidator {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
/// Check whether a string contains an excessively long run of one character.
///
/// Returns `true` when `s` is at least 50 bytes long and some character
/// occurs more than 20 times in a row. Shorter strings are never flagged.
fn has_excessive_repetition(s: &str) -> bool {
    /// Strings shorter than this many bytes are not inspected.
    const MIN_CHECKED_LEN: usize = 50;
    /// A run longer than this many identical characters is suspicious.
    const MAX_ALLOWED_RUN: usize = 20;

    if s.len() < MIN_CHECKED_LEN {
        return false;
    }

    // Stream over the characters instead of collecting them: the input is
    // untrusted and may be large, and only the previous character and the
    // current run length are needed.
    let mut previous: Option<char> = None;
    let mut run = 0usize;

    for c in s.chars() {
        if previous == Some(c) {
            run += 1;
        } else {
            previous = Some(c);
            run = 1;
        }

        // Stop at the first offending run; the rest cannot change the answer.
        if run > MAX_ALLOWED_RUN {
            return true;
        }
    }

    false
}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `result` carries at least one error with the given code.
    fn has_code(result: &ValidationResult, code: ValidationErrorCode) -> bool {
        result.errors.iter().any(|e| e.code == code)
    }

    /// Ordinary text passes with no errors.
    #[test]
    fn test_valid_input() {
        let outcome = InputValidator::new().validate("Hello, this is a normal message.");
        assert!(outcome.is_valid);
        assert!(outcome.errors.is_empty());
    }

    /// Empty input is rejected with the `Empty` code.
    #[test]
    fn test_empty_input() {
        let outcome = InputValidator::new().validate("");
        assert!(!outcome.is_valid);
        assert!(has_code(&outcome, ValidationErrorCode::Empty));
    }

    /// Input over the configured maximum is rejected with `TooLong`.
    #[test]
    fn test_too_long_input() {
        let outcome = InputValidator::new()
            .with_max_length(10)
            .validate("This is way too long for the limit");
        assert!(!outcome.is_valid);
        assert!(has_code(&outcome, ValidationErrorCode::TooLong));
    }

    /// Forbidden patterns match regardless of case.
    #[test]
    fn test_forbidden_pattern() {
        let outcome = InputValidator::new()
            .forbid_pattern("forbidden")
            .validate("This contains FORBIDDEN content");
        assert!(!outcome.is_valid);
        assert!(has_code(&outcome, ValidationErrorCode::ForbiddenContent));
    }

    /// A long run of one character yields a warning, not an error.
    #[test]
    fn test_excessive_repetition_warning() {
        // The input must be at least 50 bytes for the repetition check to run.
        let padded = format!("Start of message{}End of message", "a".repeat(30));
        let outcome = InputValidator::new().validate(&padded);
        assert!(outcome.is_valid); // Still valid, just a warning
        assert!(!outcome.warnings.is_empty());
    }

    /// An embedded null character is rejected with `InvalidEncoding`.
    #[test]
    fn test_null_bytes_rejected() {
        let outcome = InputValidator::new().validate("Hello\x00World");
        assert!(!outcome.is_valid);
        assert!(has_code(&outcome, ValidationErrorCode::InvalidEncoding));
    }

    /// Strings nested inside JSON objects are validated too.
    #[test]
    fn test_validate_tool_params() {
        let params = serde_json::json!({
            "name": "test",
            "nested": {
                "value": "contains secret_word here"
            }
        });
        let outcome = InputValidator::new()
            .forbid_pattern("secret_word")
            .validate_tool_params(&params);
        assert!(!outcome.is_valid);
    }

    /// Mostly-whitespace input yields a whitespace warning but stays valid.
    #[test]
    fn test_high_whitespace_warning() {
        let whitespace_heavy = format!("a{}", " ".repeat(150));
        let outcome = InputValidator::new().validate(&whitespace_heavy);
        assert!(outcome.is_valid); // Valid, but has warning
        assert!(outcome.warnings.iter().any(|w| w.contains("whitespace")));
    }
}