1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use thiserror::Error;
4
5#[derive(Error, Debug)]
6pub enum ValidationError {
7 #[error("Required field '{field}' is missing")]
8 Required { field: String },
9
10 #[error("Field '{field}' has invalid format: {reason}")]
11 InvalidFormat { field: String, reason: String },
12
13 #[error("Field '{field}' exceeds maximum length of {max_length}")]
14 TooLong { field: String, max_length: usize },
15
16 #[error("Field '{field}' is below minimum length of {min_length}")]
17 TooShort { field: String, min_length: usize },
18
19 #[error("Field '{field}' value '{value}' is not in allowed list")]
20 NotInAllowedList { field: String, value: String },
21
22 #[error("Field '{field}' contains prohibited characters")]
23 ProhibitedCharacters { field: String },
24
25 #[error("Field '{field}' has invalid pattern")]
26 InvalidPattern { field: String },
27
28 #[error("Multiple validation errors: {errors:?}")]
29 Multiple { errors: Vec<ValidationError> },
30}
31
32pub type ValidationResult<T> = std::result::Result<T, ValidationError>;
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ValidationRule {
36 pub required: bool,
37 pub min_length: Option<usize>,
38 pub max_length: Option<usize>,
39 pub pattern: Option<String>,
40 pub allowed_values: Option<Vec<String>>,
41 pub prohibited_chars: Option<Vec<char>>,
42}
43
44impl Default for ValidationRule {
45 fn default() -> Self {
46 Self {
47 required: false,
48 min_length: None,
49 max_length: None,
50 pattern: None,
51 allowed_values: None,
52 prohibited_chars: Some(vec!['<', '>', '&', '"', '\'']), }
54 }
55}
56
57pub struct InputValidator {
58 rules: HashMap<String, ValidationRule>,
59}
60
61impl Default for InputValidator {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl InputValidator {
68 pub fn new() -> Self {
69 Self {
70 rules: HashMap::new(),
71 }
72 }
73
74 pub fn add_rule<S: Into<String>>(mut self, field: S, rule: ValidationRule) -> Self {
75 self.rules.insert(field.into(), rule);
76 self
77 }
78
79 pub fn validate_string(&self, field: &str, value: Option<&str>) -> ValidationResult<()> {
80 let default_rule = ValidationRule::default();
81 let rule = self.rules.get(field).unwrap_or(&default_rule);
82
83 if rule.required && value.is_none() {
85 return Err(ValidationError::Required {
86 field: field.to_string(),
87 });
88 }
89
90 if let Some(val) = value {
91 if let Some(min_len) = rule.min_length {
93 if val.len() < min_len {
94 return Err(ValidationError::TooShort {
95 field: field.to_string(),
96 min_length: min_len,
97 });
98 }
99 }
100
101 if let Some(max_len) = rule.max_length {
102 if val.len() > max_len {
103 return Err(ValidationError::TooLong {
104 field: field.to_string(),
105 max_length: max_len,
106 });
107 }
108 }
109
110 if let Some(allowed) = &rule.allowed_values {
112 if !allowed.contains(&val.to_string()) {
113 return Err(ValidationError::NotInAllowedList {
114 field: field.to_string(),
115 value: val.to_string(),
116 });
117 }
118 }
119
120 if let Some(prohibited) = &rule.prohibited_chars {
122 for &ch in prohibited {
123 if val.contains(ch) {
124 return Err(ValidationError::ProhibitedCharacters {
125 field: field.to_string(),
126 });
127 }
128 }
129 }
130
131 #[cfg(feature = "transpiler")]
133 if let Some(pattern) = &rule.pattern {
134 let regex =
135 regex::Regex::new(pattern).map_err(|_| ValidationError::InvalidPattern {
136 field: field.to_string(),
137 })?;
138
139 if !regex.is_match(val) {
140 return Err(ValidationError::InvalidFormat {
141 field: field.to_string(),
142 reason: format!("does not match pattern: {}", pattern),
143 });
144 }
145 }
146
147 #[cfg(not(feature = "transpiler"))]
149 if rule.pattern.is_some() {
150 return Err(ValidationError::InvalidFormat {
153 field: field.to_string(),
154 reason: "Pattern validation requires transpiler feature".to_string(),
155 });
156 }
157 }
158
159 Ok(())
160 }
161
162 pub fn validate_mission_input(&self, input: &serde_json::Value) -> ValidationResult<()> {
163 let mut errors = Vec::new();
164
165 if let Some(obj) = input.as_object() {
166 for (key, value) in obj {
167 let string_value = value.as_str();
168 if let Err(e) = self.validate_string(key, string_value) {
169 errors.push(e);
170 }
171 }
172 }
173
174 if errors.is_empty() {
175 Ok(())
176 } else if errors.len() == 1 {
177 Err(errors.into_iter().next().unwrap())
178 } else {
179 Err(ValidationError::Multiple { errors })
180 }
181 }
182}
183
184pub fn create_mission_validator() -> InputValidator {
185 InputValidator::new()
186 .add_rule(
187 "name",
188 ValidationRule {
189 required: true,
190 min_length: Some(1),
191 max_length: Some(100),
192 pattern: Some(r"^[a-zA-Z0-9\s\-_]+$".to_string()),
193 ..Default::default()
194 },
195 )
196 .add_rule(
197 "version",
198 ValidationRule {
199 required: true,
200 pattern: Some(r"^\d+\.\d+(\.\d+)?$".to_string()),
201 ..Default::default()
202 },
203 )
204 .add_rule(
205 "description",
206 ValidationRule {
207 max_length: Some(1000),
208 ..Default::default()
209 },
210 )
211}
212
213pub fn create_tool_input_validator() -> InputValidator {
214 InputValidator::new()
215 .add_rule(
216 "tool_name",
217 ValidationRule {
218 required: true,
219 min_length: Some(1),
220 max_length: Some(50),
221 pattern: Some(r"^[a-zA-Z0-9_]+$".to_string()),
222 ..Default::default()
223 },
224 )
225 .add_rule(
226 "command",
227 ValidationRule {
228 max_length: Some(500),
229 prohibited_chars: Some(vec!['&', '|', ';', '`', '$']),
230 ..Default::default()
231 },
232 )
233 .add_rule(
234 "file_path",
235 ValidationRule {
236 max_length: Some(255),
237 prohibited_chars: Some(vec!['<', '>', ':', '"', '|', '?', '*']),
238 ..Default::default()
239 },
240 )
241}
242
243pub fn create_api_input_validator() -> InputValidator {
244 InputValidator::new()
245 .add_rule(
246 "api_key",
247 ValidationRule {
248 required: true,
249 min_length: Some(16),
250 max_length: Some(128),
251 pattern: Some(r"^[a-zA-Z0-9\-_]+$".to_string()),
252 ..Default::default()
253 },
254 )
255 .add_rule(
256 "endpoint",
257 ValidationRule {
258 required: true,
259 pattern: Some(r"^/[a-zA-Z0-9\-_/]*$".to_string()),
260 max_length: Some(200),
261 ..Default::default()
262 },
263 )
264 .add_rule(
265 "user_input",
266 ValidationRule {
267 max_length: Some(10000),
268 prohibited_chars: Some(vec!['<', '>', '&', '"', '\'']),
269 ..Default::default()
270 },
271 )
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that expect a pattern rule to ACCEPT a value are gated on the
    // `transpiler` feature: without it `validate_string` rejects every rule
    // that carries a pattern, so those assertions could never pass.

    /// Builds a validator holding a single rule for `field`.
    fn single_rule_validator(field: &str, rule: ValidationRule) -> InputValidator {
        InputValidator::new().add_rule(field, rule)
    }

    #[test]
    fn test_required_field_validation() {
        let validator = single_rule_validator(
            "required_field",
            ValidationRule {
                required: true,
                ..Default::default()
            },
        );

        assert!(validator.validate_string("required_field", None).is_err());
        assert!(validator
            .validate_string("required_field", Some("value"))
            .is_ok());
    }

    #[test]
    fn test_length_validation() {
        let validator = single_rule_validator(
            "length_field",
            ValidationRule {
                min_length: Some(3),
                max_length: Some(10),
                ..Default::default()
            },
        );

        assert!(validator
            .validate_string("length_field", Some("ab"))
            .is_err());
        assert!(validator
            .validate_string("length_field", Some("abc"))
            .is_ok());
        assert!(validator
            .validate_string("length_field", Some("abcdefghijk"))
            .is_err());
    }

    #[test]
    fn test_prohibited_characters() {
        let validator = single_rule_validator(
            "safe_field",
            ValidationRule {
                prohibited_chars: Some(vec!['<', '>', '&']),
                ..Default::default()
            },
        );

        assert!(validator
            .validate_string("safe_field", Some("safe text"))
            .is_ok());
        assert!(validator
            .validate_string("safe_field", Some("unsafe <script>"))
            .is_err());
        assert!(validator
            .validate_string("safe_field", Some("unsafe & dangerous"))
            .is_err());
    }

    #[cfg(feature = "transpiler")]
    #[test]
    fn test_pattern_validation() {
        let validator = single_rule_validator(
            "version",
            ValidationRule {
                pattern: Some(r"^\d+\.\d+\.\d+$".to_string()),
                ..Default::default()
            },
        );

        assert!(validator.validate_string("version", Some("1.0.0")).is_ok());
        assert!(validator
            .validate_string("version", Some("invalid"))
            .is_err());
    }

    #[cfg(not(feature = "transpiler"))]
    #[test]
    fn test_pattern_rejected_without_transpiler() {
        let validator = single_rule_validator(
            "version",
            ValidationRule {
                pattern: Some(r"^\d+\.\d+\.\d+$".to_string()),
                ..Default::default()
            },
        );

        // Even a well-formed value is rejected, because the pattern cannot be
        // evaluated without the feature.
        assert!(matches!(
            validator.validate_string("version", Some("1.0.0")),
            Err(ValidationError::InvalidFormat { .. })
        ));
    }

    #[cfg(feature = "transpiler")]
    #[test]
    fn test_mission_validator_accepts_valid_input() {
        let validator = create_mission_validator();

        let valid_mission = serde_json::json!({
            "name": "Valid Mission",
            "version": "1.0.0",
            "description": "A valid mission description"
        });

        assert!(validator.validate_mission_input(&valid_mission).is_ok());
    }

    #[test]
    fn test_mission_validator_rejects_invalid_input() {
        let validator = create_mission_validator();

        let invalid_mission = serde_json::json!({
            "name": "Invalid<script>",
            "version": "invalid_version"
        });

        assert!(validator.validate_mission_input(&invalid_mission).is_err());
    }

    #[cfg(feature = "transpiler")]
    #[test]
    fn test_tool_name_validation() {
        let validator = create_tool_input_validator();

        assert!(validator
            .validate_string("tool_name", Some("valid_tool"))
            .is_ok());
        assert!(validator
            .validate_string("tool_name", Some("invalid-tool!"))
            .is_err());
    }

    #[test]
    fn test_tool_command_validation() {
        let validator = create_tool_input_validator();

        assert!(validator.validate_string("command", Some("ls -la")).is_ok());
        assert!(validator
            .validate_string("command", Some("rm -rf / && evil"))
            .is_err());
    }
}