Skip to main content

sh_layer0/
input_validator.rs

//! 输入验证模块
//!
//! 验证所有外部输入的格式和安全性。
4
5use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// 验证结果
10#[derive(Debug, Serialize, Deserialize)]
11pub struct ValidationResult {
12    /// 是否有效
13    pub valid: bool,
14    /// 错误消息
15    pub errors: Vec<String>,
16    /// 验证后的数据(可选)
17    pub sanitized: Option<String>,
18}
19
/// Validates untrusted text against a byte-length limit and a denylist of
/// substrings.
///
/// Construct with `InputValidator::new()` (or `Default`) and customize via the
/// consuming builder methods `with_max_length` / `add_forbidden_pattern`.
#[derive(Debug, Clone)]
pub struct InputValidator {
    /// Maximum accepted input length, measured in bytes.
    max_length: usize,
    /// Substrings whose presence makes an input invalid (case-sensitive match).
    forbidden_patterns: Vec<String>,
    /// Required-field table, reserved for future checks.
    // Nothing reads this yet; the allowance silences the dead-code lint until
    // a field-presence check is implemented.
    #[allow(dead_code)]
    required_fields: HashMap<String, bool>,
}
30
31impl InputValidator {
32    pub fn new() -> Self {
33        Self {
34            max_length: 100_000, // 100KB 默认上限
35            forbidden_patterns: vec![
36                // 潜在危险模式
37                "<script>".to_string(),
38                "javascript:".to_string(),
39                "data:".to_string(),
40            ],
41            required_fields: HashMap::new(),
42        }
43    }
44
45    /// 验证输入
46    pub fn validate(&self, input: &str) -> Result<ValidationResult> {
47        let mut errors = Vec::new();
48
49        // 检查长度
50        if input.len() > self.max_length {
51            errors.push(format!(
52                "Input too long: {} bytes (max {})",
53                input.len(),
54                self.max_length
55            ));
56        }
57
58        // 检查禁止模式
59        for pattern in &self.forbidden_patterns {
60            if input.contains(pattern) {
61                errors.push(format!("Forbidden pattern detected: {}", pattern));
62            }
63        }
64
65        // 检查空输入
66        if input.trim().is_empty() {
67            errors.push("Input is empty".to_string());
68        }
69
70        let valid = errors.is_empty();
71        let sanitized = if valid {
72            Some(self.sanitize(input))
73        } else {
74            None
75        };
76
77        Ok(ValidationResult {
78            valid,
79            errors,
80            sanitized,
81        })
82    }
83
84    /// 清理输入
85    fn sanitize(&self, input: &str) -> String {
86        // 移除控制字符
87        input
88            .chars()
89            .filter(|c| !c.is_control() || *c == '\n' || *c == '\t')
90            .collect::<String>()
91            .trim()
92            .to_string()
93    }
94
95    /// 设置最大长度
96    pub fn with_max_length(mut self, max_length: usize) -> Self {
97        self.max_length = max_length;
98        self
99    }
100
101    /// 添加禁止模式
102    pub fn add_forbidden_pattern(mut self, pattern: String) -> Self {
103        self.forbidden_patterns.push(pattern);
104        self
105    }
106}
107
108impl Default for InputValidator {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_input() {
        let validator = InputValidator::new();
        let result = validator.validate("Hello, world!").unwrap();
        assert!(result.valid);
        assert!(result.sanitized.is_some());
    }

    #[test]
    fn test_empty_input() {
        let validator = InputValidator::new();
        let result = validator.validate("").unwrap();
        assert!(!result.valid);
        assert!(result.errors.contains(&"Input is empty".to_string()));
    }

    #[test]
    fn test_forbidden_pattern() {
        let validator = InputValidator::new();
        let result = validator.validate("<script>alert('xss')</script>").unwrap();
        assert!(!result.valid);
    }

    #[test]
    fn test_max_length_boundary() {
        // The limit is inclusive: exactly max_length bytes is accepted.
        let validator = InputValidator::new().with_max_length(100);

        // Exactly 100 bytes: valid.
        let input_at_limit = "a".repeat(100);
        let result = validator.validate(&input_at_limit).unwrap();
        assert!(result.valid, "Input at max length should be valid");

        // 101 bytes: invalid.
        let input_over_limit = "a".repeat(101);
        let result = validator.validate(&input_over_limit).unwrap();
        assert!(!result.valid, "Input over max length should be invalid");
        assert!(result.errors.iter().any(|e| e.contains("too long")));
    }

    #[test]
    fn test_unicode_handling() {
        let validator = InputValidator::new();

        // Chinese text.
        let result = validator.validate("你好世界,这是一个测试").unwrap();
        assert!(result.valid);

        // Emoji.
        let result = validator.validate("Hello 🦀 Rust 🚀🎉").unwrap();
        assert!(result.valid);

        // Mixed scripts, including right-to-left ones.
        let result = validator.validate("日本語テスト العربية עברית").unwrap();
        assert!(result.valid);

        // Control characters (here NUL) are stripped from the sanitized output.
        let result = validator.validate("Hello\x00World").unwrap();
        assert!(result.valid);
        assert_eq!(result.sanitized.as_deref(), Some("HelloWorld"));
    }

    #[test]
    fn test_concurrent_validation() {
        use std::sync::Arc;
        use std::thread;

        let validator = Arc::new(InputValidator::new());
        let mut handles = vec![];

        for i in 0..10 {
            let v = Arc::clone(&validator);
            handles.push(thread::spawn(move || {
                let input = format!("Test input {}", i);
                v.validate(&input).unwrap()
            }));
        }

        let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        // Every validation must succeed.
        for result in results {
            assert!(result.valid);
        }
    }

    #[test]
    fn test_sanitize_removes_control_chars() {
        let validator = InputValidator::new();

        // Input mixing several control characters with newline and tab.
        let input = "Hello\x00\x01\x02World\nNewLine\tTab";
        let result = validator.validate(input).unwrap();

        assert!(result.valid);
        let sanitized = result.sanitized.unwrap();
        // Control characters are removed, except \n and \t.
        assert!(!sanitized.contains('\x00'));
        assert!(!sanitized.contains('\x01'));
        assert!(!sanitized.contains('\x02'));
        // Newline and tab are preserved.
        assert!(sanitized.contains('\n'));
        assert!(sanitized.contains('\t'));
        assert_eq!(sanitized, "HelloWorld\nNewLine\tTab");
    }

    #[test]
    fn test_whitespace_only_input() {
        let validator = InputValidator::new();

        let result = validator.validate("   \t\n  ").unwrap();
        assert!(!result.valid);
        assert!(result.errors.contains(&"Input is empty".to_string()));
    }

    #[test]
    fn test_custom_forbidden_patterns() {
        let validator = InputValidator::new()
            .add_forbidden_pattern("SELECT * FROM".to_string())
            .add_forbidden_pattern("DROP TABLE".to_string());

        // SQL-injection-looking input.
        let result = validator.validate("SELECT * FROM users").unwrap();
        assert!(!result.valid);
        assert!(result.errors.iter().any(|e| e.contains("SELECT * FROM")));

        let result = validator.validate("DROP TABLE users").unwrap();
        assert!(!result.valid);

        // Ordinary text that merely shares a word with a pattern.
        let result = validator
            .validate("SELECT your option from the menu")
            .unwrap();
        assert!(result.valid);
    }

    // ========== Error-handling tests ==========

    #[test]
    fn test_multiple_forbidden_patterns() {
        let validator = InputValidator::new();

        // Input containing all three default forbidden patterns at once.
        let input = "<script>javascript:alert('xss')data:text/html";
        let result = validator.validate(input).unwrap();
        assert!(!result.valid);
        // Each pattern is reported exactly once.
        assert_eq!(result.errors.len(), 3);
    }

    #[test]
    fn test_nested_forbidden_patterns() {
        let validator = InputValidator::new();

        // Repeated / nested occurrences of a forbidden pattern.
        let result = validator
            .validate("<script><script>nested</script></script>")
            .unwrap();
        assert!(!result.valid);
    }

    #[test]
    fn test_case_sensitive_patterns() {
        let validator = InputValidator::new();

        // Matching is case-sensitive, so upper-case tags are NOT caught by the
        // default denylist. This pins current behavior; it is a known gap of a
        // substring denylist, not a security guarantee.
        let result = validator.validate("<SCRIPT>alert('xss')</SCRIPT>").unwrap();
        assert!(result.valid);
    }

    #[test]
    fn test_partial_forbidden_pattern() {
        let validator = InputValidator::new().add_forbidden_pattern("dangerous".to_string());

        // A pattern is detected inside a sentence...
        let result = validator.validate("This is dangerous!").unwrap();
        assert!(!result.valid);

        // ...and as a substring of a longer word.
        let result = validator.validate("verydangerousword").unwrap();
        assert!(!result.valid);
    }

    #[test]
    fn test_multiple_errors_accumulation() {
        let validator = InputValidator::new().with_max_length(50);

        // Triggers two errors at once: too long + forbidden pattern. The input
        // is not empty, so the empty check does not fire.
        let input = format!("<script>{}", "a".repeat(100));
        let result = validator.validate(&input).unwrap();
        assert!(!result.valid);
        assert_eq!(result.errors.len(), 2);
        assert!(result.errors.iter().any(|e| e.contains("too long")));
        assert!(result.errors.iter().any(|e| e.contains("<script>")));
    }

    #[test]
    fn test_extreme_max_length() {
        // Tiny maximum length.
        let validator = InputValidator::new().with_max_length(1);

        let result = validator.validate("ab").unwrap();
        assert!(!result.valid);

        let result = validator.validate("a").unwrap();
        assert!(result.valid);
    }

    #[test]
    fn test_zero_max_length() {
        let validator = InputValidator::new().with_max_length(0);

        // Any non-empty input is rejected as too long.
        let result = validator.validate("a").unwrap();
        assert!(!result.valid);

        // Empty input is rejected as well (because it is empty).
        let result = validator.validate("").unwrap();
        assert!(!result.valid);
    }

    #[test]
    fn test_sanitize_preserves_newlines_and_tabs() {
        let validator = InputValidator::new();

        let input = "Line1\nLine2\tTabbed";
        let result = validator.validate(input).unwrap();
        assert!(result.valid);
        let sanitized = result.sanitized.unwrap();
        assert!(sanitized.contains('\n'));
        assert!(sanitized.contains('\t'));
    }

    #[test]
    fn test_sanitize_trims_whitespace() {
        let validator = InputValidator::new();

        let input = "   hello world   ";
        let result = validator.validate(input).unwrap();
        assert!(result.valid);
        assert_eq!(result.sanitized.unwrap(), "hello world");
    }

    #[test]
    fn test_all_control_characters_removed() {
        let validator = InputValidator::new();

        // A run of C0 control characters (NUL through BS).
        let input = "Hello\x00\x01\x02\x03\x04\x05\x06\x07\x08World";
        let result = validator.validate(input).unwrap();
        assert!(result.valid);
        let sanitized = result.sanitized.unwrap();
        assert!(!sanitized.contains('\x00'));
        assert!(!sanitized.contains('\x01'));
        assert!(!sanitized.contains('\x07'));
        assert!(sanitized.contains("Hello"));
        assert!(sanitized.contains("World"));
        assert_eq!(sanitized, "HelloWorld");
    }

    #[test]
    fn test_result_serialization() {
        let validator = InputValidator::new();

        let result = validator.validate("Hello").unwrap();
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("valid"));
        assert!(json.contains("errors"));
        assert!(json.contains("sanitized"));
    }

    #[test]
    fn test_result_deserialization() {
        let json = "{\"valid\":true,\"errors\":[],\"sanitized\":\"test\"}";
        let result: ValidationResult = serde_json::from_str(json).unwrap();
        assert!(result.valid);
        assert!(result.errors.is_empty());
        assert_eq!(result.sanitized, Some("test".to_string()));
    }

    #[test]
    fn test_empty_errors_list_when_valid() {
        let validator = InputValidator::new();

        let result = validator.validate("valid input").unwrap();
        assert!(result.valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_sanitized_none_when_invalid() {
        let validator = InputValidator::new();

        let result = validator.validate("<script>").unwrap();
        assert!(!result.valid);
        assert!(result.sanitized.is_none());
    }

    #[test]
    fn test_validator_builder_chain() {
        let validator = InputValidator::new()
            .with_max_length(500)
            .add_forbidden_pattern("bad1".to_string())
            .add_forbidden_pattern("bad2".to_string())
            .add_forbidden_pattern("bad3".to_string());

        // Each custom pattern is reported once; no default pattern matches.
        let result = validator.validate("bad1bad2bad3").unwrap();
        assert!(!result.valid);
        assert_eq!(result.errors.len(), 3);
    }
}