sh_layer0/
input_validator.rs1use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Serialize, Deserialize)]
11pub struct ValidationResult {
12 pub valid: bool,
14 pub errors: Vec<String>,
16 pub sanitized: Option<String>,
18}
19
/// Configurable validator for untrusted text input.
///
/// Holds the size limit and the list of forbidden substrings that
/// `validate` checks an input against.
#[derive(Debug, Clone)]
pub struct InputValidator {
    /// Maximum accepted input size, in bytes.
    max_length: usize,
    /// Substrings whose presence makes an input invalid (matched case-sensitively).
    forbidden_patterns: Vec<String>,
    // Never read by the code in this file; `new` only initializes it to an
    // empty map, so the lint is silenced rather than the field removed.
    // NOTE(review): presumably reserved for future required-field checks —
    // confirm, or drop the field.
    #[allow(dead_code)]
    required_fields: HashMap<String, bool>,
}
30
31impl InputValidator {
32 pub fn new() -> Self {
33 Self {
34 max_length: 100_000, forbidden_patterns: vec![
36 "<script>".to_string(),
38 "javascript:".to_string(),
39 "data:".to_string(),
40 ],
41 required_fields: HashMap::new(),
42 }
43 }
44
45 pub fn validate(&self, input: &str) -> Result<ValidationResult> {
47 let mut errors = Vec::new();
48
49 if input.len() > self.max_length {
51 errors.push(format!(
52 "Input too long: {} bytes (max {})",
53 input.len(),
54 self.max_length
55 ));
56 }
57
58 for pattern in &self.forbidden_patterns {
60 if input.contains(pattern) {
61 errors.push(format!("Forbidden pattern detected: {}", pattern));
62 }
63 }
64
65 if input.trim().is_empty() {
67 errors.push("Input is empty".to_string());
68 }
69
70 let valid = errors.is_empty();
71 let sanitized = if valid {
72 Some(self.sanitize(input))
73 } else {
74 None
75 };
76
77 Ok(ValidationResult {
78 valid,
79 errors,
80 sanitized,
81 })
82 }
83
84 fn sanitize(&self, input: &str) -> String {
86 input
88 .chars()
89 .filter(|c| !c.is_control() || *c == '\n' || *c == '\t')
90 .collect::<String>()
91 .trim()
92 .to_string()
93 }
94
95 pub fn with_max_length(mut self, max_length: usize) -> Self {
97 self.max_length = max_length;
98 self
99 }
100
101 pub fn add_forbidden_pattern(mut self, pattern: String) -> Self {
103 self.forbidden_patterns.push(pattern);
104 self
105 }
106}
107
108impl Default for InputValidator {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `text` through `validator` and unwraps the outer `Result`.
    fn run(validator: &InputValidator, text: &str) -> ValidationResult {
        validator.validate(text).unwrap()
    }

    /// Runs `text` through a validator built with `InputValidator::new()`.
    fn run_default(text: &str) -> ValidationResult {
        run(&InputValidator::new(), text)
    }

    #[test]
    fn test_valid_input() {
        let outcome = run_default("Hello, world!");
        assert!(outcome.valid);
        assert!(outcome.sanitized.is_some());
    }

    #[test]
    fn test_empty_input() {
        let outcome = run_default("");
        assert!(!outcome.valid);
        assert!(outcome.errors.iter().any(|msg| msg == "Input is empty"));
    }

    #[test]
    fn test_forbidden_pattern() {
        assert!(!run_default("<script>alert('xss')</script>").valid);
    }

    #[test]
    fn test_max_length_boundary() {
        let limited = InputValidator::new().with_max_length(100);

        // Exactly at the limit is still accepted.
        let at_limit = run(&limited, &"a".repeat(100));
        assert!(at_limit.valid, "Input at max length should be valid");

        // One byte over the limit is rejected with a "too long" error.
        let over_limit = run(&limited, &"a".repeat(101));
        assert!(!over_limit.valid, "Input over max length should be invalid");
        assert!(over_limit.errors.iter().any(|e| e.contains("too long")));
    }

    #[test]
    fn test_unicode_handling() {
        let multilingual = [
            "你好世界,这是一个测试",
            "Hello 🦀 Rust 🚀🎉",
            "日本語テスト العربية עברית",
        ];
        for text in multilingual.iter() {
            assert!(run_default(text).valid);
        }

        // An embedded NUL is accepted and stripped by sanitization.
        let with_nul = run_default("Hello\x00World");
        assert!(with_nul.valid);
        assert!(with_nul.sanitized.unwrap().contains("HelloWorld"));
    }

    #[test]
    fn test_concurrent_validation() {
        use std::sync::Arc;
        use std::thread;

        let shared = Arc::new(InputValidator::new());

        // Spawn every worker before joining any of them.
        let workers: Vec<_> = (0..10)
            .map(|n| {
                let validator = Arc::clone(&shared);
                thread::spawn(move || {
                    let text = format!("Test input {}", n);
                    validator.validate(&text).unwrap()
                })
            })
            .collect();

        let outcomes: Vec<ValidationResult> = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .collect();

        for outcome in &outcomes {
            assert!(outcome.valid);
        }
    }

    #[test]
    fn test_sanitize_removes_control_chars() {
        let outcome = run_default("Hello\x00\x01\x02World\nNewLine\tTab");
        assert!(outcome.valid);

        let cleaned = outcome.sanitized.unwrap();
        for &stripped in ['\x00', '\x01', '\x02'].iter() {
            assert!(!cleaned.contains(stripped));
        }
        for &kept in ['\n', '\t'].iter() {
            assert!(cleaned.contains(kept));
        }
    }

    #[test]
    fn test_whitespace_only_input() {
        let outcome = run_default(" \t\n ");
        assert!(!outcome.valid);
        assert!(outcome.errors.iter().any(|msg| msg == "Input is empty"));
    }

    #[test]
    fn test_custom_forbidden_patterns() {
        let sql_guard = InputValidator::new()
            .add_forbidden_pattern("SELECT * FROM".to_string())
            .add_forbidden_pattern("DROP TABLE".to_string());

        let select_all = run(&sql_guard, "SELECT * FROM users");
        assert!(!select_all.valid);
        assert!(select_all
            .errors
            .iter()
            .any(|e| e.contains("SELECT * FROM")));

        assert!(!run(&sql_guard, "DROP TABLE users").valid);

        // Only the full pattern is forbidden, not the individual words.
        assert!(run(&sql_guard, "SELECT your option from the menu").valid);
    }

    #[test]
    fn test_multiple_forbidden_patterns() {
        let outcome = run_default("<script>javascript:alert('xss')data:text/html");
        assert!(!outcome.valid);
        assert!(outcome.errors.len() >= 3);
    }

    #[test]
    fn test_nested_forbidden_patterns() {
        let outcome = run_default("<script><script>nested</script></script>");
        assert!(!outcome.valid);
    }

    #[test]
    fn test_case_sensitive_patterns() {
        // Pins the current behaviour: an upper-case tag is not matched.
        let outcome = run_default("<SCRIPT>alert('xss')</SCRIPT>");
        assert!(outcome.valid);
    }

    #[test]
    fn test_partial_forbidden_pattern() {
        let guard = InputValidator::new().add_forbidden_pattern("dangerous".to_string());

        for text in ["This is dangerous!", "verydangerousword"].iter() {
            assert!(!run(&guard, text).valid);
        }
    }

    #[test]
    fn test_multiple_errors_accumulation() {
        let limited = InputValidator::new().with_max_length(50);

        // Both too long and containing a forbidden pattern.
        let oversized = format!("<script>{}", "a".repeat(100));
        let outcome = run(&limited, &oversized);
        assert!(!outcome.valid);
        assert!(outcome.errors.len() >= 2);
    }

    #[test]
    fn test_extreme_max_length() {
        let one_byte = InputValidator::new().with_max_length(1);

        assert!(!run(&one_byte, "ab").valid);
        assert!(run(&one_byte, "a").valid);
    }

    #[test]
    fn test_zero_max_length() {
        let zero = InputValidator::new().with_max_length(0);

        assert!(!run(&zero, "a").valid);
        assert!(!run(&zero, "").valid);
    }

    #[test]
    fn test_sanitize_preserves_newlines_and_tabs() {
        let outcome = run_default("Line1\nLine2\tTabbed");
        assert!(outcome.valid);

        let cleaned = outcome.sanitized.unwrap();
        assert!(cleaned.contains('\n'));
        assert!(cleaned.contains('\t'));
    }

    #[test]
    fn test_sanitize_trims_whitespace() {
        let outcome = run_default(" hello world ");
        assert!(outcome.valid);
        assert_eq!(outcome.sanitized.as_deref(), Some("hello world"));
    }

    #[test]
    fn test_all_control_characters_removed() {
        let outcome = run_default("Hello\x00\x01\x02\x03\x04\x05\x06\x07\x08World");
        assert!(outcome.valid);

        let cleaned = outcome.sanitized.unwrap();
        for &stripped in ['\x00', '\x01', '\x07'].iter() {
            assert!(!cleaned.contains(stripped));
        }
        for word in ["Hello", "World"].iter() {
            assert!(cleaned.contains(word));
        }
    }

    #[test]
    fn test_result_serialization() {
        let outcome = run_default("Hello");
        let encoded = serde_json::to_string(&outcome).unwrap();

        for key in ["valid", "errors", "sanitized"].iter() {
            assert!(encoded.contains(key));
        }
    }

    #[test]
    fn test_result_deserialization() {
        let encoded = r#"{"valid":true,"errors":[],"sanitized":"test"}"#;
        let decoded: ValidationResult = serde_json::from_str(encoded).unwrap();

        assert!(decoded.valid);
        assert!(decoded.errors.is_empty());
        assert_eq!(decoded.sanitized.as_deref(), Some("test"));
    }

    #[test]
    fn test_empty_errors_list_when_valid() {
        let outcome = run_default("valid input");
        assert!(outcome.valid);
        assert!(outcome.errors.is_empty());
    }

    #[test]
    fn test_sanitized_none_when_invalid() {
        let outcome = run_default("<script>");
        assert!(!outcome.valid);
        assert!(outcome.sanitized.is_none());
    }

    #[test]
    fn test_validator_builder_chain() {
        let chained = InputValidator::new()
            .with_max_length(500)
            .add_forbidden_pattern("bad1".to_string())
            .add_forbidden_pattern("bad2".to_string())
            .add_forbidden_pattern("bad3".to_string());

        let outcome = run(&chained, "bad1bad2bad3");
        assert!(!outcome.valid);
        assert!(outcome.errors.len() >= 3);
    }
}