1use std::any::{Any, TypeId};
7use std::collections::{HashMap, HashSet};
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, RwLock};
11use std::task::{Context, Poll};
12
13use tower_service::Service;
14
15use crate::async_task::TaskStore;
16use crate::context::{
17 CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
18 ServerNotification,
19};
20use crate::error::{Error, JsonRpcError, Result};
21use crate::filter::{PromptFilter, ResourceFilter, ToolFilter};
22use crate::prompt::Prompt;
23use crate::protocol::*;
24use crate::resource::{Resource, ResourceTemplate};
25use crate::session::SessionState;
26use crate::tool::Tool;
27
28pub type CompletionHandler = Arc<
30 dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
31 + Send
32 + Sync,
33>;
34
35#[derive(Clone)]
60pub struct McpRouter {
61 inner: Arc<McpRouterInner>,
62 session: SessionState,
63}
64
65impl std::fmt::Debug for McpRouter {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("McpRouter")
68 .field("server_name", &self.inner.server_name)
69 .field("server_version", &self.inner.server_version)
70 .field("tools_count", &self.inner.tools.len())
71 .field("resources_count", &self.inner.resources.len())
72 .field("prompts_count", &self.inner.prompts.len())
73 .field("session_phase", &self.session.phase())
74 .finish()
75 }
76}
77
78#[derive(Clone)]
80struct McpRouterInner {
81 server_name: String,
82 server_version: String,
83 server_title: Option<String>,
85 server_description: Option<String>,
87 server_icons: Option<Vec<ToolIcon>>,
89 server_website_url: Option<String>,
91 instructions: Option<String>,
92 tools: HashMap<String, Arc<Tool>>,
93 resources: HashMap<String, Arc<Resource>>,
94 resource_templates: Vec<Arc<ResourceTemplate>>,
96 prompts: HashMap<String, Arc<Prompt>>,
97 in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
99 notification_tx: Option<NotificationSender>,
101 client_requester: Option<ClientRequesterHandle>,
103 task_store: TaskStore,
105 subscriptions: Arc<RwLock<HashSet<String>>>,
107 completion_handler: Option<CompletionHandler>,
109 tool_filter: Option<ToolFilter>,
111 resource_filter: Option<ResourceFilter>,
113 prompt_filter: Option<PromptFilter>,
115}
116
117impl McpRouter {
118 pub fn new() -> Self {
120 Self {
121 inner: Arc::new(McpRouterInner {
122 server_name: "tower-mcp".to_string(),
123 server_version: env!("CARGO_PKG_VERSION").to_string(),
124 server_title: None,
125 server_description: None,
126 server_icons: None,
127 server_website_url: None,
128 instructions: None,
129 tools: HashMap::new(),
130 resources: HashMap::new(),
131 resource_templates: Vec::new(),
132 prompts: HashMap::new(),
133 in_flight: Arc::new(RwLock::new(HashMap::new())),
134 notification_tx: None,
135 client_requester: None,
136 task_store: TaskStore::new(),
137 subscriptions: Arc::new(RwLock::new(HashSet::new())),
138 completion_handler: None,
139 tool_filter: None,
140 resource_filter: None,
141 prompt_filter: None,
142 }),
143 session: SessionState::new(),
144 }
145 }
146
147 pub fn task_store(&self) -> &TaskStore {
149 &self.inner.task_store
150 }
151
152 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
156 Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
157 self
158 }
159
160 pub fn notification_sender(&self) -> Option<&NotificationSender> {
162 self.inner.notification_tx.as_ref()
163 }
164
165 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
170 Arc::make_mut(&mut self.inner).client_requester = Some(requester);
171 self
172 }
173
174 pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
176 self.inner.client_requester.as_ref()
177 }
178
179 pub fn create_context(
184 &self,
185 request_id: RequestId,
186 progress_token: Option<ProgressToken>,
187 ) -> RequestContext {
188 let ctx = RequestContext::new(request_id.clone());
189
190 let ctx = if let Some(token) = progress_token {
192 ctx.with_progress_token(token)
193 } else {
194 ctx
195 };
196
197 let ctx = if let Some(tx) = &self.inner.notification_tx {
199 ctx.with_notification_sender(tx.clone())
200 } else {
201 ctx
202 };
203
204 let ctx = if let Some(requester) = &self.inner.client_requester {
206 ctx.with_client_requester(requester.clone())
207 } else {
208 ctx
209 };
210
211 let token = ctx.cancellation_token();
213 if let Ok(mut in_flight) = self.inner.in_flight.write() {
214 in_flight.insert(request_id, token);
215 }
216
217 ctx
218 }
219
220 pub fn complete_request(&self, request_id: &RequestId) {
222 if let Ok(mut in_flight) = self.inner.in_flight.write() {
223 in_flight.remove(request_id);
224 }
225 }
226
227 fn cancel_request(&self, request_id: &RequestId) -> bool {
229 let Ok(in_flight) = self.inner.in_flight.read() else {
230 return false;
231 };
232 let Some(token) = in_flight.get(request_id) else {
233 return false;
234 };
235 token.cancel();
236 true
237 }
238
239 pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
241 let inner = Arc::make_mut(&mut self.inner);
242 inner.server_name = name.into();
243 inner.server_version = version.into();
244 self
245 }
246
247 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
249 Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
250 self
251 }
252
253 pub fn server_title(mut self, title: impl Into<String>) -> Self {
255 Arc::make_mut(&mut self.inner).server_title = Some(title.into());
256 self
257 }
258
259 pub fn server_description(mut self, description: impl Into<String>) -> Self {
261 Arc::make_mut(&mut self.inner).server_description = Some(description.into());
262 self
263 }
264
265 pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
267 Arc::make_mut(&mut self.inner).server_icons = Some(icons);
268 self
269 }
270
271 pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
273 Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
274 self
275 }
276
277 pub fn tool(mut self, tool: Tool) -> Self {
279 Arc::make_mut(&mut self.inner)
280 .tools
281 .insert(tool.name.clone(), Arc::new(tool));
282 self
283 }
284
285 pub fn resource(mut self, resource: Resource) -> Self {
287 Arc::make_mut(&mut self.inner)
288 .resources
289 .insert(resource.uri.clone(), Arc::new(resource));
290 self
291 }
292
293 pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
324 Arc::make_mut(&mut self.inner)
325 .resource_templates
326 .push(Arc::new(template));
327 self
328 }
329
330 pub fn prompt(mut self, prompt: Prompt) -> Self {
332 Arc::make_mut(&mut self.inner)
333 .prompts
334 .insert(prompt.name.clone(), Arc::new(prompt));
335 self
336 }
337
338 pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
364 tools
365 .into_iter()
366 .fold(self, |router, tool| router.tool(tool))
367 }
368
369 pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
388 resources
389 .into_iter()
390 .fold(self, |router, resource| router.resource(resource))
391 }
392
393 pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
412 prompts
413 .into_iter()
414 .fold(self, |router, prompt| router.prompt(prompt))
415 }
416
417 pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
444 where
445 F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
446 Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
447 {
448 Arc::make_mut(&mut self.inner).completion_handler =
449 Some(Arc::new(move |params| Box::pin(handler(params))));
450 self
451 }
452
453 pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
490 Arc::make_mut(&mut self.inner).tool_filter = Some(filter);
491 self
492 }
493
494 pub fn resource_filter(mut self, filter: ResourceFilter) -> Self {
525 Arc::make_mut(&mut self.inner).resource_filter = Some(filter);
526 self
527 }
528
529 pub fn prompt_filter(mut self, filter: PromptFilter) -> Self {
558 Arc::make_mut(&mut self.inner).prompt_filter = Some(filter);
559 self
560 }
561
562 pub fn session(&self) -> &SessionState {
564 &self.session
565 }
566
567 pub fn log(&self, params: LoggingMessageParams) -> bool {
589 let Some(tx) = &self.inner.notification_tx else {
590 return false;
591 };
592 tx.try_send(ServerNotification::LogMessage(params)).is_ok()
593 }
594
595 pub fn log_info(&self, message: &str) -> bool {
599 self.log(
600 LoggingMessageParams::new(LogLevel::Info)
601 .with_data(serde_json::json!({ "message": message })),
602 )
603 }
604
605 pub fn log_warning(&self, message: &str) -> bool {
607 self.log(
608 LoggingMessageParams::new(LogLevel::Warning)
609 .with_data(serde_json::json!({ "message": message })),
610 )
611 }
612
613 pub fn log_error(&self, message: &str) -> bool {
615 self.log(
616 LoggingMessageParams::new(LogLevel::Error)
617 .with_data(serde_json::json!({ "message": message })),
618 )
619 }
620
621 pub fn log_debug(&self, message: &str) -> bool {
623 self.log(
624 LoggingMessageParams::new(LogLevel::Debug)
625 .with_data(serde_json::json!({ "message": message })),
626 )
627 }
628
629 pub fn is_subscribed(&self, uri: &str) -> bool {
631 if let Ok(subs) = self.inner.subscriptions.read() {
632 return subs.contains(uri);
633 }
634 false
635 }
636
637 pub fn subscribed_uris(&self) -> Vec<String> {
639 if let Ok(subs) = self.inner.subscriptions.read() {
640 return subs.iter().cloned().collect();
641 }
642 Vec::new()
643 }
644
645 fn subscribe(&self, uri: &str) -> bool {
647 if let Ok(mut subs) = self.inner.subscriptions.write() {
648 return subs.insert(uri.to_string());
649 }
650 false
651 }
652
653 fn unsubscribe(&self, uri: &str) -> bool {
655 if let Ok(mut subs) = self.inner.subscriptions.write() {
656 return subs.remove(uri);
657 }
658 false
659 }
660
661 pub fn notify_resource_updated(&self, uri: &str) -> bool {
666 if !self.is_subscribed(uri) {
668 return false;
669 }
670
671 let Some(tx) = &self.inner.notification_tx else {
672 return false;
673 };
674 tx.try_send(ServerNotification::ResourceUpdated {
675 uri: uri.to_string(),
676 })
677 .is_ok()
678 }
679
680 pub fn notify_resources_list_changed(&self) -> bool {
684 let Some(tx) = &self.inner.notification_tx else {
685 return false;
686 };
687 tx.try_send(ServerNotification::ResourcesListChanged)
688 .is_ok()
689 }
690
691 fn capabilities(&self) -> ServerCapabilities {
693 let has_resources =
694 !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
695
696 ServerCapabilities {
697 tools: if self.inner.tools.is_empty() {
698 None
699 } else {
700 Some(ToolsCapability::default())
701 },
702 resources: if has_resources {
703 Some(ResourcesCapability {
704 subscribe: true,
705 ..Default::default()
706 })
707 } else {
708 None
709 },
710 prompts: if self.inner.prompts.is_empty() {
711 None
712 } else {
713 Some(PromptsCapability::default())
714 },
715 logging: if self.inner.notification_tx.is_some() {
717 Some(LoggingCapability::default())
718 } else {
719 None
720 },
721 tasks: Some(TasksCapability::default()),
723 completions: if self.inner.completion_handler.is_some() {
725 Some(CompletionsCapability::default())
726 } else {
727 None
728 },
729 }
730 }
731
732 async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
734 let method = request.method_name();
736 if !self.session.is_request_allowed(method) {
737 tracing::warn!(
738 method = %method,
739 phase = ?self.session.phase(),
740 "Request rejected: session not initialized"
741 );
742 return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
743 "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
744 method
745 ))));
746 }
747
748 match request {
749 McpRequest::Initialize(params) => {
750 tracing::info!(
751 client = %params.client_info.name,
752 version = %params.client_info.version,
753 "Client initializing"
754 );
755
756 let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
759 .contains(¶ms.protocol_version.as_str())
760 {
761 params.protocol_version
762 } else {
763 crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
764 };
765
766 self.session.mark_initializing();
768
769 Ok(McpResponse::Initialize(InitializeResult {
770 protocol_version,
771 capabilities: self.capabilities(),
772 server_info: Implementation {
773 name: self.inner.server_name.clone(),
774 version: self.inner.server_version.clone(),
775 title: self.inner.server_title.clone(),
776 description: self.inner.server_description.clone(),
777 icons: self.inner.server_icons.clone(),
778 website_url: self.inner.server_website_url.clone(),
779 },
780 instructions: self.inner.instructions.clone(),
781 }))
782 }
783
784 McpRequest::ListTools(_params) => {
785 let tools: Vec<ToolDefinition> = self
786 .inner
787 .tools
788 .values()
789 .filter(|t| {
790 self.inner
792 .tool_filter
793 .as_ref()
794 .map(|f| f.is_visible(&self.session, t))
795 .unwrap_or(true)
796 })
797 .map(|t| t.definition())
798 .collect();
799
800 Ok(McpResponse::ListTools(ListToolsResult {
801 tools,
802 next_cursor: None,
803 }))
804 }
805
806 McpRequest::CallTool(params) => {
807 let tool =
808 self.inner.tools.get(¶ms.name).ok_or_else(|| {
809 Error::JsonRpc(JsonRpcError::method_not_found(¶ms.name))
810 })?;
811
812 if let Some(filter) = &self.inner.tool_filter {
814 if !filter.is_visible(&self.session, tool) {
815 return Err(filter.denial_error(¶ms.name));
816 }
817 }
818
819 let progress_token = params.meta.and_then(|m| m.progress_token);
821 let ctx = self.create_context(request_id, progress_token);
822
823 tracing::debug!(tool = %params.name, "Calling tool");
824 let result = tool.call_with_context(ctx, params.arguments).await?;
825
826 Ok(McpResponse::CallTool(result))
827 }
828
829 McpRequest::ListResources(_params) => {
830 let resources: Vec<ResourceDefinition> = self
831 .inner
832 .resources
833 .values()
834 .filter(|r| {
835 self.inner
837 .resource_filter
838 .as_ref()
839 .map(|f| f.is_visible(&self.session, r))
840 .unwrap_or(true)
841 })
842 .map(|r| r.definition())
843 .collect();
844
845 Ok(McpResponse::ListResources(ListResourcesResult {
846 resources,
847 next_cursor: None,
848 }))
849 }
850
851 McpRequest::ListResourceTemplates(_params) => {
852 let resource_templates: Vec<ResourceTemplateDefinition> = self
853 .inner
854 .resource_templates
855 .iter()
856 .map(|t| t.definition())
857 .collect();
858
859 Ok(McpResponse::ListResourceTemplates(
860 ListResourceTemplatesResult {
861 resource_templates,
862 next_cursor: None,
863 },
864 ))
865 }
866
867 McpRequest::ReadResource(params) => {
868 if let Some(resource) = self.inner.resources.get(¶ms.uri) {
870 if let Some(filter) = &self.inner.resource_filter {
872 if !filter.is_visible(&self.session, resource) {
873 return Err(filter.denial_error(¶ms.uri));
874 }
875 }
876
877 tracing::debug!(uri = %params.uri, "Reading static resource");
878 let result = resource.read().await?;
879 return Ok(McpResponse::ReadResource(result));
880 }
881
882 for template in &self.inner.resource_templates {
884 if let Some(variables) = template.match_uri(¶ms.uri) {
885 tracing::debug!(
886 uri = %params.uri,
887 template = %template.uri_template,
888 "Reading resource via template"
889 );
890 let result = template.read(¶ms.uri, variables).await?;
891 return Ok(McpResponse::ReadResource(result));
892 }
893 }
894
895 Err(Error::JsonRpc(JsonRpcError::resource_not_found(
897 ¶ms.uri,
898 )))
899 }
900
901 McpRequest::SubscribeResource(params) => {
902 if !self.inner.resources.contains_key(¶ms.uri) {
904 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
905 ¶ms.uri,
906 )));
907 }
908
909 tracing::debug!(uri = %params.uri, "Subscribing to resource");
910 self.subscribe(¶ms.uri);
911
912 Ok(McpResponse::SubscribeResource(EmptyResult {}))
913 }
914
915 McpRequest::UnsubscribeResource(params) => {
916 if !self.inner.resources.contains_key(¶ms.uri) {
918 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
919 ¶ms.uri,
920 )));
921 }
922
923 tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
924 self.unsubscribe(¶ms.uri);
925
926 Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
927 }
928
929 McpRequest::ListPrompts(_params) => {
930 let prompts: Vec<PromptDefinition> = self
931 .inner
932 .prompts
933 .values()
934 .filter(|p| {
935 self.inner
937 .prompt_filter
938 .as_ref()
939 .map(|f| f.is_visible(&self.session, p))
940 .unwrap_or(true)
941 })
942 .map(|p| p.definition())
943 .collect();
944
945 Ok(McpResponse::ListPrompts(ListPromptsResult {
946 prompts,
947 next_cursor: None,
948 }))
949 }
950
951 McpRequest::GetPrompt(params) => {
952 let prompt = self.inner.prompts.get(¶ms.name).ok_or_else(|| {
953 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
954 "Prompt not found: {}",
955 params.name
956 )))
957 })?;
958
959 if let Some(filter) = &self.inner.prompt_filter {
961 if !filter.is_visible(&self.session, prompt) {
962 return Err(filter.denial_error(¶ms.name));
963 }
964 }
965
966 tracing::debug!(name = %params.name, "Getting prompt");
967 let result = prompt.get(params.arguments).await?;
968
969 Ok(McpResponse::GetPrompt(result))
970 }
971
972 McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
973
974 McpRequest::EnqueueTask(params) => {
975 let tool = self.inner.tools.get(¶ms.tool_name).ok_or_else(|| {
977 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
978 "Tool not found: {}",
979 params.tool_name
980 )))
981 })?;
982
983 let (task_id, cancellation_token) = self.inner.task_store.create_task(
985 ¶ms.tool_name,
986 params.arguments.clone(),
987 params.ttl,
988 );
989
990 tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
991
992 let ctx = self.create_context(request_id, None);
994
995 let task_store = self.inner.task_store.clone();
997 let tool = tool.clone();
998 let arguments = params.arguments;
999 let task_id_clone = task_id.clone();
1000
1001 tokio::spawn(async move {
1002 if cancellation_token.is_cancelled() {
1004 tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
1005 return;
1006 }
1007
1008 match tool.call_with_context(ctx, arguments).await {
1010 Ok(result) => {
1011 if cancellation_token.is_cancelled() {
1012 tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
1013 } else {
1014 task_store.complete_task(&task_id_clone, result);
1015 tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
1016 }
1017 }
1018 Err(e) => {
1019 task_store.fail_task(&task_id_clone, &e.to_string());
1020 tracing::warn!(task_id = %task_id_clone, error = %e, "Task failed");
1021 }
1022 }
1023 });
1024
1025 Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
1026 task_id,
1027 status: TaskStatus::Working,
1028 poll_interval: Some(2),
1029 }))
1030 }
1031
1032 McpRequest::ListTasks(params) => {
1033 let tasks = self.inner.task_store.list_tasks(params.status);
1034
1035 Ok(McpResponse::ListTasks(ListTasksResult {
1036 tasks,
1037 next_cursor: None,
1038 }))
1039 }
1040
1041 McpRequest::GetTaskInfo(params) => {
1042 let task = self
1043 .inner
1044 .task_store
1045 .get_task(¶ms.task_id)
1046 .ok_or_else(|| {
1047 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1048 "Task not found: {}",
1049 params.task_id
1050 )))
1051 })?;
1052
1053 Ok(McpResponse::GetTaskInfo(task))
1054 }
1055
1056 McpRequest::GetTaskResult(params) => {
1057 let (status, result, error) = self
1058 .inner
1059 .task_store
1060 .get_task_full(¶ms.task_id)
1061 .ok_or_else(|| {
1062 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1063 "Task not found: {}",
1064 params.task_id
1065 )))
1066 })?;
1067
1068 Ok(McpResponse::GetTaskResult(GetTaskResultResult {
1069 task_id: params.task_id,
1070 status,
1071 result,
1072 error,
1073 }))
1074 }
1075
1076 McpRequest::CancelTask(params) => {
1077 let status = self
1078 .inner
1079 .task_store
1080 .cancel_task(¶ms.task_id, params.reason.as_deref())
1081 .ok_or_else(|| {
1082 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1083 "Task not found: {}",
1084 params.task_id
1085 )))
1086 })?;
1087
1088 let cancelled = status == TaskStatus::Cancelled;
1089
1090 Ok(McpResponse::CancelTask(CancelTaskResult {
1091 cancelled,
1092 status,
1093 }))
1094 }
1095
1096 McpRequest::SetLoggingLevel(params) => {
1097 tracing::debug!(level = ?params.level, "Client set logging level");
1101 Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
1102 }
1103
1104 McpRequest::Complete(params) => {
1105 tracing::debug!(
1106 reference = ?params.reference,
1107 argument = %params.argument.name,
1108 "Completion request"
1109 );
1110
1111 if let Some(ref handler) = self.inner.completion_handler {
1113 let result = handler(params).await?;
1114 Ok(McpResponse::Complete(result))
1115 } else {
1116 Ok(McpResponse::Complete(CompleteResult::new(vec![])))
1118 }
1119 }
1120
1121 McpRequest::Unknown { method, .. } => {
1122 Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
1123 }
1124 }
1125 }
1126
1127 pub fn handle_notification(&self, notification: McpNotification) {
1129 match notification {
1130 McpNotification::Initialized => {
1131 if self.session.mark_initialized() {
1132 tracing::info!("Session initialized, entering operation phase");
1133 } else {
1134 tracing::warn!(
1135 "Received initialized notification in unexpected state: {:?}",
1136 self.session.phase()
1137 );
1138 }
1139 }
1140 McpNotification::Cancelled(params) => {
1141 if self.cancel_request(¶ms.request_id) {
1142 tracing::info!(
1143 request_id = ?params.request_id,
1144 reason = ?params.reason,
1145 "Request cancelled"
1146 );
1147 } else {
1148 tracing::debug!(
1149 request_id = ?params.request_id,
1150 reason = ?params.reason,
1151 "Cancellation requested for unknown request"
1152 );
1153 }
1154 }
1155 McpNotification::Progress(params) => {
1156 tracing::trace!(
1157 token = ?params.progress_token,
1158 progress = params.progress,
1159 total = ?params.total,
1160 "Progress notification"
1161 );
1162 }
1164 McpNotification::RootsListChanged => {
1165 tracing::info!("Client roots list changed");
1166 }
1169 McpNotification::Unknown { method, .. } => {
1170 tracing::debug!(method = %method, "Unknown notification received");
1171 }
1172 }
1173 }
1174}
1175
1176impl Default for McpRouter {
1177 fn default() -> Self {
1178 Self::new()
1179 }
1180}
1181
/// Type-keyed map for attaching arbitrary values to a request.
///
/// At most one value is kept per concrete type `T`. Values are stored behind
/// `Arc`, so cloning an `Extensions` shares the stored values instead of
/// copying them.
#[derive(Default, Clone)]
pub struct Extensions {
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty extension map.
    pub fn new() -> Self {
        Extensions {
            map: HashMap::new(),
        }
    }

