rustchain/validation/mod.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use thiserror::Error;
4
5#[derive(Error, Debug)]
6pub enum ValidationError {
7    #[error("Required field '{field}' is missing")]
8    Required { field: String },
9
10    #[error("Field '{field}' has invalid format: {reason}")]
11    InvalidFormat { field: String, reason: String },
12
13    #[error("Field '{field}' exceeds maximum length of {max_length}")]
14    TooLong { field: String, max_length: usize },
15
16    #[error("Field '{field}' is below minimum length of {min_length}")]
17    TooShort { field: String, min_length: usize },
18
19    #[error("Field '{field}' value '{value}' is not in allowed list")]
20    NotInAllowedList { field: String, value: String },
21
22    #[error("Field '{field}' contains prohibited characters")]
23    ProhibitedCharacters { field: String },
24
25    #[error("Field '{field}' has invalid pattern")]
26    InvalidPattern { field: String },
27
28    #[error("Multiple validation errors: {errors:?}")]
29    Multiple { errors: Vec<ValidationError> },
30}
31
32pub type ValidationResult<T> = std::result::Result<T, ValidationError>;
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ValidationRule {
36    pub required: bool,
37    pub min_length: Option<usize>,
38    pub max_length: Option<usize>,
39    pub pattern: Option<String>,
40    pub allowed_values: Option<Vec<String>>,
41    pub prohibited_chars: Option<Vec<char>>,
42}
43
44impl Default for ValidationRule {
45    fn default() -> Self {
46        Self {
47            required: false,
48            min_length: None,
49            max_length: None,
50            pattern: None,
51            allowed_values: None,
52            prohibited_chars: Some(vec!['<', '>', '&', '"', '\'']), // Basic XSS protection
53        }
54    }
55}
56
57pub struct InputValidator {
58    rules: HashMap<String, ValidationRule>,
59}
60
61impl Default for InputValidator {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl InputValidator {
68    pub fn new() -> Self {
69        Self {
70            rules: HashMap::new(),
71        }
72    }
73
74    pub fn add_rule<S: Into<String>>(mut self, field: S, rule: ValidationRule) -> Self {
75        self.rules.insert(field.into(), rule);
76        self
77    }
78
79    pub fn validate_string(&self, field: &str, value: Option<&str>) -> ValidationResult<()> {
80        let default_rule = ValidationRule::default();
81        let rule = self.rules.get(field).unwrap_or(&default_rule);
82
83        // Check if required
84        if rule.required && value.is_none() {
85            return Err(ValidationError::Required {
86                field: field.to_string(),
87            });
88        }
89
90        if let Some(val) = value {
91            // Check length constraints
92            if let Some(min_len) = rule.min_length {
93                if val.len() < min_len {
94                    return Err(ValidationError::TooShort {
95                        field: field.to_string(),
96                        min_length: min_len,
97                    });
98                }
99            }
100
101            if let Some(max_len) = rule.max_length {
102                if val.len() > max_len {
103                    return Err(ValidationError::TooLong {
104                        field: field.to_string(),
105                        max_length: max_len,
106                    });
107                }
108            }
109
110            // Check allowed values
111            if let Some(allowed) = &rule.allowed_values {
112                if !allowed.contains(&val.to_string()) {
113                    return Err(ValidationError::NotInAllowedList {
114                        field: field.to_string(),
115                        value: val.to_string(),
116                    });
117                }
118            }
119
120            // Check prohibited characters
121            if let Some(prohibited) = &rule.prohibited_chars {
122                for &ch in prohibited {
123                    if val.contains(ch) {
124                        return Err(ValidationError::ProhibitedCharacters {
125                            field: field.to_string(),
126                        });
127                    }
128                }
129            }
130
131            // Check pattern
132            #[cfg(feature = "transpiler")]
133            if let Some(pattern) = &rule.pattern {
134                let regex =
135                    regex::Regex::new(pattern).map_err(|_| ValidationError::InvalidPattern {
136                        field: field.to_string(),
137                    })?;
138
139                if !regex.is_match(val) {
140                    return Err(ValidationError::InvalidFormat {
141                        field: field.to_string(),
142                        reason: format!("does not match pattern: {}", pattern),
143                    });
144                }
145            }
146            
147            // Fallback for when regex is not available
148            #[cfg(not(feature = "transpiler"))]
149            if rule.pattern.is_some() {
150                // Basic pattern matching without regex - just check if pattern is set
151                // This provides compatibility when transpiler feature is disabled
152                return Err(ValidationError::InvalidFormat {
153                    field: field.to_string(),
154                    reason: "Pattern validation requires transpiler feature".to_string(),
155                });
156            }
157        }
158
159        Ok(())
160    }
161
162    pub fn validate_mission_input(&self, input: &serde_json::Value) -> ValidationResult<()> {
163        let mut errors = Vec::new();
164
165        if let Some(obj) = input.as_object() {
166            for (key, value) in obj {
167                let string_value = value.as_str();
168                if let Err(e) = self.validate_string(key, string_value) {
169                    errors.push(e);
170                }
171            }
172        }
173
174        if errors.is_empty() {
175            Ok(())
176        } else if errors.len() == 1 {
177            Err(errors.into_iter().next().unwrap())
178        } else {
179            Err(ValidationError::Multiple { errors })
180        }
181    }
182}
183
184pub fn create_mission_validator() -> InputValidator {
185    InputValidator::new()
186        .add_rule(
187            "name",
188            ValidationRule {
189                required: true,
190                min_length: Some(1),
191                max_length: Some(100),
192                pattern: Some(r"^[a-zA-Z0-9\s\-_]+$".to_string()),
193                ..Default::default()
194            },
195        )
196        .add_rule(
197            "version",
198            ValidationRule {
199                required: true,
200                pattern: Some(r"^\d+\.\d+(\.\d+)?$".to_string()),
201                ..Default::default()
202            },
203        )
204        .add_rule(
205            "description",
206            ValidationRule {
207                max_length: Some(1000),
208                ..Default::default()
209            },
210        )
211}
212
213pub fn create_tool_input_validator() -> InputValidator {
214    InputValidator::new()
215        .add_rule(
216            "tool_name",
217            ValidationRule {
218                required: true,
219                min_length: Some(1),
220                max_length: Some(50),
221                pattern: Some(r"^[a-zA-Z0-9_]+$".to_string()),
222                ..Default::default()
223            },
224        )
225        .add_rule(
226            "command",
227            ValidationRule {
228                max_length: Some(500),
229                prohibited_chars: Some(vec!['&', '|', ';', '`', '$']),
230                ..Default::default()
231            },
232        )
233        .add_rule(
234            "file_path",
235            ValidationRule {
236                max_length: Some(255),
237                prohibited_chars: Some(vec!['<', '>', ':', '"', '|', '?', '*']),
238                ..Default::default()
239            },
240        )
241}
242
243pub fn create_api_input_validator() -> InputValidator {
244    InputValidator::new()
245        .add_rule(
246            "api_key",
247            ValidationRule {
248                required: true,
249                min_length: Some(16),
250                max_length: Some(128),
251                pattern: Some(r"^[a-zA-Z0-9\-_]+$".to_string()),
252                ..Default::default()
253            },
254        )
255        .add_rule(
256            "endpoint",
257            ValidationRule {
258                required: true,
259                pattern: Some(r"^/[a-zA-Z0-9\-_/]*$".to_string()),
260                max_length: Some(200),
261                ..Default::default()
262            },
263        )
264        .add_rule(
265            "user_input",
266            ValidationRule {
267                max_length: Some(10000),
268                prohibited_chars: Some(vec!['<', '>', '&', '"', '\'']),
269                ..Default::default()
270            },
271        )
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_required_field_validation() {
        let validator = InputValidator::new().add_rule(
            "required_field",
            ValidationRule {
                required: true,
                ..Default::default()
            },
        );

        assert!(validator.validate_string("required_field", None).is_err());
        assert!(validator
            .validate_string("required_field", Some("value"))
            .is_ok());
    }

    #[test]
    fn test_length_validation() {
        let validator = InputValidator::new().add_rule(
            "length_field",
            ValidationRule {
                min_length: Some(3),
                max_length: Some(10),
                ..Default::default()
            },
        );

        assert!(validator
            .validate_string("length_field", Some("ab"))
            .is_err());
        assert!(validator
            .validate_string("length_field", Some("abc"))
            .is_ok());
        assert!(validator
            .validate_string("length_field", Some("abcdefghijk"))
            .is_err());
    }

    #[test]
    fn test_prohibited_characters() {
        let validator = InputValidator::new().add_rule(
            "safe_field",
            ValidationRule {
                prohibited_chars: Some(vec!['<', '>', '&']),
                ..Default::default()
            },
        );

        assert!(validator
            .validate_string("safe_field", Some("safe text"))
            .is_ok());
        assert!(validator
            .validate_string("safe_field", Some("unsafe <script>"))
            .is_err());
        assert!(validator
            .validate_string("safe_field", Some("unsafe & dangerous"))
            .is_err());
    }

    #[test]
    fn test_allowed_values() {
        let validator = InputValidator::new().add_rule(
            "mode",
            ValidationRule {
                allowed_values: Some(vec!["fast".to_string(), "safe".to_string()]),
                ..Default::default()
            },
        );

        assert!(validator.validate_string("mode", Some("fast")).is_ok());
        assert!(matches!(
            validator.validate_string("mode", Some("slow")),
            Err(ValidationError::NotInAllowedList { .. })
        ));
    }

    #[test]
    fn test_unknown_field_uses_default_rule() {
        let validator = InputValidator::new();

        assert!(validator
            .validate_string("anything", Some("plain text"))
            .is_ok());
        assert!(validator.validate_string("anything", None).is_ok());
        assert!(matches!(
            validator.validate_string("anything", Some("<b>")),
            Err(ValidationError::ProhibitedCharacters { .. })
        ));
    }

    // Pattern checks need the regex engine behind the `transpiler` feature.
    #[cfg(feature = "transpiler")]
    #[test]
    fn test_pattern_validation() {
        let validator = InputValidator::new().add_rule(
            "version",
            ValidationRule {
                pattern: Some(r"^\d+\.\d+\.\d+$".to_string()),
                ..Default::default()
            },
        );

        assert!(validator.validate_string("version", Some("1.0.0")).is_ok());
        assert!(validator
            .validate_string("version", Some("invalid"))
            .is_err());
    }

    // Without the `transpiler` feature, a rule with a pattern fails closed.
    #[cfg(not(feature = "transpiler"))]
    #[test]
    fn test_pattern_requires_transpiler_feature() {
        let validator = InputValidator::new().add_rule(
            "version",
            ValidationRule {
                pattern: Some(r"^\d+\.\d+\.\d+$".to_string()),
                ..Default::default()
            },
        );

        assert!(matches!(
            validator.validate_string("version", Some("1.0.0")),
            Err(ValidationError::InvalidFormat { .. })
        ));
    }

    // The valid mission has patterned fields, so accepting it needs `transpiler`.
    #[cfg(feature = "transpiler")]
    #[test]
    fn test_mission_validator() {
        let validator = create_mission_validator();

        let valid_mission = serde_json::json!({
            "name": "Valid Mission",
            "version": "1.0.0",
            "description": "A valid mission description"
        });

        assert!(validator.validate_mission_input(&valid_mission).is_ok());

        let invalid_mission = serde_json::json!({
            "name": "Invalid<script>",
            "version": "invalid_version"
        });

        assert!(validator.validate_mission_input(&invalid_mission).is_err());
    }

    // Prohibited characters are checked before patterns, so this holds with or
    // without the `transpiler` feature.
    #[test]
    fn test_mission_name_rejects_markup() {
        let validator = create_mission_validator();

        assert!(matches!(
            validator.validate_string("name", Some("Invalid<script>")),
            Err(ValidationError::ProhibitedCharacters { .. })
        ));
    }

    // `tool_name` has a pattern, so accepting a valid name needs `transpiler`.
    #[cfg(feature = "transpiler")]
    #[test]
    fn test_tool_input_validator() {
        let validator = create_tool_input_validator();

        assert!(validator
            .validate_string("tool_name", Some("valid_tool"))
            .is_ok());
        assert!(validator
            .validate_string("tool_name", Some("invalid-tool!"))
            .is_err());
    }

    // `command` has no pattern, so this runs regardless of features.
    #[test]
    fn test_command_validation() {
        let validator = create_tool_input_validator();

        assert!(validator.validate_string("command", Some("ls -la")).is_ok());
        assert!(validator
            .validate_string("command", Some("rm -rf / && evil"))
            .is_err());
        assert!(validator
            .validate_string("command", Some("echo $HOME"))
            .is_err());
    }
}