1use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11
12use tower_service::Service;
13
14use crate::async_task::TaskStore;
15use crate::context::{
16 CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
17 ServerNotification,
18};
19use crate::error::{Error, JsonRpcError, Result};
20use crate::filter::{PromptFilter, ResourceFilter, ToolFilter};
21use crate::prompt::Prompt;
22use crate::protocol::*;
23use crate::resource::{Resource, ResourceTemplate};
24use crate::session::SessionState;
25use crate::tool::Tool;
26
/// Shared, type-erased async callback that answers a completion request.
///
/// The callback receives the request's [`CompleteParams`] and returns a boxed
/// `Send` future resolving to `Result<CompleteResult>`. The `Send + Sync`
/// bounds let the handler be stored in the router behind an `Arc` and be
/// called from any thread.
pub type CompletionHandler = Arc<
    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
        + Send
        + Sync,
>;
33
/// MCP request router: holds the registered tools, resources, and prompts and
/// dispatches incoming requests to them.
///
/// Cloning is cheap for the configuration: `inner` sits behind an `Arc`, and
/// the builder methods use `Arc::make_mut`, so a clone that is modified later
/// gets its own copy of the configuration.
#[derive(Clone)]
pub struct McpRouter {
    /// Server metadata, registries, and shared runtime state.
    inner: Arc<McpRouterInner>,
    /// Lifecycle state of the session this router serves.
    // NOTE(review): presumably clones share the same session state (the tests
    // initialize through a clone) — confirm against `SessionState`.
    session: SessionState,
}
63
64impl std::fmt::Debug for McpRouter {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("McpRouter")
67 .field("server_name", &self.inner.server_name)
68 .field("server_version", &self.inner.server_version)
69 .field("tools_count", &self.inner.tools.len())
70 .field("resources_count", &self.inner.resources.len())
71 .field("prompts_count", &self.inner.prompts.len())
72 .field("session_phase", &self.session.phase())
73 .finish()
74 }
75}
76
/// Configuration and shared runtime state behind [`McpRouter`].
///
/// `Clone` is required because the builder methods go through
/// `Arc::make_mut`, which copies this struct when the `Arc` is shared.
#[derive(Clone)]
struct McpRouterInner {
    /// Name reported in the `initialize` response.
    server_name: String,
    /// Version reported in the `initialize` response.
    server_version: String,
    /// Optional title reported in the `initialize` response.
    server_title: Option<String>,
    /// Optional description reported in the `initialize` response.
    server_description: Option<String>,
    /// Optional icons reported in the `initialize` response.
    server_icons: Option<Vec<ToolIcon>>,
    /// Optional website URL reported in the `initialize` response.
    server_website_url: Option<String>,
    /// Optional instructions returned alongside the server info.
    instructions: Option<String>,
    /// Registered tools, keyed by tool name.
    tools: HashMap<String, Arc<Tool>>,
    /// Registered static resources, keyed by URI.
    resources: HashMap<String, Arc<Resource>>,
    /// Registered resource templates, tried in registration order when a URI
    /// does not match a static resource.
    resource_templates: Vec<Arc<ResourceTemplate>>,
    /// Registered prompts, keyed by prompt name.
    prompts: HashMap<String, Arc<Prompt>>,
    /// Cancellation tokens of requests that currently have a context,
    /// keyed by request ID. Shared across clones through the inner `Arc`.
    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
    /// Channel used to push server notifications; `None` disables them.
    notification_tx: Option<NotificationSender>,
    /// Handle attached to request contexts when present.
    client_requester: Option<ClientRequesterHandle>,
    /// Store backing the async task requests.
    task_store: TaskStore,
    /// URIs the client has subscribed to.
    subscriptions: Arc<RwLock<HashSet<String>>>,
    /// Callback answering completion requests, if configured.
    completion_handler: Option<CompletionHandler>,
    /// Optional visibility filter applied to tools.
    tool_filter: Option<ToolFilter>,
    /// Optional visibility filter applied to static resources.
    resource_filter: Option<ResourceFilter>,
    /// Optional visibility filter applied to prompts.
    prompt_filter: Option<PromptFilter>,
    /// Values registered via `with_state` / `with_extension`, attached to
    /// every request context.
    extensions: Arc<crate::context::Extensions>,
}
117
118impl McpRouter {
119 pub fn new() -> Self {
121 Self {
122 inner: Arc::new(McpRouterInner {
123 server_name: "tower-mcp".to_string(),
124 server_version: env!("CARGO_PKG_VERSION").to_string(),
125 server_title: None,
126 server_description: None,
127 server_icons: None,
128 server_website_url: None,
129 instructions: None,
130 tools: HashMap::new(),
131 resources: HashMap::new(),
132 resource_templates: Vec::new(),
133 prompts: HashMap::new(),
134 in_flight: Arc::new(RwLock::new(HashMap::new())),
135 notification_tx: None,
136 client_requester: None,
137 task_store: TaskStore::new(),
138 subscriptions: Arc::new(RwLock::new(HashSet::new())),
139 extensions: Arc::new(crate::context::Extensions::new()),
140 completion_handler: None,
141 tool_filter: None,
142 resource_filter: None,
143 prompt_filter: None,
144 }),
145 session: SessionState::new(),
146 }
147 }
148
149 pub fn task_store(&self) -> &TaskStore {
151 &self.inner.task_store
152 }
153
154 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
158 Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
159 self
160 }
161
162 pub fn notification_sender(&self) -> Option<&NotificationSender> {
164 self.inner.notification_tx.as_ref()
165 }
166
167 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
172 Arc::make_mut(&mut self.inner).client_requester = Some(requester);
173 self
174 }
175
176 pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
178 self.inner.client_requester.as_ref()
179 }
180
181 pub fn with_state<T: Clone + Send + Sync + 'static>(mut self, state: T) -> Self {
225 let inner = Arc::make_mut(&mut self.inner);
226 Arc::make_mut(&mut inner.extensions).insert(state);
227 self
228 }
229
230 pub fn with_extension<T: Clone + Send + Sync + 'static>(self, value: T) -> Self {
235 self.with_state(value)
236 }
237
238 pub fn extensions(&self) -> &crate::context::Extensions {
240 &self.inner.extensions
241 }
242
243 pub fn create_context(
248 &self,
249 request_id: RequestId,
250 progress_token: Option<ProgressToken>,
251 ) -> RequestContext {
252 let ctx = RequestContext::new(request_id.clone());
253
254 let ctx = if let Some(token) = progress_token {
256 ctx.with_progress_token(token)
257 } else {
258 ctx
259 };
260
261 let ctx = if let Some(tx) = &self.inner.notification_tx {
263 ctx.with_notification_sender(tx.clone())
264 } else {
265 ctx
266 };
267
268 let ctx = if let Some(requester) = &self.inner.client_requester {
270 ctx.with_client_requester(requester.clone())
271 } else {
272 ctx
273 };
274
275 let ctx = ctx.with_extensions(self.inner.extensions.clone());
277
278 let token = ctx.cancellation_token();
280 if let Ok(mut in_flight) = self.inner.in_flight.write() {
281 in_flight.insert(request_id, token);
282 }
283
284 ctx
285 }
286
287 pub fn complete_request(&self, request_id: &RequestId) {
289 if let Ok(mut in_flight) = self.inner.in_flight.write() {
290 in_flight.remove(request_id);
291 }
292 }
293
294 fn cancel_request(&self, request_id: &RequestId) -> bool {
296 let Ok(in_flight) = self.inner.in_flight.read() else {
297 return false;
298 };
299 let Some(token) = in_flight.get(request_id) else {
300 return false;
301 };
302 token.cancel();
303 true
304 }
305
306 pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
308 let inner = Arc::make_mut(&mut self.inner);
309 inner.server_name = name.into();
310 inner.server_version = version.into();
311 self
312 }
313
314 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
316 Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
317 self
318 }
319
320 pub fn server_title(mut self, title: impl Into<String>) -> Self {
322 Arc::make_mut(&mut self.inner).server_title = Some(title.into());
323 self
324 }
325
326 pub fn server_description(mut self, description: impl Into<String>) -> Self {
328 Arc::make_mut(&mut self.inner).server_description = Some(description.into());
329 self
330 }
331
332 pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
334 Arc::make_mut(&mut self.inner).server_icons = Some(icons);
335 self
336 }
337
338 pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
340 Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
341 self
342 }
343
344 pub fn tool(mut self, tool: Tool) -> Self {
346 Arc::make_mut(&mut self.inner)
347 .tools
348 .insert(tool.name.clone(), Arc::new(tool));
349 self
350 }
351
352 pub fn resource(mut self, resource: Resource) -> Self {
354 Arc::make_mut(&mut self.inner)
355 .resources
356 .insert(resource.uri.clone(), Arc::new(resource));
357 self
358 }
359
360 pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
391 Arc::make_mut(&mut self.inner)
392 .resource_templates
393 .push(Arc::new(template));
394 self
395 }
396
397 pub fn prompt(mut self, prompt: Prompt) -> Self {
399 Arc::make_mut(&mut self.inner)
400 .prompts
401 .insert(prompt.name.clone(), Arc::new(prompt));
402 self
403 }
404
405 pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
431 tools
432 .into_iter()
433 .fold(self, |router, tool| router.tool(tool))
434 }
435
436 pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
455 resources
456 .into_iter()
457 .fold(self, |router, resource| router.resource(resource))
458 }
459
460 pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
479 prompts
480 .into_iter()
481 .fold(self, |router, prompt| router.prompt(prompt))
482 }
483
484 pub fn merge(mut self, other: McpRouter) -> Self {
531 let inner = Arc::make_mut(&mut self.inner);
532 let other_inner = other.inner;
533
534 for (name, tool) in &other_inner.tools {
536 inner.tools.insert(name.clone(), tool.clone());
537 }
538
539 for (uri, resource) in &other_inner.resources {
541 inner.resources.insert(uri.clone(), resource.clone());
542 }
543
544 for template in &other_inner.resource_templates {
547 inner.resource_templates.push(template.clone());
548 }
549
550 for (name, prompt) in &other_inner.prompts {
552 inner.prompts.insert(name.clone(), prompt.clone());
553 }
554
555 self
556 }
557
558 pub fn nest(mut self, prefix: impl Into<String>, other: McpRouter) -> Self {
600 let prefix = prefix.into();
601 let inner = Arc::make_mut(&mut self.inner);
602 let other_inner = other.inner;
603
604 for tool in other_inner.tools.values() {
606 let prefixed_tool = tool.with_name_prefix(&prefix);
607 inner
608 .tools
609 .insert(prefixed_tool.name.clone(), Arc::new(prefixed_tool));
610 }
611
612 for (uri, resource) in &other_inner.resources {
614 inner.resources.insert(uri.clone(), resource.clone());
615 }
616
617 for template in &other_inner.resource_templates {
619 inner.resource_templates.push(template.clone());
620 }
621
622 for (name, prompt) in &other_inner.prompts {
624 inner.prompts.insert(name.clone(), prompt.clone());
625 }
626
627 self
628 }
629
630 pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
657 where
658 F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
659 Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
660 {
661 Arc::make_mut(&mut self.inner).completion_handler =
662 Some(Arc::new(move |params| Box::pin(handler(params))));
663 self
664 }
665
666 pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
703 Arc::make_mut(&mut self.inner).tool_filter = Some(filter);
704 self
705 }
706
707 pub fn resource_filter(mut self, filter: ResourceFilter) -> Self {
738 Arc::make_mut(&mut self.inner).resource_filter = Some(filter);
739 self
740 }
741
742 pub fn prompt_filter(mut self, filter: PromptFilter) -> Self {
771 Arc::make_mut(&mut self.inner).prompt_filter = Some(filter);
772 self
773 }
774
775 pub fn session(&self) -> &SessionState {
777 &self.session
778 }
779
780 pub fn log(&self, params: LoggingMessageParams) -> bool {
802 let Some(tx) = &self.inner.notification_tx else {
803 return false;
804 };
805 tx.try_send(ServerNotification::LogMessage(params)).is_ok()
806 }
807
808 pub fn log_info(&self, message: &str) -> bool {
812 self.log(
813 LoggingMessageParams::new(LogLevel::Info)
814 .with_data(serde_json::json!({ "message": message })),
815 )
816 }
817
818 pub fn log_warning(&self, message: &str) -> bool {
820 self.log(
821 LoggingMessageParams::new(LogLevel::Warning)
822 .with_data(serde_json::json!({ "message": message })),
823 )
824 }
825
826 pub fn log_error(&self, message: &str) -> bool {
828 self.log(
829 LoggingMessageParams::new(LogLevel::Error)
830 .with_data(serde_json::json!({ "message": message })),
831 )
832 }
833
834 pub fn log_debug(&self, message: &str) -> bool {
836 self.log(
837 LoggingMessageParams::new(LogLevel::Debug)
838 .with_data(serde_json::json!({ "message": message })),
839 )
840 }
841
842 pub fn is_subscribed(&self, uri: &str) -> bool {
844 if let Ok(subs) = self.inner.subscriptions.read() {
845 return subs.contains(uri);
846 }
847 false
848 }
849
850 pub fn subscribed_uris(&self) -> Vec<String> {
852 if let Ok(subs) = self.inner.subscriptions.read() {
853 return subs.iter().cloned().collect();
854 }
855 Vec::new()
856 }
857
858 fn subscribe(&self, uri: &str) -> bool {
860 if let Ok(mut subs) = self.inner.subscriptions.write() {
861 return subs.insert(uri.to_string());
862 }
863 false
864 }
865
866 fn unsubscribe(&self, uri: &str) -> bool {
868 if let Ok(mut subs) = self.inner.subscriptions.write() {
869 return subs.remove(uri);
870 }
871 false
872 }
873
874 pub fn notify_resource_updated(&self, uri: &str) -> bool {
879 if !self.is_subscribed(uri) {
881 return false;
882 }
883
884 let Some(tx) = &self.inner.notification_tx else {
885 return false;
886 };
887 tx.try_send(ServerNotification::ResourceUpdated {
888 uri: uri.to_string(),
889 })
890 .is_ok()
891 }
892
893 pub fn notify_resources_list_changed(&self) -> bool {
897 let Some(tx) = &self.inner.notification_tx else {
898 return false;
899 };
900 tx.try_send(ServerNotification::ResourcesListChanged)
901 .is_ok()
902 }
903
904 fn capabilities(&self) -> ServerCapabilities {
906 let has_resources =
907 !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
908
909 ServerCapabilities {
910 tools: if self.inner.tools.is_empty() {
911 None
912 } else {
913 Some(ToolsCapability::default())
914 },
915 resources: if has_resources {
916 Some(ResourcesCapability {
917 subscribe: true,
918 ..Default::default()
919 })
920 } else {
921 None
922 },
923 prompts: if self.inner.prompts.is_empty() {
924 None
925 } else {
926 Some(PromptsCapability::default())
927 },
928 logging: if self.inner.notification_tx.is_some() {
930 Some(LoggingCapability::default())
931 } else {
932 None
933 },
934 tasks: Some(TasksCapability::default()),
936 completions: if self.inner.completion_handler.is_some() {
938 Some(CompletionsCapability::default())
939 } else {
940 None
941 },
942 }
943 }
944
945 async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
947 let method = request.method_name();
949 if !self.session.is_request_allowed(method) {
950 tracing::warn!(
951 method = %method,
952 phase = ?self.session.phase(),
953 "Request rejected: session not initialized"
954 );
955 return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
956 "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
957 method
958 ))));
959 }
960
961 match request {
962 McpRequest::Initialize(params) => {
963 tracing::info!(
964 client = %params.client_info.name,
965 version = %params.client_info.version,
966 "Client initializing"
967 );
968
969 let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
972 .contains(¶ms.protocol_version.as_str())
973 {
974 params.protocol_version
975 } else {
976 crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
977 };
978
979 self.session.mark_initializing();
981
982 Ok(McpResponse::Initialize(InitializeResult {
983 protocol_version,
984 capabilities: self.capabilities(),
985 server_info: Implementation {
986 name: self.inner.server_name.clone(),
987 version: self.inner.server_version.clone(),
988 title: self.inner.server_title.clone(),
989 description: self.inner.server_description.clone(),
990 icons: self.inner.server_icons.clone(),
991 website_url: self.inner.server_website_url.clone(),
992 },
993 instructions: self.inner.instructions.clone(),
994 }))
995 }
996
997 McpRequest::ListTools(_params) => {
998 let tools: Vec<ToolDefinition> = self
999 .inner
1000 .tools
1001 .values()
1002 .filter(|t| {
1003 self.inner
1005 .tool_filter
1006 .as_ref()
1007 .map(|f| f.is_visible(&self.session, t))
1008 .unwrap_or(true)
1009 })
1010 .map(|t| t.definition())
1011 .collect();
1012
1013 Ok(McpResponse::ListTools(ListToolsResult {
1014 tools,
1015 next_cursor: None,
1016 }))
1017 }
1018
1019 McpRequest::CallTool(params) => {
1020 let tool =
1021 self.inner.tools.get(¶ms.name).ok_or_else(|| {
1022 Error::JsonRpc(JsonRpcError::method_not_found(¶ms.name))
1023 })?;
1024
1025 if let Some(filter) = &self.inner.tool_filter {
1027 if !filter.is_visible(&self.session, tool) {
1028 return Err(filter.denial_error(¶ms.name));
1029 }
1030 }
1031
1032 let progress_token = params.meta.and_then(|m| m.progress_token);
1034 let ctx = self.create_context(request_id, progress_token);
1035
1036 tracing::debug!(tool = %params.name, "Calling tool");
1037 let result = tool.call_with_context(ctx, params.arguments).await;
1038
1039 Ok(McpResponse::CallTool(result))
1040 }
1041
1042 McpRequest::ListResources(_params) => {
1043 let resources: Vec<ResourceDefinition> = self
1044 .inner
1045 .resources
1046 .values()
1047 .filter(|r| {
1048 self.inner
1050 .resource_filter
1051 .as_ref()
1052 .map(|f| f.is_visible(&self.session, r))
1053 .unwrap_or(true)
1054 })
1055 .map(|r| r.definition())
1056 .collect();
1057
1058 Ok(McpResponse::ListResources(ListResourcesResult {
1059 resources,
1060 next_cursor: None,
1061 }))
1062 }
1063
1064 McpRequest::ListResourceTemplates(_params) => {
1065 let resource_templates: Vec<ResourceTemplateDefinition> = self
1066 .inner
1067 .resource_templates
1068 .iter()
1069 .map(|t| t.definition())
1070 .collect();
1071
1072 Ok(McpResponse::ListResourceTemplates(
1073 ListResourceTemplatesResult {
1074 resource_templates,
1075 next_cursor: None,
1076 },
1077 ))
1078 }
1079
1080 McpRequest::ReadResource(params) => {
1081 if let Some(resource) = self.inner.resources.get(¶ms.uri) {
1083 if let Some(filter) = &self.inner.resource_filter {
1085 if !filter.is_visible(&self.session, resource) {
1086 return Err(filter.denial_error(¶ms.uri));
1087 }
1088 }
1089
1090 tracing::debug!(uri = %params.uri, "Reading static resource");
1091 let result = resource.read().await;
1092 return Ok(McpResponse::ReadResource(result));
1093 }
1094
1095 for template in &self.inner.resource_templates {
1097 if let Some(variables) = template.match_uri(¶ms.uri) {
1098 tracing::debug!(
1099 uri = %params.uri,
1100 template = %template.uri_template,
1101 "Reading resource via template"
1102 );
1103 let result = template.read(¶ms.uri, variables).await?;
1104 return Ok(McpResponse::ReadResource(result));
1105 }
1106 }
1107
1108 Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1110 ¶ms.uri,
1111 )))
1112 }
1113
1114 McpRequest::SubscribeResource(params) => {
1115 if !self.inner.resources.contains_key(¶ms.uri) {
1117 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1118 ¶ms.uri,
1119 )));
1120 }
1121
1122 tracing::debug!(uri = %params.uri, "Subscribing to resource");
1123 self.subscribe(¶ms.uri);
1124
1125 Ok(McpResponse::SubscribeResource(EmptyResult {}))
1126 }
1127
1128 McpRequest::UnsubscribeResource(params) => {
1129 if !self.inner.resources.contains_key(¶ms.uri) {
1131 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1132 ¶ms.uri,
1133 )));
1134 }
1135
1136 tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
1137 self.unsubscribe(¶ms.uri);
1138
1139 Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
1140 }
1141
1142 McpRequest::ListPrompts(_params) => {
1143 let prompts: Vec<PromptDefinition> = self
1144 .inner
1145 .prompts
1146 .values()
1147 .filter(|p| {
1148 self.inner
1150 .prompt_filter
1151 .as_ref()
1152 .map(|f| f.is_visible(&self.session, p))
1153 .unwrap_or(true)
1154 })
1155 .map(|p| p.definition())
1156 .collect();
1157
1158 Ok(McpResponse::ListPrompts(ListPromptsResult {
1159 prompts,
1160 next_cursor: None,
1161 }))
1162 }
1163
1164 McpRequest::GetPrompt(params) => {
1165 let prompt = self.inner.prompts.get(¶ms.name).ok_or_else(|| {
1166 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1167 "Prompt not found: {}",
1168 params.name
1169 )))
1170 })?;
1171
1172 if let Some(filter) = &self.inner.prompt_filter {
1174 if !filter.is_visible(&self.session, prompt) {
1175 return Err(filter.denial_error(¶ms.name));
1176 }
1177 }
1178
1179 tracing::debug!(name = %params.name, "Getting prompt");
1180 let result = prompt.get(params.arguments).await?;
1181
1182 Ok(McpResponse::GetPrompt(result))
1183 }
1184
1185 McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
1186
1187 McpRequest::EnqueueTask(params) => {
1188 let tool = self.inner.tools.get(¶ms.tool_name).ok_or_else(|| {
1190 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1191 "Tool not found: {}",
1192 params.tool_name
1193 )))
1194 })?;
1195
1196 let (task_id, cancellation_token) = self.inner.task_store.create_task(
1198 ¶ms.tool_name,
1199 params.arguments.clone(),
1200 params.ttl,
1201 );
1202
1203 tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
1204
1205 let ctx = self.create_context(request_id, None);
1207
1208 let task_store = self.inner.task_store.clone();
1210 let tool = tool.clone();
1211 let arguments = params.arguments;
1212 let task_id_clone = task_id.clone();
1213
1214 tokio::spawn(async move {
1215 if cancellation_token.is_cancelled() {
1217 tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
1218 return;
1219 }
1220
1221 let result = tool.call_with_context(ctx, arguments).await;
1223
1224 if cancellation_token.is_cancelled() {
1225 tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
1226 } else if result.is_error {
1227 let error_msg = result.first_text().unwrap_or("Tool execution failed");
1229 task_store.fail_task(&task_id_clone, error_msg);
1230 tracing::warn!(task_id = %task_id_clone, error = %error_msg, "Task failed");
1231 } else {
1232 task_store.complete_task(&task_id_clone, result);
1233 tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
1234 }
1235 });
1236
1237 Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
1238 task_id,
1239 status: TaskStatus::Working,
1240 poll_interval: Some(2),
1241 }))
1242 }
1243
1244 McpRequest::ListTasks(params) => {
1245 let tasks = self.inner.task_store.list_tasks(params.status);
1246
1247 Ok(McpResponse::ListTasks(ListTasksResult {
1248 tasks,
1249 next_cursor: None,
1250 }))
1251 }
1252
1253 McpRequest::GetTaskInfo(params) => {
1254 let task = self
1255 .inner
1256 .task_store
1257 .get_task(¶ms.task_id)
1258 .ok_or_else(|| {
1259 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1260 "Task not found: {}",
1261 params.task_id
1262 )))
1263 })?;
1264
1265 Ok(McpResponse::GetTaskInfo(task))
1266 }
1267
1268 McpRequest::GetTaskResult(params) => {
1269 let (status, result, error) = self
1270 .inner
1271 .task_store
1272 .get_task_full(¶ms.task_id)
1273 .ok_or_else(|| {
1274 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1275 "Task not found: {}",
1276 params.task_id
1277 )))
1278 })?;
1279
1280 Ok(McpResponse::GetTaskResult(GetTaskResultResult {
1281 task_id: params.task_id,
1282 status,
1283 result,
1284 error,
1285 }))
1286 }
1287
1288 McpRequest::CancelTask(params) => {
1289 let status = self
1290 .inner
1291 .task_store
1292 .cancel_task(¶ms.task_id, params.reason.as_deref())
1293 .ok_or_else(|| {
1294 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1295 "Task not found: {}",
1296 params.task_id
1297 )))
1298 })?;
1299
1300 let cancelled = status == TaskStatus::Cancelled;
1301
1302 Ok(McpResponse::CancelTask(CancelTaskResult {
1303 cancelled,
1304 status,
1305 }))
1306 }
1307
1308 McpRequest::SetLoggingLevel(params) => {
1309 tracing::debug!(level = ?params.level, "Client set logging level");
1313 Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
1314 }
1315
1316 McpRequest::Complete(params) => {
1317 tracing::debug!(
1318 reference = ?params.reference,
1319 argument = %params.argument.name,
1320 "Completion request"
1321 );
1322
1323 if let Some(ref handler) = self.inner.completion_handler {
1325 let result = handler(params).await?;
1326 Ok(McpResponse::Complete(result))
1327 } else {
1328 Ok(McpResponse::Complete(CompleteResult::new(vec![])))
1330 }
1331 }
1332
1333 McpRequest::Unknown { method, .. } => {
1334 Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
1335 }
1336 }
1337 }
1338
1339 pub fn handle_notification(&self, notification: McpNotification) {
1341 match notification {
1342 McpNotification::Initialized => {
1343 if self.session.mark_initialized() {
1344 tracing::info!("Session initialized, entering operation phase");
1345 } else {
1346 tracing::warn!(
1347 "Received initialized notification in unexpected state: {:?}",
1348 self.session.phase()
1349 );
1350 }
1351 }
1352 McpNotification::Cancelled(params) => {
1353 if self.cancel_request(¶ms.request_id) {
1354 tracing::info!(
1355 request_id = ?params.request_id,
1356 reason = ?params.reason,
1357 "Request cancelled"
1358 );
1359 } else {
1360 tracing::debug!(
1361 request_id = ?params.request_id,
1362 reason = ?params.reason,
1363 "Cancellation requested for unknown request"
1364 );
1365 }
1366 }
1367 McpNotification::Progress(params) => {
1368 tracing::trace!(
1369 token = ?params.progress_token,
1370 progress = params.progress,
1371 total = ?params.total,
1372 "Progress notification"
1373 );
1374 }
1376 McpNotification::RootsListChanged => {
1377 tracing::info!("Client roots list changed");
1378 }
1381 McpNotification::Unknown { method, .. } => {
1382 tracing::debug!(method = %method, "Unknown notification received");
1383 }
1384 }
1385 }
1386}
1387
1388impl Default for McpRouter {
1389 fn default() -> Self {
1390 Self::new()
1391 }
1392}
1393
1394pub use crate::context::Extensions;
1400
/// Request accepted by the [`McpRouter`] service.
#[derive(Debug)]
pub struct RouterRequest {
    /// JSON-RPC request ID; echoed back in the [`RouterResponse`].
    pub id: RequestId,
    /// The MCP request to dispatch.
    pub inner: McpRequest,
    /// Per-request extensions supplied by the caller.
    // NOTE(review): the router's `Service::call` does not read this field;
    // presumably it is meant for middleware — confirm.
    pub extensions: Extensions,
}
1409
/// Response produced by the [`McpRouter`] service.
#[derive(Debug)]
pub struct RouterResponse {
    /// ID of the request this response answers.
    pub id: RequestId,
    /// The typed MCP response, or the JSON-RPC error to report instead.
    pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
1416
1417impl RouterResponse {
1418 pub fn into_jsonrpc(self) -> JsonRpcResponse {
1420 match self.inner {
1421 Ok(response) => match serde_json::to_value(response) {
1422 Ok(result) => JsonRpcResponse::result(self.id, result),
1423 Err(e) => {
1424 tracing::error!(error = %e, "Failed to serialize response");
1425 JsonRpcResponse::error(
1426 Some(self.id),
1427 JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1428 )
1429 }
1430 },
1431 Err(error) => JsonRpcResponse::error(Some(self.id), error),
1432 }
1433 }
1434}
1435
1436impl Service<RouterRequest> for McpRouter {
1437 type Response = RouterResponse;
1438 type Error = std::convert::Infallible; type Future =
1440 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1441
1442 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1443 Poll::Ready(Ok(()))
1444 }
1445
1446 fn call(&mut self, req: RouterRequest) -> Self::Future {
1447 let router = self.clone();
1448 let request_id = req.id.clone();
1449 Box::pin(async move {
1450 let result = router.handle(req.id, req.inner).await;
1451 router.complete_request(&request_id);
1453 Ok(RouterResponse {
1454 id: request_id,
1455 inner: result.map_err(|e| match e {
1460 Error::JsonRpc(err) => err,
1461 Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1462 e => JsonRpcError::internal_error(e.to_string()),
1463 }),
1464 })
1465 })
1466 }
1467}
1468
1469#[cfg(test)]
1470mod tests {
1471 use super::*;
1472 use crate::extract::{Context, Json};
1473 use crate::jsonrpc::JsonRpcService;
1474 use crate::tool::ToolBuilder;
1475 use schemars::JsonSchema;
1476 use serde::Deserialize;
1477 use tower::ServiceExt;
1478
    // Input of the `add` test tool: the two integers to sum.
    // (Plain comments on purpose: doc comments would be picked up by the
    // derived JSON schema.)
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        // First operand.
        a: i64,
        // Second operand.
        b: i64,
    }
