rustyclaw_core/security/
validator.rs1use std::collections::HashSet;
14
15#[derive(Debug, Clone)]
17pub struct ValidationResult {
18 pub is_valid: bool,
20 pub errors: Vec<ValidationError>,
22 pub warnings: Vec<String>,
24}
25
26impl ValidationResult {
27 pub fn ok() -> Self {
29 Self {
30 is_valid: true,
31 errors: vec![],
32 warnings: vec![],
33 }
34 }
35
36 pub fn error(error: ValidationError) -> Self {
38 Self {
39 is_valid: false,
40 errors: vec![error],
41 warnings: vec![],
42 }
43 }
44
45 pub fn with_warning(mut self, warning: impl Into<String>) -> Self {
47 self.warnings.push(warning.into());
48 self
49 }
50
51 pub fn merge(mut self, other: Self) -> Self {
53 self.is_valid = self.is_valid && other.is_valid;
54 self.errors.extend(other.errors);
55 self.warnings.extend(other.warnings);
56 self
57 }
58}
59
60impl Default for ValidationResult {
61 fn default() -> Self {
62 Self::ok()
63 }
64}
65
66#[derive(Debug, Clone)]
68pub struct ValidationError {
69 pub field: String,
71 pub message: String,
73 pub code: ValidationErrorCode,
75}
76
/// Machine-readable category attached to every [`ValidationError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationErrorCode {
    /// The input was empty.
    Empty,
    /// The input exceeded the configured maximum length.
    TooLong,
    /// The input was shorter than the configured minimum length.
    TooShort,
    /// The input did not match an expected format.
    /// (Not produced by the checks visible in this module.)
    InvalidFormat,
    /// The input contained a forbidden pattern.
    ForbiddenContent,
    /// The input contained bytes that are not acceptable, such as NUL.
    InvalidEncoding,
    /// The input matched a suspicious pattern.
    /// (Not produced by the checks visible in this module.)
    SuspiciousPattern,
}
88
/// Configurable validator for untrusted text input.
///
/// Derives `Debug` and `Clone` so a configured validator can be logged,
/// embedded in other `Debug` types, and reused as a template.
#[derive(Debug, Clone)]
pub struct InputValidator {
    /// Largest accepted input size, in bytes.
    max_length: usize,
    /// Smallest accepted input size, in bytes.
    min_length: usize,
    /// Substrings that must not occur in the input.
    forbidden_patterns: HashSet<String>,
}
98
99impl InputValidator {
100 pub fn new() -> Self {
102 Self {
103 max_length: 100_000,
104 min_length: 1,
105 forbidden_patterns: HashSet::new(),
106 }
107 }
108
109 pub fn with_max_length(mut self, max: usize) -> Self {
111 self.max_length = max;
112 self
113 }
114
115 pub fn with_min_length(mut self, min: usize) -> Self {
117 self.min_length = min;
118 self
119 }
120
121 pub fn forbid_pattern(mut self, pattern: impl Into<String>) -> Self {
123 self.forbidden_patterns
124 .insert(pattern.into().to_lowercase());
125 self
126 }
127
128 pub fn validate(&self, input: &str) -> ValidationResult {
130 let mut result = ValidationResult::ok();
131
132 if input.is_empty() {
134 return ValidationResult::error(ValidationError {
135 field: "input".to_string(),
136 message: "Input cannot be empty".to_string(),
137 code: ValidationErrorCode::Empty,
138 });
139 }
140
141 if input.len() > self.max_length {
143 result = result.merge(ValidationResult::error(ValidationError {
144 field: "input".to_string(),
145 message: format!(
146 "Input too long: {} bytes (max {})",
147 input.len(),
148 self.max_length
149 ),
150 code: ValidationErrorCode::TooLong,
151 }));
152 }
153
154 if input.len() < self.min_length {
155 result = result.merge(ValidationResult::error(ValidationError {
156 field: "input".to_string(),
157 message: format!(
158 "Input too short: {} bytes (min {})",
159 input.len(),
160 self.min_length
161 ),
162 code: ValidationErrorCode::TooShort,
163 }));
164 }
165
166 if input.chars().any(|c| c == '\x00') {
168 result = result.merge(ValidationResult::error(ValidationError {
169 field: "input".to_string(),
170 message: "Input contains null bytes".to_string(),
171 code: ValidationErrorCode::InvalidEncoding,
172 }));
173 }
174
175 let lower_input = input.to_lowercase();
177 for pattern in &self.forbidden_patterns {
178 if lower_input.contains(pattern) {
179 result = result.merge(ValidationResult::error(ValidationError {
180 field: "input".to_string(),
181 message: format!("Input contains forbidden pattern: {}", pattern),
182 code: ValidationErrorCode::ForbiddenContent,
183 }));
184 }
185 }
186
187 let whitespace_ratio =
189 input.chars().filter(|c| c.is_whitespace()).count() as f64 / input.len() as f64;
190 if whitespace_ratio > 0.9 && input.len() > 100 {
191 result = result.with_warning("Input has unusually high whitespace ratio");
192 }
193
194 if has_excessive_repetition(input) {
196 result = result.with_warning("Input has excessive character repetition");
197 }
198
199 result
200 }
201
202 pub fn validate_tool_params(&self, params: &serde_json::Value) -> ValidationResult {
204 let mut result = ValidationResult::ok();
205
206 fn check_strings(
207 value: &serde_json::Value,
208 validator: &InputValidator,
209 result: &mut ValidationResult,
210 ) {
211 match value {
212 serde_json::Value::String(s) => {
213 let string_result = validator.validate(s);
214 *result = std::mem::take(result).merge(string_result);
215 }
216 serde_json::Value::Array(arr) => {
217 for item in arr {
218 check_strings(item, validator, result);
219 }
220 }
221 serde_json::Value::Object(obj) => {
222 for (_, v) in obj {
223 check_strings(v, validator, result);
224 }
225 }
226 _ => {}
227 }
228 }
229
230 check_strings(params, self, &mut result);
231 result
232 }
233}
234
235impl Default for InputValidator {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
/// Reports whether `s` contains a run of more than 20 identical consecutive
/// characters.
///
/// Inputs shorter than 50 bytes are never flagged. The length gate is in
/// bytes, while runs are counted in `char`s.
fn has_excessive_repetition(s: &str) -> bool {
    const MIN_LEN_BYTES: usize = 50;
    const MAX_RUN: usize = 20;

    if s.len() < MIN_LEN_BYTES {
        return false;
    }

    // Single pass over the characters: no intermediate `Vec<char>`, and we
    // stop as soon as a long enough run has been seen.
    let mut previous: Option<char> = None;
    let mut run = 0usize;
    for c in s.chars() {
        if previous == Some(c) {
            run += 1;
        } else {
            previous = Some(c);
            run = 1;
        }
        if run > MAX_RUN {
            return true;
        }
    }

    false
}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Ordinary text passes with no errors.
    #[test]
    fn test_valid_input() {
        let validator = InputValidator::new();
        let result = validator.validate("Hello, this is a normal message.");
        assert!(result.is_valid);
        assert!(result.errors.is_empty());
    }

    /// The empty string is rejected with the `Empty` code.
    #[test]
    fn test_empty_input() {
        let validator = InputValidator::new();
        let result = validator.validate("");
        assert!(!result.is_valid);
        assert!(result
            .errors
            .iter()
            .any(|e| e.code == ValidationErrorCode::Empty));
    }

    /// Input above the configured maximum is rejected with `TooLong`.
    #[test]
    fn test_too_long_input() {
        let validator = InputValidator::new().with_max_length(10);
        let result = validator.validate("This is way too long for the limit");
        assert!(!result.is_valid);
        assert!(result
            .errors
            .iter()
            .any(|e| e.code == ValidationErrorCode::TooLong));
    }

    /// Forbidden patterns match regardless of letter case.
    #[test]
    fn test_forbidden_pattern() {
        let validator = InputValidator::new().forbid_pattern("forbidden");
        let result = validator.validate("This contains FORBIDDEN content");
        assert!(!result.is_valid);
        assert!(result
            .errors
            .iter()
            .any(|e| e.code == ValidationErrorCode::ForbiddenContent));
    }

    /// A long run of one character yields a warning but not an error.
    #[test]
    fn test_excessive_repetition_warning() {
        let validator = InputValidator::new();
        let result = validator.validate(&format!(
            "Start of message{}End of message",
            "a".repeat(30)
        ));
        assert!(result.is_valid);
        assert!(!result.warnings.is_empty());
    }

    /// NUL characters are rejected with `InvalidEncoding`.
    #[test]
    fn test_null_bytes_rejected() {
        let validator = InputValidator::new();
        let result = validator.validate("Hello\x00World");
        assert!(!result.is_valid);
        assert!(result
            .errors
            .iter()
            .any(|e| e.code == ValidationErrorCode::InvalidEncoding));
    }

    /// Strings nested inside JSON objects are validated too.
    #[test]
    fn test_validate_tool_params() {
        let validator = InputValidator::new().forbid_pattern("secret_word");
        let params = serde_json::json!({
            "name": "test",
            "nested": {
                "value": "contains secret_word here"
            }
        });
        let result = validator.validate_tool_params(&params);
        assert!(!result.is_valid);
    }

    /// Mostly-whitespace input yields a warning but not an error.
    #[test]
    fn test_high_whitespace_warning() {
        let validator = InputValidator::new();
        let whitespace_heavy = format!("a{}", " ".repeat(150));
        let result = validator.validate(&whitespace_heavy);
        assert!(result.is_valid);
        assert!(result.warnings.iter().any(|w| w.contains("whitespace")));
    }
}