1use regex::Regex;
7use serde_json::Value;
8use std::collections::{HashMap, HashSet};
9
10use crate::jsonrpc::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
11use crate::types::*;
12
/// Validator for JSON-RPC messages and protocol-level types (tools, prompts,
/// resources and initialize requests).
///
/// Construct one with [`ProtocolValidator::new`] and adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ProtocolValidator {
    /// Limits and patterns consulted by the `validate_*` methods.
    rules: ValidationRules,
    /// Set to `true` by `with_strict_mode`.
    // NOTE(review): none of the validation paths visible in this file read this
    // flag — confirm where strict mode is meant to take effect.
    strict_mode: bool,
}
21
/// Configurable limits and patterns used by [`ProtocolValidator`].
#[derive(Debug, Clone)]
pub struct ValidationRules {
    /// Upper bound for a whole message.
    // NOTE(review): presumably bytes; not enforced by the validators visible in
    // this file — confirm against the transport layer.
    pub max_message_size: usize,
    /// Upper bound for the number of messages in a batch.
    // NOTE(review): not enforced by the validators visible in this file.
    pub max_batch_size: usize,
    /// Maximum `len()` (byte length) accepted for tool names and for string
    /// values found inside request/notification params.
    pub max_string_length: usize,
    /// Maximum number of elements accepted for arrays inside params and for a
    /// prompt's argument list.
    pub max_array_length: usize,
    /// Maximum nesting depth tolerated while walking params.
    pub max_object_depth: usize,
    /// Pattern a resource URI has to match.
    pub uri_regex: Regex,
    /// Pattern a non-empty method name has to match.
    pub method_name_regex: Regex,
    /// Field names keyed by a message/type kind (for example `"request"` or
    /// `"tool"`).
    // NOTE(review): populated by `Default`, but not consulted by the validators
    // visible in this file — confirm intended use.
    pub required_fields: HashMap<String, HashSet<String>>,
}
42
/// Outcome of a validation run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationResult {
    /// Nothing to report.
    Valid,
    /// Acceptable, but with the attached warnings.
    ValidWithWarnings(Vec<ValidationWarning>),
    /// Rejected; carries the errors only (this variant has no slot for warnings).
    Invalid(Vec<ValidationError>),
}
53
/// A non-fatal finding produced during validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationWarning {
    /// Machine-readable identifier, e.g. `"UNEXPECTED_PARAMS"`.
    pub code: String,
    /// Human-readable description of the finding.
    pub message: String,
    /// Dot-separated location of the offending field, when one is known.
    pub field_path: Option<String>,
}
64
/// A fatal finding produced during validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Machine-readable identifier, e.g. `"EMPTY_METHOD_NAME"`.
    pub code: String,
    /// Human-readable description of the problem.
    pub message: String,
    /// Dot-separated location of the offending field, when one is known.
    pub field_path: Option<String>,
}
75
/// Mutable scratch state threaded through a single validation run.
#[derive(Debug, Clone)]
struct ValidationContext {
    /// Stack of path segments leading to the value currently being inspected;
    /// joined with `.` to form a field path.
    path: Vec<String>,
    /// Current nesting level while walking a JSON value.
    depth: usize,
    /// Warnings collected so far.
    warnings: Vec<ValidationWarning>,
    /// Errors collected so far.
    errors: Vec<ValidationError>,
}
88
89impl Default for ValidationRules {
90 fn default() -> Self {
91 let uri_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9+.-]*:").unwrap();
92 let method_name_regex = Regex::new(r"^[a-zA-Z][a-zA-Z0-9_/]*$").unwrap();
93
94 let mut required_fields = HashMap::new();
95
96 required_fields.insert(
98 "request".to_string(),
99 ["jsonrpc", "method", "id"]
100 .iter()
101 .map(|s| s.to_string())
102 .collect(),
103 );
104 required_fields.insert(
105 "response".to_string(),
106 ["jsonrpc", "id"].iter().map(|s| s.to_string()).collect(),
107 );
108 required_fields.insert(
109 "notification".to_string(),
110 ["jsonrpc", "method"]
111 .iter()
112 .map(|s| s.to_string())
113 .collect(),
114 );
115
116 required_fields.insert(
118 "initialize".to_string(),
119 ["protocolVersion", "capabilities", "clientInfo"]
120 .iter()
121 .map(|s| s.to_string())
122 .collect(),
123 );
124 required_fields.insert(
125 "tool".to_string(),
126 ["name", "inputSchema"]
127 .iter()
128 .map(|s| s.to_string())
129 .collect(),
130 );
131 required_fields.insert(
132 "prompt".to_string(),
133 ["name"].iter().map(|s| s.to_string()).collect(),
134 );
135 required_fields.insert(
136 "resource".to_string(),
137 ["uri", "name"].iter().map(|s| s.to_string()).collect(),
138 );
139
140 Self {
141 max_message_size: 10 * 1024 * 1024, max_batch_size: 100,
143 max_string_length: 1024 * 1024, max_array_length: 10000,
145 max_object_depth: 32,
146 uri_regex,
147 method_name_regex,
148 required_fields,
149 }
150 }
151}
152
153impl ProtocolValidator {
154 pub fn new() -> Self {
156 Self {
157 rules: ValidationRules::default(),
158 strict_mode: false,
159 }
160 }
161
162 pub fn with_strict_mode(mut self) -> Self {
164 self.strict_mode = true;
165 self
166 }
167
168 pub fn with_rules(mut self, rules: ValidationRules) -> Self {
170 self.rules = rules;
171 self
172 }
173
174 pub fn validate_request(&self, request: &JsonRpcRequest) -> ValidationResult {
176 let mut ctx = ValidationContext::new();
177
178 self.validate_jsonrpc_request(request, &mut ctx);
180
181 self.validate_method_name(&request.method, &mut ctx);
183
184 if let Some(params) = &request.params {
186 self.validate_method_params(&request.method, params, &mut ctx);
187 }
188
189 ctx.into_result()
190 }
191
192 pub fn validate_response(&self, response: &JsonRpcResponse) -> ValidationResult {
194 let mut ctx = ValidationContext::new();
195
196 self.validate_jsonrpc_response(response, &mut ctx);
198
199 match (response.result.is_some(), response.error.is_some()) {
201 (true, true) => {
202 ctx.add_error(
203 "RESPONSE_BOTH_RESULT_AND_ERROR",
204 "Response cannot have both result and error".to_string(),
205 None,
206 );
207 }
208 (false, false) => {
209 ctx.add_error(
210 "RESPONSE_MISSING_RESULT_OR_ERROR",
211 "Response must have either result or error".to_string(),
212 None,
213 );
214 }
215 _ => {} }
217
218 ctx.into_result()
219 }
220
221 pub fn validate_notification(&self, notification: &JsonRpcNotification) -> ValidationResult {
223 let mut ctx = ValidationContext::new();
224
225 self.validate_jsonrpc_notification(notification, &mut ctx);
227
228 self.validate_method_name(¬ification.method, &mut ctx);
230
231 if let Some(params) = ¬ification.params {
233 self.validate_method_params(¬ification.method, params, &mut ctx);
234 }
235
236 ctx.into_result()
237 }
238
239 pub fn validate_tool(&self, tool: &Tool) -> ValidationResult {
241 let mut ctx = ValidationContext::new();
242
243 if tool.name.is_empty() {
245 ctx.add_error(
246 "TOOL_EMPTY_NAME",
247 "Tool name cannot be empty".to_string(),
248 Some("name".to_string()),
249 );
250 }
251
252 if tool.name.len() > self.rules.max_string_length {
253 ctx.add_error(
254 "TOOL_NAME_TOO_LONG",
255 format!(
256 "Tool name exceeds maximum length of {}",
257 self.rules.max_string_length
258 ),
259 Some("name".to_string()),
260 );
261 }
262
263 self.validate_tool_input(&tool.input_schema, &mut ctx);
265
266 ctx.into_result()
267 }
268
269 pub fn validate_prompt(&self, prompt: &Prompt) -> ValidationResult {
271 let mut ctx = ValidationContext::new();
272
273 if prompt.name.is_empty() {
275 ctx.add_error(
276 "PROMPT_EMPTY_NAME",
277 "Prompt name cannot be empty".to_string(),
278 Some("name".to_string()),
279 );
280 }
281
282 if let Some(arguments) = &prompt.arguments
284 && arguments.len() > self.rules.max_array_length
285 {
286 ctx.add_error(
287 "PROMPT_TOO_MANY_ARGS",
288 format!(
289 "Prompt has too many arguments (max: {})",
290 self.rules.max_array_length
291 ),
292 Some("arguments".to_string()),
293 );
294 }
295
296 ctx.into_result()
297 }
298
299 pub fn validate_resource(&self, resource: &Resource) -> ValidationResult {
301 let mut ctx = ValidationContext::new();
302
303 if !self.rules.uri_regex.is_match(&resource.uri) {
305 ctx.add_error(
306 "RESOURCE_INVALID_URI",
307 format!("Invalid URI format: {}", resource.uri),
308 Some("uri".to_string()),
309 );
310 }
311
312 if resource.name.is_empty() {
314 ctx.add_error(
315 "RESOURCE_EMPTY_NAME",
316 "Resource name cannot be empty".to_string(),
317 Some("name".to_string()),
318 );
319 }
320
321 ctx.into_result()
322 }
323
324 pub fn validate_initialize_request(&self, request: &InitializeRequest) -> ValidationResult {
326 let mut ctx = ValidationContext::new();
327
328 if !crate::SUPPORTED_VERSIONS.contains(&request.protocol_version.as_str()) {
330 ctx.add_warning(
331 "UNSUPPORTED_PROTOCOL_VERSION",
332 format!(
333 "Protocol version {} is not officially supported",
334 request.protocol_version
335 ),
336 Some("protocolVersion".to_string()),
337 );
338 }
339
340 if request.client_info.name.is_empty() {
342 ctx.add_error(
343 "EMPTY_CLIENT_NAME",
344 "Client name cannot be empty".to_string(),
345 Some("clientInfo.name".to_string()),
346 );
347 }
348
349 if request.client_info.version.is_empty() {
350 ctx.add_error(
351 "EMPTY_CLIENT_VERSION",
352 "Client version cannot be empty".to_string(),
353 Some("clientInfo.version".to_string()),
354 );
355 }
356
357 ctx.into_result()
358 }
359
    /// Envelope-level checks for a request.
    ///
    /// Currently performs no checks; both arguments are unused.
    fn validate_jsonrpc_request(&self, _request: &JsonRpcRequest, _ctx: &mut ValidationContext) {
    }
368
369 fn validate_jsonrpc_response(&self, response: &JsonRpcResponse, ctx: &mut ValidationContext) {
370 if let Some(error) = &response.error {
372 self.validate_jsonrpc_error(error, ctx);
373 }
374 }
375
    /// Envelope-level checks for a notification.
    ///
    /// Currently performs no checks; both arguments are unused.
    fn validate_jsonrpc_notification(
        &self,
        _notification: &JsonRpcNotification,
        _ctx: &mut ValidationContext,
    ) {
    }
383
384 fn validate_jsonrpc_error(
385 &self,
386 error: &crate::jsonrpc::JsonRpcError,
387 ctx: &mut ValidationContext,
388 ) {
389 if error.code >= 0 {
391 ctx.add_warning(
392 "POSITIVE_ERROR_CODE",
393 "Error codes should be negative according to JSON-RPC spec".to_string(),
394 Some("error.code".to_string()),
395 );
396 }
397
398 if error.message.is_empty() {
399 ctx.add_error(
400 "EMPTY_ERROR_MESSAGE",
401 "Error message cannot be empty".to_string(),
402 Some("error.message".to_string()),
403 );
404 }
405 }
406
407 fn validate_method_name(&self, method: &str, ctx: &mut ValidationContext) {
408 if method.is_empty() {
409 ctx.add_error(
410 "EMPTY_METHOD_NAME",
411 "Method name cannot be empty".to_string(),
412 Some("method".to_string()),
413 );
414 return;
415 }
416
417 if !self.rules.method_name_regex.is_match(method) {
418 ctx.add_error(
419 "INVALID_METHOD_NAME",
420 format!("Invalid method name format: {method}"),
421 Some("method".to_string()),
422 );
423 }
424 }
425
426 fn validate_method_params(&self, method: &str, params: &Value, ctx: &mut ValidationContext) {
427 ctx.push_path("params".to_string());
428
429 match method {
430 "initialize" => self.validate_value_structure(params, "initialize", ctx),
431 "tools/list" => {
432 if !params.is_null() && !params.as_object().is_some_and(|obj| obj.is_empty()) {
434 ctx.add_warning(
435 "UNEXPECTED_PARAMS",
436 "tools/list should not have parameters".to_string(),
437 None,
438 );
439 }
440 }
441 "tools/call" => self.validate_value_structure(params, "call_tool", ctx),
442 _ => {
443 self.validate_value_structure(params, "generic", ctx);
445 }
446 }
447
448 ctx.pop_path();
449 }
450
451 fn validate_tool_input(&self, input: &ToolInputSchema, ctx: &mut ValidationContext) {
452 ctx.push_path("inputSchema".to_string());
453
454 if input.schema_type != "object" {
456 ctx.add_warning(
457 "NON_OBJECT_SCHEMA",
458 "Tool input schema should typically be 'object'".to_string(),
459 Some("type".to_string()),
460 );
461 }
462
463 ctx.pop_path();
464 }
465
466 fn validate_value_structure(
467 &self,
468 value: &Value,
469 _expected_type: &str,
470 ctx: &mut ValidationContext,
471 ) {
472 if ctx.depth > self.rules.max_object_depth {
474 ctx.add_error(
475 "MAX_DEPTH_EXCEEDED",
476 format!(
477 "Maximum object depth ({}) exceeded",
478 self.rules.max_object_depth
479 ),
480 None,
481 );
482 return;
483 }
484
485 match value {
486 Value::Object(obj) => {
487 ctx.depth += 1;
488 for (key, val) in obj {
489 ctx.push_path(key.clone());
490 self.validate_value_structure(val, "unknown", ctx);
491 ctx.pop_path();
492 }
493 ctx.depth -= 1;
494 }
495 Value::Array(arr) => {
496 if arr.len() > self.rules.max_array_length {
497 ctx.add_error(
498 "ARRAY_TOO_LONG",
499 format!(
500 "Array exceeds maximum length of {}",
501 self.rules.max_array_length
502 ),
503 None,
504 );
505 }
506
507 for (index, val) in arr.iter().enumerate() {
508 ctx.push_path(index.to_string());
509 self.validate_value_structure(val, "unknown", ctx);
510 ctx.pop_path();
511 }
512 }
513 Value::String(s) => {
514 if s.len() > self.rules.max_string_length {
515 ctx.add_error(
516 "STRING_TOO_LONG",
517 format!(
518 "String exceeds maximum length of {}",
519 self.rules.max_string_length
520 ),
521 None,
522 );
523 }
524 }
525 _ => {} }
527 }
528}
529
530impl Default for ProtocolValidator {
531 fn default() -> Self {
532 Self::new()
533 }
534}
535
536impl ValidationContext {
537 fn new() -> Self {
538 Self {
539 path: Vec::new(),
540 depth: 0,
541 warnings: Vec::new(),
542 errors: Vec::new(),
543 }
544 }
545
546 fn push_path(&mut self, segment: String) {
547 self.path.push(segment);
548 }
549
550 fn pop_path(&mut self) {
551 self.path.pop();
552 }
553
554 fn current_path(&self) -> Option<String> {
555 if self.path.is_empty() {
556 None
557 } else {
558 Some(self.path.join("."))
559 }
560 }
561
562 fn add_error(&mut self, code: &str, message: String, field_path: Option<String>) {
563 let path = field_path.or_else(|| self.current_path());
564 self.errors.push(ValidationError {
565 code: code.to_string(),
566 message,
567 field_path: path,
568 });
569 }
570
571 fn add_warning(&mut self, code: &str, message: String, field_path: Option<String>) {
572 let path = field_path.or_else(|| self.current_path());
573 self.warnings.push(ValidationWarning {
574 code: code.to_string(),
575 message,
576 field_path: path,
577 });
578 }
579
580 fn into_result(self) -> ValidationResult {
581 if !self.errors.is_empty() {
582 ValidationResult::Invalid(self.errors)
583 } else if !self.warnings.is_empty() {
584 ValidationResult::ValidWithWarnings(self.warnings)
585 } else {
586 ValidationResult::Valid
587 }
588 }
589}
590
591impl ValidationResult {
592 pub fn is_valid(&self) -> bool {
594 !matches!(self, ValidationResult::Invalid(_))
595 }
596
597 pub fn is_invalid(&self) -> bool {
599 matches!(self, ValidationResult::Invalid(_))
600 }
601
602 pub fn has_warnings(&self) -> bool {
604 matches!(self, ValidationResult::ValidWithWarnings(_))
605 }
606
607 pub fn warnings(&self) -> &[ValidationWarning] {
609 match self {
610 ValidationResult::ValidWithWarnings(warnings) => warnings,
611 _ => &[],
612 }
613 }
614
615 pub fn errors(&self) -> &[ValidationError] {
617 match self {
618 ValidationResult::Invalid(errors) => errors,
619 _ => &[],
620 }
621 }
622}
623
624pub mod utils {
626 use super::*;
627
628 pub fn error(code: &str, message: &str) -> ValidationError {
630 ValidationError {
631 code: code.to_string(),
632 message: message.to_string(),
633 field_path: None,
634 }
635 }
636
637 pub fn warning(code: &str, message: &str) -> ValidationWarning {
639 ValidationWarning {
640 code: code.to_string(),
641 message: message.to_string(),
642 field_path: None,
643 }
644 }
645
646 pub fn is_valid_uri(uri: &str) -> bool {
648 ValidationRules::default().uri_regex.is_match(uri)
649 }
650
651 pub fn is_valid_method_name(method: &str) -> bool {
653 ValidationRules::default()
654 .method_name_regex
655 .is_match(method)
656 }
657}
658
659#[cfg(test)]
660mod tests {
661 use super::*;
662 use crate::jsonrpc::JsonRpcVersion;
663 #[test]
666 fn test_tool_validation() {
667 let validator = ProtocolValidator::new();
668
669 let tool = Tool {
670 name: "test_tool".to_string(),
671 title: Some("Test Tool".to_string()),
672 description: Some("A test tool".to_string()),
673 input_schema: ToolInputSchema {
674 schema_type: "object".to_string(),
675 properties: None,
676 required: None,
677 additional_properties: None,
678 },
679 output_schema: None,
680 annotations: None,
681 meta: None,
682 };
683
684 let result = validator.validate_tool(&tool);
685 assert!(result.is_valid());
686
687 let invalid_tool = Tool {
689 name: String::new(),
690 title: None,
691 description: None,
692 input_schema: tool.input_schema.clone(),
693 output_schema: None,
694 annotations: None,
695 meta: None,
696 };
697
698 let result = validator.validate_tool(&invalid_tool);
699 assert!(result.is_invalid());
700 }
701
702 #[test]
703 fn test_request_validation() {
704 let validator = ProtocolValidator::new();
705
706 let request = JsonRpcRequest {
707 jsonrpc: JsonRpcVersion,
708 method: "tools/list".to_string(),
709 params: None,
710 id: RequestId::String("test-id".to_string()),
711 };
712
713 let result = validator.validate_request(&request);
714 assert!(result.is_valid());
715
716 let invalid_request = JsonRpcRequest {
718 jsonrpc: JsonRpcVersion,
719 method: String::new(),
720 params: None,
721 id: RequestId::String("test-id".to_string()),
722 };
723
724 let result = validator.validate_request(&invalid_request);
725 assert!(result.is_invalid());
726 }
727
728 #[test]
729 fn test_initialize_validation() {
730 let validator = ProtocolValidator::new();
731
732 let request = InitializeRequest {
733 protocol_version: "2025-06-18".to_string(),
734 capabilities: ClientCapabilities::default(),
735 client_info: Implementation {
736 name: "test-client".to_string(),
737 title: Some("Test Client".to_string()),
738 version: "1.0.0".to_string(),
739 },
740 };
741
742 let result = validator.validate_initialize_request(&request);
743 assert!(result.is_valid());
744
745 let request_with_old_version = InitializeRequest {
747 protocol_version: "2023-01-01".to_string(),
748 capabilities: ClientCapabilities::default(),
749 client_info: Implementation {
750 name: "test-client".to_string(),
751 title: Some("Test Client".to_string()),
752 version: "1.0.0".to_string(),
753 },
754 };
755
756 let result = validator.validate_initialize_request(&request_with_old_version);
757 assert!(result.is_valid()); assert!(result.has_warnings());
759 }
760
761 #[test]
762 fn test_validation_result() {
763 let valid = ValidationResult::Valid;
764 assert!(valid.is_valid());
765 assert!(!valid.is_invalid());
766 assert!(!valid.has_warnings());
767
768 let warnings = vec![utils::warning("TEST", "Test warning")];
769 let valid_with_warnings = ValidationResult::ValidWithWarnings(warnings.clone());
770 assert!(valid_with_warnings.is_valid());
771 assert!(valid_with_warnings.has_warnings());
772 assert_eq!(valid_with_warnings.warnings(), &warnings);
773
774 let errors = vec![utils::error("TEST", "Test error")];
775 let invalid = ValidationResult::Invalid(errors.clone());
776 assert!(!invalid.is_valid());
777 assert!(invalid.is_invalid());
778 assert_eq!(invalid.errors(), &errors);
779 }
780
781 #[test]
782 fn test_utils() {
783 assert!(utils::is_valid_uri("file://test.txt"));
784 assert!(utils::is_valid_uri("https://example.com"));
785 assert!(!utils::is_valid_uri("not-a-uri"));
786
787 assert!(utils::is_valid_method_name("tools/list"));
788 assert!(utils::is_valid_method_name("initialize"));
789 assert!(!utils::is_valid_method_name("invalid-method-name!"));
790 }
791}