1use std::any::{Any, TypeId};
7use std::collections::{HashMap, HashSet};
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, RwLock};
11use std::task::{Context, Poll};
12
13use tower_service::Service;
14
15use crate::async_task::TaskStore;
16use crate::context::{
17 CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
18 ServerNotification,
19};
20use crate::error::{Error, JsonRpcError, Result};
21use crate::prompt::Prompt;
22use crate::protocol::*;
23use crate::resource::{Resource, ResourceTemplate};
24use crate::session::SessionState;
25use crate::tool::Tool;
26
/// Shared, thread-safe callback used to answer completion requests.
///
/// The callback receives the request's [`CompleteParams`] and returns a boxed,
/// `Send` future resolving to a [`CompleteResult`] (or an error). It is stored
/// behind an `Arc` so it can be cloned cheaply together with the router.
pub type CompletionHandler = Arc<
    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
        + Send
        + Sync,
>;
33
/// Request router for an MCP server.
///
/// Holds the registered tools, resources, resource templates and prompts
/// (inside a shared, reference-counted `McpRouterInner`) together with the
/// session lifecycle state. The type is `Clone`; builder-style methods use
/// `Arc::make_mut`, so configuration is copied on write when the inner value
/// is shared.
#[derive(Clone)]
pub struct McpRouter {
    /// Server metadata, registries and shared runtime tables.
    inner: Arc<McpRouterInner>,
    /// Session lifecycle state consulted before dispatching each request.
    // NOTE(review): the tests in this file rely on clones observing the same
    // session phase — confirm `SessionState` shares its state when cloned.
    session: SessionState,
}
63
64impl std::fmt::Debug for McpRouter {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("McpRouter")
67 .field("server_name", &self.inner.server_name)
68 .field("server_version", &self.inner.server_version)
69 .field("tools_count", &self.inner.tools.len())
70 .field("resources_count", &self.inner.resources.len())
71 .field("prompts_count", &self.inner.prompts.len())
72 .field("session_phase", &self.session.phase())
73 .finish()
74 }
75}
76
/// Configuration and shared runtime state behind [`McpRouter`].
///
/// `Clone` is required because the builder methods mutate this value through
/// `Arc::make_mut`. The `Arc<RwLock<..>>` fields are reference-counted, so a
/// cloned `McpRouterInner` keeps pointing at the same in-flight and
/// subscription tables.
#[derive(Clone)]
struct McpRouterInner {
    /// Server name reported in the `initialize` response.
    server_name: String,
    /// Server version reported in the `initialize` response.
    server_version: String,
    /// Optional title reported in the `initialize` response.
    server_title: Option<String>,
    /// Optional description reported in the `initialize` response.
    server_description: Option<String>,
    /// Optional icons reported in the `initialize` response.
    server_icons: Option<Vec<ToolIcon>>,
    /// Optional website URL reported in the `initialize` response.
    server_website_url: Option<String>,
    /// Optional instructions returned to the client on `initialize`.
    instructions: Option<String>,
    /// Registered tools, keyed by tool name.
    tools: HashMap<String, Arc<Tool>>,
    /// Registered static resources, keyed by URI.
    resources: HashMap<String, Arc<Resource>>,
    /// Registered resource templates, tried in registration order on read.
    resource_templates: Vec<Arc<ResourceTemplate>>,
    /// Registered prompts, keyed by prompt name.
    prompts: HashMap<String, Arc<Prompt>>,
    /// Cancellation tokens of requests that currently have a context, keyed
    /// by request id.
    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
    /// Sender for server-to-client notifications, if configured.
    notification_tx: Option<NotificationSender>,
    /// Handle attached to request contexts for server-to-client requests,
    /// if configured.
    client_requester: Option<ClientRequesterHandle>,
    /// Store backing the asynchronous task requests.
    task_store: TaskStore,
    /// URIs the client has subscribed to.
    subscriptions: Arc<RwLock<HashSet<String>>>,
    /// Optional handler answering completion requests.
    completion_handler: Option<CompletionHandler>,
}
109
110impl McpRouter {
111 pub fn new() -> Self {
113 Self {
114 inner: Arc::new(McpRouterInner {
115 server_name: "tower-mcp".to_string(),
116 server_version: env!("CARGO_PKG_VERSION").to_string(),
117 server_title: None,
118 server_description: None,
119 server_icons: None,
120 server_website_url: None,
121 instructions: None,
122 tools: HashMap::new(),
123 resources: HashMap::new(),
124 resource_templates: Vec::new(),
125 prompts: HashMap::new(),
126 in_flight: Arc::new(RwLock::new(HashMap::new())),
127 notification_tx: None,
128 client_requester: None,
129 task_store: TaskStore::new(),
130 subscriptions: Arc::new(RwLock::new(HashSet::new())),
131 completion_handler: None,
132 }),
133 session: SessionState::new(),
134 }
135 }
136
137 pub fn task_store(&self) -> &TaskStore {
139 &self.inner.task_store
140 }
141
142 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
146 Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
147 self
148 }
149
150 pub fn notification_sender(&self) -> Option<&NotificationSender> {
152 self.inner.notification_tx.as_ref()
153 }
154
155 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
160 Arc::make_mut(&mut self.inner).client_requester = Some(requester);
161 self
162 }
163
164 pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
166 self.inner.client_requester.as_ref()
167 }
168
169 pub fn create_context(
174 &self,
175 request_id: RequestId,
176 progress_token: Option<ProgressToken>,
177 ) -> RequestContext {
178 let ctx = RequestContext::new(request_id.clone());
179
180 let ctx = if let Some(token) = progress_token {
182 ctx.with_progress_token(token)
183 } else {
184 ctx
185 };
186
187 let ctx = if let Some(tx) = &self.inner.notification_tx {
189 ctx.with_notification_sender(tx.clone())
190 } else {
191 ctx
192 };
193
194 let ctx = if let Some(requester) = &self.inner.client_requester {
196 ctx.with_client_requester(requester.clone())
197 } else {
198 ctx
199 };
200
201 let token = ctx.cancellation_token();
203 if let Ok(mut in_flight) = self.inner.in_flight.write() {
204 in_flight.insert(request_id, token);
205 }
206
207 ctx
208 }
209
210 pub fn complete_request(&self, request_id: &RequestId) {
212 if let Ok(mut in_flight) = self.inner.in_flight.write() {
213 in_flight.remove(request_id);
214 }
215 }
216
217 fn cancel_request(&self, request_id: &RequestId) -> bool {
219 let Ok(in_flight) = self.inner.in_flight.read() else {
220 return false;
221 };
222 let Some(token) = in_flight.get(request_id) else {
223 return false;
224 };
225 token.cancel();
226 true
227 }
228
229 pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
231 let inner = Arc::make_mut(&mut self.inner);
232 inner.server_name = name.into();
233 inner.server_version = version.into();
234 self
235 }
236
237 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
239 Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
240 self
241 }
242
243 pub fn server_title(mut self, title: impl Into<String>) -> Self {
245 Arc::make_mut(&mut self.inner).server_title = Some(title.into());
246 self
247 }
248
249 pub fn server_description(mut self, description: impl Into<String>) -> Self {
251 Arc::make_mut(&mut self.inner).server_description = Some(description.into());
252 self
253 }
254
255 pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
257 Arc::make_mut(&mut self.inner).server_icons = Some(icons);
258 self
259 }
260
261 pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
263 Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
264 self
265 }
266
267 pub fn tool(mut self, tool: Tool) -> Self {
269 Arc::make_mut(&mut self.inner)
270 .tools
271 .insert(tool.name.clone(), Arc::new(tool));
272 self
273 }
274
275 pub fn resource(mut self, resource: Resource) -> Self {
277 Arc::make_mut(&mut self.inner)
278 .resources
279 .insert(resource.uri.clone(), Arc::new(resource));
280 self
281 }
282
283 pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
314 Arc::make_mut(&mut self.inner)
315 .resource_templates
316 .push(Arc::new(template));
317 self
318 }
319
320 pub fn prompt(mut self, prompt: Prompt) -> Self {
322 Arc::make_mut(&mut self.inner)
323 .prompts
324 .insert(prompt.name.clone(), Arc::new(prompt));
325 self
326 }
327
328 pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
354 tools
355 .into_iter()
356 .fold(self, |router, tool| router.tool(tool))
357 }
358
359 pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
378 resources
379 .into_iter()
380 .fold(self, |router, resource| router.resource(resource))
381 }
382
383 pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
402 prompts
403 .into_iter()
404 .fold(self, |router, prompt| router.prompt(prompt))
405 }
406
407 pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
434 where
435 F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
436 Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
437 {
438 Arc::make_mut(&mut self.inner).completion_handler =
439 Some(Arc::new(move |params| Box::pin(handler(params))));
440 self
441 }
442
443 pub fn session(&self) -> &SessionState {
445 &self.session
446 }
447
448 pub fn log(&self, params: LoggingMessageParams) -> bool {
470 let Some(tx) = &self.inner.notification_tx else {
471 return false;
472 };
473 tx.try_send(ServerNotification::LogMessage(params)).is_ok()
474 }
475
476 pub fn log_info(&self, message: &str) -> bool {
480 self.log(
481 LoggingMessageParams::new(LogLevel::Info)
482 .with_data(serde_json::json!({ "message": message })),
483 )
484 }
485
486 pub fn log_warning(&self, message: &str) -> bool {
488 self.log(
489 LoggingMessageParams::new(LogLevel::Warning)
490 .with_data(serde_json::json!({ "message": message })),
491 )
492 }
493
494 pub fn log_error(&self, message: &str) -> bool {
496 self.log(
497 LoggingMessageParams::new(LogLevel::Error)
498 .with_data(serde_json::json!({ "message": message })),
499 )
500 }
501
502 pub fn log_debug(&self, message: &str) -> bool {
504 self.log(
505 LoggingMessageParams::new(LogLevel::Debug)
506 .with_data(serde_json::json!({ "message": message })),
507 )
508 }
509
510 pub fn is_subscribed(&self, uri: &str) -> bool {
512 if let Ok(subs) = self.inner.subscriptions.read() {
513 return subs.contains(uri);
514 }
515 false
516 }
517
518 pub fn subscribed_uris(&self) -> Vec<String> {
520 if let Ok(subs) = self.inner.subscriptions.read() {
521 return subs.iter().cloned().collect();
522 }
523 Vec::new()
524 }
525
526 fn subscribe(&self, uri: &str) -> bool {
528 if let Ok(mut subs) = self.inner.subscriptions.write() {
529 return subs.insert(uri.to_string());
530 }
531 false
532 }
533
534 fn unsubscribe(&self, uri: &str) -> bool {
536 if let Ok(mut subs) = self.inner.subscriptions.write() {
537 return subs.remove(uri);
538 }
539 false
540 }
541
542 pub fn notify_resource_updated(&self, uri: &str) -> bool {
547 if !self.is_subscribed(uri) {
549 return false;
550 }
551
552 let Some(tx) = &self.inner.notification_tx else {
553 return false;
554 };
555 tx.try_send(ServerNotification::ResourceUpdated {
556 uri: uri.to_string(),
557 })
558 .is_ok()
559 }
560
561 pub fn notify_resources_list_changed(&self) -> bool {
565 let Some(tx) = &self.inner.notification_tx else {
566 return false;
567 };
568 tx.try_send(ServerNotification::ResourcesListChanged)
569 .is_ok()
570 }
571
572 fn capabilities(&self) -> ServerCapabilities {
574 let has_resources =
575 !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
576
577 ServerCapabilities {
578 tools: if self.inner.tools.is_empty() {
579 None
580 } else {
581 Some(ToolsCapability::default())
582 },
583 resources: if has_resources {
584 Some(ResourcesCapability {
585 subscribe: true,
586 ..Default::default()
587 })
588 } else {
589 None
590 },
591 prompts: if self.inner.prompts.is_empty() {
592 None
593 } else {
594 Some(PromptsCapability::default())
595 },
596 logging: if self.inner.notification_tx.is_some() {
598 Some(LoggingCapability::default())
599 } else {
600 None
601 },
602 tasks: Some(TasksCapability::default()),
604 completions: if self.inner.completion_handler.is_some() {
606 Some(CompletionsCapability::default())
607 } else {
608 None
609 },
610 }
611 }
612
613 async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
615 let method = request.method_name();
617 if !self.session.is_request_allowed(method) {
618 tracing::warn!(
619 method = %method,
620 phase = ?self.session.phase(),
621 "Request rejected: session not initialized"
622 );
623 return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
624 "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
625 method
626 ))));
627 }
628
629 match request {
630 McpRequest::Initialize(params) => {
631 tracing::info!(
632 client = %params.client_info.name,
633 version = %params.client_info.version,
634 "Client initializing"
635 );
636
637 let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
640 .contains(¶ms.protocol_version.as_str())
641 {
642 params.protocol_version
643 } else {
644 crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
645 };
646
647 self.session.mark_initializing();
649
650 Ok(McpResponse::Initialize(InitializeResult {
651 protocol_version,
652 capabilities: self.capabilities(),
653 server_info: Implementation {
654 name: self.inner.server_name.clone(),
655 version: self.inner.server_version.clone(),
656 title: self.inner.server_title.clone(),
657 description: self.inner.server_description.clone(),
658 icons: self.inner.server_icons.clone(),
659 website_url: self.inner.server_website_url.clone(),
660 },
661 instructions: self.inner.instructions.clone(),
662 }))
663 }
664
665 McpRequest::ListTools(_params) => {
666 let tools: Vec<ToolDefinition> =
667 self.inner.tools.values().map(|t| t.definition()).collect();
668
669 Ok(McpResponse::ListTools(ListToolsResult {
670 tools,
671 next_cursor: None,
672 }))
673 }
674
675 McpRequest::CallTool(params) => {
676 let tool =
677 self.inner.tools.get(¶ms.name).ok_or_else(|| {
678 Error::JsonRpc(JsonRpcError::method_not_found(¶ms.name))
679 })?;
680
681 let progress_token = params.meta.and_then(|m| m.progress_token);
683 let ctx = self.create_context(request_id, progress_token);
684
685 tracing::debug!(tool = %params.name, "Calling tool");
686 let result = tool.call_with_context(ctx, params.arguments).await?;
687
688 Ok(McpResponse::CallTool(result))
689 }
690
691 McpRequest::ListResources(_params) => {
692 let resources: Vec<ResourceDefinition> = self
693 .inner
694 .resources
695 .values()
696 .map(|r| r.definition())
697 .collect();
698
699 Ok(McpResponse::ListResources(ListResourcesResult {
700 resources,
701 next_cursor: None,
702 }))
703 }
704
705 McpRequest::ListResourceTemplates(_params) => {
706 let resource_templates: Vec<ResourceTemplateDefinition> = self
707 .inner
708 .resource_templates
709 .iter()
710 .map(|t| t.definition())
711 .collect();
712
713 Ok(McpResponse::ListResourceTemplates(
714 ListResourceTemplatesResult {
715 resource_templates,
716 next_cursor: None,
717 },
718 ))
719 }
720
721 McpRequest::ReadResource(params) => {
722 if let Some(resource) = self.inner.resources.get(¶ms.uri) {
724 tracing::debug!(uri = %params.uri, "Reading static resource");
725 let result = resource.read().await?;
726 return Ok(McpResponse::ReadResource(result));
727 }
728
729 for template in &self.inner.resource_templates {
731 if let Some(variables) = template.match_uri(¶ms.uri) {
732 tracing::debug!(
733 uri = %params.uri,
734 template = %template.uri_template,
735 "Reading resource via template"
736 );
737 let result = template.read(¶ms.uri, variables).await?;
738 return Ok(McpResponse::ReadResource(result));
739 }
740 }
741
742 Err(Error::JsonRpc(JsonRpcError::resource_not_found(
744 ¶ms.uri,
745 )))
746 }
747
748 McpRequest::SubscribeResource(params) => {
749 if !self.inner.resources.contains_key(¶ms.uri) {
751 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
752 ¶ms.uri,
753 )));
754 }
755
756 tracing::debug!(uri = %params.uri, "Subscribing to resource");
757 self.subscribe(¶ms.uri);
758
759 Ok(McpResponse::SubscribeResource(EmptyResult {}))
760 }
761
762 McpRequest::UnsubscribeResource(params) => {
763 if !self.inner.resources.contains_key(¶ms.uri) {
765 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
766 ¶ms.uri,
767 )));
768 }
769
770 tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
771 self.unsubscribe(¶ms.uri);
772
773 Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
774 }
775
776 McpRequest::ListPrompts(_params) => {
777 let prompts: Vec<PromptDefinition> = self
778 .inner
779 .prompts
780 .values()
781 .map(|p| p.definition())
782 .collect();
783
784 Ok(McpResponse::ListPrompts(ListPromptsResult {
785 prompts,
786 next_cursor: None,
787 }))
788 }
789
790 McpRequest::GetPrompt(params) => {
791 let prompt = self.inner.prompts.get(¶ms.name).ok_or_else(|| {
792 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
793 "Prompt not found: {}",
794 params.name
795 )))
796 })?;
797
798 tracing::debug!(name = %params.name, "Getting prompt");
799 let result = prompt.get(params.arguments).await?;
800
801 Ok(McpResponse::GetPrompt(result))
802 }
803
804 McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
805
806 McpRequest::EnqueueTask(params) => {
807 let tool = self.inner.tools.get(¶ms.tool_name).ok_or_else(|| {
809 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
810 "Tool not found: {}",
811 params.tool_name
812 )))
813 })?;
814
815 let (task_id, cancellation_token) = self.inner.task_store.create_task(
817 ¶ms.tool_name,
818 params.arguments.clone(),
819 params.ttl,
820 );
821
822 tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
823
824 let ctx = self.create_context(request_id, None);
826
827 let task_store = self.inner.task_store.clone();
829 let tool = tool.clone();
830 let arguments = params.arguments;
831 let task_id_clone = task_id.clone();
832
833 tokio::spawn(async move {
834 if cancellation_token.is_cancelled() {
836 tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
837 return;
838 }
839
840 match tool.call_with_context(ctx, arguments).await {
842 Ok(result) => {
843 if cancellation_token.is_cancelled() {
844 tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
845 } else {
846 task_store.complete_task(&task_id_clone, result);
847 tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
848 }
849 }
850 Err(e) => {
851 task_store.fail_task(&task_id_clone, &e.to_string());
852 tracing::warn!(task_id = %task_id_clone, error = %e, "Task failed");
853 }
854 }
855 });
856
857 Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
858 task_id,
859 status: TaskStatus::Working,
860 poll_interval: Some(2),
861 }))
862 }
863
864 McpRequest::ListTasks(params) => {
865 let tasks = self.inner.task_store.list_tasks(params.status);
866
867 Ok(McpResponse::ListTasks(ListTasksResult {
868 tasks,
869 next_cursor: None,
870 }))
871 }
872
873 McpRequest::GetTaskInfo(params) => {
874 let task = self
875 .inner
876 .task_store
877 .get_task(¶ms.task_id)
878 .ok_or_else(|| {
879 Error::JsonRpc(JsonRpcError::invalid_params(format!(
880 "Task not found: {}",
881 params.task_id
882 )))
883 })?;
884
885 Ok(McpResponse::GetTaskInfo(task))
886 }
887
888 McpRequest::GetTaskResult(params) => {
889 let (status, result, error) = self
890 .inner
891 .task_store
892 .get_task_full(¶ms.task_id)
893 .ok_or_else(|| {
894 Error::JsonRpc(JsonRpcError::invalid_params(format!(
895 "Task not found: {}",
896 params.task_id
897 )))
898 })?;
899
900 Ok(McpResponse::GetTaskResult(GetTaskResultResult {
901 task_id: params.task_id,
902 status,
903 result,
904 error,
905 }))
906 }
907
908 McpRequest::CancelTask(params) => {
909 let status = self
910 .inner
911 .task_store
912 .cancel_task(¶ms.task_id, params.reason.as_deref())
913 .ok_or_else(|| {
914 Error::JsonRpc(JsonRpcError::invalid_params(format!(
915 "Task not found: {}",
916 params.task_id
917 )))
918 })?;
919
920 let cancelled = status == TaskStatus::Cancelled;
921
922 Ok(McpResponse::CancelTask(CancelTaskResult {
923 cancelled,
924 status,
925 }))
926 }
927
928 McpRequest::SetLoggingLevel(params) => {
929 tracing::debug!(level = ?params.level, "Client set logging level");
933 Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
934 }
935
936 McpRequest::Complete(params) => {
937 tracing::debug!(
938 reference = ?params.reference,
939 argument = %params.argument.name,
940 "Completion request"
941 );
942
943 if let Some(ref handler) = self.inner.completion_handler {
945 let result = handler(params).await?;
946 Ok(McpResponse::Complete(result))
947 } else {
948 Ok(McpResponse::Complete(CompleteResult::new(vec![])))
950 }
951 }
952
953 McpRequest::Unknown { method, .. } => {
954 Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
955 }
956 }
957 }
958
959 pub fn handle_notification(&self, notification: McpNotification) {
961 match notification {
962 McpNotification::Initialized => {
963 if self.session.mark_initialized() {
964 tracing::info!("Session initialized, entering operation phase");
965 } else {
966 tracing::warn!(
967 "Received initialized notification in unexpected state: {:?}",
968 self.session.phase()
969 );
970 }
971 }
972 McpNotification::Cancelled(params) => {
973 if self.cancel_request(¶ms.request_id) {
974 tracing::info!(
975 request_id = ?params.request_id,
976 reason = ?params.reason,
977 "Request cancelled"
978 );
979 } else {
980 tracing::debug!(
981 request_id = ?params.request_id,
982 reason = ?params.reason,
983 "Cancellation requested for unknown request"
984 );
985 }
986 }
987 McpNotification::Progress(params) => {
988 tracing::trace!(
989 token = ?params.progress_token,
990 progress = params.progress,
991 total = ?params.total,
992 "Progress notification"
993 );
994 }
996 McpNotification::RootsListChanged => {
997 tracing::info!("Client roots list changed");
998 }
1001 McpNotification::Unknown { method, .. } => {
1002 tracing::debug!(method = %method, "Unknown notification received");
1003 }
1004 }
1005 }
1006}
1007
1008impl Default for McpRouter {
1009 fn default() -> Self {
1010 Self::new()
1011 }
1012}
1013
/// Type-keyed bag of values attached to a request.
///
/// At most one value is stored per concrete type; inserting a second value
/// of the same type replaces the first. Values are kept behind `Arc`, so
/// cloning the map clones only the pointers.
#[derive(Default, Clone)]
pub struct Extensions {
    /// Stored values, keyed by the `TypeId` of their concrete type.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty extension map.
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Stores `val`, replacing any previously stored value of type `T`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = TypeId::of::<T>();
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Returns a reference to the stored value of type `T`, if one exists.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let erased = self.map.get(&TypeId::of::<T>())?;
        erased.downcast_ref::<T>()
    }
}