    /// Stores `val`, replacing any previously stored value of the same type.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = TypeId::of::<T>();
        let entry: Arc<dyn Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, entry);
    }

    /// Returns the stored value of type `T`, or `None` if none was inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let entry = self.map.get(&TypeId::of::<T>())?;
        entry.downcast_ref::<T>()
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only the number of stored values; the values themselves are
    /// type-erased and need not implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.map.len();
        f.debug_struct("Extensions").field("len", &len).finish()
    }
}
1236
1237#[derive(Debug)]
1239pub struct RouterRequest {
1240 pub id: RequestId,
1241 pub inner: McpRequest,
1242 pub extensions: Extensions,
1244}
1245
1246#[derive(Debug)]
1248pub struct RouterResponse {
1249 pub id: RequestId,
1250 pub inner: std::result::Result<McpResponse, JsonRpcError>,
1251}
1252
1253impl RouterResponse {
1254 pub fn into_jsonrpc(self) -> JsonRpcResponse {
1256 match self.inner {
1257 Ok(response) => match serde_json::to_value(response) {
1258 Ok(result) => JsonRpcResponse::result(self.id, result),
1259 Err(e) => {
1260 tracing::error!(error = %e, "Failed to serialize response");
1261 JsonRpcResponse::error(
1262 Some(self.id),
1263 JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1264 )
1265 }
1266 },
1267 Err(error) => JsonRpcResponse::error(Some(self.id), error),
1268 }
1269 }
1270}
1271
1272impl Service<RouterRequest> for McpRouter {
1273 type Response = RouterResponse;
1274 type Error = std::convert::Infallible; type Future =
1276 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1277
1278 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1279 Poll::Ready(Ok(()))
1280 }
1281
1282 fn call(&mut self, req: RouterRequest) -> Self::Future {
1283 let router = self.clone();
1284 let request_id = req.id.clone();
1285 Box::pin(async move {
1286 let result = router.handle(req.id, req.inner).await;
1287 router.complete_request(&request_id);
1289 Ok(RouterResponse {
1290 id: request_id,
1291 inner: result.map_err(|e| match e {
1296 Error::JsonRpc(err) => err,
1297 Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1298 e => JsonRpcError::internal_error(e.to_string()),
1299 }),
1300 })
1301 })
1302 }
1303}
1304
1305#[cfg(test)]
1306mod tests {
1307 use super::*;
1308 use crate::jsonrpc::JsonRpcService;
1309 use crate::tool::ToolBuilder;
1310 use schemars::JsonSchema;
1311 use serde::Deserialize;
1312 use tower::ServiceExt;
1313
1314 #[derive(Debug, Deserialize, JsonSchema)]
1315 struct AddInput {
1316 a: i64,
1317 b: i64,
1318 }
1319
1320 async fn init_router(router: &mut McpRouter) {
1322 let init_req = RouterRequest {
1324 id: RequestId::Number(0),
1325 inner: McpRequest::Initialize(InitializeParams {
1326 protocol_version: "2025-11-25".to_string(),
1327 capabilities: ClientCapabilities {
1328 roots: None,
1329 sampling: None,
1330 elicitation: None,
1331 },
1332 client_info: Implementation {
1333 name: "test".to_string(),
1334 version: "1.0".to_string(),
1335 ..Default::default()
1336 },
1337 }),
1338 extensions: Extensions::new(),
1339 };
1340 let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1341 router.handle_notification(McpNotification::Initialized);
1343 }
1344
1345 #[tokio::test]
1346 async fn test_router_list_tools() {
1347 let add_tool = ToolBuilder::new("add")
1348 .description("Add two numbers")
1349 .handler(|input: AddInput| async move {
1350 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1351 })
1352 .build()
1353 .expect("valid tool name");
1354
1355 let mut router = McpRouter::new().tool(add_tool);
1356
1357 init_router(&mut router).await;
1359
1360 let req = RouterRequest {
1361 id: RequestId::Number(1),
1362 inner: McpRequest::ListTools(ListToolsParams::default()),
1363 extensions: Extensions::new(),
1364 };
1365
1366 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1367
1368 match resp.inner {
1369 Ok(McpResponse::ListTools(result)) => {
1370 assert_eq!(result.tools.len(), 1);
1371 assert_eq!(result.tools[0].name, "add");
1372 }
1373 _ => panic!("Expected ListTools response"),
1374 }
1375 }
1376
1377 #[tokio::test]
1378 async fn test_router_call_tool() {
1379 let add_tool = ToolBuilder::new("add")
1380 .description("Add two numbers")
1381 .handler(|input: AddInput| async move {
1382 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1383 })
1384 .build()
1385 .expect("valid tool name");
1386
1387 let mut router = McpRouter::new().tool(add_tool);
1388
1389 init_router(&mut router).await;
1391
1392 let req = RouterRequest {
1393 id: RequestId::Number(1),
1394 inner: McpRequest::CallTool(CallToolParams {
1395 name: "add".to_string(),
1396 arguments: serde_json::json!({"a": 2, "b": 3}),
1397 meta: None,
1398 }),
1399 extensions: Extensions::new(),
1400 };
1401
1402 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1403
1404 match resp.inner {
1405 Ok(McpResponse::CallTool(result)) => {
1406 assert!(!result.is_error);
1407 match &result.content[0] {
1409 Content::Text { text, .. } => assert_eq!(text, "5"),
1410 _ => panic!("Expected text content"),
1411 }
1412 }
1413 _ => panic!("Expected CallTool response"),
1414 }
1415 }
1416
1417 async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1419 let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1420 "protocolVersion": "2025-11-25",
1421 "capabilities": {},
1422 "clientInfo": { "name": "test", "version": "1.0" }
1423 }));
1424 let _ = service.call_single(init_req).await.unwrap();
1425 router.handle_notification(McpNotification::Initialized);
1426 }
1427
1428 #[tokio::test]
1429 async fn test_jsonrpc_service() {
1430 let add_tool = ToolBuilder::new("add")
1431 .description("Add two numbers")
1432 .handler(|input: AddInput| async move {
1433 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1434 })
1435 .build()
1436 .expect("valid tool name");
1437
1438 let router = McpRouter::new().tool(add_tool);
1439 let mut service = JsonRpcService::new(router.clone());
1440
1441 init_jsonrpc_service(&mut service, &router).await;
1443
1444 let req = JsonRpcRequest::new(1, "tools/list");
1445
1446 let resp = service.call_single(req).await.unwrap();
1447
1448 match resp {
1449 JsonRpcResponse::Result(r) => {
1450 assert_eq!(r.id, RequestId::Number(1));
1451 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1452 assert_eq!(tools.len(), 1);
1453 }
1454 JsonRpcResponse::Error(_) => panic!("Expected success response"),
1455 }
1456 }
1457
1458 #[tokio::test]
1459 async fn test_batch_request() {
1460 let add_tool = ToolBuilder::new("add")
1461 .description("Add two numbers")
1462 .handler(|input: AddInput| async move {
1463 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1464 })
1465 .build()
1466 .expect("valid tool name");
1467
1468 let router = McpRouter::new().tool(add_tool);
1469 let mut service = JsonRpcService::new(router.clone());
1470
1471 init_jsonrpc_service(&mut service, &router).await;
1473
1474 let requests = vec![
1476 JsonRpcRequest::new(1, "tools/list"),
1477 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1478 "name": "add",
1479 "arguments": {"a": 10, "b": 20}
1480 })),
1481 JsonRpcRequest::new(3, "ping"),
1482 ];
1483
1484 let responses = service.call_batch(requests).await.unwrap();
1485
1486 assert_eq!(responses.len(), 3);
1487
1488 match &responses[0] {
1490 JsonRpcResponse::Result(r) => {
1491 assert_eq!(r.id, RequestId::Number(1));
1492 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1493 assert_eq!(tools.len(), 1);
1494 }
1495 JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1496 }
1497
1498 match &responses[1] {
1500 JsonRpcResponse::Result(r) => {
1501 assert_eq!(r.id, RequestId::Number(2));
1502 let content = r.result.get("content").unwrap().as_array().unwrap();
1503 let text = content[0].get("text").unwrap().as_str().unwrap();
1504 assert_eq!(text, "30");
1505 }
1506 JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1507 }
1508
1509 match &responses[2] {
1511 JsonRpcResponse::Result(r) => {
1512 assert_eq!(r.id, RequestId::Number(3));
1513 }
1514 JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1515 }
1516 }
1517
1518 #[tokio::test]
1519 async fn test_empty_batch_error() {
1520 let router = McpRouter::new();
1521 let mut service = JsonRpcService::new(router);
1522
1523 let result = service.call_batch(vec![]).await;
1524 assert!(result.is_err());
1525 }
1526
1527 #[tokio::test]
1532 async fn test_progress_token_extraction() {
1533 use crate::context::{RequestContext, ServerNotification, notification_channel};
1534 use crate::protocol::ProgressToken;
1535 use std::sync::Arc;
1536 use std::sync::atomic::{AtomicBool, Ordering};
1537
1538 let progress_reported = Arc::new(AtomicBool::new(false));
1540 let progress_ref = progress_reported.clone();
1541
1542 let tool = ToolBuilder::new("progress_tool")
1544 .description("Tool that reports progress")
1545 .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1546 let reported = progress_ref.clone();
1547 async move {
1548 ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1550 .await;
1551 reported.store(true, Ordering::SeqCst);
1552 Ok(CallToolResult::text("done"))
1553 }
1554 })
1555 .build()
1556 .expect("valid tool name");
1557
1558 let (tx, mut rx) = notification_channel(10);
1560 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1561 let mut service = JsonRpcService::new(router.clone());
1562
1563 init_jsonrpc_service(&mut service, &router).await;
1565
1566 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1568 "name": "progress_tool",
1569 "arguments": {"a": 1, "b": 2},
1570 "_meta": {
1571 "progressToken": "test-token-123"
1572 }
1573 }));
1574
1575 let resp = service.call_single(req).await.unwrap();
1576
1577 match resp {
1579 JsonRpcResponse::Result(_) => {}
1580 JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1581 }
1582
1583 assert!(progress_reported.load(Ordering::SeqCst));
1585
1586 let notification = rx.try_recv().expect("Expected progress notification");
1588 match notification {
1589 ServerNotification::Progress(params) => {
1590 assert_eq!(
1591 params.progress_token,
1592 ProgressToken::String("test-token-123".to_string())
1593 );
1594 assert_eq!(params.progress, 50.0);
1595 assert_eq!(params.total, Some(100.0));
1596 assert_eq!(params.message.as_deref(), Some("Halfway"));
1597 }
1598 _ => panic!("Expected Progress notification"),
1599 }
1600 }
1601
1602 #[tokio::test]
1603 async fn test_tool_call_without_progress_token() {
1604 use crate::context::{RequestContext, notification_channel};
1605 use std::sync::Arc;
1606 use std::sync::atomic::{AtomicBool, Ordering};
1607
1608 let progress_attempted = Arc::new(AtomicBool::new(false));
1609 let progress_ref = progress_attempted.clone();
1610
1611 let tool = ToolBuilder::new("no_token_tool")
1612 .description("Tool that tries to report progress without token")
1613 .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1614 let attempted = progress_ref.clone();
1615 async move {
1616 ctx.report_progress(50.0, Some(100.0), None).await;
1618 attempted.store(true, Ordering::SeqCst);
1619 Ok(CallToolResult::text("done"))
1620 }
1621 })
1622 .build()
1623 .expect("valid tool name");
1624
1625 let (tx, mut rx) = notification_channel(10);
1626 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1627 let mut service = JsonRpcService::new(router.clone());
1628
1629 init_jsonrpc_service(&mut service, &router).await;
1630
1631 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1633 "name": "no_token_tool",
1634 "arguments": {"a": 1, "b": 2}
1635 }));
1636
1637 let resp = service.call_single(req).await.unwrap();
1638 assert!(matches!(resp, JsonRpcResponse::Result(_)));
1639
1640 assert!(progress_attempted.load(Ordering::SeqCst));
1642
1643 assert!(rx.try_recv().is_err());
1645 }
1646
1647 #[tokio::test]
1648 async fn test_batch_errors_returned_not_dropped() {
1649 let add_tool = ToolBuilder::new("add")
1650 .description("Add two numbers")
1651 .handler(|input: AddInput| async move {
1652 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1653 })
1654 .build()
1655 .expect("valid tool name");
1656
1657 let router = McpRouter::new().tool(add_tool);
1658 let mut service = JsonRpcService::new(router.clone());
1659
1660 init_jsonrpc_service(&mut service, &router).await;
1661
1662 let requests = vec![
1664 JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1666 "name": "add",
1667 "arguments": {"a": 10, "b": 20}
1668 })),
1669 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1671 "name": "nonexistent_tool",
1672 "arguments": {}
1673 })),
1674 JsonRpcRequest::new(3, "ping"),
1676 ];
1677
1678 let responses = service.call_batch(requests).await.unwrap();
1679
1680 assert_eq!(responses.len(), 3);
1682
1683 match &responses[0] {
1685 JsonRpcResponse::Result(r) => {
1686 assert_eq!(r.id, RequestId::Number(1));
1687 }
1688 JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1689 }
1690
1691 match &responses[1] {
1693 JsonRpcResponse::Error(e) => {
1694 assert_eq!(e.id, Some(RequestId::Number(2)));
1695 assert!(e.error.message.contains("not found") || e.error.code == -32601);
1697 }
1698 JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1699 }
1700
1701 match &responses[2] {
1703 JsonRpcResponse::Result(r) => {
1704 assert_eq!(r.id, RequestId::Number(3));
1705 }
1706 JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1707 }
1708 }
1709
1710 #[tokio::test]
1715 async fn test_list_resource_templates() {
1716 use crate::resource::ResourceTemplateBuilder;
1717 use std::collections::HashMap;
1718
1719 let template = ResourceTemplateBuilder::new("file:///{path}")
1720 .name("Project Files")
1721 .description("Access project files")
1722 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1723 Ok(ReadResourceResult {
1724 contents: vec![ResourceContent {
1725 uri,
1726 mime_type: None,
1727 text: None,
1728 blob: None,
1729 }],
1730 })
1731 });
1732
1733 let mut router = McpRouter::new().resource_template(template);
1734
1735 init_router(&mut router).await;
1737
1738 let req = RouterRequest {
1739 id: RequestId::Number(1),
1740 inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1741 extensions: Extensions::new(),
1742 };
1743
1744 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1745
1746 match resp.inner {
1747 Ok(McpResponse::ListResourceTemplates(result)) => {
1748 assert_eq!(result.resource_templates.len(), 1);
1749 assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1750 assert_eq!(result.resource_templates[0].name, "Project Files");
1751 }
1752 _ => panic!("Expected ListResourceTemplates response"),
1753 }
1754 }
1755
1756 #[tokio::test]
1757 async fn test_read_resource_via_template() {
1758 use crate::resource::ResourceTemplateBuilder;
1759 use std::collections::HashMap;
1760
1761 let template = ResourceTemplateBuilder::new("db://users/{id}")
1762 .name("User Records")
1763 .handler(|uri: String, vars: HashMap<String, String>| async move {
1764 let id = vars.get("id").unwrap().clone();
1765 Ok(ReadResourceResult {
1766 contents: vec![ResourceContent {
1767 uri,
1768 mime_type: Some("application/json".to_string()),
1769 text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1770 blob: None,
1771 }],
1772 })
1773 });
1774
1775 let mut router = McpRouter::new().resource_template(template);
1776
1777 init_router(&mut router).await;
1779
1780 let req = RouterRequest {
1782 id: RequestId::Number(1),
1783 inner: McpRequest::ReadResource(ReadResourceParams {
1784 uri: "db://users/123".to_string(),
1785 }),
1786 extensions: Extensions::new(),
1787 };
1788
1789 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1790
1791 match resp.inner {
1792 Ok(McpResponse::ReadResource(result)) => {
1793 assert_eq!(result.contents.len(), 1);
1794 assert_eq!(result.contents[0].uri, "db://users/123");
1795 assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1796 }
1797 _ => panic!("Expected ReadResource response"),
1798 }
1799 }
1800
1801 #[tokio::test]
1802 async fn test_static_resource_takes_precedence_over_template() {
1803 use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1804 use std::collections::HashMap;
1805
1806 let template = ResourceTemplateBuilder::new("file:///{path}")
1808 .name("Files Template")
1809 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1810 Ok(ReadResourceResult {
1811 contents: vec![ResourceContent {
1812 uri,
1813 mime_type: None,
1814 text: Some("from template".to_string()),
1815 blob: None,
1816 }],
1817 })
1818 });
1819
1820 let static_resource = ResourceBuilder::new("file:///README.md")
1822 .name("README")
1823 .text("from static resource");
1824
1825 let mut router = McpRouter::new()
1826 .resource_template(template)
1827 .resource(static_resource);
1828
1829 init_router(&mut router).await;
1831
1832 let req = RouterRequest {
1834 id: RequestId::Number(1),
1835 inner: McpRequest::ReadResource(ReadResourceParams {
1836 uri: "file:///README.md".to_string(),
1837 }),
1838 extensions: Extensions::new(),
1839 };
1840
1841 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1842
1843 match resp.inner {
1844 Ok(McpResponse::ReadResource(result)) => {
1845 assert_eq!(
1847 result.contents[0].text.as_deref(),
1848 Some("from static resource")
1849 );
1850 }
1851 _ => panic!("Expected ReadResource response"),
1852 }
1853 }
1854
1855 #[tokio::test]
1856 async fn test_resource_not_found_when_no_match() {
1857 use crate::resource::ResourceTemplateBuilder;
1858 use std::collections::HashMap;
1859
1860 let template = ResourceTemplateBuilder::new("db://users/{id}")
1861 .name("Users")
1862 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1863 Ok(ReadResourceResult {
1864 contents: vec![ResourceContent {
1865 uri,
1866 mime_type: None,
1867 text: None,
1868 blob: None,
1869 }],
1870 })
1871 });
1872
1873 let mut router = McpRouter::new().resource_template(template);
1874
1875 init_router(&mut router).await;
1877
1878 let req = RouterRequest {
1880 id: RequestId::Number(1),
1881 inner: McpRequest::ReadResource(ReadResourceParams {
1882 uri: "db://posts/123".to_string(),
1883 }),
1884 extensions: Extensions::new(),
1885 };
1886
1887 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1888
1889 match resp.inner {
1890 Err(err) => {
1891 assert!(err.message.contains("not found"));
1892 }
1893 Ok(_) => panic!("Expected error for non-matching URI"),
1894 }
1895 }
1896
1897 #[tokio::test]
1898 async fn test_capabilities_include_resources_with_only_templates() {
1899 use crate::resource::ResourceTemplateBuilder;
1900 use std::collections::HashMap;
1901
1902 let template = ResourceTemplateBuilder::new("file:///{path}")
1903 .name("Files")
1904 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1905 Ok(ReadResourceResult {
1906 contents: vec![ResourceContent {
1907 uri,
1908 mime_type: None,
1909 text: None,
1910 blob: None,
1911 }],
1912 })
1913 });
1914
1915 let mut router = McpRouter::new().resource_template(template);
1916
1917 let init_req = RouterRequest {
1919 id: RequestId::Number(0),
1920 inner: McpRequest::Initialize(InitializeParams {
1921 protocol_version: "2025-11-25".to_string(),
1922 capabilities: ClientCapabilities {
1923 roots: None,
1924 sampling: None,
1925 elicitation: None,
1926 },
1927 client_info: Implementation {
1928 name: "test".to_string(),
1929 version: "1.0".to_string(),
1930 ..Default::default()
1931 },
1932 }),
1933 extensions: Extensions::new(),
1934 };
1935 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1936
1937 match resp.inner {
1938 Ok(McpResponse::Initialize(result)) => {
1939 assert!(result.capabilities.resources.is_some());
1941 }
1942 _ => panic!("Expected Initialize response"),
1943 }
1944 }
1945
1946 #[tokio::test]
1951 async fn test_log_sends_notification() {
1952 use crate::context::notification_channel;
1953
1954 let (tx, mut rx) = notification_channel(10);
1955 let router = McpRouter::new().with_notification_sender(tx);
1956
1957 let sent = router.log_info("Test message");
1959 assert!(sent);
1960
1961 let notification = rx.try_recv().unwrap();
1963 match notification {
1964 ServerNotification::LogMessage(params) => {
1965 assert_eq!(params.level, LogLevel::Info);
1966 let data = params.data.unwrap();
1967 assert_eq!(
1968 data.get("message").unwrap().as_str().unwrap(),
1969 "Test message"
1970 );
1971 }
1972 _ => panic!("Expected LogMessage notification"),
1973 }
1974 }
1975
1976 #[tokio::test]
1977 async fn test_log_with_custom_params() {
1978 use crate::context::notification_channel;
1979
1980 let (tx, mut rx) = notification_channel(10);
1981 let router = McpRouter::new().with_notification_sender(tx);
1982
1983 let params = LoggingMessageParams::new(LogLevel::Error)
1985 .with_logger("database")
1986 .with_data(serde_json::json!({
1987 "error": "Connection failed",
1988 "host": "localhost"
1989 }));
1990
1991 let sent = router.log(params);
1992 assert!(sent);
1993
1994 let notification = rx.try_recv().unwrap();
1995 match notification {
1996 ServerNotification::LogMessage(params) => {
1997 assert_eq!(params.level, LogLevel::Error);
1998 assert_eq!(params.logger.as_deref(), Some("database"));
1999 let data = params.data.unwrap();
2000 assert_eq!(
2001 data.get("error").unwrap().as_str().unwrap(),
2002 "Connection failed"
2003 );
2004 }
2005 _ => panic!("Expected LogMessage notification"),
2006 }
2007 }
2008
2009 #[tokio::test]
2010 async fn test_log_without_channel_returns_false() {
2011 let router = McpRouter::new();
2013
2014 assert!(!router.log_info("Test"));
2016 assert!(!router.log_warning("Test"));
2017 assert!(!router.log_error("Test"));
2018 assert!(!router.log_debug("Test"));
2019 }
2020
2021 #[tokio::test]
2022 async fn test_logging_capability_with_channel() {
2023 use crate::context::notification_channel;
2024
2025 let (tx, _rx) = notification_channel(10);
2026 let mut router = McpRouter::new().with_notification_sender(tx);
2027
2028 let init_req = RouterRequest {
2030 id: RequestId::Number(0),
2031 inner: McpRequest::Initialize(InitializeParams {
2032 protocol_version: "2025-11-25".to_string(),
2033 capabilities: ClientCapabilities {
2034 roots: None,
2035 sampling: None,
2036 elicitation: None,
2037 },
2038 client_info: Implementation {
2039 name: "test".to_string(),
2040 version: "1.0".to_string(),
2041 ..Default::default()
2042 },
2043 }),
2044 extensions: Extensions::new(),
2045 };
2046 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2047
2048 match resp.inner {
2049 Ok(McpResponse::Initialize(result)) => {
2050 assert!(result.capabilities.logging.is_some());
2052 }
2053 _ => panic!("Expected Initialize response"),
2054 }
2055 }
2056
2057 #[tokio::test]
2058 async fn test_no_logging_capability_without_channel() {
2059 let mut router = McpRouter::new();
2060
2061 let init_req = RouterRequest {
2063 id: RequestId::Number(0),
2064 inner: McpRequest::Initialize(InitializeParams {
2065 protocol_version: "2025-11-25".to_string(),
2066 capabilities: ClientCapabilities {
2067 roots: None,
2068 sampling: None,
2069 elicitation: None,
2070 },
2071 client_info: Implementation {
2072 name: "test".to_string(),
2073 version: "1.0".to_string(),
2074 ..Default::default()
2075 },
2076 }),
2077 extensions: Extensions::new(),
2078 };
2079 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2080
2081 match resp.inner {
2082 Ok(McpResponse::Initialize(result)) => {
2083 assert!(result.capabilities.logging.is_none());
2085 }
2086 _ => panic!("Expected Initialize response"),
2087 }
2088 }
2089
2090 #[tokio::test]
2095 async fn test_enqueue_task() {
2096 let add_tool = ToolBuilder::new("add")
2097 .description("Add two numbers")
2098 .handler(|input: AddInput| async move {
2099 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2100 })
2101 .build()
2102 .expect("valid tool name");
2103
2104 let mut router = McpRouter::new().tool(add_tool);
2105 init_router(&mut router).await;
2106
2107 let req = RouterRequest {
2108 id: RequestId::Number(1),
2109 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2110 tool_name: "add".to_string(),
2111 arguments: serde_json::json!({"a": 5, "b": 10}),
2112 ttl: None,
2113 }),
2114 extensions: Extensions::new(),
2115 };
2116
2117 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2118
2119 match resp.inner {
2120 Ok(McpResponse::EnqueueTask(result)) => {
2121 assert!(result.task_id.starts_with("task-"));
2122 assert_eq!(result.status, TaskStatus::Working);
2123 }
2124 _ => panic!("Expected EnqueueTask response"),
2125 }
2126 }
2127
2128 #[tokio::test]
2129 async fn test_list_tasks_empty() {
2130 let mut router = McpRouter::new();
2131 init_router(&mut router).await;
2132
2133 let req = RouterRequest {
2134 id: RequestId::Number(1),
2135 inner: McpRequest::ListTasks(ListTasksParams::default()),
2136 extensions: Extensions::new(),
2137 };
2138
2139 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2140
2141 match resp.inner {
2142 Ok(McpResponse::ListTasks(result)) => {
2143 assert!(result.tasks.is_empty());
2144 }
2145 _ => panic!("Expected ListTasks response"),
2146 }
2147 }
2148
2149 #[tokio::test]
2150 async fn test_task_lifecycle_complete() {
2151 let add_tool = ToolBuilder::new("add")
2152 .description("Add two numbers")
2153 .handler(|input: AddInput| async move {
2154 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2155 })
2156 .build()
2157 .expect("valid tool name");
2158
2159 let mut router = McpRouter::new().tool(add_tool);
2160 init_router(&mut router).await;
2161
2162 let req = RouterRequest {
2164 id: RequestId::Number(1),
2165 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2166 tool_name: "add".to_string(),
2167 arguments: serde_json::json!({"a": 7, "b": 8}),
2168 ttl: None,
2169 }),
2170 extensions: Extensions::new(),
2171 };
2172
2173 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2174 let task_id = match resp.inner {
2175 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2176 _ => panic!("Expected EnqueueTask response"),
2177 };
2178
2179 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2181
2182 let req = RouterRequest {
2184 id: RequestId::Number(2),
2185 inner: McpRequest::GetTaskResult(GetTaskResultParams {
2186 task_id: task_id.clone(),
2187 }),
2188 extensions: Extensions::new(),
2189 };
2190
2191 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2192
2193 match resp.inner {
2194 Ok(McpResponse::GetTaskResult(result)) => {
2195 assert_eq!(result.task_id, task_id);
2196 assert_eq!(result.status, TaskStatus::Completed);
2197 assert!(result.result.is_some());
2198 assert!(result.error.is_none());
2199
2200 let tool_result = result.result.unwrap();
2202 match &tool_result.content[0] {
2203 Content::Text { text, .. } => assert_eq!(text, "15"),
2204 _ => panic!("Expected text content"),
2205 }
2206 }
2207 _ => panic!("Expected GetTaskResult response"),
2208 }
2209 }
2210
2211 #[tokio::test]
2212 async fn test_task_cancellation() {
2213 let slow_tool = ToolBuilder::new("slow")
2215 .description("Slow tool")
2216 .handler(|_input: serde_json::Value| async move {
2217 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2218 Ok(CallToolResult::text("done"))
2219 })
2220 .build()
2221 .expect("valid tool name");
2222
2223 let mut router = McpRouter::new().tool(slow_tool);
2224 init_router(&mut router).await;
2225
2226 let req = RouterRequest {
2228 id: RequestId::Number(1),
2229 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2230 tool_name: "slow".to_string(),
2231 arguments: serde_json::json!({}),
2232 ttl: None,
2233 }),
2234 extensions: Extensions::new(),
2235 };
2236
2237 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2238 let task_id = match resp.inner {
2239 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2240 _ => panic!("Expected EnqueueTask response"),
2241 };
2242
2243 let req = RouterRequest {
2245 id: RequestId::Number(2),
2246 inner: McpRequest::CancelTask(CancelTaskParams {
2247 task_id: task_id.clone(),
2248 reason: Some("Test cancellation".to_string()),
2249 }),
2250 extensions: Extensions::new(),
2251 };
2252
2253 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2254
2255 match resp.inner {
2256 Ok(McpResponse::CancelTask(result)) => {
2257 assert!(result.cancelled);
2258 assert_eq!(result.status, TaskStatus::Cancelled);
2259 }
2260 _ => panic!("Expected CancelTask response"),
2261 }
2262 }
2263
2264 #[tokio::test]
2265 async fn test_get_task_info() {
2266 let add_tool = ToolBuilder::new("add")
2267 .description("Add two numbers")
2268 .handler(|input: AddInput| async move {
2269 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2270 })
2271 .build()
2272 .expect("valid tool name");
2273
2274 let mut router = McpRouter::new().tool(add_tool);
2275 init_router(&mut router).await;
2276
2277 let req = RouterRequest {
2279 id: RequestId::Number(1),
2280 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2281 tool_name: "add".to_string(),
2282 arguments: serde_json::json!({"a": 1, "b": 2}),
2283 ttl: Some(600),
2284 }),
2285 extensions: Extensions::new(),
2286 };
2287
2288 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2289 let task_id = match resp.inner {
2290 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2291 _ => panic!("Expected EnqueueTask response"),
2292 };
2293
2294 let req = RouterRequest {
2296 id: RequestId::Number(2),
2297 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2298 task_id: task_id.clone(),
2299 }),
2300 extensions: Extensions::new(),
2301 };
2302
2303 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2304
2305 match resp.inner {
2306 Ok(McpResponse::GetTaskInfo(info)) => {
2307 assert_eq!(info.task_id, task_id);
2308 assert!(info.created_at.contains('T')); assert_eq!(info.ttl, Some(600));
2310 }
2311 _ => panic!("Expected GetTaskInfo response"),
2312 }
2313 }
2314
2315 #[tokio::test]
2316 async fn test_enqueue_nonexistent_tool() {
2317 let mut router = McpRouter::new();
2318 init_router(&mut router).await;
2319
2320 let req = RouterRequest {
2321 id: RequestId::Number(1),
2322 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2323 tool_name: "nonexistent".to_string(),
2324 arguments: serde_json::json!({}),
2325 ttl: None,
2326 }),
2327 extensions: Extensions::new(),
2328 };
2329
2330 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2331
2332 match resp.inner {
2333 Err(e) => {
2334 assert!(e.message.contains("not found"));
2335 }
2336 _ => panic!("Expected error response"),
2337 }
2338 }
2339
2340 #[tokio::test]
2341 async fn test_get_nonexistent_task() {
2342 let mut router = McpRouter::new();
2343 init_router(&mut router).await;
2344
2345 let req = RouterRequest {
2346 id: RequestId::Number(1),
2347 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2348 task_id: "task-999".to_string(),
2349 }),
2350 extensions: Extensions::new(),
2351 };
2352
2353 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2354
2355 match resp.inner {
2356 Err(e) => {
2357 assert!(e.message.contains("not found"));
2358 }
2359 _ => panic!("Expected error response"),
2360 }
2361 }
2362
2363 #[tokio::test]
2368 async fn test_subscribe_to_resource() {
2369 use crate::resource::ResourceBuilder;
2370
2371 let resource = ResourceBuilder::new("file:///test.txt")
2372 .name("Test File")
2373 .text("Hello");
2374
2375 let mut router = McpRouter::new().resource(resource);
2376 init_router(&mut router).await;
2377
2378 let req = RouterRequest {
2380 id: RequestId::Number(1),
2381 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2382 uri: "file:///test.txt".to_string(),
2383 }),
2384 extensions: Extensions::new(),
2385 };
2386
2387 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2388
2389 match resp.inner {
2390 Ok(McpResponse::SubscribeResource(_)) => {
2391 assert!(router.is_subscribed("file:///test.txt"));
2393 }
2394 _ => panic!("Expected SubscribeResource response"),
2395 }
2396 }
2397
2398 #[tokio::test]
2399 async fn test_unsubscribe_from_resource() {
2400 use crate::resource::ResourceBuilder;
2401
2402 let resource = ResourceBuilder::new("file:///test.txt")
2403 .name("Test File")
2404 .text("Hello");
2405
2406 let mut router = McpRouter::new().resource(resource);
2407 init_router(&mut router).await;
2408
2409 let req = RouterRequest {
2411 id: RequestId::Number(1),
2412 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2413 uri: "file:///test.txt".to_string(),
2414 }),
2415 extensions: Extensions::new(),
2416 };
2417 let _ = router.ready().await.unwrap().call(req).await.unwrap();
2418 assert!(router.is_subscribed("file:///test.txt"));
2419
2420 let req = RouterRequest {
2422 id: RequestId::Number(2),
2423 inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2424 uri: "file:///test.txt".to_string(),
2425 }),
2426 extensions: Extensions::new(),
2427 };
2428
2429 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2430
2431 match resp.inner {
2432 Ok(McpResponse::UnsubscribeResource(_)) => {
2433 assert!(!router.is_subscribed("file:///test.txt"));
2435 }
2436 _ => panic!("Expected UnsubscribeResource response"),
2437 }
2438 }
2439
2440 #[tokio::test]
2441 async fn test_subscribe_nonexistent_resource() {
2442 let mut router = McpRouter::new();
2443 init_router(&mut router).await;
2444
2445 let req = RouterRequest {
2446 id: RequestId::Number(1),
2447 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2448 uri: "file:///nonexistent.txt".to_string(),
2449 }),
2450 extensions: Extensions::new(),
2451 };
2452
2453 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2454
2455 match resp.inner {
2456 Err(e) => {
2457 assert!(e.message.contains("not found"));
2458 }
2459 _ => panic!("Expected error response"),
2460 }
2461 }
2462
2463 #[tokio::test]
2464 async fn test_notify_resource_updated() {
2465 use crate::context::notification_channel;
2466 use crate::resource::ResourceBuilder;
2467
2468 let (tx, mut rx) = notification_channel(10);
2469
2470 let resource = ResourceBuilder::new("file:///test.txt")
2471 .name("Test File")
2472 .text("Hello");
2473
2474 let router = McpRouter::new()
2475 .resource(resource)
2476 .with_notification_sender(tx);
2477
2478 router.subscribe("file:///test.txt");
2480
2481 let sent = router.notify_resource_updated("file:///test.txt");
2483 assert!(sent);
2484
2485 let notification = rx.try_recv().unwrap();
2487 match notification {
2488 ServerNotification::ResourceUpdated { uri } => {
2489 assert_eq!(uri, "file:///test.txt");
2490 }
2491 _ => panic!("Expected ResourceUpdated notification"),
2492 }
2493 }
2494
2495 #[tokio::test]
2496 async fn test_notify_resource_updated_not_subscribed() {
2497 use crate::context::notification_channel;
2498 use crate::resource::ResourceBuilder;
2499
2500 let (tx, mut rx) = notification_channel(10);
2501
2502 let resource = ResourceBuilder::new("file:///test.txt")
2503 .name("Test File")
2504 .text("Hello");
2505
2506 let router = McpRouter::new()
2507 .resource(resource)
2508 .with_notification_sender(tx);
2509
2510 let sent = router.notify_resource_updated("file:///test.txt");
2512 assert!(!sent); assert!(rx.try_recv().is_err());
2516 }
2517
2518 #[tokio::test]
2519 async fn test_notify_resources_list_changed() {
2520 use crate::context::notification_channel;
2521
2522 let (tx, mut rx) = notification_channel(10);
2523 let router = McpRouter::new().with_notification_sender(tx);
2524
2525 let sent = router.notify_resources_list_changed();
2526 assert!(sent);
2527
2528 let notification = rx.try_recv().unwrap();
2529 match notification {
2530 ServerNotification::ResourcesListChanged => {}
2531 _ => panic!("Expected ResourcesListChanged notification"),
2532 }
2533 }
2534
2535 #[tokio::test]
2536 async fn test_subscribed_uris() {
2537 use crate::resource::ResourceBuilder;
2538
2539 let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2540
2541 let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2542
2543 let router = McpRouter::new().resource(resource1).resource(resource2);
2544
2545 router.subscribe("file:///a.txt");
2547 router.subscribe("file:///b.txt");
2548
2549 let uris = router.subscribed_uris();
2550 assert_eq!(uris.len(), 2);
2551 assert!(uris.contains(&"file:///a.txt".to_string()));
2552 assert!(uris.contains(&"file:///b.txt".to_string()));
2553 }
2554
2555 #[tokio::test]
2556 async fn test_subscription_capability_advertised() {
2557 use crate::resource::ResourceBuilder;
2558
2559 let resource = ResourceBuilder::new("file:///test.txt")
2560 .name("Test")
2561 .text("Hello");
2562
2563 let mut router = McpRouter::new().resource(resource);
2564
2565 let init_req = RouterRequest {
2567 id: RequestId::Number(0),
2568 inner: McpRequest::Initialize(InitializeParams {
2569 protocol_version: "2025-11-25".to_string(),
2570 capabilities: ClientCapabilities {
2571 roots: None,
2572 sampling: None,
2573 elicitation: None,
2574 },
2575 client_info: Implementation {
2576 name: "test".to_string(),
2577 version: "1.0".to_string(),
2578 ..Default::default()
2579 },
2580 }),
2581 extensions: Extensions::new(),
2582 };
2583 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2584
2585 match resp.inner {
2586 Ok(McpResponse::Initialize(result)) => {
2587 let resources_cap = result.capabilities.resources.unwrap();
2589 assert!(resources_cap.subscribe);
2590 }
2591 _ => panic!("Expected Initialize response"),
2592 }
2593 }
2594
2595 #[tokio::test]
2596 async fn test_completion_handler() {
2597 let router = McpRouter::new()
2598 .server_info("test", "1.0")
2599 .completion_handler(|params: CompleteParams| async move {
2600 let prefix = ¶ms.argument.value;
2602 let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2603 .into_iter()
2604 .filter(|s| s.starts_with(prefix))
2605 .map(String::from)
2606 .collect();
2607 Ok(CompleteResult::new(suggestions))
2608 });
2609
2610 let init_req = RouterRequest {
2612 id: RequestId::Number(0),
2613 inner: McpRequest::Initialize(InitializeParams {
2614 protocol_version: "2025-11-25".to_string(),
2615 capabilities: ClientCapabilities::default(),
2616 client_info: Implementation {
2617 name: "test".to_string(),
2618 version: "1.0".to_string(),
2619 ..Default::default()
2620 },
2621 }),
2622 extensions: Extensions::new(),
2623 };
2624 let resp = router
2625 .clone()
2626 .ready()
2627 .await
2628 .unwrap()
2629 .call(init_req)
2630 .await
2631 .unwrap();
2632
2633 match resp.inner {
2635 Ok(McpResponse::Initialize(result)) => {
2636 assert!(result.capabilities.completions.is_some());
2637 }
2638 _ => panic!("Expected Initialize response"),
2639 }
2640
2641 router.handle_notification(McpNotification::Initialized);
2643
2644 let complete_req = RouterRequest {
2646 id: RequestId::Number(1),
2647 inner: McpRequest::Complete(CompleteParams {
2648 reference: CompletionReference::prompt("test-prompt"),
2649 argument: CompletionArgument::new("query", "al"),
2650 }),
2651 extensions: Extensions::new(),
2652 };
2653 let resp = router
2654 .clone()
2655 .ready()
2656 .await
2657 .unwrap()
2658 .call(complete_req)
2659 .await
2660 .unwrap();
2661
2662 match resp.inner {
2663 Ok(McpResponse::Complete(result)) => {
2664 assert_eq!(result.completion.values, vec!["alpha"]);
2665 }
2666 _ => panic!("Expected Complete response"),
2667 }
2668 }
2669
2670 #[tokio::test]
2671 async fn test_completion_without_handler_returns_empty() {
2672 let router = McpRouter::new().server_info("test", "1.0");
2673
2674 let init_req = RouterRequest {
2676 id: RequestId::Number(0),
2677 inner: McpRequest::Initialize(InitializeParams {
2678 protocol_version: "2025-11-25".to_string(),
2679 capabilities: ClientCapabilities::default(),
2680 client_info: Implementation {
2681 name: "test".to_string(),
2682 version: "1.0".to_string(),
2683 ..Default::default()
2684 },
2685 }),
2686 extensions: Extensions::new(),
2687 };
2688 let resp = router
2689 .clone()
2690 .ready()
2691 .await
2692 .unwrap()
2693 .call(init_req)
2694 .await
2695 .unwrap();
2696
2697 match resp.inner {
2699 Ok(McpResponse::Initialize(result)) => {
2700 assert!(result.capabilities.completions.is_none());
2701 }
2702 _ => panic!("Expected Initialize response"),
2703 }
2704
2705 router.handle_notification(McpNotification::Initialized);
2707
2708 let complete_req = RouterRequest {
2710 id: RequestId::Number(1),
2711 inner: McpRequest::Complete(CompleteParams {
2712 reference: CompletionReference::prompt("test-prompt"),
2713 argument: CompletionArgument::new("query", "al"),
2714 }),
2715 extensions: Extensions::new(),
2716 };
2717 let resp = router
2718 .clone()
2719 .ready()
2720 .await
2721 .unwrap()
2722 .call(complete_req)
2723 .await
2724 .unwrap();
2725
2726 match resp.inner {
2727 Ok(McpResponse::Complete(result)) => {
2728 assert!(result.completion.values.is_empty());
2729 }
2730 _ => panic!("Expected Complete response"),
2731 }
2732 }
2733
2734 #[tokio::test]
2735 async fn test_tool_filter_list() {
2736 use crate::filter::CapabilityFilter;
2737 use crate::tool::Tool;
2738
2739 let public_tool = ToolBuilder::new("public")
2740 .description("Public tool")
2741 .handler(|_: AddInput| async move { Ok(CallToolResult::text("public")) })
2742 .build()
2743 .expect("valid tool name");
2744
2745 let admin_tool = ToolBuilder::new("admin")
2746 .description("Admin tool")
2747 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2748 .build()
2749 .expect("valid tool name");
2750
2751 let mut router = McpRouter::new()
2752 .tool(public_tool)
2753 .tool(admin_tool)
2754 .tool_filter(CapabilityFilter::new(|_, tool: &Tool| tool.name != "admin"));
2755
2756 init_router(&mut router).await;
2758
2759 let req = RouterRequest {
2760 id: RequestId::Number(1),
2761 inner: McpRequest::ListTools(ListToolsParams::default()),
2762 extensions: Extensions::new(),
2763 };
2764
2765 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2766
2767 match resp.inner {
2768 Ok(McpResponse::ListTools(result)) => {
2769 assert_eq!(result.tools.len(), 1);
2771 assert_eq!(result.tools[0].name, "public");
2772 }
2773 _ => panic!("Expected ListTools response"),
2774 }
2775 }
2776
2777 #[tokio::test]
2778 async fn test_tool_filter_call_denied() {
2779 use crate::filter::CapabilityFilter;
2780 use crate::tool::Tool;
2781
2782 let admin_tool = ToolBuilder::new("admin")
2783 .description("Admin tool")
2784 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2785 .build()
2786 .expect("valid tool name");
2787
2788 let mut router = McpRouter::new()
2789 .tool(admin_tool)
2790 .tool_filter(CapabilityFilter::new(|_, _: &Tool| false)); init_router(&mut router).await;
2794
2795 let req = RouterRequest {
2796 id: RequestId::Number(1),
2797 inner: McpRequest::CallTool(CallToolParams {
2798 name: "admin".to_string(),
2799 arguments: serde_json::json!({"a": 1, "b": 2}),
2800 meta: None,
2801 }),
2802 extensions: Extensions::new(),
2803 };
2804
2805 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2806
2807 match resp.inner {
2809 Err(e) => {
2810 assert_eq!(e.code, -32601); }
2812 _ => panic!("Expected JsonRpc error"),
2813 }
2814 }
2815
2816 #[tokio::test]
2817 async fn test_tool_filter_call_allowed() {
2818 use crate::filter::CapabilityFilter;
2819 use crate::tool::Tool;
2820
2821 let public_tool = ToolBuilder::new("public")
2822 .description("Public tool")
2823 .handler(|input: AddInput| async move {
2824 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2825 })
2826 .build()
2827 .expect("valid tool name");
2828
2829 let mut router = McpRouter::new()
2830 .tool(public_tool)
2831 .tool_filter(CapabilityFilter::new(|_, _: &Tool| true)); init_router(&mut router).await;
2835
2836 let req = RouterRequest {
2837 id: RequestId::Number(1),
2838 inner: McpRequest::CallTool(CallToolParams {
2839 name: "public".to_string(),
2840 arguments: serde_json::json!({"a": 1, "b": 2}),
2841 meta: None,
2842 }),
2843 extensions: Extensions::new(),
2844 };
2845
2846 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2847
2848 match resp.inner {
2849 Ok(McpResponse::CallTool(result)) => {
2850 assert!(!result.is_error);
2851 }
2852 _ => panic!("Expected CallTool response"),
2853 }
2854 }
2855
2856 #[tokio::test]
2857 async fn test_tool_filter_custom_denial() {
2858 use crate::filter::{CapabilityFilter, DenialBehavior};
2859 use crate::tool::Tool;
2860
2861 let admin_tool = ToolBuilder::new("admin")
2862 .description("Admin tool")
2863 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2864 .build()
2865 .expect("valid tool name");
2866
2867 let mut router = McpRouter::new().tool(admin_tool).tool_filter(
2868 CapabilityFilter::new(|_, _: &Tool| false)
2869 .denial_behavior(DenialBehavior::Unauthorized),
2870 );
2871
2872 init_router(&mut router).await;
2874
2875 let req = RouterRequest {
2876 id: RequestId::Number(1),
2877 inner: McpRequest::CallTool(CallToolParams {
2878 name: "admin".to_string(),
2879 arguments: serde_json::json!({"a": 1, "b": 2}),
2880 meta: None,
2881 }),
2882 extensions: Extensions::new(),
2883 };
2884
2885 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2886
2887 match resp.inner {
2889 Err(e) => {
2890 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
2892 }
2893 _ => panic!("Expected JsonRpc error"),
2894 }
2895 }
2896
2897 #[tokio::test]
2898 async fn test_resource_filter_list() {
2899 use crate::filter::CapabilityFilter;
2900 use crate::resource::{Resource, ResourceBuilder};
2901
2902 let public_resource = ResourceBuilder::new("file:///public.txt")
2903 .name("Public File")
2904 .text("public content");
2905
2906 let secret_resource = ResourceBuilder::new("file:///secret.txt")
2907 .name("Secret File")
2908 .text("secret content");
2909
2910 let mut router = McpRouter::new()
2911 .resource(public_resource)
2912 .resource(secret_resource)
2913 .resource_filter(CapabilityFilter::new(|_, r: &Resource| {
2914 !r.name.contains("Secret")
2915 }));
2916
2917 init_router(&mut router).await;
2919
2920 let req = RouterRequest {
2921 id: RequestId::Number(1),
2922 inner: McpRequest::ListResources(ListResourcesParams::default()),
2923 extensions: Extensions::new(),
2924 };
2925
2926 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2927
2928 match resp.inner {
2929 Ok(McpResponse::ListResources(result)) => {
2930 assert_eq!(result.resources.len(), 1);
2932 assert_eq!(result.resources[0].name, "Public File");
2933 }
2934 _ => panic!("Expected ListResources response"),
2935 }
2936 }
2937
2938 #[tokio::test]
2939 async fn test_resource_filter_read_denied() {
2940 use crate::filter::CapabilityFilter;
2941 use crate::resource::{Resource, ResourceBuilder};
2942
2943 let secret_resource = ResourceBuilder::new("file:///secret.txt")
2944 .name("Secret File")
2945 .text("secret content");
2946
2947 let mut router = McpRouter::new()
2948 .resource(secret_resource)
2949 .resource_filter(CapabilityFilter::new(|_, _: &Resource| false)); init_router(&mut router).await;
2953
2954 let req = RouterRequest {
2955 id: RequestId::Number(1),
2956 inner: McpRequest::ReadResource(ReadResourceParams {
2957 uri: "file:///secret.txt".to_string(),
2958 }),
2959 extensions: Extensions::new(),
2960 };
2961
2962 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2963
2964 match resp.inner {
2966 Err(e) => {
2967 assert_eq!(e.code, -32601); }
2969 _ => panic!("Expected JsonRpc error"),
2970 }
2971 }
2972
2973 #[tokio::test]
2974 async fn test_resource_filter_read_allowed() {
2975 use crate::filter::CapabilityFilter;
2976 use crate::resource::{Resource, ResourceBuilder};
2977
2978 let public_resource = ResourceBuilder::new("file:///public.txt")
2979 .name("Public File")
2980 .text("public content");
2981
2982 let mut router = McpRouter::new()
2983 .resource(public_resource)
2984 .resource_filter(CapabilityFilter::new(|_, _: &Resource| true)); init_router(&mut router).await;
2988
2989 let req = RouterRequest {
2990 id: RequestId::Number(1),
2991 inner: McpRequest::ReadResource(ReadResourceParams {
2992 uri: "file:///public.txt".to_string(),
2993 }),
2994 extensions: Extensions::new(),
2995 };
2996
2997 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2998
2999 match resp.inner {
3000 Ok(McpResponse::ReadResource(result)) => {
3001 assert_eq!(result.contents.len(), 1);
3002 assert_eq!(result.contents[0].text.as_deref(), Some("public content"));
3003 }
3004 _ => panic!("Expected ReadResource response"),
3005 }
3006 }
3007
3008 #[tokio::test]
3009 async fn test_resource_filter_custom_denial() {
3010 use crate::filter::{CapabilityFilter, DenialBehavior};
3011 use crate::resource::{Resource, ResourceBuilder};
3012
3013 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3014 .name("Secret File")
3015 .text("secret content");
3016
3017 let mut router = McpRouter::new().resource(secret_resource).resource_filter(
3018 CapabilityFilter::new(|_, _: &Resource| false)
3019 .denial_behavior(DenialBehavior::Unauthorized),
3020 );
3021
3022 init_router(&mut router).await;
3024
3025 let req = RouterRequest {
3026 id: RequestId::Number(1),
3027 inner: McpRequest::ReadResource(ReadResourceParams {
3028 uri: "file:///secret.txt".to_string(),
3029 }),
3030 extensions: Extensions::new(),
3031 };
3032
3033 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3034
3035 match resp.inner {
3037 Err(e) => {
3038 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3040 }
3041 _ => panic!("Expected JsonRpc error"),
3042 }
3043 }
3044
3045 #[tokio::test]
3046 async fn test_prompt_filter_list() {
3047 use crate::filter::CapabilityFilter;
3048 use crate::prompt::{Prompt, PromptBuilder};
3049
3050 let public_prompt = PromptBuilder::new("greeting")
3051 .description("A greeting")
3052 .user_message("Hello!");
3053
3054 let admin_prompt = PromptBuilder::new("system_debug")
3055 .description("Admin prompt")
3056 .user_message("Debug");
3057
3058 let mut router = McpRouter::new()
3059 .prompt(public_prompt)
3060 .prompt(admin_prompt)
3061 .prompt_filter(CapabilityFilter::new(|_, p: &Prompt| {
3062 !p.name.contains("system")
3063 }));
3064
3065 init_router(&mut router).await;
3067
3068 let req = RouterRequest {
3069 id: RequestId::Number(1),
3070 inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3071 extensions: Extensions::new(),
3072 };
3073
3074 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3075
3076 match resp.inner {
3077 Ok(McpResponse::ListPrompts(result)) => {
3078 assert_eq!(result.prompts.len(), 1);
3080 assert_eq!(result.prompts[0].name, "greeting");
3081 }
3082 _ => panic!("Expected ListPrompts response"),
3083 }
3084 }
3085
3086 #[tokio::test]
3087 async fn test_prompt_filter_get_denied() {
3088 use crate::filter::CapabilityFilter;
3089 use crate::prompt::{Prompt, PromptBuilder};
3090 use std::collections::HashMap;
3091
3092 let admin_prompt = PromptBuilder::new("system_debug")
3093 .description("Admin prompt")
3094 .user_message("Debug");
3095
3096 let mut router = McpRouter::new()
3097 .prompt(admin_prompt)
3098 .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| false)); init_router(&mut router).await;
3102
3103 let req = RouterRequest {
3104 id: RequestId::Number(1),
3105 inner: McpRequest::GetPrompt(GetPromptParams {
3106 name: "system_debug".to_string(),
3107 arguments: HashMap::new(),
3108 }),
3109 extensions: Extensions::new(),
3110 };
3111
3112 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3113
3114 match resp.inner {
3116 Err(e) => {
3117 assert_eq!(e.code, -32601); }
3119 _ => panic!("Expected JsonRpc error"),
3120 }
3121 }
3122
3123 #[tokio::test]
3124 async fn test_prompt_filter_get_allowed() {
3125 use crate::filter::CapabilityFilter;
3126 use crate::prompt::{Prompt, PromptBuilder};
3127 use std::collections::HashMap;
3128
3129 let public_prompt = PromptBuilder::new("greeting")
3130 .description("A greeting")
3131 .user_message("Hello!");
3132
3133 let mut router = McpRouter::new()
3134 .prompt(public_prompt)
3135 .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| true)); init_router(&mut router).await;
3139
3140 let req = RouterRequest {
3141 id: RequestId::Number(1),
3142 inner: McpRequest::GetPrompt(GetPromptParams {
3143 name: "greeting".to_string(),
3144 arguments: HashMap::new(),
3145 }),
3146 extensions: Extensions::new(),
3147 };
3148
3149 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3150
3151 match resp.inner {
3152 Ok(McpResponse::GetPrompt(result)) => {
3153 assert_eq!(result.messages.len(), 1);
3154 }
3155 _ => panic!("Expected GetPrompt response"),
3156 }
3157 }
3158
3159 #[tokio::test]
3160 async fn test_prompt_filter_custom_denial() {
3161 use crate::filter::{CapabilityFilter, DenialBehavior};
3162 use crate::prompt::{Prompt, PromptBuilder};
3163 use std::collections::HashMap;
3164
3165 let admin_prompt = PromptBuilder::new("system_debug")
3166 .description("Admin prompt")
3167 .user_message("Debug");
3168
3169 let mut router = McpRouter::new().prompt(admin_prompt).prompt_filter(
3170 CapabilityFilter::new(|_, _: &Prompt| false)
3171 .denial_behavior(DenialBehavior::Unauthorized),
3172 );
3173
3174 init_router(&mut router).await;
3176
3177 let req = RouterRequest {
3178 id: RequestId::Number(1),
3179 inner: McpRequest::GetPrompt(GetPromptParams {
3180 name: "system_debug".to_string(),
3181 arguments: HashMap::new(),
3182 }),
3183 extensions: Extensions::new(),
3184 };
3185
3186 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3187
3188 match resp.inner {
3190 Err(e) => {
3191 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3193 }
3194 _ => panic!("Expected JsonRpc error"),
3195 }
3196 }
3197}