turbomcp_protocol/
validation.rs

//! # Protocol Validation
//!
//! This module provides comprehensive validation for MCP protocol messages,
//! ensuring data integrity and specification compliance.
5
6use regex::Regex;
7use serde_json::Value;
8use std::collections::{HashMap, HashSet};
9
10use crate::jsonrpc::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
11use crate::types::*;
12
13/// Protocol message validator
14#[derive(Debug, Clone)]
15pub struct ProtocolValidator {
16    /// Validation rules
17    rules: ValidationRules,
18    /// Strict validation mode
19    strict_mode: bool,
20}
21
22/// Validation rules configuration
23#[derive(Debug, Clone)]
24pub struct ValidationRules {
25    /// Maximum message size in bytes
26    pub max_message_size: usize,
27    /// Maximum batch size
28    pub max_batch_size: usize,
29    /// Maximum string length
30    pub max_string_length: usize,
31    /// Maximum array length
32    pub max_array_length: usize,
33    /// Maximum object depth
34    pub max_object_depth: usize,
35    /// URI validation regex
36    pub uri_regex: Regex,
37    /// Method name validation regex
38    pub method_name_regex: Regex,
39    /// Required fields per message type
40    pub required_fields: HashMap<String, HashSet<String>>,
41}
42
43/// Validation result
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum ValidationResult {
46    /// Validation passed
47    Valid,
48    /// Validation passed with warnings
49    ValidWithWarnings(Vec<ValidationWarning>),
50    /// Validation failed
51    Invalid(Vec<ValidationError>),
52}
53
/// A non-fatal finding produced during validation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ValidationWarning {
    /// Machine-readable warning code
    pub code: String,
    /// Human-readable description of the warning
    pub message: String,
    /// Path of the field the warning refers to, when there is one
    pub field_path: Option<String>,
}
64
/// A fatal finding produced during validation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ValidationError {
    /// Machine-readable error code
    pub code: String,
    /// Human-readable description of the error
    pub message: String,
    /// Path of the field the error refers to, when there is one
    pub field_path: Option<String>,
}
75
76/// Validation context for tracking state during validation
77#[derive(Debug, Clone)]
78struct ValidationContext {
79    /// Current field path
80    path: Vec<String>,
81    /// Current object depth
82    depth: usize,
83    /// Accumulated warnings
84    warnings: Vec<ValidationWarning>,
85    /// Accumulated errors
86    errors: Vec<ValidationError>,
87}
88
89impl Default for ValidationRules {
90    fn default() -> Self {
91        let uri_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9+.-]*:").unwrap();
92        let method_name_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_/]*$").unwrap();
93
94        let mut required_fields = HashMap::new();
95
96        // JSON-RPC required fields
97        required_fields.insert(
98            "request".to_string(),
99            ["jsonrpc", "method", "id"]
100                .iter()
101                .map(|s| s.to_string())
102                .collect(),
103        );
104        required_fields.insert(
105            "response".to_string(),
106            ["jsonrpc", "id"].iter().map(|s| s.to_string()).collect(),
107        );
108        required_fields.insert(
109            "notification".to_string(),
110            ["jsonrpc", "method"]
111                .iter()
112                .map(|s| s.to_string())
113                .collect(),
114        );
115
116        // MCP message required fields
117        required_fields.insert(
118            "initialize".to_string(),
119            ["protocolVersion", "capabilities", "clientInfo"]
120                .iter()
121                .map(|s| s.to_string())
122                .collect(),
123        );
124        required_fields.insert(
125            "tool".to_string(),
126            ["name", "inputSchema"]
127                .iter()
128                .map(|s| s.to_string())
129                .collect(),
130        );
131        required_fields.insert(
132            "prompt".to_string(),
133            ["name"].iter().map(|s| s.to_string()).collect(),
134        );
135        required_fields.insert(
136            "resource".to_string(),
137            ["uri", "name"].iter().map(|s| s.to_string()).collect(),
138        );
139
140        Self {
141            max_message_size: 10 * 1024 * 1024, // 10MB
142            max_batch_size: 100,
143            max_string_length: 1024 * 1024, // 1MB
144            max_array_length: 10000,
145            max_object_depth: 32,
146            uri_regex,
147            method_name_regex,
148            required_fields,
149        }
150    }
151}
152
153impl ProtocolValidator {
154    /// Create a new validator with default rules
155    pub fn new() -> Self {
156        Self {
157            rules: ValidationRules::default(),
158            strict_mode: false,
159        }
160    }
161
162    /// Enable strict validation mode
163    pub fn with_strict_mode(mut self) -> Self {
164        self.strict_mode = true;
165        self
166    }
167
168    /// Set custom validation rules
169    pub fn with_rules(mut self, rules: ValidationRules) -> Self {
170        self.rules = rules;
171        self
172    }
173
174    /// Validate a JSON-RPC request
175    pub fn validate_request(&self, request: &JsonRpcRequest) -> ValidationResult {
176        let mut ctx = ValidationContext::new();
177
178        // Validate JSON-RPC structure
179        self.validate_jsonrpc_request(request, &mut ctx);
180
181        // Validate method name
182        self.validate_method_name(&request.method, &mut ctx);
183
184        // Validate parameters based on method
185        if let Some(params) = &request.params {
186            self.validate_method_params(&request.method, params, &mut ctx);
187        }
188
189        ctx.into_result()
190    }
191
192    /// Validate a JSON-RPC response
193    pub fn validate_response(&self, response: &JsonRpcResponse) -> ValidationResult {
194        let mut ctx = ValidationContext::new();
195
196        // Validate JSON-RPC structure
197        self.validate_jsonrpc_response(response, &mut ctx);
198
199        // Ensure either result or error is present (but not both)
200        match (response.result.is_some(), response.error.is_some()) {
201            (true, true) => {
202                ctx.add_error(
203                    "RESPONSE_BOTH_RESULT_AND_ERROR",
204                    "Response cannot have both result and error".to_string(),
205                    None,
206                );
207            }
208            (false, false) => {
209                ctx.add_error(
210                    "RESPONSE_MISSING_RESULT_OR_ERROR",
211                    "Response must have either result or error".to_string(),
212                    None,
213                );
214            }
215            _ => {} // Valid
216        }
217
218        ctx.into_result()
219    }
220
221    /// Validate a JSON-RPC notification
222    pub fn validate_notification(&self, notification: &JsonRpcNotification) -> ValidationResult {
223        let mut ctx = ValidationContext::new();
224
225        // Validate JSON-RPC structure
226        self.validate_jsonrpc_notification(notification, &mut ctx);
227
228        // Validate method name
229        self.validate_method_name(&notification.method, &mut ctx);
230
231        // Validate parameters based on method
232        if let Some(params) = &notification.params {
233            self.validate_method_params(&notification.method, params, &mut ctx);
234        }
235
236        ctx.into_result()
237    }
238
239    /// Validate MCP protocol types
240    pub fn validate_tool(&self, tool: &Tool) -> ValidationResult {
241        let mut ctx = ValidationContext::new();
242
243        // Validate tool name
244        if tool.name.is_empty() {
245            ctx.add_error(
246                "TOOL_EMPTY_NAME",
247                "Tool name cannot be empty".to_string(),
248                Some("name".to_string()),
249            );
250        }
251
252        if tool.name.len() > self.rules.max_string_length {
253            ctx.add_error(
254                "TOOL_NAME_TOO_LONG",
255                format!(
256                    "Tool name exceeds maximum length of {}",
257                    self.rules.max_string_length
258                ),
259                Some("name".to_string()),
260            );
261        }
262
263        // Validate input schema
264        self.validate_tool_input(&tool.input_schema, &mut ctx);
265
266        ctx.into_result()
267    }
268
269    /// Validate a prompt
270    pub fn validate_prompt(&self, prompt: &Prompt) -> ValidationResult {
271        let mut ctx = ValidationContext::new();
272
273        // Validate prompt name
274        if prompt.name.is_empty() {
275            ctx.add_error(
276                "PROMPT_EMPTY_NAME",
277                "Prompt name cannot be empty".to_string(),
278                Some("name".to_string()),
279            );
280        }
281
282        // Validate arguments if present
283        if let Some(arguments) = &prompt.arguments
284            && arguments.len() > self.rules.max_array_length
285        {
286            ctx.add_error(
287                "PROMPT_TOO_MANY_ARGS",
288                format!(
289                    "Prompt has too many arguments (max: {})",
290                    self.rules.max_array_length
291                ),
292                Some("arguments".to_string()),
293            );
294        }
295
296        ctx.into_result()
297    }
298
299    /// Validate a resource
300    pub fn validate_resource(&self, resource: &Resource) -> ValidationResult {
301        let mut ctx = ValidationContext::new();
302
303        // Validate URI
304        if !self.rules.uri_regex.is_match(&resource.uri) {
305            ctx.add_error(
306                "RESOURCE_INVALID_URI",
307                format!("Invalid URI format: {}", resource.uri),
308                Some("uri".to_string()),
309            );
310        }
311
312        // Validate name
313        if resource.name.is_empty() {
314            ctx.add_error(
315                "RESOURCE_EMPTY_NAME",
316                "Resource name cannot be empty".to_string(),
317                Some("name".to_string()),
318            );
319        }
320
321        ctx.into_result()
322    }
323
324    /// Validate initialization request
325    pub fn validate_initialize_request(&self, request: &InitializeRequest) -> ValidationResult {
326        let mut ctx = ValidationContext::new();
327
328        // Validate protocol version
329        if !crate::SUPPORTED_VERSIONS.contains(&request.protocol_version.as_str()) {
330            ctx.add_warning(
331                "UNSUPPORTED_PROTOCOL_VERSION",
332                format!(
333                    "Protocol version {} is not officially supported",
334                    request.protocol_version
335                ),
336                Some("protocolVersion".to_string()),
337            );
338        }
339
340        // Validate client info
341        if request.client_info.name.is_empty() {
342            ctx.add_error(
343                "EMPTY_CLIENT_NAME",
344                "Client name cannot be empty".to_string(),
345                Some("clientInfo.name".to_string()),
346            );
347        }
348
349        if request.client_info.version.is_empty() {
350            ctx.add_error(
351                "EMPTY_CLIENT_VERSION",
352                "Client version cannot be empty".to_string(),
353                Some("clientInfo.version".to_string()),
354            );
355        }
356
357        ctx.into_result()
358    }
359
360    // Private validation methods
361
362    fn validate_jsonrpc_request(&self, _request: &JsonRpcRequest, _ctx: &mut ValidationContext) {
363        // Method name validation is done separately
364
365        // Validate ID is present (required for requests)
366        // Note: ID validation is handled by the type system
367    }
368
369    fn validate_jsonrpc_response(&self, response: &JsonRpcResponse, ctx: &mut ValidationContext) {
370        // Basic structure validation is handled by the type system
371        if let Some(error) = &response.error {
372            self.validate_jsonrpc_error(error, ctx);
373        }
374    }
375
376    fn validate_jsonrpc_notification(
377        &self,
378        _notification: &JsonRpcNotification,
379        _ctx: &mut ValidationContext,
380    ) {
381        // Basic structure validation is handled by the type system
382    }
383
384    fn validate_jsonrpc_error(
385        &self,
386        error: &crate::jsonrpc::JsonRpcError,
387        ctx: &mut ValidationContext,
388    ) {
389        // Error codes should be in the valid range
390        if error.code >= 0 {
391            ctx.add_warning(
392                "POSITIVE_ERROR_CODE",
393                "Error codes should be negative according to JSON-RPC spec".to_string(),
394                Some("error.code".to_string()),
395            );
396        }
397
398        if error.message.is_empty() {
399            ctx.add_error(
400                "EMPTY_ERROR_MESSAGE",
401                "Error message cannot be empty".to_string(),
402                Some("error.message".to_string()),
403            );
404        }
405    }
406
407    fn validate_method_name(&self, method: &str, ctx: &mut ValidationContext) {
408        if method.is_empty() {
409            ctx.add_error(
410                "EMPTY_METHOD_NAME",
411                "Method name cannot be empty".to_string(),
412                Some("method".to_string()),
413            );
414            return;
415        }
416
417        if !self.rules.method_name_regex.is_match(method) {
418            ctx.add_error(
419                "INVALID_METHOD_NAME",
420                format!("Invalid method name format: {method}"),
421                Some("method".to_string()),
422            );
423        }
424    }
425
426    fn validate_method_params(&self, method: &str, params: &Value, ctx: &mut ValidationContext) {
427        ctx.push_path("params".to_string());
428
429        match method {
430            "initialize" => self.validate_value_structure(params, "initialize", ctx),
431            "tools/list" => {
432                // Should be empty object or null
433                if !params.is_null() && !params.as_object().is_some_and(|obj| obj.is_empty()) {
434                    ctx.add_warning(
435                        "UNEXPECTED_PARAMS",
436                        "tools/list should not have parameters".to_string(),
437                        None,
438                    );
439                }
440            }
441            "tools/call" => self.validate_value_structure(params, "call_tool", ctx),
442            _ => {
443                // Unknown method - validate basic structure
444                self.validate_value_structure(params, "generic", ctx);
445            }
446        }
447
448        ctx.pop_path();
449    }
450
451    fn validate_tool_input(&self, input: &ToolInputSchema, ctx: &mut ValidationContext) {
452        ctx.push_path("inputSchema".to_string());
453
454        // Validate schema type
455        if input.schema_type != "object" {
456            ctx.add_warning(
457                "NON_OBJECT_SCHEMA",
458                "Tool input schema should typically be 'object'".to_string(),
459                Some("type".to_string()),
460            );
461        }
462
463        ctx.pop_path();
464    }
465
466    fn validate_value_structure(
467        &self,
468        value: &Value,
469        _expected_type: &str,
470        ctx: &mut ValidationContext,
471    ) {
472        // Prevent infinite recursion
473        if ctx.depth > self.rules.max_object_depth {
474            ctx.add_error(
475                "MAX_DEPTH_EXCEEDED",
476                format!(
477                    "Maximum object depth ({}) exceeded",
478                    self.rules.max_object_depth
479                ),
480                None,
481            );
482            return;
483        }
484
485        match value {
486            Value::Object(obj) => {
487                ctx.depth += 1;
488                for (key, val) in obj {
489                    ctx.push_path(key.clone());
490                    self.validate_value_structure(val, "unknown", ctx);
491                    ctx.pop_path();
492                }
493                ctx.depth -= 1;
494            }
495            Value::Array(arr) => {
496                if arr.len() > self.rules.max_array_length {
497                    ctx.add_error(
498                        "ARRAY_TOO_LONG",
499                        format!(
500                            "Array exceeds maximum length of {}",
501                            self.rules.max_array_length
502                        ),
503                        None,
504                    );
505                }
506
507                for (index, val) in arr.iter().enumerate() {
508                    ctx.push_path(index.to_string());
509                    self.validate_value_structure(val, "unknown", ctx);
510                    ctx.pop_path();
511                }
512            }
513            Value::String(s) => {
514                if s.len() > self.rules.max_string_length {
515                    ctx.add_error(
516                        "STRING_TOO_LONG",
517                        format!(
518                            "String exceeds maximum length of {}",
519                            self.rules.max_string_length
520                        ),
521                        None,
522                    );
523                }
524            }
525            _ => {} // Other types are fine
526        }
527    }
528}
529
530impl Default for ProtocolValidator {
531    fn default() -> Self {
532        Self::new()
533    }
534}
535
536impl ValidationContext {
537    fn new() -> Self {
538        Self {
539            path: Vec::new(),
540            depth: 0,
541            warnings: Vec::new(),
542            errors: Vec::new(),
543        }
544    }
545
546    fn push_path(&mut self, segment: String) {
547        self.path.push(segment);
548    }
549
550    fn pop_path(&mut self) {
551        self.path.pop();
552    }
553
554    fn current_path(&self) -> Option<String> {
555        if self.path.is_empty() {
556            None
557        } else {
558            Some(self.path.join("."))
559        }
560    }
561
562    fn add_error(&mut self, code: &str, message: String, field_path: Option<String>) {
563        let path = field_path.or_else(|| self.current_path());
564        self.errors.push(ValidationError {
565            code: code.to_string(),
566            message,
567            field_path: path,
568        });
569    }
570
571    fn add_warning(&mut self, code: &str, message: String, field_path: Option<String>) {
572        let path = field_path.or_else(|| self.current_path());
573        self.warnings.push(ValidationWarning {
574            code: code.to_string(),
575            message,
576            field_path: path,
577        });
578    }
579
580    fn into_result(self) -> ValidationResult {
581        if !self.errors.is_empty() {
582            ValidationResult::Invalid(self.errors)
583        } else if !self.warnings.is_empty() {
584            ValidationResult::ValidWithWarnings(self.warnings)
585        } else {
586            ValidationResult::Valid
587        }
588    }
589}
590
591impl ValidationResult {
592    /// Check if validation passed (with or without warnings)
593    pub fn is_valid(&self) -> bool {
594        !matches!(self, ValidationResult::Invalid(_))
595    }
596
597    /// Check if validation failed
598    pub fn is_invalid(&self) -> bool {
599        matches!(self, ValidationResult::Invalid(_))
600    }
601
602    /// Check if validation has warnings
603    pub fn has_warnings(&self) -> bool {
604        matches!(self, ValidationResult::ValidWithWarnings(_))
605    }
606
607    /// Get warnings (if any)
608    pub fn warnings(&self) -> &[ValidationWarning] {
609        match self {
610            ValidationResult::ValidWithWarnings(warnings) => warnings,
611            _ => &[],
612        }
613    }
614
615    /// Get errors (if any)
616    pub fn errors(&self) -> &[ValidationError] {
617        match self {
618            ValidationResult::Invalid(errors) => errors,
619            _ => &[],
620        }
621    }
622}
623
624/// Utility functions for validation
625pub mod utils {
626    use super::*;
627
628    /// Create a validation error
629    pub fn error(code: &str, message: &str) -> ValidationError {
630        ValidationError {
631            code: code.to_string(),
632            message: message.to_string(),
633            field_path: None,
634        }
635    }
636
637    /// Create a validation warning
638    pub fn warning(code: &str, message: &str) -> ValidationWarning {
639        ValidationWarning {
640            code: code.to_string(),
641            message: message.to_string(),
642            field_path: None,
643        }
644    }
645
646    /// Check if a string is a valid URI
647    pub fn is_valid_uri(uri: &str) -> bool {
648        ValidationRules::default().uri_regex.is_match(uri)
649    }
650
651    /// Check if a string is a valid method name
652    pub fn is_valid_method_name(method: &str) -> bool {
653        ValidationRules::default()
654            .method_name_regex
655            .is_match(method)
656    }
657}
658
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jsonrpc::JsonRpcVersion;

    /// Input schema of type `object` with every optional part left out.
    fn object_schema() -> ToolInputSchema {
        ToolInputSchema {
            schema_type: "object".to_string(),
            properties: None,
            required: None,
            additional_properties: None,
        }
    }

    /// Parameterless request for `method` carrying a fixed string id.
    fn request_for(method: &str) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            method: method.to_string(),
            params: None,
            id: RequestId::String("test-id".to_string()),
        }
    }

    /// Initialize request from a fixed test client announcing `version`.
    fn initialize_with_version(version: &str) -> InitializeRequest {
        InitializeRequest {
            protocol_version: version.to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "test-client".to_string(),
                title: Some("Test Client".to_string()),
                version: "1.0.0".to_string(),
            },
        }
    }

    #[test]
    fn test_tool_validation() {
        let validator = ProtocolValidator::new();

        let well_formed = Tool {
            name: "test_tool".to_string(),
            title: Some("Test Tool".to_string()),
            description: Some("A test tool".to_string()),
            input_schema: object_schema(),
            output_schema: None,
            annotations: None,
            meta: None,
        };
        assert!(validator.validate_tool(&well_formed).is_valid());

        // An empty name must be rejected
        let nameless = Tool {
            name: String::new(),
            title: None,
            description: None,
            input_schema: object_schema(),
            output_schema: None,
            annotations: None,
            meta: None,
        };
        assert!(validator.validate_tool(&nameless).is_invalid());
    }

    #[test]
    fn test_request_validation() {
        let validator = ProtocolValidator::new();

        let listing = request_for("tools/list");
        assert!(validator.validate_request(&listing).is_valid());

        // An empty method name must be rejected
        let unnamed = request_for("");
        assert!(validator.validate_request(&unnamed).is_invalid());
    }

    #[test]
    fn test_initialize_validation() {
        let validator = ProtocolValidator::new();

        let current = initialize_with_version("2025-06-18");
        assert!(validator.validate_initialize_request(&current).is_valid());

        // An unsupported version is a warning, not an error
        let outdated = initialize_with_version("2023-01-01");
        let outcome = validator.validate_initialize_request(&outdated);
        assert!(outcome.is_valid());
        assert!(outcome.has_warnings());
    }

    #[test]
    fn test_validation_result() {
        let clean = ValidationResult::Valid;
        assert!(clean.is_valid());
        assert!(!clean.is_invalid());
        assert!(!clean.has_warnings());

        let warnings = vec![utils::warning("TEST", "Test warning")];
        let warned = ValidationResult::ValidWithWarnings(warnings.clone());
        assert!(warned.is_valid());
        assert!(warned.has_warnings());
        assert_eq!(warned.warnings(), warnings.as_slice());

        let errors = vec![utils::error("TEST", "Test error")];
        let rejected = ValidationResult::Invalid(errors.clone());
        assert!(!rejected.is_valid());
        assert!(rejected.is_invalid());
        assert_eq!(rejected.errors(), errors.as_slice());
    }

    #[test]
    fn test_utils() {
        for accepted in ["file://test.txt", "https://example.com"] {
            assert!(utils::is_valid_uri(accepted));
        }
        assert!(!utils::is_valid_uri("not-a-uri"));

        for accepted in ["tools/list", "initialize"] {
            assert!(utils::is_valid_method_name(accepted));
        }
        assert!(!utils::is_valid_method_name("invalid-method-name!"));
    }
}