/// Prints only the number of stored values; the values themselves are
/// type-erased and need not implement `Debug`.
impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Extensions");
        out.field("len", &self.map.len());
        out.finish()
    }
}
1068
/// Request type accepted by the router's `Service` implementation.
#[derive(Debug)]
pub struct RouterRequest {
    /// JSON-RPC request id; echoed back in the [`RouterResponse`].
    pub id: RequestId,
    /// The parsed MCP request to dispatch.
    pub inner: McpRequest,
    /// Type-keyed values attached to the request.
    // NOTE(review): not read by the router code visible in this file —
    // presumably consumed by middleware layers; confirm.
    pub extensions: Extensions,
}
1077
/// Response type produced by the router's `Service` implementation.
#[derive(Debug)]
pub struct RouterResponse {
    /// Id of the request this response answers.
    pub id: RequestId,
    /// The MCP response on success, or the JSON-RPC error to report.
    pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
1084
1085impl RouterResponse {
1086 pub fn into_jsonrpc(self) -> JsonRpcResponse {
1088 match self.inner {
1089 Ok(response) => match serde_json::to_value(response) {
1090 Ok(result) => JsonRpcResponse::result(self.id, result),
1091 Err(e) => {
1092 tracing::error!(error = %e, "Failed to serialize response");
1093 JsonRpcResponse::error(
1094 Some(self.id),
1095 JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1096 )
1097 }
1098 },
1099 Err(error) => JsonRpcResponse::error(Some(self.id), error),
1100 }
1101 }
1102}
1103
1104impl Service<RouterRequest> for McpRouter {
1105 type Response = RouterResponse;
1106 type Error = std::convert::Infallible; type Future =
1108 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1109
1110 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1111 Poll::Ready(Ok(()))
1112 }
1113
1114 fn call(&mut self, req: RouterRequest) -> Self::Future {
1115 let router = self.clone();
1116 let request_id = req.id.clone();
1117 Box::pin(async move {
1118 let result = router.handle(req.id, req.inner).await;
1119 router.complete_request(&request_id);
1121 Ok(RouterResponse {
1122 id: request_id,
1123 inner: result.map_err(|e| match e {
1128 Error::JsonRpc(err) => err,
1129 Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1130 e => JsonRpcError::internal_error(e.to_string()),
1131 }),
1132 })
1133 })
1134 }
1135}
1136
1137#[cfg(test)]
1138mod tests {
1139 use super::*;
1140 use crate::jsonrpc::JsonRpcService;
1141 use crate::tool::ToolBuilder;
1142 use schemars::JsonSchema;
1143 use serde::Deserialize;
1144 use tower::ServiceExt;
1145
    /// Input of the `add` test tool: the two integers to sum.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i64,
        b: i64,
    }
