1use regex::Regex;
7use serde_json::Value;
8use std::collections::{HashMap, HashSet};
9
10use crate::jsonrpc::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
11use crate::types::*;
12
13#[derive(Debug, Clone)]
15pub struct ProtocolValidator {
16 rules: ValidationRules,
18 strict_mode: bool,
20}
21
22#[derive(Debug, Clone)]
24pub struct ValidationRules {
25 pub max_message_size: usize,
27 pub max_batch_size: usize,
29 pub max_string_length: usize,
31 pub max_array_length: usize,
33 pub max_object_depth: usize,
35 pub uri_regex: Regex,
37 pub method_name_regex: Regex,
39 pub required_fields: HashMap<String, HashSet<String>>,
41}
42
43#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum ValidationResult {
46 Valid,
48 ValidWithWarnings(Vec<ValidationWarning>),
50 Invalid(Vec<ValidationError>),
52}
53
/// Non-fatal finding produced during validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationWarning {
    /// Machine-readable identifier such as `UNEXPECTED_PARAMS`.
    pub code: String,
    /// Human-readable description of the finding.
    pub message: String,
    /// Path of the field the warning refers to, when one is known.
    pub field_path: Option<String>,
}
64
/// Fatal finding produced during validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Machine-readable identifier such as `EMPTY_METHOD_NAME`.
    pub code: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// Path of the offending field, when one is known.
    pub field_path: Option<String>,
}
75
76#[derive(Debug, Clone)]
78struct ValidationContext {
79 path: Vec<String>,
81 depth: usize,
83 warnings: Vec<ValidationWarning>,
85 errors: Vec<ValidationError>,
87}
88
89impl Default for ValidationRules {
90 fn default() -> Self {
91 let uri_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9+.-]*:").unwrap();
92 let method_name_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_/]*$").unwrap();
93
94 let mut required_fields = HashMap::new();
95
96 required_fields.insert(
98 "request".to_string(),
99 ["jsonrpc", "method", "id"]
100 .iter()
101 .map(|s| s.to_string())
102 .collect(),
103 );
104 required_fields.insert(
105 "response".to_string(),
106 ["jsonrpc", "id"].iter().map(|s| s.to_string()).collect(),
107 );
108 required_fields.insert(
109 "notification".to_string(),
110 ["jsonrpc", "method"]
111 .iter()
112 .map(|s| s.to_string())
113 .collect(),
114 );
115
116 required_fields.insert(
118 "initialize".to_string(),
119 ["protocolVersion", "capabilities", "clientInfo"]
120 .iter()
121 .map(|s| s.to_string())
122 .collect(),
123 );
124 required_fields.insert(
125 "tool".to_string(),
126 ["name", "inputSchema"]
127 .iter()
128 .map(|s| s.to_string())
129 .collect(),
130 );
131 required_fields.insert(
132 "prompt".to_string(),
133 ["name"].iter().map(|s| s.to_string()).collect(),
134 );
135 required_fields.insert(
136 "resource".to_string(),
137 ["uri", "name"].iter().map(|s| s.to_string()).collect(),
138 );
139
140 Self {
141 max_message_size: 10 * 1024 * 1024, max_batch_size: 100,
143 max_string_length: 1024 * 1024, max_array_length: 10000,
145 max_object_depth: 32,
146 uri_regex,
147 method_name_regex,
148 required_fields,
149 }
150 }
151}
152
153impl ProtocolValidator {
154 pub fn new() -> Self {
156 Self {
157 rules: ValidationRules::default(),
158 strict_mode: false,
159 }
160 }
161
162 pub fn with_strict_mode(mut self) -> Self {
164 self.strict_mode = true;
165 self
166 }
167
168 pub fn with_rules(mut self, rules: ValidationRules) -> Self {
170 self.rules = rules;
171 self
172 }
173
174 pub fn validate_request(&self, request: &JsonRpcRequest) -> ValidationResult {
176 let mut ctx = ValidationContext::new();
177
178 self.validate_jsonrpc_request(request, &mut ctx);
180
181 if let Some(params) = &request.params {
183 self.validate_method_params(&request.method, params, &mut ctx);
184 }
185
186 ctx.into_result()
187 }
188
189 pub fn validate_response(&self, response: &JsonRpcResponse) -> ValidationResult {
191 let mut ctx = ValidationContext::new();
192
193 self.validate_jsonrpc_response(response, &mut ctx);
195
196 match (response.result().is_some(), response.error().is_some()) {
200 (true, true) => {
201 ctx.add_error(
202 "RESPONSE_BOTH_RESULT_AND_ERROR",
203 "Response cannot have both result and error".to_string(),
204 None,
205 );
206 }
207 (false, false) => {
208 ctx.add_error(
209 "RESPONSE_MISSING_RESULT_OR_ERROR",
210 "Response must have either result or error".to_string(),
211 None,
212 );
213 }
214 _ => {} }
216
217 ctx.into_result()
218 }
219
220 pub fn validate_notification(&self, notification: &JsonRpcNotification) -> ValidationResult {
222 let mut ctx = ValidationContext::new();
223
224 self.validate_jsonrpc_notification(notification, &mut ctx);
226
227 self.validate_method_name(¬ification.method, &mut ctx);
229
230 if let Some(params) = ¬ification.params {
232 self.validate_method_params(¬ification.method, params, &mut ctx);
233 }
234
235 ctx.into_result()
236 }
237
238 pub fn validate_tool(&self, tool: &Tool) -> ValidationResult {
240 let mut ctx = ValidationContext::new();
241
242 if tool.name.is_empty() {
244 ctx.add_error(
245 "TOOL_EMPTY_NAME",
246 "Tool name cannot be empty".to_string(),
247 Some("name".to_string()),
248 );
249 }
250
251 if tool.name.len() > self.rules.max_string_length {
252 ctx.add_error(
253 "TOOL_NAME_TOO_LONG",
254 format!(
255 "Tool name exceeds maximum length of {}",
256 self.rules.max_string_length
257 ),
258 Some("name".to_string()),
259 );
260 }
261
262 self.validate_tool_input(&tool.input_schema, &mut ctx);
264
265 ctx.into_result()
266 }
267
268 pub fn validate_prompt(&self, prompt: &Prompt) -> ValidationResult {
270 let mut ctx = ValidationContext::new();
271
272 if prompt.name.is_empty() {
274 ctx.add_error(
275 "PROMPT_EMPTY_NAME",
276 "Prompt name cannot be empty".to_string(),
277 Some("name".to_string()),
278 );
279 }
280
281 if let Some(arguments) = &prompt.arguments
283 && arguments.len() > self.rules.max_array_length
284 {
285 ctx.add_error(
286 "PROMPT_TOO_MANY_ARGS",
287 format!(
288 "Prompt has too many arguments (max: {})",
289 self.rules.max_array_length
290 ),
291 Some("arguments".to_string()),
292 );
293 }
294
295 ctx.into_result()
296 }
297
298 pub fn validate_resource(&self, resource: &Resource) -> ValidationResult {
300 let mut ctx = ValidationContext::new();
301
302 if !self.rules.uri_regex.is_match(&resource.uri) {
304 ctx.add_error(
305 "RESOURCE_INVALID_URI",
306 format!("Invalid URI format: {}", resource.uri),
307 Some("uri".to_string()),
308 );
309 }
310
311 if resource.name.is_empty() {
313 ctx.add_error(
314 "RESOURCE_EMPTY_NAME",
315 "Resource name cannot be empty".to_string(),
316 Some("name".to_string()),
317 );
318 }
319
320 ctx.into_result()
321 }
322
323 pub fn validate_initialize_request(&self, request: &InitializeRequest) -> ValidationResult {
325 let mut ctx = ValidationContext::new();
326
327 if !crate::SUPPORTED_VERSIONS.contains(&request.protocol_version.as_str()) {
329 ctx.add_warning(
330 "UNSUPPORTED_PROTOCOL_VERSION",
331 format!(
332 "Protocol version {} is not officially supported",
333 request.protocol_version
334 ),
335 Some("protocolVersion".to_string()),
336 );
337 }
338
339 if request.client_info.name.is_empty() {
341 ctx.add_error(
342 "EMPTY_CLIENT_NAME",
343 "Client name cannot be empty".to_string(),
344 Some("clientInfo.name".to_string()),
345 );
346 }
347
348 if request.client_info.version.is_empty() {
349 ctx.add_error(
350 "EMPTY_CLIENT_VERSION",
351 "Client version cannot be empty".to_string(),
352 Some("clientInfo.version".to_string()),
353 );
354 }
355
356 ctx.into_result()
357 }
358
359 fn validate_jsonrpc_request(&self, request: &JsonRpcRequest, ctx: &mut ValidationContext) {
362 if request.method.is_empty() {
367 ctx.add_error(
368 "EMPTY_METHOD_NAME",
369 "Method name cannot be empty".to_string(),
370 Some("method".to_string()),
371 );
372 } else if request.method.len() > self.rules.max_string_length {
373 ctx.add_error(
374 "METHOD_NAME_TOO_LONG",
375 format!(
376 "Method name exceeds maximum length of {}",
377 self.rules.max_string_length
378 ),
379 Some("method".to_string()),
380 );
381 } else if !utils::is_valid_method_name(&request.method) {
382 ctx.add_error(
383 "INVALID_METHOD_NAME",
384 format!("Invalid method name format: '{}'", request.method),
385 Some("method".to_string()),
386 );
387 }
388
389 if let Some(ref params) = request.params {
391 self.validate_parameters(params, ctx);
392 }
393
394 self.validate_request_id(&request.id, ctx);
397 }
398
399 fn validate_jsonrpc_response(&self, response: &JsonRpcResponse, ctx: &mut ValidationContext) {
400 self.validate_response_id(&response.id, ctx);
408
409 if let Some(error) = response.error() {
411 self.validate_jsonrpc_error(error, ctx);
412 }
413
414 if let Some(result) = response.result() {
416 self.validate_result_value(result, ctx);
417 }
418 }
419
420 fn validate_jsonrpc_notification(
421 &self,
422 notification: &JsonRpcNotification,
423 ctx: &mut ValidationContext,
424 ) {
425 if notification.method.is_empty() {
430 ctx.add_error(
431 "EMPTY_METHOD_NAME",
432 "Method name cannot be empty".to_string(),
433 Some("method".to_string()),
434 );
435 } else if notification.method.len() > self.rules.max_string_length {
436 ctx.add_error(
437 "METHOD_NAME_TOO_LONG",
438 format!(
439 "Method name exceeds maximum length of {}",
440 self.rules.max_string_length
441 ),
442 Some("method".to_string()),
443 );
444 } else if !utils::is_valid_method_name(¬ification.method) {
445 ctx.add_error(
446 "INVALID_METHOD_NAME",
447 format!("Invalid method name format: '{}'", notification.method),
448 Some("method".to_string()),
449 );
450 }
451
452 if let Some(ref params) = notification.params {
454 self.validate_parameters(params, ctx);
455 }
456
457 }
459
460 fn validate_jsonrpc_error(
461 &self,
462 error: &crate::jsonrpc::JsonRpcError,
463 ctx: &mut ValidationContext,
464 ) {
465 if error.code >= 0 {
467 ctx.add_warning(
468 "POSITIVE_ERROR_CODE",
469 "Error codes should be negative according to JSON-RPC spec".to_string(),
470 Some("error.code".to_string()),
471 );
472 }
473
474 if error.message.is_empty() {
475 ctx.add_error(
476 "EMPTY_ERROR_MESSAGE",
477 "Error message cannot be empty".to_string(),
478 Some("error.message".to_string()),
479 );
480 }
481 }
482
483 fn validate_method_name(&self, method: &str, ctx: &mut ValidationContext) {
484 if method.is_empty() {
485 ctx.add_error(
486 "EMPTY_METHOD_NAME",
487 "Method name cannot be empty".to_string(),
488 Some("method".to_string()),
489 );
490 return;
491 }
492
493 if !self.rules.method_name_regex.is_match(method) {
494 ctx.add_error(
495 "INVALID_METHOD_NAME",
496 format!("Invalid method name format: {method}"),
497 Some("method".to_string()),
498 );
499 }
500 }
501
502 fn validate_method_params(&self, method: &str, params: &Value, ctx: &mut ValidationContext) {
503 ctx.push_path("params".to_string());
504
505 match method {
506 "initialize" => self.validate_value_structure(params, "initialize", ctx),
507 "tools/list" => {
508 if !params.is_null() && !params.as_object().is_some_and(|obj| obj.is_empty()) {
510 ctx.add_warning(
511 "UNEXPECTED_PARAMS",
512 "tools/list should not have parameters".to_string(),
513 None,
514 );
515 }
516 }
517 "tools/call" => self.validate_value_structure(params, "call_tool", ctx),
518 _ => {
519 self.validate_value_structure(params, "generic", ctx);
521 }
522 }
523
524 ctx.pop_path();
525 }
526
527 fn validate_tool_input(&self, input: &ToolInputSchema, ctx: &mut ValidationContext) {
528 ctx.push_path("inputSchema".to_string());
529
530 if input.schema_type != "object" {
532 ctx.add_warning(
533 "NON_OBJECT_SCHEMA",
534 "Tool input schema should typically be 'object'".to_string(),
535 Some("type".to_string()),
536 );
537 }
538
539 ctx.pop_path();
540 }
541
542 fn validate_value_structure(
543 &self,
544 value: &Value,
545 _expected_type: &str,
546 ctx: &mut ValidationContext,
547 ) {
548 if ctx.depth > self.rules.max_object_depth {
550 ctx.add_error(
551 "MAX_DEPTH_EXCEEDED",
552 format!(
553 "Maximum object depth ({}) exceeded",
554 self.rules.max_object_depth
555 ),
556 None,
557 );
558 return;
559 }
560
561 match value {
562 Value::Object(obj) => {
563 ctx.depth += 1;
564 for (key, val) in obj {
565 ctx.push_path(key.clone());
566 self.validate_value_structure(val, "unknown", ctx);
567 ctx.pop_path();
568 }
569 ctx.depth -= 1;
570 }
571 Value::Array(arr) => {
572 if arr.len() > self.rules.max_array_length {
573 ctx.add_error(
574 "ARRAY_TOO_LONG",
575 format!(
576 "Array exceeds maximum length of {}",
577 self.rules.max_array_length
578 ),
579 None,
580 );
581 }
582
583 for (index, val) in arr.iter().enumerate() {
584 ctx.push_path(index.to_string());
585 self.validate_value_structure(val, "unknown", ctx);
586 ctx.pop_path();
587 }
588 }
589 Value::String(s) => {
590 if s.len() > self.rules.max_string_length {
591 ctx.add_error(
592 "STRING_TOO_LONG",
593 format!(
594 "String exceeds maximum length of {}",
595 self.rules.max_string_length
596 ),
597 None,
598 );
599 }
600 }
601 _ => {} }
603 }
604
605 fn validate_parameters(&self, params: &Value, ctx: &mut ValidationContext) {
606 self.validate_value_structure(params, "params", ctx);
608
609 match params {
611 Value::Array(arr) => {
612 if arr.len() > self.rules.max_array_length {
614 ctx.add_error(
615 "PARAMS_ARRAY_TOO_LONG",
616 format!(
617 "Parameter array exceeds maximum length of {}",
618 self.rules.max_array_length
619 ),
620 Some("params".to_string()),
621 );
622 }
623 }
624 _ => {
625 }
627 }
628 }
629
630 fn validate_request_id(&self, _id: &crate::types::RequestId, _ctx: &mut ValidationContext) {
631 }
635
636 fn validate_response_id(&self, id: &crate::jsonrpc::ResponseId, _ctx: &mut ValidationContext) {
637 if id.is_null() {
639 }
642 }
644
645 fn validate_result_value(&self, result: &Value, ctx: &mut ValidationContext) {
646 self.validate_value_structure(result, "result", ctx);
648
649 }
652}
653
654impl Default for ProtocolValidator {
655 fn default() -> Self {
656 Self::new()
657 }
658}
659
660impl ValidationContext {
661 fn new() -> Self {
662 Self {
663 path: Vec::new(),
664 depth: 0,
665 warnings: Vec::new(),
666 errors: Vec::new(),
667 }
668 }
669
670 fn push_path(&mut self, segment: String) {
671 self.path.push(segment);
672 }
673
674 fn pop_path(&mut self) {
675 self.path.pop();
676 }
677
678 fn current_path(&self) -> Option<String> {
679 if self.path.is_empty() {
680 None
681 } else {
682 Some(self.path.join("."))
683 }
684 }
685
686 fn add_error(&mut self, code: &str, message: String, field_path: Option<String>) {
687 let path = field_path.or_else(|| self.current_path());
688 self.errors.push(ValidationError {
689 code: code.to_string(),
690 message,
691 field_path: path,
692 });
693 }
694
695 fn add_warning(&mut self, code: &str, message: String, field_path: Option<String>) {
696 let path = field_path.or_else(|| self.current_path());
697 self.warnings.push(ValidationWarning {
698 code: code.to_string(),
699 message,
700 field_path: path,
701 });
702 }
703
704 fn into_result(self) -> ValidationResult {
705 if !self.errors.is_empty() {
706 ValidationResult::Invalid(self.errors)
707 } else if !self.warnings.is_empty() {
708 ValidationResult::ValidWithWarnings(self.warnings)
709 } else {
710 ValidationResult::Valid
711 }
712 }
713}
714
715impl ValidationResult {
716 pub fn is_valid(&self) -> bool {
718 !matches!(self, ValidationResult::Invalid(_))
719 }
720
721 pub fn is_invalid(&self) -> bool {
723 matches!(self, ValidationResult::Invalid(_))
724 }
725
726 pub fn has_warnings(&self) -> bool {
728 matches!(self, ValidationResult::ValidWithWarnings(_))
729 }
730
731 pub fn warnings(&self) -> &[ValidationWarning] {
733 match self {
734 ValidationResult::ValidWithWarnings(warnings) => warnings,
735 _ => &[],
736 }
737 }
738
739 pub fn errors(&self) -> &[ValidationError] {
741 match self {
742 ValidationResult::Invalid(errors) => errors,
743 _ => &[],
744 }
745 }
746}
747
748pub mod utils {
750 use super::*;
751
752 pub fn error(code: &str, message: &str) -> ValidationError {
754 ValidationError {
755 code: code.to_string(),
756 message: message.to_string(),
757 field_path: None,
758 }
759 }
760
761 pub fn warning(code: &str, message: &str) -> ValidationWarning {
763 ValidationWarning {
764 code: code.to_string(),
765 message: message.to_string(),
766 field_path: None,
767 }
768 }
769
770 pub fn is_valid_uri(uri: &str) -> bool {
772 ValidationRules::default().uri_regex.is_match(uri)
773 }
774
775 pub fn is_valid_method_name(method: &str) -> bool {
777 ValidationRules::default()
778 .method_name_regex
779 .is_match(method)
780 }
781}
782
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jsonrpc::JsonRpcVersion;

    /// Input schema of type `object` with no further constraints.
    fn object_schema() -> ToolInputSchema {
        ToolInputSchema {
            schema_type: "object".to_string(),
            properties: None,
            required: None,
            additional_properties: None,
        }
    }

    /// Request with the given method, no params, and a fixed string id.
    fn request_for(method: &str) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            method: method.to_string(),
            params: None,
            id: RequestId::String("test-id".to_string()),
        }
    }

    /// Initialize request from a fixed test client with the given version.
    fn initialize_with_version(protocol_version: &str) -> InitializeRequest {
        InitializeRequest {
            protocol_version: protocol_version.to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "test-client".to_string(),
                title: Some("Test Client".to_string()),
                version: "1.0.0".to_string(),
            },
            _meta: None,
        }
    }

    #[test]
    fn test_tool_validation() {
        let validator = ProtocolValidator::new();

        let named_tool = Tool {
            name: "test_tool".to_string(),
            title: Some("Test Tool".to_string()),
            description: Some("A test tool".to_string()),
            input_schema: object_schema(),
            output_schema: None,
            annotations: None,
            meta: None,
        };
        assert!(validator.validate_tool(&named_tool).is_valid());

        // An empty name must be rejected.
        let nameless_tool = Tool {
            name: String::new(),
            title: None,
            description: None,
            input_schema: object_schema(),
            output_schema: None,
            annotations: None,
            meta: None,
        };
        assert!(validator.validate_tool(&nameless_tool).is_invalid());
    }

    #[test]
    fn test_request_validation() {
        let validator = ProtocolValidator::new();

        let well_formed = request_for("tools/list");
        assert!(validator.validate_request(&well_formed).is_valid());

        // An empty method name must be rejected.
        let missing_method = request_for("");
        assert!(validator.validate_request(&missing_method).is_invalid());
    }

    #[test]
    fn test_initialize_validation() {
        let validator = ProtocolValidator::new();

        let current = initialize_with_version("2025-06-18");
        assert!(validator.validate_initialize_request(&current).is_valid());

        // An unknown protocol version is accepted, but with a warning.
        let outdated = initialize_with_version("2023-01-01");
        let outcome = validator.validate_initialize_request(&outdated);
        assert!(outcome.is_valid());
        assert!(outcome.has_warnings());
    }

    #[test]
    fn test_validation_result() {
        let clean = ValidationResult::Valid;
        assert!(clean.is_valid());
        assert!(!clean.is_invalid());
        assert!(!clean.has_warnings());

        let expected_warnings = vec![utils::warning("TEST", "Test warning")];
        let warned = ValidationResult::ValidWithWarnings(expected_warnings.clone());
        assert!(warned.is_valid());
        assert!(warned.has_warnings());
        assert_eq!(warned.warnings(), &expected_warnings);

        let expected_errors = vec![utils::error("TEST", "Test error")];
        let failed = ValidationResult::Invalid(expected_errors.clone());
        assert!(!failed.is_valid());
        assert!(failed.is_invalid());
        assert_eq!(failed.errors(), &expected_errors);
    }

    #[test]
    fn test_utils() {
        for uri in ["file://test.txt", "https://example.com"] {
            assert!(utils::is_valid_uri(uri));
        }
        assert!(!utils::is_valid_uri("not-a-uri"));

        for method in ["tools/list", "initialize"] {
            assert!(utils::is_valid_method_name(method));
        }
        assert!(!utils::is_valid_method_name("invalid-method-name!"));
    }
}