1484
1485 async fn init_router(router: &mut McpRouter) {
1487 let init_req = RouterRequest {
1489 id: RequestId::Number(0),
1490 inner: McpRequest::Initialize(InitializeParams {
1491 protocol_version: "2025-11-25".to_string(),
1492 capabilities: ClientCapabilities {
1493 roots: None,
1494 sampling: None,
1495 elicitation: None,
1496 },
1497 client_info: Implementation {
1498 name: "test".to_string(),
1499 version: "1.0".to_string(),
1500 ..Default::default()
1501 },
1502 }),
1503 extensions: Extensions::new(),
1504 };
1505 let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1506 router.handle_notification(McpNotification::Initialized);
1508 }
1509
1510 #[tokio::test]
1511 async fn test_router_list_tools() {
1512 let add_tool = ToolBuilder::new("add")
1513 .description("Add two numbers")
1514 .handler(|input: AddInput| async move {
1515 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1516 })
1517 .build()
1518 .expect("valid tool name");
1519
1520 let mut router = McpRouter::new().tool(add_tool);
1521
1522 init_router(&mut router).await;
1524
1525 let req = RouterRequest {
1526 id: RequestId::Number(1),
1527 inner: McpRequest::ListTools(ListToolsParams::default()),
1528 extensions: Extensions::new(),
1529 };
1530
1531 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1532
1533 match resp.inner {
1534 Ok(McpResponse::ListTools(result)) => {
1535 assert_eq!(result.tools.len(), 1);
1536 assert_eq!(result.tools[0].name, "add");
1537 }
1538 _ => panic!("Expected ListTools response"),
1539 }
1540 }
1541
1542 #[tokio::test]
1543 async fn test_router_call_tool() {
1544 let add_tool = ToolBuilder::new("add")
1545 .description("Add two numbers")
1546 .handler(|input: AddInput| async move {
1547 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1548 })
1549 .build()
1550 .expect("valid tool name");
1551
1552 let mut router = McpRouter::new().tool(add_tool);
1553
1554 init_router(&mut router).await;
1556
1557 let req = RouterRequest {
1558 id: RequestId::Number(1),
1559 inner: McpRequest::CallTool(CallToolParams {
1560 name: "add".to_string(),
1561 arguments: serde_json::json!({"a": 2, "b": 3}),
1562 meta: None,
1563 }),
1564 extensions: Extensions::new(),
1565 };
1566
1567 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1568
1569 match resp.inner {
1570 Ok(McpResponse::CallTool(result)) => {
1571 assert!(!result.is_error);
1572 match &result.content[0] {
1574 Content::Text { text, .. } => assert_eq!(text, "5"),
1575 _ => panic!("Expected text content"),
1576 }
1577 }
1578 _ => panic!("Expected CallTool response"),
1579 }
1580 }
1581
1582 async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1584 let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1585 "protocolVersion": "2025-11-25",
1586 "capabilities": {},
1587 "clientInfo": { "name": "test", "version": "1.0" }
1588 }));
1589 let _ = service.call_single(init_req).await.unwrap();
1590 router.handle_notification(McpNotification::Initialized);
1591 }
1592
1593 #[tokio::test]
1594 async fn test_jsonrpc_service() {
1595 let add_tool = ToolBuilder::new("add")
1596 .description("Add two numbers")
1597 .handler(|input: AddInput| async move {
1598 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1599 })
1600 .build()
1601 .expect("valid tool name");
1602
1603 let router = McpRouter::new().tool(add_tool);
1604 let mut service = JsonRpcService::new(router.clone());
1605
1606 init_jsonrpc_service(&mut service, &router).await;
1608
1609 let req = JsonRpcRequest::new(1, "tools/list");
1610
1611 let resp = service.call_single(req).await.unwrap();
1612
1613 match resp {
1614 JsonRpcResponse::Result(r) => {
1615 assert_eq!(r.id, RequestId::Number(1));
1616 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1617 assert_eq!(tools.len(), 1);
1618 }
1619 JsonRpcResponse::Error(_) => panic!("Expected success response"),
1620 }
1621 }
1622
1623 #[tokio::test]
1624 async fn test_batch_request() {
1625 let add_tool = ToolBuilder::new("add")
1626 .description("Add two numbers")
1627 .handler(|input: AddInput| async move {
1628 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1629 })
1630 .build()
1631 .expect("valid tool name");
1632
1633 let router = McpRouter::new().tool(add_tool);
1634 let mut service = JsonRpcService::new(router.clone());
1635
1636 init_jsonrpc_service(&mut service, &router).await;
1638
1639 let requests = vec![
1641 JsonRpcRequest::new(1, "tools/list"),
1642 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1643 "name": "add",
1644 "arguments": {"a": 10, "b": 20}
1645 })),
1646 JsonRpcRequest::new(3, "ping"),
1647 ];
1648
1649 let responses = service.call_batch(requests).await.unwrap();
1650
1651 assert_eq!(responses.len(), 3);
1652
1653 match &responses[0] {
1655 JsonRpcResponse::Result(r) => {
1656 assert_eq!(r.id, RequestId::Number(1));
1657 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1658 assert_eq!(tools.len(), 1);
1659 }
1660 JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1661 }
1662
1663 match &responses[1] {
1665 JsonRpcResponse::Result(r) => {
1666 assert_eq!(r.id, RequestId::Number(2));
1667 let content = r.result.get("content").unwrap().as_array().unwrap();
1668 let text = content[0].get("text").unwrap().as_str().unwrap();
1669 assert_eq!(text, "30");
1670 }
1671 JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1672 }
1673
1674 match &responses[2] {
1676 JsonRpcResponse::Result(r) => {
1677 assert_eq!(r.id, RequestId::Number(3));
1678 }
1679 JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1680 }
1681 }
1682
1683 #[tokio::test]
1684 async fn test_empty_batch_error() {
1685 let router = McpRouter::new();
1686 let mut service = JsonRpcService::new(router);
1687
1688 let result = service.call_batch(vec![]).await;
1689 assert!(result.is_err());
1690 }
1691
1692 #[tokio::test]
1697 async fn test_progress_token_extraction() {
1698 use crate::context::{ServerNotification, notification_channel};
1699 use crate::protocol::ProgressToken;
1700 use std::sync::Arc;
1701 use std::sync::atomic::{AtomicBool, Ordering};
1702
1703 let progress_reported = Arc::new(AtomicBool::new(false));
1705 let progress_ref = progress_reported.clone();
1706
1707 let tool = ToolBuilder::new("progress_tool")
1709 .description("Tool that reports progress")
1710 .extractor_handler_typed::<_, _, _, AddInput>(
1711 (),
1712 move |ctx: Context, Json(_input): Json<AddInput>| {
1713 let reported = progress_ref.clone();
1714 async move {
1715 ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1717 .await;
1718 reported.store(true, Ordering::SeqCst);
1719 Ok(CallToolResult::text("done"))
1720 }
1721 },
1722 )
1723 .build()
1724 .expect("valid tool name");
1725
1726 let (tx, mut rx) = notification_channel(10);
1728 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1729 let mut service = JsonRpcService::new(router.clone());
1730
1731 init_jsonrpc_service(&mut service, &router).await;
1733
1734 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1736 "name": "progress_tool",
1737 "arguments": {"a": 1, "b": 2},
1738 "_meta": {
1739 "progressToken": "test-token-123"
1740 }
1741 }));
1742
1743 let resp = service.call_single(req).await.unwrap();
1744
1745 match resp {
1747 JsonRpcResponse::Result(_) => {}
1748 JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1749 }
1750
1751 assert!(progress_reported.load(Ordering::SeqCst));
1753
1754 let notification = rx.try_recv().expect("Expected progress notification");
1756 match notification {
1757 ServerNotification::Progress(params) => {
1758 assert_eq!(
1759 params.progress_token,
1760 ProgressToken::String("test-token-123".to_string())
1761 );
1762 assert_eq!(params.progress, 50.0);
1763 assert_eq!(params.total, Some(100.0));
1764 assert_eq!(params.message.as_deref(), Some("Halfway"));
1765 }
1766 _ => panic!("Expected Progress notification"),
1767 }
1768 }
1769
1770 #[tokio::test]
1771 async fn test_tool_call_without_progress_token() {
1772 use crate::context::notification_channel;
1773 use std::sync::Arc;
1774 use std::sync::atomic::{AtomicBool, Ordering};
1775
1776 let progress_attempted = Arc::new(AtomicBool::new(false));
1777 let progress_ref = progress_attempted.clone();
1778
1779 let tool = ToolBuilder::new("no_token_tool")
1780 .description("Tool that tries to report progress without token")
1781 .extractor_handler_typed::<_, _, _, AddInput>(
1782 (),
1783 move |ctx: Context, Json(_input): Json<AddInput>| {
1784 let attempted = progress_ref.clone();
1785 async move {
1786 ctx.report_progress(50.0, Some(100.0), None).await;
1788 attempted.store(true, Ordering::SeqCst);
1789 Ok(CallToolResult::text("done"))
1790 }
1791 },
1792 )
1793 .build()
1794 .expect("valid tool name");
1795
1796 let (tx, mut rx) = notification_channel(10);
1797 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1798 let mut service = JsonRpcService::new(router.clone());
1799
1800 init_jsonrpc_service(&mut service, &router).await;
1801
1802 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1804 "name": "no_token_tool",
1805 "arguments": {"a": 1, "b": 2}
1806 }));
1807
1808 let resp = service.call_single(req).await.unwrap();
1809 assert!(matches!(resp, JsonRpcResponse::Result(_)));
1810
1811 assert!(progress_attempted.load(Ordering::SeqCst));
1813
1814 assert!(rx.try_recv().is_err());
1816 }
1817
1818 #[tokio::test]
1819 async fn test_batch_errors_returned_not_dropped() {
1820 let add_tool = ToolBuilder::new("add")
1821 .description("Add two numbers")
1822 .handler(|input: AddInput| async move {
1823 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1824 })
1825 .build()
1826 .expect("valid tool name");
1827
1828 let router = McpRouter::new().tool(add_tool);
1829 let mut service = JsonRpcService::new(router.clone());
1830
1831 init_jsonrpc_service(&mut service, &router).await;
1832
1833 let requests = vec![
1835 JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1837 "name": "add",
1838 "arguments": {"a": 10, "b": 20}
1839 })),
1840 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1842 "name": "nonexistent_tool",
1843 "arguments": {}
1844 })),
1845 JsonRpcRequest::new(3, "ping"),
1847 ];
1848
1849 let responses = service.call_batch(requests).await.unwrap();
1850
1851 assert_eq!(responses.len(), 3);
1853
1854 match &responses[0] {
1856 JsonRpcResponse::Result(r) => {
1857 assert_eq!(r.id, RequestId::Number(1));
1858 }
1859 JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1860 }
1861
1862 match &responses[1] {
1864 JsonRpcResponse::Error(e) => {
1865 assert_eq!(e.id, Some(RequestId::Number(2)));
1866 assert!(e.error.message.contains("not found") || e.error.code == -32601);
1868 }
1869 JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1870 }
1871
1872 match &responses[2] {
1874 JsonRpcResponse::Result(r) => {
1875 assert_eq!(r.id, RequestId::Number(3));
1876 }
1877 JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1878 }
1879 }
1880
1881 #[tokio::test]
1886 async fn test_list_resource_templates() {
1887 use crate::resource::ResourceTemplateBuilder;
1888 use std::collections::HashMap;
1889
1890 let template = ResourceTemplateBuilder::new("file:///{path}")
1891 .name("Project Files")
1892 .description("Access project files")
1893 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1894 Ok(ReadResourceResult {
1895 contents: vec![ResourceContent {
1896 uri,
1897 mime_type: None,
1898 text: None,
1899 blob: None,
1900 }],
1901 })
1902 });
1903
1904 let mut router = McpRouter::new().resource_template(template);
1905
1906 init_router(&mut router).await;
1908
1909 let req = RouterRequest {
1910 id: RequestId::Number(1),
1911 inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1912 extensions: Extensions::new(),
1913 };
1914
1915 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1916
1917 match resp.inner {
1918 Ok(McpResponse::ListResourceTemplates(result)) => {
1919 assert_eq!(result.resource_templates.len(), 1);
1920 assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1921 assert_eq!(result.resource_templates[0].name, "Project Files");
1922 }
1923 _ => panic!("Expected ListResourceTemplates response"),
1924 }
1925 }
1926
1927 #[tokio::test]
1928 async fn test_read_resource_via_template() {
1929 use crate::resource::ResourceTemplateBuilder;
1930 use std::collections::HashMap;
1931
1932 let template = ResourceTemplateBuilder::new("db://users/{id}")
1933 .name("User Records")
1934 .handler(|uri: String, vars: HashMap<String, String>| async move {
1935 let id = vars.get("id").unwrap().clone();
1936 Ok(ReadResourceResult {
1937 contents: vec![ResourceContent {
1938 uri,
1939 mime_type: Some("application/json".to_string()),
1940 text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1941 blob: None,
1942 }],
1943 })
1944 });
1945
1946 let mut router = McpRouter::new().resource_template(template);
1947
1948 init_router(&mut router).await;
1950
1951 let req = RouterRequest {
1953 id: RequestId::Number(1),
1954 inner: McpRequest::ReadResource(ReadResourceParams {
1955 uri: "db://users/123".to_string(),
1956 }),
1957 extensions: Extensions::new(),
1958 };
1959
1960 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1961
1962 match resp.inner {
1963 Ok(McpResponse::ReadResource(result)) => {
1964 assert_eq!(result.contents.len(), 1);
1965 assert_eq!(result.contents[0].uri, "db://users/123");
1966 assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1967 }
1968 _ => panic!("Expected ReadResource response"),
1969 }
1970 }
1971
1972 #[tokio::test]
1973 async fn test_static_resource_takes_precedence_over_template() {
1974 use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1975 use std::collections::HashMap;
1976
1977 let template = ResourceTemplateBuilder::new("file:///{path}")
1979 .name("Files Template")
1980 .handler(|uri: String, _vars: HashMap<String, String>| async move {
1981 Ok(ReadResourceResult {
1982 contents: vec![ResourceContent {
1983 uri,
1984 mime_type: None,
1985 text: Some("from template".to_string()),
1986 blob: None,
1987 }],
1988 })
1989 });
1990
1991 let static_resource = ResourceBuilder::new("file:///README.md")
1993 .name("README")
1994 .text("from static resource");
1995
1996 let mut router = McpRouter::new()
1997 .resource_template(template)
1998 .resource(static_resource);
1999
2000 init_router(&mut router).await;
2002
2003 let req = RouterRequest {
2005 id: RequestId::Number(1),
2006 inner: McpRequest::ReadResource(ReadResourceParams {
2007 uri: "file:///README.md".to_string(),
2008 }),
2009 extensions: Extensions::new(),
2010 };
2011
2012 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2013
2014 match resp.inner {
2015 Ok(McpResponse::ReadResource(result)) => {
2016 assert_eq!(
2018 result.contents[0].text.as_deref(),
2019 Some("from static resource")
2020 );
2021 }
2022 _ => panic!("Expected ReadResource response"),
2023 }
2024 }
2025
2026 #[tokio::test]
2027 async fn test_resource_not_found_when_no_match() {
2028 use crate::resource::ResourceTemplateBuilder;
2029 use std::collections::HashMap;
2030
2031 let template = ResourceTemplateBuilder::new("db://users/{id}")
2032 .name("Users")
2033 .handler(|uri: String, _vars: HashMap<String, String>| async move {
2034 Ok(ReadResourceResult {
2035 contents: vec![ResourceContent {
2036 uri,
2037 mime_type: None,
2038 text: None,
2039 blob: None,
2040 }],
2041 })
2042 });
2043
2044 let mut router = McpRouter::new().resource_template(template);
2045
2046 init_router(&mut router).await;
2048
2049 let req = RouterRequest {
2051 id: RequestId::Number(1),
2052 inner: McpRequest::ReadResource(ReadResourceParams {
2053 uri: "db://posts/123".to_string(),
2054 }),
2055 extensions: Extensions::new(),
2056 };
2057
2058 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2059
2060 match resp.inner {
2061 Err(err) => {
2062 assert!(err.message.contains("not found"));
2063 }
2064 Ok(_) => panic!("Expected error for non-matching URI"),
2065 }
2066 }
2067
2068 #[tokio::test]
2069 async fn test_capabilities_include_resources_with_only_templates() {
2070 use crate::resource::ResourceTemplateBuilder;
2071 use std::collections::HashMap;
2072
2073 let template = ResourceTemplateBuilder::new("file:///{path}")
2074 .name("Files")
2075 .handler(|uri: String, _vars: HashMap<String, String>| async move {
2076 Ok(ReadResourceResult {
2077 contents: vec![ResourceContent {
2078 uri,
2079 mime_type: None,
2080 text: None,
2081 blob: None,
2082 }],
2083 })
2084 });
2085
2086 let mut router = McpRouter::new().resource_template(template);
2087
2088 let init_req = RouterRequest {
2090 id: RequestId::Number(0),
2091 inner: McpRequest::Initialize(InitializeParams {
2092 protocol_version: "2025-11-25".to_string(),
2093 capabilities: ClientCapabilities {
2094 roots: None,
2095 sampling: None,
2096 elicitation: None,
2097 },
2098 client_info: Implementation {
2099 name: "test".to_string(),
2100 version: "1.0".to_string(),
2101 ..Default::default()
2102 },
2103 }),
2104 extensions: Extensions::new(),
2105 };
2106 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2107
2108 match resp.inner {
2109 Ok(McpResponse::Initialize(result)) => {
2110 assert!(result.capabilities.resources.is_some());
2112 }
2113 _ => panic!("Expected Initialize response"),
2114 }
2115 }
2116
2117 #[tokio::test]
2122 async fn test_log_sends_notification() {
2123 use crate::context::notification_channel;
2124
2125 let (tx, mut rx) = notification_channel(10);
2126 let router = McpRouter::new().with_notification_sender(tx);
2127
2128 let sent = router.log_info("Test message");
2130 assert!(sent);
2131
2132 let notification = rx.try_recv().unwrap();
2134 match notification {
2135 ServerNotification::LogMessage(params) => {
2136 assert_eq!(params.level, LogLevel::Info);
2137 let data = params.data.unwrap();
2138 assert_eq!(
2139 data.get("message").unwrap().as_str().unwrap(),
2140 "Test message"
2141 );
2142 }
2143 _ => panic!("Expected LogMessage notification"),
2144 }
2145 }
2146
2147 #[tokio::test]
2148 async fn test_log_with_custom_params() {
2149 use crate::context::notification_channel;
2150
2151 let (tx, mut rx) = notification_channel(10);
2152 let router = McpRouter::new().with_notification_sender(tx);
2153
2154 let params = LoggingMessageParams::new(LogLevel::Error)
2156 .with_logger("database")
2157 .with_data(serde_json::json!({
2158 "error": "Connection failed",
2159 "host": "localhost"
2160 }));
2161
2162 let sent = router.log(params);
2163 assert!(sent);
2164
2165 let notification = rx.try_recv().unwrap();
2166 match notification {
2167 ServerNotification::LogMessage(params) => {
2168 assert_eq!(params.level, LogLevel::Error);
2169 assert_eq!(params.logger.as_deref(), Some("database"));
2170 let data = params.data.unwrap();
2171 assert_eq!(
2172 data.get("error").unwrap().as_str().unwrap(),
2173 "Connection failed"
2174 );
2175 }
2176 _ => panic!("Expected LogMessage notification"),
2177 }
2178 }
2179
2180 #[tokio::test]
2181 async fn test_log_without_channel_returns_false() {
2182 let router = McpRouter::new();
2184
2185 assert!(!router.log_info("Test"));
2187 assert!(!router.log_warning("Test"));
2188 assert!(!router.log_error("Test"));
2189 assert!(!router.log_debug("Test"));
2190 }
2191
2192 #[tokio::test]
2193 async fn test_logging_capability_with_channel() {
2194 use crate::context::notification_channel;
2195
2196 let (tx, _rx) = notification_channel(10);
2197 let mut router = McpRouter::new().with_notification_sender(tx);
2198
2199 let init_req = RouterRequest {
2201 id: RequestId::Number(0),
2202 inner: McpRequest::Initialize(InitializeParams {
2203 protocol_version: "2025-11-25".to_string(),
2204 capabilities: ClientCapabilities {
2205 roots: None,
2206 sampling: None,
2207 elicitation: None,
2208 },
2209 client_info: Implementation {
2210 name: "test".to_string(),
2211 version: "1.0".to_string(),
2212 ..Default::default()
2213 },
2214 }),
2215 extensions: Extensions::new(),
2216 };
2217 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2218
2219 match resp.inner {
2220 Ok(McpResponse::Initialize(result)) => {
2221 assert!(result.capabilities.logging.is_some());
2223 }
2224 _ => panic!("Expected Initialize response"),
2225 }
2226 }
2227
2228 #[tokio::test]
2229 async fn test_no_logging_capability_without_channel() {
2230 let mut router = McpRouter::new();
2231
2232 let init_req = RouterRequest {
2234 id: RequestId::Number(0),
2235 inner: McpRequest::Initialize(InitializeParams {
2236 protocol_version: "2025-11-25".to_string(),
2237 capabilities: ClientCapabilities {
2238 roots: None,
2239 sampling: None,
2240 elicitation: None,
2241 },
2242 client_info: Implementation {
2243 name: "test".to_string(),
2244 version: "1.0".to_string(),
2245 ..Default::default()
2246 },
2247 }),
2248 extensions: Extensions::new(),
2249 };
2250 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2251
2252 match resp.inner {
2253 Ok(McpResponse::Initialize(result)) => {
2254 assert!(result.capabilities.logging.is_none());
2256 }
2257 _ => panic!("Expected Initialize response"),
2258 }
2259 }
2260
2261 #[tokio::test]
2266 async fn test_enqueue_task() {
2267 let add_tool = ToolBuilder::new("add")
2268 .description("Add two numbers")
2269 .handler(|input: AddInput| async move {
2270 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2271 })
2272 .build()
2273 .expect("valid tool name");
2274
2275 let mut router = McpRouter::new().tool(add_tool);
2276 init_router(&mut router).await;
2277
2278 let req = RouterRequest {
2279 id: RequestId::Number(1),
2280 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2281 tool_name: "add".to_string(),
2282 arguments: serde_json::json!({"a": 5, "b": 10}),
2283 ttl: None,
2284 }),
2285 extensions: Extensions::new(),
2286 };
2287
2288 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2289
2290 match resp.inner {
2291 Ok(McpResponse::EnqueueTask(result)) => {
2292 assert!(result.task_id.starts_with("task-"));
2293 assert_eq!(result.status, TaskStatus::Working);
2294 }
2295 _ => panic!("Expected EnqueueTask response"),
2296 }
2297 }
2298
2299 #[tokio::test]
2300 async fn test_list_tasks_empty() {
2301 let mut router = McpRouter::new();
2302 init_router(&mut router).await;
2303
2304 let req = RouterRequest {
2305 id: RequestId::Number(1),
2306 inner: McpRequest::ListTasks(ListTasksParams::default()),
2307 extensions: Extensions::new(),
2308 };
2309
2310 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2311
2312 match resp.inner {
2313 Ok(McpResponse::ListTasks(result)) => {
2314 assert!(result.tasks.is_empty());
2315 }
2316 _ => panic!("Expected ListTasks response"),
2317 }
2318 }
2319
2320 #[tokio::test]
2321 async fn test_task_lifecycle_complete() {
2322 let add_tool = ToolBuilder::new("add")
2323 .description("Add two numbers")
2324 .handler(|input: AddInput| async move {
2325 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2326 })
2327 .build()
2328 .expect("valid tool name");
2329
2330 let mut router = McpRouter::new().tool(add_tool);
2331 init_router(&mut router).await;
2332
2333 let req = RouterRequest {
2335 id: RequestId::Number(1),
2336 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2337 tool_name: "add".to_string(),
2338 arguments: serde_json::json!({"a": 7, "b": 8}),
2339 ttl: None,
2340 }),
2341 extensions: Extensions::new(),
2342 };
2343
2344 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2345 let task_id = match resp.inner {
2346 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2347 _ => panic!("Expected EnqueueTask response"),
2348 };
2349
2350 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2352
2353 let req = RouterRequest {
2355 id: RequestId::Number(2),
2356 inner: McpRequest::GetTaskResult(GetTaskResultParams {
2357 task_id: task_id.clone(),
2358 }),
2359 extensions: Extensions::new(),
2360 };
2361
2362 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2363
2364 match resp.inner {
2365 Ok(McpResponse::GetTaskResult(result)) => {
2366 assert_eq!(result.task_id, task_id);
2367 assert_eq!(result.status, TaskStatus::Completed);
2368 assert!(result.result.is_some());
2369 assert!(result.error.is_none());
2370
2371 let tool_result = result.result.unwrap();
2373 match &tool_result.content[0] {
2374 Content::Text { text, .. } => assert_eq!(text, "15"),
2375 _ => panic!("Expected text content"),
2376 }
2377 }
2378 _ => panic!("Expected GetTaskResult response"),
2379 }
2380 }
2381
2382 #[tokio::test]
2383 async fn test_task_cancellation() {
2384 let slow_tool = ToolBuilder::new("slow")
2386 .description("Slow tool")
2387 .handler(|_input: serde_json::Value| async move {
2388 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2389 Ok(CallToolResult::text("done"))
2390 })
2391 .build()
2392 .expect("valid tool name");
2393
2394 let mut router = McpRouter::new().tool(slow_tool);
2395 init_router(&mut router).await;
2396
2397 let req = RouterRequest {
2399 id: RequestId::Number(1),
2400 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2401 tool_name: "slow".to_string(),
2402 arguments: serde_json::json!({}),
2403 ttl: None,
2404 }),
2405 extensions: Extensions::new(),
2406 };
2407
2408 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2409 let task_id = match resp.inner {
2410 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2411 _ => panic!("Expected EnqueueTask response"),
2412 };
2413
2414 let req = RouterRequest {
2416 id: RequestId::Number(2),
2417 inner: McpRequest::CancelTask(CancelTaskParams {
2418 task_id: task_id.clone(),
2419 reason: Some("Test cancellation".to_string()),
2420 }),
2421 extensions: Extensions::new(),
2422 };
2423
2424 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2425
2426 match resp.inner {
2427 Ok(McpResponse::CancelTask(result)) => {
2428 assert!(result.cancelled);
2429 assert_eq!(result.status, TaskStatus::Cancelled);
2430 }
2431 _ => panic!("Expected CancelTask response"),
2432 }
2433 }
2434
2435 #[tokio::test]
2436 async fn test_get_task_info() {
2437 let add_tool = ToolBuilder::new("add")
2438 .description("Add two numbers")
2439 .handler(|input: AddInput| async move {
2440 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2441 })
2442 .build()
2443 .expect("valid tool name");
2444
2445 let mut router = McpRouter::new().tool(add_tool);
2446 init_router(&mut router).await;
2447
2448 let req = RouterRequest {
2450 id: RequestId::Number(1),
2451 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2452 tool_name: "add".to_string(),
2453 arguments: serde_json::json!({"a": 1, "b": 2}),
2454 ttl: Some(600),
2455 }),
2456 extensions: Extensions::new(),
2457 };
2458
2459 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2460 let task_id = match resp.inner {
2461 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2462 _ => panic!("Expected EnqueueTask response"),
2463 };
2464
2465 let req = RouterRequest {
2467 id: RequestId::Number(2),
2468 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2469 task_id: task_id.clone(),
2470 }),
2471 extensions: Extensions::new(),
2472 };
2473
2474 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2475
2476 match resp.inner {
2477 Ok(McpResponse::GetTaskInfo(info)) => {
2478 assert_eq!(info.task_id, task_id);
2479 assert!(info.created_at.contains('T')); assert_eq!(info.ttl, Some(600));
2481 }
2482 _ => panic!("Expected GetTaskInfo response"),
2483 }
2484 }
2485
2486 #[tokio::test]
2487 async fn test_enqueue_nonexistent_tool() {
2488 let mut router = McpRouter::new();
2489 init_router(&mut router).await;
2490
2491 let req = RouterRequest {
2492 id: RequestId::Number(1),
2493 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2494 tool_name: "nonexistent".to_string(),
2495 arguments: serde_json::json!({}),
2496 ttl: None,
2497 }),
2498 extensions: Extensions::new(),
2499 };
2500
2501 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2502
2503 match resp.inner {
2504 Err(e) => {
2505 assert!(e.message.contains("not found"));
2506 }
2507 _ => panic!("Expected error response"),
2508 }
2509 }
2510
2511 #[tokio::test]
2512 async fn test_get_nonexistent_task() {
2513 let mut router = McpRouter::new();
2514 init_router(&mut router).await;
2515
2516 let req = RouterRequest {
2517 id: RequestId::Number(1),
2518 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2519 task_id: "task-999".to_string(),
2520 }),
2521 extensions: Extensions::new(),
2522 };
2523
2524 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2525
2526 match resp.inner {
2527 Err(e) => {
2528 assert!(e.message.contains("not found"));
2529 }
2530 _ => panic!("Expected error response"),
2531 }
2532 }
2533
2534 #[tokio::test]
2539 async fn test_subscribe_to_resource() {
2540 use crate::resource::ResourceBuilder;
2541
2542 let resource = ResourceBuilder::new("file:///test.txt")
2543 .name("Test File")
2544 .text("Hello");
2545
2546 let mut router = McpRouter::new().resource(resource);
2547 init_router(&mut router).await;
2548
2549 let req = RouterRequest {
2551 id: RequestId::Number(1),
2552 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2553 uri: "file:///test.txt".to_string(),
2554 }),
2555 extensions: Extensions::new(),
2556 };
2557
2558 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2559
2560 match resp.inner {
2561 Ok(McpResponse::SubscribeResource(_)) => {
2562 assert!(router.is_subscribed("file:///test.txt"));
2564 }
2565 _ => panic!("Expected SubscribeResource response"),
2566 }
2567 }
2568
2569 #[tokio::test]
2570 async fn test_unsubscribe_from_resource() {
2571 use crate::resource::ResourceBuilder;
2572
2573 let resource = ResourceBuilder::new("file:///test.txt")
2574 .name("Test File")
2575 .text("Hello");
2576
2577 let mut router = McpRouter::new().resource(resource);
2578 init_router(&mut router).await;
2579
2580 let req = RouterRequest {
2582 id: RequestId::Number(1),
2583 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2584 uri: "file:///test.txt".to_string(),
2585 }),
2586 extensions: Extensions::new(),
2587 };
2588 let _ = router.ready().await.unwrap().call(req).await.unwrap();
2589 assert!(router.is_subscribed("file:///test.txt"));
2590
2591 let req = RouterRequest {
2593 id: RequestId::Number(2),
2594 inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2595 uri: "file:///test.txt".to_string(),
2596 }),
2597 extensions: Extensions::new(),
2598 };
2599
2600 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2601
2602 match resp.inner {
2603 Ok(McpResponse::UnsubscribeResource(_)) => {
2604 assert!(!router.is_subscribed("file:///test.txt"));
2606 }
2607 _ => panic!("Expected UnsubscribeResource response"),
2608 }
2609 }
2610
2611 #[tokio::test]
2612 async fn test_subscribe_nonexistent_resource() {
2613 let mut router = McpRouter::new();
2614 init_router(&mut router).await;
2615
2616 let req = RouterRequest {
2617 id: RequestId::Number(1),
2618 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2619 uri: "file:///nonexistent.txt".to_string(),
2620 }),
2621 extensions: Extensions::new(),
2622 };
2623
2624 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2625
2626 match resp.inner {
2627 Err(e) => {
2628 assert!(e.message.contains("not found"));
2629 }
2630 _ => panic!("Expected error response"),
2631 }
2632 }
2633
2634 #[tokio::test]
2635 async fn test_notify_resource_updated() {
2636 use crate::context::notification_channel;
2637 use crate::resource::ResourceBuilder;
2638
2639 let (tx, mut rx) = notification_channel(10);
2640
2641 let resource = ResourceBuilder::new("file:///test.txt")
2642 .name("Test File")
2643 .text("Hello");
2644
2645 let router = McpRouter::new()
2646 .resource(resource)
2647 .with_notification_sender(tx);
2648
2649 router.subscribe("file:///test.txt");
2651
2652 let sent = router.notify_resource_updated("file:///test.txt");
2654 assert!(sent);
2655
2656 let notification = rx.try_recv().unwrap();
2658 match notification {
2659 ServerNotification::ResourceUpdated { uri } => {
2660 assert_eq!(uri, "file:///test.txt");
2661 }
2662 _ => panic!("Expected ResourceUpdated notification"),
2663 }
2664 }
2665
2666 #[tokio::test]
2667 async fn test_notify_resource_updated_not_subscribed() {
2668 use crate::context::notification_channel;
2669 use crate::resource::ResourceBuilder;
2670
2671 let (tx, mut rx) = notification_channel(10);
2672
2673 let resource = ResourceBuilder::new("file:///test.txt")
2674 .name("Test File")
2675 .text("Hello");
2676
2677 let router = McpRouter::new()
2678 .resource(resource)
2679 .with_notification_sender(tx);
2680
2681 let sent = router.notify_resource_updated("file:///test.txt");
2683 assert!(!sent); assert!(rx.try_recv().is_err());
2687 }
2688
2689 #[tokio::test]
2690 async fn test_notify_resources_list_changed() {
2691 use crate::context::notification_channel;
2692
2693 let (tx, mut rx) = notification_channel(10);
2694 let router = McpRouter::new().with_notification_sender(tx);
2695
2696 let sent = router.notify_resources_list_changed();
2697 assert!(sent);
2698
2699 let notification = rx.try_recv().unwrap();
2700 match notification {
2701 ServerNotification::ResourcesListChanged => {}
2702 _ => panic!("Expected ResourcesListChanged notification"),
2703 }
2704 }
2705
2706 #[tokio::test]
2707 async fn test_subscribed_uris() {
2708 use crate::resource::ResourceBuilder;
2709
2710 let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2711
2712 let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2713
2714 let router = McpRouter::new().resource(resource1).resource(resource2);
2715
2716 router.subscribe("file:///a.txt");
2718 router.subscribe("file:///b.txt");
2719
2720 let uris = router.subscribed_uris();
2721 assert_eq!(uris.len(), 2);
2722 assert!(uris.contains(&"file:///a.txt".to_string()));
2723 assert!(uris.contains(&"file:///b.txt".to_string()));
2724 }
2725
2726 #[tokio::test]
2727 async fn test_subscription_capability_advertised() {
2728 use crate::resource::ResourceBuilder;
2729
2730 let resource = ResourceBuilder::new("file:///test.txt")
2731 .name("Test")
2732 .text("Hello");
2733
2734 let mut router = McpRouter::new().resource(resource);
2735
2736 let init_req = RouterRequest {
2738 id: RequestId::Number(0),
2739 inner: McpRequest::Initialize(InitializeParams {
2740 protocol_version: "2025-11-25".to_string(),
2741 capabilities: ClientCapabilities {
2742 roots: None,
2743 sampling: None,
2744 elicitation: None,
2745 },
2746 client_info: Implementation {
2747 name: "test".to_string(),
2748 version: "1.0".to_string(),
2749 ..Default::default()
2750 },
2751 }),
2752 extensions: Extensions::new(),
2753 };
2754 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2755
2756 match resp.inner {
2757 Ok(McpResponse::Initialize(result)) => {
2758 let resources_cap = result.capabilities.resources.unwrap();
2760 assert!(resources_cap.subscribe);
2761 }
2762 _ => panic!("Expected Initialize response"),
2763 }
2764 }
2765
2766 #[tokio::test]
2767 async fn test_completion_handler() {
2768 let router = McpRouter::new()
2769 .server_info("test", "1.0")
2770 .completion_handler(|params: CompleteParams| async move {
2771 let prefix = ¶ms.argument.value;
2773 let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2774 .into_iter()
2775 .filter(|s| s.starts_with(prefix))
2776 .map(String::from)
2777 .collect();
2778 Ok(CompleteResult::new(suggestions))
2779 });
2780
2781 let init_req = RouterRequest {
2783 id: RequestId::Number(0),
2784 inner: McpRequest::Initialize(InitializeParams {
2785 protocol_version: "2025-11-25".to_string(),
2786 capabilities: ClientCapabilities::default(),
2787 client_info: Implementation {
2788 name: "test".to_string(),
2789 version: "1.0".to_string(),
2790 ..Default::default()
2791 },
2792 }),
2793 extensions: Extensions::new(),
2794 };
2795 let resp = router
2796 .clone()
2797 .ready()
2798 .await
2799 .unwrap()
2800 .call(init_req)
2801 .await
2802 .unwrap();
2803
2804 match resp.inner {
2806 Ok(McpResponse::Initialize(result)) => {
2807 assert!(result.capabilities.completions.is_some());
2808 }
2809 _ => panic!("Expected Initialize response"),
2810 }
2811
2812 router.handle_notification(McpNotification::Initialized);
2814
2815 let complete_req = RouterRequest {
2817 id: RequestId::Number(1),
2818 inner: McpRequest::Complete(CompleteParams {
2819 reference: CompletionReference::prompt("test-prompt"),
2820 argument: CompletionArgument::new("query", "al"),
2821 }),
2822 extensions: Extensions::new(),
2823 };
2824 let resp = router
2825 .clone()
2826 .ready()
2827 .await
2828 .unwrap()
2829 .call(complete_req)
2830 .await
2831 .unwrap();
2832
2833 match resp.inner {
2834 Ok(McpResponse::Complete(result)) => {
2835 assert_eq!(result.completion.values, vec!["alpha"]);
2836 }
2837 _ => panic!("Expected Complete response"),
2838 }
2839 }
2840
2841 #[tokio::test]
2842 async fn test_completion_without_handler_returns_empty() {
2843 let router = McpRouter::new().server_info("test", "1.0");
2844
2845 let init_req = RouterRequest {
2847 id: RequestId::Number(0),
2848 inner: McpRequest::Initialize(InitializeParams {
2849 protocol_version: "2025-11-25".to_string(),
2850 capabilities: ClientCapabilities::default(),
2851 client_info: Implementation {
2852 name: "test".to_string(),
2853 version: "1.0".to_string(),
2854 ..Default::default()
2855 },
2856 }),
2857 extensions: Extensions::new(),
2858 };
2859 let resp = router
2860 .clone()
2861 .ready()
2862 .await
2863 .unwrap()
2864 .call(init_req)
2865 .await
2866 .unwrap();
2867
2868 match resp.inner {
2870 Ok(McpResponse::Initialize(result)) => {
2871 assert!(result.capabilities.completions.is_none());
2872 }
2873 _ => panic!("Expected Initialize response"),
2874 }
2875
2876 router.handle_notification(McpNotification::Initialized);
2878
2879 let complete_req = RouterRequest {
2881 id: RequestId::Number(1),
2882 inner: McpRequest::Complete(CompleteParams {
2883 reference: CompletionReference::prompt("test-prompt"),
2884 argument: CompletionArgument::new("query", "al"),
2885 }),
2886 extensions: Extensions::new(),
2887 };
2888 let resp = router
2889 .clone()
2890 .ready()
2891 .await
2892 .unwrap()
2893 .call(complete_req)
2894 .await
2895 .unwrap();
2896
2897 match resp.inner {
2898 Ok(McpResponse::Complete(result)) => {
2899 assert!(result.completion.values.is_empty());
2900 }
2901 _ => panic!("Expected Complete response"),
2902 }
2903 }
2904
2905 #[tokio::test]
2906 async fn test_tool_filter_list() {
2907 use crate::filter::CapabilityFilter;
2908 use crate::tool::Tool;
2909
2910 let public_tool = ToolBuilder::new("public")
2911 .description("Public tool")
2912 .handler(|_: AddInput| async move { Ok(CallToolResult::text("public")) })
2913 .build()
2914 .expect("valid tool name");
2915
2916 let admin_tool = ToolBuilder::new("admin")
2917 .description("Admin tool")
2918 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2919 .build()
2920 .expect("valid tool name");
2921
2922 let mut router = McpRouter::new()
2923 .tool(public_tool)
2924 .tool(admin_tool)
2925 .tool_filter(CapabilityFilter::new(|_, tool: &Tool| tool.name != "admin"));
2926
2927 init_router(&mut router).await;
2929
2930 let req = RouterRequest {
2931 id: RequestId::Number(1),
2932 inner: McpRequest::ListTools(ListToolsParams::default()),
2933 extensions: Extensions::new(),
2934 };
2935
2936 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2937
2938 match resp.inner {
2939 Ok(McpResponse::ListTools(result)) => {
2940 assert_eq!(result.tools.len(), 1);
2942 assert_eq!(result.tools[0].name, "public");
2943 }
2944 _ => panic!("Expected ListTools response"),
2945 }
2946 }
2947
2948 #[tokio::test]
2949 async fn test_tool_filter_call_denied() {
2950 use crate::filter::CapabilityFilter;
2951 use crate::tool::Tool;
2952
2953 let admin_tool = ToolBuilder::new("admin")
2954 .description("Admin tool")
2955 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2956 .build()
2957 .expect("valid tool name");
2958
2959 let mut router = McpRouter::new()
2960 .tool(admin_tool)
2961 .tool_filter(CapabilityFilter::new(|_, _: &Tool| false)); init_router(&mut router).await;
2965
2966 let req = RouterRequest {
2967 id: RequestId::Number(1),
2968 inner: McpRequest::CallTool(CallToolParams {
2969 name: "admin".to_string(),
2970 arguments: serde_json::json!({"a": 1, "b": 2}),
2971 meta: None,
2972 }),
2973 extensions: Extensions::new(),
2974 };
2975
2976 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2977
2978 match resp.inner {
2980 Err(e) => {
2981 assert_eq!(e.code, -32601); }
2983 _ => panic!("Expected JsonRpc error"),
2984 }
2985 }
2986
2987 #[tokio::test]
2988 async fn test_tool_filter_call_allowed() {
2989 use crate::filter::CapabilityFilter;
2990 use crate::tool::Tool;
2991
2992 let public_tool = ToolBuilder::new("public")
2993 .description("Public tool")
2994 .handler(|input: AddInput| async move {
2995 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2996 })
2997 .build()
2998 .expect("valid tool name");
2999
3000 let mut router = McpRouter::new()
3001 .tool(public_tool)
3002 .tool_filter(CapabilityFilter::new(|_, _: &Tool| true)); init_router(&mut router).await;
3006
3007 let req = RouterRequest {
3008 id: RequestId::Number(1),
3009 inner: McpRequest::CallTool(CallToolParams {
3010 name: "public".to_string(),
3011 arguments: serde_json::json!({"a": 1, "b": 2}),
3012 meta: None,
3013 }),
3014 extensions: Extensions::new(),
3015 };
3016
3017 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3018
3019 match resp.inner {
3020 Ok(McpResponse::CallTool(result)) => {
3021 assert!(!result.is_error);
3022 }
3023 _ => panic!("Expected CallTool response"),
3024 }
3025 }
3026
3027 #[tokio::test]
3028 async fn test_tool_filter_custom_denial() {
3029 use crate::filter::{CapabilityFilter, DenialBehavior};
3030 use crate::tool::Tool;
3031
3032 let admin_tool = ToolBuilder::new("admin")
3033 .description("Admin tool")
3034 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
3035 .build()
3036 .expect("valid tool name");
3037
3038 let mut router = McpRouter::new().tool(admin_tool).tool_filter(
3039 CapabilityFilter::new(|_, _: &Tool| false)
3040 .denial_behavior(DenialBehavior::Unauthorized),
3041 );
3042
3043 init_router(&mut router).await;
3045
3046 let req = RouterRequest {
3047 id: RequestId::Number(1),
3048 inner: McpRequest::CallTool(CallToolParams {
3049 name: "admin".to_string(),
3050 arguments: serde_json::json!({"a": 1, "b": 2}),
3051 meta: None,
3052 }),
3053 extensions: Extensions::new(),
3054 };
3055
3056 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3057
3058 match resp.inner {
3060 Err(e) => {
3061 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3063 }
3064 _ => panic!("Expected JsonRpc error"),
3065 }
3066 }
3067
3068 #[tokio::test]
3069 async fn test_resource_filter_list() {
3070 use crate::filter::CapabilityFilter;
3071 use crate::resource::{Resource, ResourceBuilder};
3072
3073 let public_resource = ResourceBuilder::new("file:///public.txt")
3074 .name("Public File")
3075 .text("public content");
3076
3077 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3078 .name("Secret File")
3079 .text("secret content");
3080
3081 let mut router = McpRouter::new()
3082 .resource(public_resource)
3083 .resource(secret_resource)
3084 .resource_filter(CapabilityFilter::new(|_, r: &Resource| {
3085 !r.name.contains("Secret")
3086 }));
3087
3088 init_router(&mut router).await;
3090
3091 let req = RouterRequest {
3092 id: RequestId::Number(1),
3093 inner: McpRequest::ListResources(ListResourcesParams::default()),
3094 extensions: Extensions::new(),
3095 };
3096
3097 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3098
3099 match resp.inner {
3100 Ok(McpResponse::ListResources(result)) => {
3101 assert_eq!(result.resources.len(), 1);
3103 assert_eq!(result.resources[0].name, "Public File");
3104 }
3105 _ => panic!("Expected ListResources response"),
3106 }
3107 }
3108
3109 #[tokio::test]
3110 async fn test_resource_filter_read_denied() {
3111 use crate::filter::CapabilityFilter;
3112 use crate::resource::{Resource, ResourceBuilder};
3113
3114 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3115 .name("Secret File")
3116 .text("secret content");
3117
3118 let mut router = McpRouter::new()
3119 .resource(secret_resource)
3120 .resource_filter(CapabilityFilter::new(|_, _: &Resource| false)); init_router(&mut router).await;
3124
3125 let req = RouterRequest {
3126 id: RequestId::Number(1),
3127 inner: McpRequest::ReadResource(ReadResourceParams {
3128 uri: "file:///secret.txt".to_string(),
3129 }),
3130 extensions: Extensions::new(),
3131 };
3132
3133 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3134
3135 match resp.inner {
3137 Err(e) => {
3138 assert_eq!(e.code, -32601); }
3140 _ => panic!("Expected JsonRpc error"),
3141 }
3142 }
3143
3144 #[tokio::test]
3145 async fn test_resource_filter_read_allowed() {
3146 use crate::filter::CapabilityFilter;
3147 use crate::resource::{Resource, ResourceBuilder};
3148
3149 let public_resource = ResourceBuilder::new("file:///public.txt")
3150 .name("Public File")
3151 .text("public content");
3152
3153 let mut router = McpRouter::new()
3154 .resource(public_resource)
3155 .resource_filter(CapabilityFilter::new(|_, _: &Resource| true)); init_router(&mut router).await;
3159
3160 let req = RouterRequest {
3161 id: RequestId::Number(1),
3162 inner: McpRequest::ReadResource(ReadResourceParams {
3163 uri: "file:///public.txt".to_string(),
3164 }),
3165 extensions: Extensions::new(),
3166 };
3167
3168 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3169
3170 match resp.inner {
3171 Ok(McpResponse::ReadResource(result)) => {
3172 assert_eq!(result.contents.len(), 1);
3173 assert_eq!(result.contents[0].text.as_deref(), Some("public content"));
3174 }
3175 _ => panic!("Expected ReadResource response"),
3176 }
3177 }
3178
3179 #[tokio::test]
3180 async fn test_resource_filter_custom_denial() {
3181 use crate::filter::{CapabilityFilter, DenialBehavior};
3182 use crate::resource::{Resource, ResourceBuilder};
3183
3184 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3185 .name("Secret File")
3186 .text("secret content");
3187
3188 let mut router = McpRouter::new().resource(secret_resource).resource_filter(
3189 CapabilityFilter::new(|_, _: &Resource| false)
3190 .denial_behavior(DenialBehavior::Unauthorized),
3191 );
3192
3193 init_router(&mut router).await;
3195
3196 let req = RouterRequest {
3197 id: RequestId::Number(1),
3198 inner: McpRequest::ReadResource(ReadResourceParams {
3199 uri: "file:///secret.txt".to_string(),
3200 }),
3201 extensions: Extensions::new(),
3202 };
3203
3204 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3205
3206 match resp.inner {
3208 Err(e) => {
3209 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3211 }
3212 _ => panic!("Expected JsonRpc error"),
3213 }
3214 }
3215
3216 #[tokio::test]
3217 async fn test_prompt_filter_list() {
3218 use crate::filter::CapabilityFilter;
3219 use crate::prompt::{Prompt, PromptBuilder};
3220
3221 let public_prompt = PromptBuilder::new("greeting")
3222 .description("A greeting")
3223 .user_message("Hello!");
3224
3225 let admin_prompt = PromptBuilder::new("system_debug")
3226 .description("Admin prompt")
3227 .user_message("Debug");
3228
3229 let mut router = McpRouter::new()
3230 .prompt(public_prompt)
3231 .prompt(admin_prompt)
3232 .prompt_filter(CapabilityFilter::new(|_, p: &Prompt| {
3233 !p.name.contains("system")
3234 }));
3235
3236 init_router(&mut router).await;
3238
3239 let req = RouterRequest {
3240 id: RequestId::Number(1),
3241 inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3242 extensions: Extensions::new(),
3243 };
3244
3245 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3246
3247 match resp.inner {
3248 Ok(McpResponse::ListPrompts(result)) => {
3249 assert_eq!(result.prompts.len(), 1);
3251 assert_eq!(result.prompts[0].name, "greeting");
3252 }
3253 _ => panic!("Expected ListPrompts response"),
3254 }
3255 }
3256
3257 #[tokio::test]
3258 async fn test_prompt_filter_get_denied() {
3259 use crate::filter::CapabilityFilter;
3260 use crate::prompt::{Prompt, PromptBuilder};
3261 use std::collections::HashMap;
3262
3263 let admin_prompt = PromptBuilder::new("system_debug")
3264 .description("Admin prompt")
3265 .user_message("Debug");
3266
3267 let mut router = McpRouter::new()
3268 .prompt(admin_prompt)
3269 .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| false)); init_router(&mut router).await;
3273
3274 let req = RouterRequest {
3275 id: RequestId::Number(1),
3276 inner: McpRequest::GetPrompt(GetPromptParams {
3277 name: "system_debug".to_string(),
3278 arguments: HashMap::new(),
3279 }),
3280 extensions: Extensions::new(),
3281 };
3282
3283 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3284
3285 match resp.inner {
3287 Err(e) => {
3288 assert_eq!(e.code, -32601); }
3290 _ => panic!("Expected JsonRpc error"),
3291 }
3292 }
3293
3294 #[tokio::test]
3295 async fn test_prompt_filter_get_allowed() {
3296 use crate::filter::CapabilityFilter;
3297 use crate::prompt::{Prompt, PromptBuilder};
3298 use std::collections::HashMap;
3299
3300 let public_prompt = PromptBuilder::new("greeting")
3301 .description("A greeting")
3302 .user_message("Hello!");
3303
3304 let mut router = McpRouter::new()
3305 .prompt(public_prompt)
3306 .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| true)); init_router(&mut router).await;
3310
3311 let req = RouterRequest {
3312 id: RequestId::Number(1),
3313 inner: McpRequest::GetPrompt(GetPromptParams {
3314 name: "greeting".to_string(),
3315 arguments: HashMap::new(),
3316 }),
3317 extensions: Extensions::new(),
3318 };
3319
3320 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3321
3322 match resp.inner {
3323 Ok(McpResponse::GetPrompt(result)) => {
3324 assert_eq!(result.messages.len(), 1);
3325 }
3326 _ => panic!("Expected GetPrompt response"),
3327 }
3328 }
3329
3330 #[tokio::test]
3331 async fn test_prompt_filter_custom_denial() {
3332 use crate::filter::{CapabilityFilter, DenialBehavior};
3333 use crate::prompt::{Prompt, PromptBuilder};
3334 use std::collections::HashMap;
3335
3336 let admin_prompt = PromptBuilder::new("system_debug")
3337 .description("Admin prompt")
3338 .user_message("Debug");
3339
3340 let mut router = McpRouter::new().prompt(admin_prompt).prompt_filter(
3341 CapabilityFilter::new(|_, _: &Prompt| false)
3342 .denial_behavior(DenialBehavior::Unauthorized),
3343 );
3344
3345 init_router(&mut router).await;
3347
3348 let req = RouterRequest {
3349 id: RequestId::Number(1),
3350 inner: McpRequest::GetPrompt(GetPromptParams {
3351 name: "system_debug".to_string(),
3352 arguments: HashMap::new(),
3353 }),
3354 extensions: Extensions::new(),
3355 };
3356
3357 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3358
3359 match resp.inner {
3361 Err(e) => {
3362 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3364 }
3365 _ => panic!("Expected JsonRpc error"),
3366 }
3367 }
3368
    /// Minimal tool input carrying a single string, used as the handler
    /// argument type by the router merge/nest tests that follow.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct StringInput {
        /// The string payload; the "echo" tool returns it unchanged.
        value: String,
    }