1151
1152 async fn init_router(router: &mut McpRouter) {
1154 let init_req = RouterRequest {
1156 id: RequestId::Number(0),
1157 inner: McpRequest::Initialize(InitializeParams {
1158 protocol_version: "2025-11-25".to_string(),
1159 capabilities: ClientCapabilities {
1160 roots: None,
1161 sampling: None,
1162 elicitation: None,
1163 },
1164 client_info: Implementation {
1165 name: "test".to_string(),
1166 version: "1.0".to_string(),
1167 ..Default::default()
1168 },
1169 }),
1170 extensions: Extensions::new(),
1171 };
1172 let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1173 router.handle_notification(McpNotification::Initialized);
1175 }
1176
1177 #[tokio::test]
1178 async fn test_router_list_tools() {
1179 let add_tool = ToolBuilder::new("add")
1180 .description("Add two numbers")
1181 .handler(|input: AddInput| async move {
1182 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1183 })
1184 .build()
1185 .expect("valid tool name");
1186
1187 let mut router = McpRouter::new().tool(add_tool);
1188
1189 init_router(&mut router).await;
1191
1192 let req = RouterRequest {
1193 id: RequestId::Number(1),
1194 inner: McpRequest::ListTools(ListToolsParams::default()),
1195 extensions: Extensions::new(),
1196 };
1197
1198 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1199
1200 match resp.inner {
1201 Ok(McpResponse::ListTools(result)) => {
1202 assert_eq!(result.tools.len(), 1);
1203 assert_eq!(result.tools[0].name, "add");
1204 }
1205 _ => panic!("Expected ListTools response"),
1206 }
1207 }
1208
1209 #[tokio::test]
1210 async fn test_router_call_tool() {
1211 let add_tool = ToolBuilder::new("add")
1212 .description("Add two numbers")
1213 .handler(|input: AddInput| async move {
1214 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1215 })
1216 .build()
1217 .expect("valid tool name");
1218
1219 let mut router = McpRouter::new().tool(add_tool);
1220
1221 init_router(&mut router).await;
1223
1224 let req = RouterRequest {
1225 id: RequestId::Number(1),
1226 inner: McpRequest::CallTool(CallToolParams {
1227 name: "add".to_string(),
1228 arguments: serde_json::json!({"a": 2, "b": 3}),
1229 meta: None,
1230 }),
1231 extensions: Extensions::new(),
1232 };
1233
1234 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1235
1236 match resp.inner {
1237 Ok(McpResponse::CallTool(result)) => {
1238 assert!(!result.is_error);
1239 match &result.content[0] {
1241 Content::Text { text, .. } => assert_eq!(text, "5"),
1242 _ => panic!("Expected text content"),
1243 }
1244 }
1245 _ => panic!("Expected CallTool response"),
1246 }
1247 }
1248
1249 async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1251 let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1252 "protocolVersion": "2025-11-25",
1253 "capabilities": {},
1254 "clientInfo": { "name": "test", "version": "1.0" }
1255 }));
1256 let _ = service.call_single(init_req).await.unwrap();
1257 router.handle_notification(McpNotification::Initialized);
1258 }
1259
1260 #[tokio::test]
1261 async fn test_jsonrpc_service() {
1262 let add_tool = ToolBuilder::new("add")
1263 .description("Add two numbers")
1264 .handler(|input: AddInput| async move {
1265 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1266 })
1267 .build()
1268 .expect("valid tool name");
1269
1270 let router = McpRouter::new().tool(add_tool);
1271 let mut service = JsonRpcService::new(router.clone());
1272
1273 init_jsonrpc_service(&mut service, &router).await;
1275
1276 let req = JsonRpcRequest::new(1, "tools/list");
1277
1278 let resp = service.call_single(req).await.unwrap();
1279
1280 match resp {
1281 JsonRpcResponse::Result(r) => {
1282 assert_eq!(r.id, RequestId::Number(1));
1283 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1284 assert_eq!(tools.len(), 1);
1285 }
1286 JsonRpcResponse::Error(_) => panic!("Expected success response"),
1287 }
1288 }
1289
1290 #[tokio::test]
1291 async fn test_batch_request() {
1292 let add_tool = ToolBuilder::new("add")
1293 .description("Add two numbers")
1294 .handler(|input: AddInput| async move {
1295 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1296 })
1297 .build()
1298 .expect("valid tool name");
1299
1300 let router = McpRouter::new().tool(add_tool);
1301 let mut service = JsonRpcService::new(router.clone());
1302
1303 init_jsonrpc_service(&mut service, &router).await;
1305
1306 let requests = vec![
1308 JsonRpcRequest::new(1, "tools/list"),
1309 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1310 "name": "add",
1311 "arguments": {"a": 10, "b": 20}
1312 })),
1313 JsonRpcRequest::new(3, "ping"),
1314 ];
1315
1316 let responses = service.call_batch(requests).await.unwrap();
1317
1318 assert_eq!(responses.len(), 3);
1319
1320 match &responses[0] {
1322 JsonRpcResponse::Result(r) => {
1323 assert_eq!(r.id, RequestId::Number(1));
1324 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1325 assert_eq!(tools.len(), 1);
1326 }
1327 JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1328 }
1329
1330 match &responses[1] {
1332 JsonRpcResponse::Result(r) => {
1333 assert_eq!(r.id, RequestId::Number(2));
1334 let content = r.result.get("content").unwrap().as_array().unwrap();
1335 let text = content[0].get("text").unwrap().as_str().unwrap();
1336 assert_eq!(text, "30");
1337 }
1338 JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1339 }
1340
1341 match &responses[2] {
1343 JsonRpcResponse::Result(r) => {
1344 assert_eq!(r.id, RequestId::Number(3));
1345 }
1346 JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1347 }
1348 }
1349
1350 #[tokio::test]
1351 async fn test_empty_batch_error() {
1352 let router = McpRouter::new();
1353 let mut service = JsonRpcService::new(router);
1354
1355 let result = service.call_batch(vec![]).await;
1356 assert!(result.is_err());
1357 }
1358
1359 #[tokio::test]
1364 async fn test_progress_token_extraction() {
1365 use crate::context::{RequestContext, ServerNotification, notification_channel};
1366 use crate::protocol::ProgressToken;
1367 use std::sync::Arc;
1368 use std::sync::atomic::{AtomicBool, Ordering};
1369
1370 let progress_reported = Arc::new(AtomicBool::new(false));
1372 let progress_ref = progress_reported.clone();
1373
1374 let tool = ToolBuilder::new("progress_tool")
1376 .description("Tool that reports progress")
1377 .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1378 let reported = progress_ref.clone();
1379 async move {
1380 ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1382 .await;
1383 reported.store(true, Ordering::SeqCst);
1384 Ok(CallToolResult::text("done"))
1385 }
1386 })
1387 .build()
1388 .expect("valid tool name");
1389
1390 let (tx, mut rx) = notification_channel(10);
1392 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1393 let mut service = JsonRpcService::new(router.clone());
1394
1395 init_jsonrpc_service(&mut service, &router).await;
1397
1398 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1400 "name": "progress_tool",
1401 "arguments": {"a": 1, "b": 2},
1402 "_meta": {
1403 "progressToken": "test-token-123"
1404 }
1405 }));
1406
1407 let resp = service.call_single(req).await.unwrap();
1408
1409 match resp {
1411 JsonRpcResponse::Result(_) => {}
1412 JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1413 }
1414
1415 assert!(progress_reported.load(Ordering::SeqCst));
1417
1418 let notification = rx.try_recv().expect("Expected progress notification");
1420 match notification {
1421 ServerNotification::Progress(params) => {
1422 assert_eq!(
1423 params.progress_token,
1424 ProgressToken::String("test-token-123".to_string())
1425 );
1426 assert_eq!(params.progress, 50.0);
1427 assert_eq!(params.total, Some(100.0));
1428 assert_eq!(params.message.as_deref(), Some("Halfway"));
1429 }
1430 _ => panic!("Expected Progress notification"),
1431 }
1432 }
1433
1434 #[tokio::test]
1435 async fn test_tool_call_without_progress_token() {
1436 use crate::context::{RequestContext, notification_channel};
1437 use std::sync::Arc;
1438 use std::sync::atomic::{AtomicBool, Ordering};
1439
1440 let progress_attempted = Arc::new(AtomicBool::new(false));
1441 let progress_ref = progress_attempted.clone();
1442
1443 let tool = ToolBuilder::new("no_token_tool")
1444 .description("Tool that tries to report progress without token")
1445 .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1446 let attempted = progress_ref.clone();
1447 async move {
1448 ctx.report_progress(50.0, Some(100.0), None).await;
1450 attempted.store(true, Ordering::SeqCst);
1451 Ok(CallToolResult::text("done"))
1452 }
1453 })
1454 .build()
1455 .expect("valid tool name");
1456
1457 let (tx, mut rx) = notification_channel(10);
1458 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1459 let mut service = JsonRpcService::new(router.clone());
1460
1461 init_jsonrpc_service(&mut service, &router).await;
1462
1463 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1465 "name": "no_token_tool",
1466 "arguments": {"a": 1, "b": 2}
1467 }));
1468
1469 let resp = service.call_single(req).await.unwrap();
1470 assert!(matches!(resp, JsonRpcResponse::Result(_)));
1471
1472 assert!(progress_attempted.load(Ordering::SeqCst));
1474
1475 assert!(rx.try_recv().is_err());
1477 }
1478
1479 #[tokio::test]
1480 async fn test_batch_errors_returned_not_dropped() {
1481 let add_tool = ToolBuilder::new("add")
1482 .description("Add two numbers")
1483 .handler(|input: AddInput| async move {
1484 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1485 })
1486 .build()
1487 .expect("valid tool name");
1488
1489 let router = McpRouter::new().tool(add_tool);
1490 let mut service = JsonRpcService::new(router.clone());
1491
1492 init_jsonrpc_service(&mut service, &router).await;
1493
1494 let requests = vec![
1496 JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1498 "name": "add",
1499 "arguments": {"a": 10, "b": 20}
1500 })),
1501 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1503 "name": "nonexistent_tool",
1504 "arguments": {}
1505 })),
1506 JsonRpcRequest::new(3, "ping"),
1508 ];
1509
1510 let responses = service.call_batch(requests).await.unwrap();
1511
1512 assert_eq!(responses.len(), 3);
1514
1515 match &responses[0] {
1517 JsonRpcResponse::Result(r) => {
1518 assert_eq!(r.id, RequestId::Number(1));
1519 }
1520 JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1521 }
1522
1523 match &responses[1] {
1525 JsonRpcResponse::Error(e) => {
1526 assert_eq!(e.id, Some(RequestId::Number(2)));
1527 assert!(e.error.message.contains("not found") || e.error.code == -32601);
1529 }
1530 JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1531 }
1532
1533 match &responses[2] {
1535 JsonRpcResponse::Result(r) => {
1536 assert_eq!(r.id, RequestId::Number(3));
1537 }
1538 JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1539 }
1540 }
1541
1542 #[tokio::test]
1547 async fn test_list_resource_templates() {
1548 use crate::resource::ResourceTemplateBuilder;
1549 use std::collections::HashMap;
1550
1551 let template = ResourceTemplateBuilder::new("file:///{path}")
1552 .name("Project Files")
1553 .description("Access project files")
1554 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1555 Ok(ReadResourceResult {
1556 contents: vec![ResourceContent {
1557 uri,
1558 mime_type: None,
1559 text: None,
1560 blob: None,
1561 }],
1562 })
1563 });
1564
1565 let mut router = McpRouter::new().resource_template(template);
1566
1567 init_router(&mut router).await;
1569
1570 let req = RouterRequest {
1571 id: RequestId::Number(1),
1572 inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1573 extensions: Extensions::new(),
1574 };
1575
1576 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1577
1578 match resp.inner {
1579 Ok(McpResponse::ListResourceTemplates(result)) => {
1580 assert_eq!(result.resource_templates.len(), 1);
1581 assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1582 assert_eq!(result.resource_templates[0].name, "Project Files");
1583 }
1584 _ => panic!("Expected ListResourceTemplates response"),
1585 }
1586 }
1587
1588 #[tokio::test]
1589 async fn test_read_resource_via_template() {
1590 use crate::resource::ResourceTemplateBuilder;
1591 use std::collections::HashMap;
1592
1593 let template = ResourceTemplateBuilder::new("db://users/{id}")
1594 .name("User Records")
1595 .handler(|uri: String, vars: HashMap<String, String>| async move {
1596 let id = vars.get("id").unwrap().clone();
1597 Ok(ReadResourceResult {
1598 contents: vec![ResourceContent {
1599 uri,
1600 mime_type: Some("application/json".to_string()),
1601 text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1602 blob: None,
1603 }],
1604 })
1605 });
1606
1607 let mut router = McpRouter::new().resource_template(template);
1608
1609 init_router(&mut router).await;
1611
1612 let req = RouterRequest {
1614 id: RequestId::Number(1),
1615 inner: McpRequest::ReadResource(ReadResourceParams {
1616 uri: "db://users/123".to_string(),
1617 }),
1618 extensions: Extensions::new(),
1619 };
1620
1621 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1622
1623 match resp.inner {
1624 Ok(McpResponse::ReadResource(result)) => {
1625 assert_eq!(result.contents.len(), 1);
1626 assert_eq!(result.contents[0].uri, "db://users/123");
1627 assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1628 }
1629 _ => panic!("Expected ReadResource response"),
1630 }
1631 }
1632
1633 #[tokio::test]
1634 async fn test_static_resource_takes_precedence_over_template() {
1635 use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1636 use std::collections::HashMap;
1637
1638 let template = ResourceTemplateBuilder::new("file:///{path}")
1640 .name("Files Template")
1641 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1642 Ok(ReadResourceResult {
1643 contents: vec![ResourceContent {
1644 uri,
1645 mime_type: None,
1646 text: Some("from template".to_string()),
1647 blob: None,
1648 }],
1649 })
1650 });
1651
1652 let static_resource = ResourceBuilder::new("file:///README.md")
1654 .name("README")
1655 .text("from static resource");
1656
1657 let mut router = McpRouter::new()
1658 .resource_template(template)
1659 .resource(static_resource);
1660
1661 init_router(&mut router).await;
1663
1664 let req = RouterRequest {
1666 id: RequestId::Number(1),
1667 inner: McpRequest::ReadResource(ReadResourceParams {
1668 uri: "file:///README.md".to_string(),
1669 }),
1670 extensions: Extensions::new(),
1671 };
1672
1673 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1674
1675 match resp.inner {
1676 Ok(McpResponse::ReadResource(result)) => {
1677 assert_eq!(
1679 result.contents[0].text.as_deref(),
1680 Some("from static resource")
1681 );
1682 }
1683 _ => panic!("Expected ReadResource response"),
1684 }
1685 }
1686
1687 #[tokio::test]
1688 async fn test_resource_not_found_when_no_match() {
1689 use crate::resource::ResourceTemplateBuilder;
1690 use std::collections::HashMap;
1691
1692 let template = ResourceTemplateBuilder::new("db://users/{id}")
1693 .name("Users")
1694 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1695 Ok(ReadResourceResult {
1696 contents: vec![ResourceContent {
1697 uri,
1698 mime_type: None,
1699 text: None,
1700 blob: None,
1701 }],
1702 })
1703 });
1704
1705 let mut router = McpRouter::new().resource_template(template);
1706
1707 init_router(&mut router).await;
1709
1710 let req = RouterRequest {
1712 id: RequestId::Number(1),
1713 inner: McpRequest::ReadResource(ReadResourceParams {
1714 uri: "db://posts/123".to_string(),
1715 }),
1716 extensions: Extensions::new(),
1717 };
1718
1719 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1720
1721 match resp.inner {
1722 Err(err) => {
1723 assert!(err.message.contains("not found"));
1724 }
1725 Ok(_) => panic!("Expected error for non-matching URI"),
1726 }
1727 }
1728
1729 #[tokio::test]
1730 async fn test_capabilities_include_resources_with_only_templates() {
1731 use crate::resource::ResourceTemplateBuilder;
1732 use std::collections::HashMap;
1733
1734 let template = ResourceTemplateBuilder::new("file:///{path}")
1735 .name("Files")
1736 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1737 Ok(ReadResourceResult {
1738 contents: vec![ResourceContent {
1739 uri,
1740 mime_type: None,
1741 text: None,
1742 blob: None,
1743 }],
1744 })
1745 });
1746
1747 let mut router = McpRouter::new().resource_template(template);
1748
1749 let init_req = RouterRequest {
1751 id: RequestId::Number(0),
1752 inner: McpRequest::Initialize(InitializeParams {
1753 protocol_version: "2025-11-25".to_string(),
1754 capabilities: ClientCapabilities {
1755 roots: None,
1756 sampling: None,
1757 elicitation: None,
1758 },
1759 client_info: Implementation {
1760 name: "test".to_string(),
1761 version: "1.0".to_string(),
1762 ..Default::default()
1763 },
1764 }),
1765 extensions: Extensions::new(),
1766 };
1767 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1768
1769 match resp.inner {
1770 Ok(McpResponse::Initialize(result)) => {
1771 assert!(result.capabilities.resources.is_some());
1773 }
1774 _ => panic!("Expected Initialize response"),
1775 }
1776 }
1777
1778 #[tokio::test]
1783 async fn test_log_sends_notification() {
1784 use crate::context::notification_channel;
1785
1786 let (tx, mut rx) = notification_channel(10);
1787 let router = McpRouter::new().with_notification_sender(tx);
1788
1789 let sent = router.log_info("Test message");
1791 assert!(sent);
1792
1793 let notification = rx.try_recv().unwrap();
1795 match notification {
1796 ServerNotification::LogMessage(params) => {
1797 assert_eq!(params.level, LogLevel::Info);
1798 let data = params.data.unwrap();
1799 assert_eq!(
1800 data.get("message").unwrap().as_str().unwrap(),
1801 "Test message"
1802 );
1803 }
1804 _ => panic!("Expected LogMessage notification"),
1805 }
1806 }
1807
1808 #[tokio::test]
1809 async fn test_log_with_custom_params() {
1810 use crate::context::notification_channel;
1811
1812 let (tx, mut rx) = notification_channel(10);
1813 let router = McpRouter::new().with_notification_sender(tx);
1814
1815 let params = LoggingMessageParams::new(LogLevel::Error)
1817 .with_logger("database")
1818 .with_data(serde_json::json!({
1819 "error": "Connection failed",
1820 "host": "localhost"
1821 }));
1822
1823 let sent = router.log(params);
1824 assert!(sent);
1825
1826 let notification = rx.try_recv().unwrap();
1827 match notification {
1828 ServerNotification::LogMessage(params) => {
1829 assert_eq!(params.level, LogLevel::Error);
1830 assert_eq!(params.logger.as_deref(), Some("database"));
1831 let data = params.data.unwrap();
1832 assert_eq!(
1833 data.get("error").unwrap().as_str().unwrap(),
1834 "Connection failed"
1835 );
1836 }
1837 _ => panic!("Expected LogMessage notification"),
1838 }
1839 }
1840
1841 #[tokio::test]
1842 async fn test_log_without_channel_returns_false() {
1843 let router = McpRouter::new();
1845
1846 assert!(!router.log_info("Test"));
1848 assert!(!router.log_warning("Test"));
1849 assert!(!router.log_error("Test"));
1850 assert!(!router.log_debug("Test"));
1851 }
1852
1853 #[tokio::test]
1854 async fn test_logging_capability_with_channel() {
1855 use crate::context::notification_channel;
1856
1857 let (tx, _rx) = notification_channel(10);
1858 let mut router = McpRouter::new().with_notification_sender(tx);
1859
1860 let init_req = RouterRequest {
1862 id: RequestId::Number(0),
1863 inner: McpRequest::Initialize(InitializeParams {
1864 protocol_version: "2025-11-25".to_string(),
1865 capabilities: ClientCapabilities {
1866 roots: None,
1867 sampling: None,
1868 elicitation: None,
1869 },
1870 client_info: Implementation {
1871 name: "test".to_string(),
1872 version: "1.0".to_string(),
1873 ..Default::default()
1874 },
1875 }),
1876 extensions: Extensions::new(),
1877 };
1878 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1879
1880 match resp.inner {
1881 Ok(McpResponse::Initialize(result)) => {
1882 assert!(result.capabilities.logging.is_some());
1884 }
1885 _ => panic!("Expected Initialize response"),
1886 }
1887 }
1888
1889 #[tokio::test]
1890 async fn test_no_logging_capability_without_channel() {
1891 let mut router = McpRouter::new();
1892
1893 let init_req = RouterRequest {
1895 id: RequestId::Number(0),
1896 inner: McpRequest::Initialize(InitializeParams {
1897 protocol_version: "2025-11-25".to_string(),
1898 capabilities: ClientCapabilities {
1899 roots: None,
1900 sampling: None,
1901 elicitation: None,
1902 },
1903 client_info: Implementation {
1904 name: "test".to_string(),
1905 version: "1.0".to_string(),
1906 ..Default::default()
1907 },
1908 }),
1909 extensions: Extensions::new(),
1910 };
1911 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1912
1913 match resp.inner {
1914 Ok(McpResponse::Initialize(result)) => {
1915 assert!(result.capabilities.logging.is_none());
1917 }
1918 _ => panic!("Expected Initialize response"),
1919 }
1920 }
1921
1922 #[tokio::test]
1927 async fn test_enqueue_task() {
1928 let add_tool = ToolBuilder::new("add")
1929 .description("Add two numbers")
1930 .handler(|input: AddInput| async move {
1931 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1932 })
1933 .build()
1934 .expect("valid tool name");
1935
1936 let mut router = McpRouter::new().tool(add_tool);
1937 init_router(&mut router).await;
1938
1939 let req = RouterRequest {
1940 id: RequestId::Number(1),
1941 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1942 tool_name: "add".to_string(),
1943 arguments: serde_json::json!({"a": 5, "b": 10}),
1944 ttl: None,
1945 }),
1946 extensions: Extensions::new(),
1947 };
1948
1949 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1950
1951 match resp.inner {
1952 Ok(McpResponse::EnqueueTask(result)) => {
1953 assert!(result.task_id.starts_with("task-"));
1954 assert_eq!(result.status, TaskStatus::Working);
1955 }
1956 _ => panic!("Expected EnqueueTask response"),
1957 }
1958 }
1959
1960 #[tokio::test]
1961 async fn test_list_tasks_empty() {
1962 let mut router = McpRouter::new();
1963 init_router(&mut router).await;
1964
1965 let req = RouterRequest {
1966 id: RequestId::Number(1),
1967 inner: McpRequest::ListTasks(ListTasksParams::default()),
1968 extensions: Extensions::new(),
1969 };
1970
1971 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1972
1973 match resp.inner {
1974 Ok(McpResponse::ListTasks(result)) => {
1975 assert!(result.tasks.is_empty());
1976 }
1977 _ => panic!("Expected ListTasks response"),
1978 }
1979 }
1980
1981 #[tokio::test]
1982 async fn test_task_lifecycle_complete() {
1983 let add_tool = ToolBuilder::new("add")
1984 .description("Add two numbers")
1985 .handler(|input: AddInput| async move {
1986 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1987 })
1988 .build()
1989 .expect("valid tool name");
1990
1991 let mut router = McpRouter::new().tool(add_tool);
1992 init_router(&mut router).await;
1993
1994 let req = RouterRequest {
1996 id: RequestId::Number(1),
1997 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1998 tool_name: "add".to_string(),
1999 arguments: serde_json::json!({"a": 7, "b": 8}),
2000 ttl: None,
2001 }),
2002 extensions: Extensions::new(),
2003 };
2004
2005 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2006 let task_id = match resp.inner {
2007 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2008 _ => panic!("Expected EnqueueTask response"),
2009 };
2010
2011 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2013
2014 let req = RouterRequest {
2016 id: RequestId::Number(2),
2017 inner: McpRequest::GetTaskResult(GetTaskResultParams {
2018 task_id: task_id.clone(),
2019 }),
2020 extensions: Extensions::new(),
2021 };
2022
2023 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2024
2025 match resp.inner {
2026 Ok(McpResponse::GetTaskResult(result)) => {
2027 assert_eq!(result.task_id, task_id);
2028 assert_eq!(result.status, TaskStatus::Completed);
2029 assert!(result.result.is_some());
2030 assert!(result.error.is_none());
2031
2032 let tool_result = result.result.unwrap();
2034 match &tool_result.content[0] {
2035 Content::Text { text, .. } => assert_eq!(text, "15"),
2036 _ => panic!("Expected text content"),
2037 }
2038 }
2039 _ => panic!("Expected GetTaskResult response"),
2040 }
2041 }
2042
2043 #[tokio::test]
2044 async fn test_task_cancellation() {
2045 let slow_tool = ToolBuilder::new("slow")
2047 .description("Slow tool")
2048 .handler(|_input: serde_json::Value| async move {
2049 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2050 Ok(CallToolResult::text("done"))
2051 })
2052 .build()
2053 .expect("valid tool name");
2054
2055 let mut router = McpRouter::new().tool(slow_tool);
2056 init_router(&mut router).await;
2057
2058 let req = RouterRequest {
2060 id: RequestId::Number(1),
2061 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2062 tool_name: "slow".to_string(),
2063 arguments: serde_json::json!({}),
2064 ttl: None,
2065 }),
2066 extensions: Extensions::new(),
2067 };
2068
2069 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2070 let task_id = match resp.inner {
2071 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2072 _ => panic!("Expected EnqueueTask response"),
2073 };
2074
2075 let req = RouterRequest {
2077 id: RequestId::Number(2),
2078 inner: McpRequest::CancelTask(CancelTaskParams {
2079 task_id: task_id.clone(),
2080 reason: Some("Test cancellation".to_string()),
2081 }),
2082 extensions: Extensions::new(),
2083 };
2084
2085 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2086
2087 match resp.inner {
2088 Ok(McpResponse::CancelTask(result)) => {
2089 assert!(result.cancelled);
2090 assert_eq!(result.status, TaskStatus::Cancelled);
2091 }
2092 _ => panic!("Expected CancelTask response"),
2093 }
2094 }
2095
2096 #[tokio::test]
2097 async fn test_get_task_info() {
2098 let add_tool = ToolBuilder::new("add")
2099 .description("Add two numbers")
2100 .handler(|input: AddInput| async move {
2101 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2102 })
2103 .build()
2104 .expect("valid tool name");
2105
2106 let mut router = McpRouter::new().tool(add_tool);
2107 init_router(&mut router).await;
2108
2109 let req = RouterRequest {
2111 id: RequestId::Number(1),
2112 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2113 tool_name: "add".to_string(),
2114 arguments: serde_json::json!({"a": 1, "b": 2}),
2115 ttl: Some(600),
2116 }),
2117 extensions: Extensions::new(),
2118 };
2119
2120 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2121 let task_id = match resp.inner {
2122 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2123 _ => panic!("Expected EnqueueTask response"),
2124 };
2125
2126 let req = RouterRequest {
2128 id: RequestId::Number(2),
2129 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2130 task_id: task_id.clone(),
2131 }),
2132 extensions: Extensions::new(),
2133 };
2134
2135 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2136
2137 match resp.inner {
2138 Ok(McpResponse::GetTaskInfo(info)) => {
2139 assert_eq!(info.task_id, task_id);
2140 assert!(info.created_at.contains('T')); assert_eq!(info.ttl, Some(600));
2142 }
2143 _ => panic!("Expected GetTaskInfo response"),
2144 }
2145 }
2146
2147 #[tokio::test]
2148 async fn test_enqueue_nonexistent_tool() {
2149 let mut router = McpRouter::new();
2150 init_router(&mut router).await;
2151
2152 let req = RouterRequest {
2153 id: RequestId::Number(1),
2154 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2155 tool_name: "nonexistent".to_string(),
2156 arguments: serde_json::json!({}),
2157 ttl: None,
2158 }),
2159 extensions: Extensions::new(),
2160 };
2161
2162 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2163
2164 match resp.inner {
2165 Err(e) => {
2166 assert!(e.message.contains("not found"));
2167 }
2168 _ => panic!("Expected error response"),
2169 }
2170 }
2171
2172 #[tokio::test]
2173 async fn test_get_nonexistent_task() {
2174 let mut router = McpRouter::new();
2175 init_router(&mut router).await;
2176
2177 let req = RouterRequest {
2178 id: RequestId::Number(1),
2179 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2180 task_id: "task-999".to_string(),
2181 }),
2182 extensions: Extensions::new(),
2183 };
2184
2185 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2186
2187 match resp.inner {
2188 Err(e) => {
2189 assert!(e.message.contains("not found"));
2190 }
2191 _ => panic!("Expected error response"),
2192 }
2193 }
2194
2195 #[tokio::test]
2200 async fn test_subscribe_to_resource() {
2201 use crate::resource::ResourceBuilder;
2202
2203 let resource = ResourceBuilder::new("file:///test.txt")
2204 .name("Test File")
2205 .text("Hello");
2206
2207 let mut router = McpRouter::new().resource(resource);
2208 init_router(&mut router).await;
2209
2210 let req = RouterRequest {
2212 id: RequestId::Number(1),
2213 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2214 uri: "file:///test.txt".to_string(),
2215 }),
2216 extensions: Extensions::new(),
2217 };
2218
2219 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2220
2221 match resp.inner {
2222 Ok(McpResponse::SubscribeResource(_)) => {
2223 assert!(router.is_subscribed("file:///test.txt"));
2225 }
2226 _ => panic!("Expected SubscribeResource response"),
2227 }
2228 }
2229
2230 #[tokio::test]
2231 async fn test_unsubscribe_from_resource() {
2232 use crate::resource::ResourceBuilder;
2233
2234 let resource = ResourceBuilder::new("file:///test.txt")
2235 .name("Test File")
2236 .text("Hello");
2237
2238 let mut router = McpRouter::new().resource(resource);
2239 init_router(&mut router).await;
2240
2241 let req = RouterRequest {
2243 id: RequestId::Number(1),
2244 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2245 uri: "file:///test.txt".to_string(),
2246 }),
2247 extensions: Extensions::new(),
2248 };
2249 let _ = router.ready().await.unwrap().call(req).await.unwrap();
2250 assert!(router.is_subscribed("file:///test.txt"));
2251
2252 let req = RouterRequest {
2254 id: RequestId::Number(2),
2255 inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2256 uri: "file:///test.txt".to_string(),
2257 }),
2258 extensions: Extensions::new(),
2259 };
2260
2261 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2262
2263 match resp.inner {
2264 Ok(McpResponse::UnsubscribeResource(_)) => {
2265 assert!(!router.is_subscribed("file:///test.txt"));
2267 }
2268 _ => panic!("Expected UnsubscribeResource response"),
2269 }
2270 }
2271
2272 #[tokio::test]
2273 async fn test_subscribe_nonexistent_resource() {
2274 let mut router = McpRouter::new();
2275 init_router(&mut router).await;
2276
2277 let req = RouterRequest {
2278 id: RequestId::Number(1),
2279 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2280 uri: "file:///nonexistent.txt".to_string(),
2281 }),
2282 extensions: Extensions::new(),
2283 };
2284
2285 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2286
2287 match resp.inner {
2288 Err(e) => {
2289 assert!(e.message.contains("not found"));
2290 }
2291 _ => panic!("Expected error response"),
2292 }
2293 }
2294
2295 #[tokio::test]
2296 async fn test_notify_resource_updated() {
2297 use crate::context::notification_channel;
2298 use crate::resource::ResourceBuilder;
2299
2300 let (tx, mut rx) = notification_channel(10);
2301
2302 let resource = ResourceBuilder::new("file:///test.txt")
2303 .name("Test File")
2304 .text("Hello");
2305
2306 let router = McpRouter::new()
2307 .resource(resource)
2308 .with_notification_sender(tx);
2309
2310 router.subscribe("file:///test.txt");
2312
2313 let sent = router.notify_resource_updated("file:///test.txt");
2315 assert!(sent);
2316
2317 let notification = rx.try_recv().unwrap();
2319 match notification {
2320 ServerNotification::ResourceUpdated { uri } => {
2321 assert_eq!(uri, "file:///test.txt");
2322 }
2323 _ => panic!("Expected ResourceUpdated notification"),
2324 }
2325 }
2326
2327 #[tokio::test]
2328 async fn test_notify_resource_updated_not_subscribed() {
2329 use crate::context::notification_channel;
2330 use crate::resource::ResourceBuilder;
2331
2332 let (tx, mut rx) = notification_channel(10);
2333
2334 let resource = ResourceBuilder::new("file:///test.txt")
2335 .name("Test File")
2336 .text("Hello");
2337
2338 let router = McpRouter::new()
2339 .resource(resource)
2340 .with_notification_sender(tx);
2341
2342 let sent = router.notify_resource_updated("file:///test.txt");
2344 assert!(!sent); assert!(rx.try_recv().is_err());
2348 }
2349
2350 #[tokio::test]
2351 async fn test_notify_resources_list_changed() {
2352 use crate::context::notification_channel;
2353
2354 let (tx, mut rx) = notification_channel(10);
2355 let router = McpRouter::new().with_notification_sender(tx);
2356
2357 let sent = router.notify_resources_list_changed();
2358 assert!(sent);
2359
2360 let notification = rx.try_recv().unwrap();
2361 match notification {
2362 ServerNotification::ResourcesListChanged => {}
2363 _ => panic!("Expected ResourcesListChanged notification"),
2364 }
2365 }
2366
2367 #[tokio::test]
2368 async fn test_subscribed_uris() {
2369 use crate::resource::ResourceBuilder;
2370
2371 let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2372
2373 let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2374
2375 let router = McpRouter::new().resource(resource1).resource(resource2);
2376
2377 router.subscribe("file:///a.txt");
2379 router.subscribe("file:///b.txt");
2380
2381 let uris = router.subscribed_uris();
2382 assert_eq!(uris.len(), 2);
2383 assert!(uris.contains(&"file:///a.txt".to_string()));
2384 assert!(uris.contains(&"file:///b.txt".to_string()));
2385 }
2386
2387 #[tokio::test]
2388 async fn test_subscription_capability_advertised() {
2389 use crate::resource::ResourceBuilder;
2390
2391 let resource = ResourceBuilder::new("file:///test.txt")
2392 .name("Test")
2393 .text("Hello");
2394
2395 let mut router = McpRouter::new().resource(resource);
2396
2397 let init_req = RouterRequest {
2399 id: RequestId::Number(0),
2400 inner: McpRequest::Initialize(InitializeParams {
2401 protocol_version: "2025-11-25".to_string(),
2402 capabilities: ClientCapabilities {
2403 roots: None,
2404 sampling: None,
2405 elicitation: None,
2406 },
2407 client_info: Implementation {
2408 name: "test".to_string(),
2409 version: "1.0".to_string(),
2410 ..Default::default()
2411 },
2412 }),
2413 extensions: Extensions::new(),
2414 };
2415 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2416
2417 match resp.inner {
2418 Ok(McpResponse::Initialize(result)) => {
2419 let resources_cap = result.capabilities.resources.unwrap();
2421 assert!(resources_cap.subscribe);
2422 }
2423 _ => panic!("Expected Initialize response"),
2424 }
2425 }
2426
2427 #[tokio::test]
2428 async fn test_completion_handler() {
2429 let router = McpRouter::new()
2430 .server_info("test", "1.0")
2431 .completion_handler(|params: CompleteParams| async move {
2432 let prefix = ¶ms.argument.value;
2434 let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2435 .into_iter()
2436 .filter(|s| s.starts_with(prefix))
2437 .map(String::from)
2438 .collect();
2439 Ok(CompleteResult::new(suggestions))
2440 });
2441
2442 let init_req = RouterRequest {
2444 id: RequestId::Number(0),
2445 inner: McpRequest::Initialize(InitializeParams {
2446 protocol_version: "2025-11-25".to_string(),
2447 capabilities: ClientCapabilities::default(),
2448 client_info: Implementation {
2449 name: "test".to_string(),
2450 version: "1.0".to_string(),
2451 ..Default::default()
2452 },
2453 }),
2454 extensions: Extensions::new(),
2455 };
2456 let resp = router
2457 .clone()
2458 .ready()
2459 .await
2460 .unwrap()
2461 .call(init_req)
2462 .await
2463 .unwrap();
2464
2465 match resp.inner {
2467 Ok(McpResponse::Initialize(result)) => {
2468 assert!(result.capabilities.completions.is_some());
2469 }
2470 _ => panic!("Expected Initialize response"),
2471 }
2472
2473 router.handle_notification(McpNotification::Initialized);
2475
2476 let complete_req = RouterRequest {
2478 id: RequestId::Number(1),
2479 inner: McpRequest::Complete(CompleteParams {
2480 reference: CompletionReference::prompt("test-prompt"),
2481 argument: CompletionArgument::new("query", "al"),
2482 }),
2483 extensions: Extensions::new(),
2484 };
2485 let resp = router
2486 .clone()
2487 .ready()
2488 .await
2489 .unwrap()
2490 .call(complete_req)
2491 .await
2492 .unwrap();
2493
2494 match resp.inner {
2495 Ok(McpResponse::Complete(result)) => {
2496 assert_eq!(result.completion.values, vec!["alpha"]);
2497 }
2498 _ => panic!("Expected Complete response"),
2499 }
2500 }
2501
2502 #[tokio::test]
2503 async fn test_completion_without_handler_returns_empty() {
2504 let router = McpRouter::new().server_info("test", "1.0");
2505
2506 let init_req = RouterRequest {
2508 id: RequestId::Number(0),
2509 inner: McpRequest::Initialize(InitializeParams {
2510 protocol_version: "2025-11-25".to_string(),
2511 capabilities: ClientCapabilities::default(),
2512 client_info: Implementation {
2513 name: "test".to_string(),
2514 version: "1.0".to_string(),
2515 ..Default::default()
2516 },
2517 }),
2518 extensions: Extensions::new(),
2519 };
2520 let resp = router
2521 .clone()
2522 .ready()
2523 .await
2524 .unwrap()
2525 .call(init_req)
2526 .await
2527 .unwrap();
2528
2529 match resp.inner {
2531 Ok(McpResponse::Initialize(result)) => {
2532 assert!(result.capabilities.completions.is_none());
2533 }
2534 _ => panic!("Expected Initialize response"),
2535 }
2536
2537 router.handle_notification(McpNotification::Initialized);
2539
2540 let complete_req = RouterRequest {
2542 id: RequestId::Number(1),
2543 inner: McpRequest::Complete(CompleteParams {
2544 reference: CompletionReference::prompt("test-prompt"),
2545 argument: CompletionArgument::new("query", "al"),
2546 }),
2547 extensions: Extensions::new(),
2548 };
2549 let resp = router
2550 .clone()
2551 .ready()
2552 .await
2553 .unwrap()
2554 .call(complete_req)
2555 .await
2556 .unwrap();
2557
2558 match resp.inner {
2559 Ok(McpResponse::Complete(result)) => {
2560 assert!(result.completion.values.is_empty());
2561 }
2562 _ => panic!("Expected Complete response"),
2563 }
2564 }
2565}