3377
3378 #[tokio::test]
3379 async fn test_router_merge_tools() {
3380 let tool_a = ToolBuilder::new("tool_a")
3382 .description("Tool A")
3383 .handler(|_: StringInput| async move { Ok(CallToolResult::text("A")) })
3384 .build()
3385 .unwrap();
3386
3387 let router_a = McpRouter::new().tool(tool_a);
3388
3389 let tool_b = ToolBuilder::new("tool_b")
3391 .description("Tool B")
3392 .handler(|_: StringInput| async move { Ok(CallToolResult::text("B")) })
3393 .build()
3394 .unwrap();
3395 let tool_c = ToolBuilder::new("tool_c")
3396 .description("Tool C")
3397 .handler(|_: StringInput| async move { Ok(CallToolResult::text("C")) })
3398 .build()
3399 .unwrap();
3400
3401 let router_b = McpRouter::new().tool(tool_b).tool(tool_c);
3402
3403 let mut merged = McpRouter::new()
3405 .server_info("merged", "1.0")
3406 .merge(router_a)
3407 .merge(router_b);
3408
3409 init_router(&mut merged).await;
3410
3411 let req = RouterRequest {
3413 id: RequestId::Number(1),
3414 inner: McpRequest::ListTools(ListToolsParams::default()),
3415 extensions: Extensions::new(),
3416 };
3417
3418 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3419
3420 match resp.inner {
3421 Ok(McpResponse::ListTools(result)) => {
3422 assert_eq!(result.tools.len(), 3);
3423 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3424 assert!(names.contains(&"tool_a"));
3425 assert!(names.contains(&"tool_b"));
3426 assert!(names.contains(&"tool_c"));
3427 }
3428 _ => panic!("Expected ListTools response"),
3429 }
3430 }
3431
3432 #[tokio::test]
3433 async fn test_router_merge_overwrites_duplicates() {
3434 let tool_v1 = ToolBuilder::new("shared")
3436 .description("Version 1")
3437 .handler(|_: StringInput| async move { Ok(CallToolResult::text("v1")) })
3438 .build()
3439 .unwrap();
3440
3441 let router_a = McpRouter::new().tool(tool_v1);
3442
3443 let tool_v2 = ToolBuilder::new("shared")
3445 .description("Version 2")
3446 .handler(|_: StringInput| async move { Ok(CallToolResult::text("v2")) })
3447 .build()
3448 .unwrap();
3449
3450 let router_b = McpRouter::new().tool(tool_v2);
3451
3452 let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3454
3455 init_router(&mut merged).await;
3456
3457 let req = RouterRequest {
3458 id: RequestId::Number(1),
3459 inner: McpRequest::ListTools(ListToolsParams::default()),
3460 extensions: Extensions::new(),
3461 };
3462
3463 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3464
3465 match resp.inner {
3466 Ok(McpResponse::ListTools(result)) => {
3467 assert_eq!(result.tools.len(), 1);
3468 assert_eq!(result.tools[0].name, "shared");
3469 assert_eq!(result.tools[0].description.as_deref(), Some("Version 2"));
3470 }
3471 _ => panic!("Expected ListTools response"),
3472 }
3473 }
3474
3475 #[tokio::test]
3476 async fn test_router_merge_resources() {
3477 use crate::resource::ResourceBuilder;
3478
3479 let router_a = McpRouter::new().resource(
3481 ResourceBuilder::new("file:///a.txt")
3482 .name("File A")
3483 .text("content a"),
3484 );
3485
3486 let router_b = McpRouter::new().resource(
3487 ResourceBuilder::new("file:///b.txt")
3488 .name("File B")
3489 .text("content b"),
3490 );
3491
3492 let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3493
3494 init_router(&mut merged).await;
3495
3496 let req = RouterRequest {
3497 id: RequestId::Number(1),
3498 inner: McpRequest::ListResources(ListResourcesParams::default()),
3499 extensions: Extensions::new(),
3500 };
3501
3502 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3503
3504 match resp.inner {
3505 Ok(McpResponse::ListResources(result)) => {
3506 assert_eq!(result.resources.len(), 2);
3507 let uris: Vec<&str> = result.resources.iter().map(|r| r.uri.as_str()).collect();
3508 assert!(uris.contains(&"file:///a.txt"));
3509 assert!(uris.contains(&"file:///b.txt"));
3510 }
3511 _ => panic!("Expected ListResources response"),
3512 }
3513 }
3514
3515 #[tokio::test]
3516 async fn test_router_merge_prompts() {
3517 use crate::prompt::PromptBuilder;
3518
3519 let router_a =
3520 McpRouter::new().prompt(PromptBuilder::new("prompt_a").user_message("Hello A"));
3521
3522 let router_b =
3523 McpRouter::new().prompt(PromptBuilder::new("prompt_b").user_message("Hello B"));
3524
3525 let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3526
3527 init_router(&mut merged).await;
3528
3529 let req = RouterRequest {
3530 id: RequestId::Number(1),
3531 inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3532 extensions: Extensions::new(),
3533 };
3534
3535 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3536
3537 match resp.inner {
3538 Ok(McpResponse::ListPrompts(result)) => {
3539 assert_eq!(result.prompts.len(), 2);
3540 let names: Vec<&str> = result.prompts.iter().map(|p| p.name.as_str()).collect();
3541 assert!(names.contains(&"prompt_a"));
3542 assert!(names.contains(&"prompt_b"));
3543 }
3544 _ => panic!("Expected ListPrompts response"),
3545 }
3546 }
3547
3548 #[tokio::test]
3549 async fn test_router_nest_prefixes_tools() {
3550 let tool_query = ToolBuilder::new("query")
3552 .description("Query the database")
3553 .handler(|_: StringInput| async move { Ok(CallToolResult::text("query result")) })
3554 .build()
3555 .unwrap();
3556 let tool_insert = ToolBuilder::new("insert")
3557 .description("Insert into database")
3558 .handler(|_: StringInput| async move { Ok(CallToolResult::text("insert result")) })
3559 .build()
3560 .unwrap();
3561
3562 let db_router = McpRouter::new().tool(tool_query).tool(tool_insert);
3563
3564 let mut router = McpRouter::new()
3566 .server_info("nested", "1.0")
3567 .nest("db", db_router);
3568
3569 init_router(&mut router).await;
3570
3571 let req = RouterRequest {
3572 id: RequestId::Number(1),
3573 inner: McpRequest::ListTools(ListToolsParams::default()),
3574 extensions: Extensions::new(),
3575 };
3576
3577 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3578
3579 match resp.inner {
3580 Ok(McpResponse::ListTools(result)) => {
3581 assert_eq!(result.tools.len(), 2);
3582 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3583 assert!(names.contains(&"db.query"));
3584 assert!(names.contains(&"db.insert"));
3585 }
3586 _ => panic!("Expected ListTools response"),
3587 }
3588 }
3589
3590 #[tokio::test]
3591 async fn test_router_nest_call_prefixed_tool() {
3592 let tool = ToolBuilder::new("echo")
3593 .description("Echo input")
3594 .handler(|input: StringInput| async move { Ok(CallToolResult::text(&input.value)) })
3595 .build()
3596 .unwrap();
3597
3598 let nested_router = McpRouter::new().tool(tool);
3599
3600 let mut router = McpRouter::new().nest("api", nested_router);
3601
3602 init_router(&mut router).await;
3603
3604 let req = RouterRequest {
3606 id: RequestId::Number(1),
3607 inner: McpRequest::CallTool(CallToolParams {
3608 name: "api.echo".to_string(),
3609 arguments: serde_json::json!({"value": "hello world"}),
3610 meta: None,
3611 }),
3612 extensions: Extensions::new(),
3613 };
3614
3615 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3616
3617 match resp.inner {
3618 Ok(McpResponse::CallTool(result)) => {
3619 assert!(!result.is_error);
3620 match &result.content[0] {
3621 Content::Text { text, .. } => assert_eq!(text, "hello world"),
3622 _ => panic!("Expected text content"),
3623 }
3624 }
3625 _ => panic!("Expected CallTool response"),
3626 }
3627 }
3628
3629 #[tokio::test]
3630 async fn test_router_multiple_nests() {
3631 let db_tool = ToolBuilder::new("query")
3632 .description("Database query")
3633 .handler(|_: StringInput| async move { Ok(CallToolResult::text("db")) })
3634 .build()
3635 .unwrap();
3636
3637 let api_tool = ToolBuilder::new("fetch")
3638 .description("API fetch")
3639 .handler(|_: StringInput| async move { Ok(CallToolResult::text("api")) })
3640 .build()
3641 .unwrap();
3642
3643 let db_router = McpRouter::new().tool(db_tool);
3644 let api_router = McpRouter::new().tool(api_tool);
3645
3646 let mut router = McpRouter::new()
3647 .nest("db", db_router)
3648 .nest("api", api_router);
3649
3650 init_router(&mut router).await;
3651
3652 let req = RouterRequest {
3653 id: RequestId::Number(1),
3654 inner: McpRequest::ListTools(ListToolsParams::default()),
3655 extensions: Extensions::new(),
3656 };
3657
3658 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3659
3660 match resp.inner {
3661 Ok(McpResponse::ListTools(result)) => {
3662 assert_eq!(result.tools.len(), 2);
3663 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3664 assert!(names.contains(&"db.query"));
3665 assert!(names.contains(&"api.fetch"));
3666 }
3667 _ => panic!("Expected ListTools response"),
3668 }
3669 }
3670
3671 #[tokio::test]
3672 async fn test_router_merge_and_nest_combined() {
3673 let tool_a = ToolBuilder::new("local")
3675 .description("Local tool")
3676 .handler(|_: StringInput| async move { Ok(CallToolResult::text("local")) })
3677 .build()
3678 .unwrap();
3679
3680 let nested_tool = ToolBuilder::new("remote")
3681 .description("Remote tool")
3682 .handler(|_: StringInput| async move { Ok(CallToolResult::text("remote")) })
3683 .build()
3684 .unwrap();
3685
3686 let nested_router = McpRouter::new().tool(nested_tool);
3687
3688 let mut router = McpRouter::new()
3689 .tool(tool_a)
3690 .nest("external", nested_router);
3691
3692 init_router(&mut router).await;
3693
3694 let req = RouterRequest {
3695 id: RequestId::Number(1),
3696 inner: McpRequest::ListTools(ListToolsParams::default()),
3697 extensions: Extensions::new(),
3698 };
3699
3700 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3701
3702 match resp.inner {
3703 Ok(McpResponse::ListTools(result)) => {
3704 assert_eq!(result.tools.len(), 2);
3705 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3706 assert!(names.contains(&"local"));
3707 assert!(names.contains(&"external.remote"));
3708 }
3709 _ => panic!("Expected ListTools response"),
3710 }
3711 }
3712
3713 #[tokio::test]
3714 async fn test_router_merge_preserves_server_info() {
3715 let child_router = McpRouter::new()
3716 .server_info("child", "2.0")
3717 .instructions("Child instructions");
3718
3719 let mut router = McpRouter::new()
3720 .server_info("parent", "1.0")
3721 .instructions("Parent instructions")
3722 .merge(child_router);
3723
3724 init_router(&mut router).await;
3725
3726 let init_req = RouterRequest {
3728 id: RequestId::Number(99),
3729 inner: McpRequest::Initialize(InitializeParams {
3730 protocol_version: "2025-11-25".to_string(),
3731 capabilities: ClientCapabilities::default(),
3732 client_info: Implementation {
3733 name: "test".to_string(),
3734 version: "1.0".to_string(),
3735 ..Default::default()
3736 },
3737 }),
3738 extensions: Extensions::new(),
3739 };
3740
3741 let child_router2 = McpRouter::new().server_info("child", "2.0");
3743 let mut fresh_router = McpRouter::new()
3744 .server_info("parent", "1.0")
3745 .merge(child_router2);
3746
3747 let resp = fresh_router
3748 .ready()
3749 .await
3750 .unwrap()
3751 .call(init_req)
3752 .await
3753 .unwrap();
3754
3755 match resp.inner {
3756 Ok(McpResponse::Initialize(result)) => {
3757 assert_eq!(result.server_info.name, "parent");
3758 assert_eq!(result.server_info.version, "1.0");
3759 }
3760 _ => panic!("Expected Initialize response"),
3761 }
3762 }
3763}