1use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11
12use tower_service::Service;
13
14use crate::async_task::TaskStore;
15use crate::context::{
16 CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
17 ServerNotification,
18};
19use crate::error::{Error, JsonRpcError, Result};
20use crate::filter::{PromptFilter, ResourceFilter, ToolFilter};
21use crate::prompt::Prompt;
22use crate::protocol::*;
23use crate::resource::{Resource, ResourceTemplate};
24use crate::session::SessionState;
25use crate::tool::Tool;
26
27pub type CompletionHandler = Arc<
29 dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
30 + Send
31 + Sync,
32>;
33
/// MCP request router.
///
/// Holds the registered tools, resources, resource templates and prompts
/// (in a reference-counted `McpRouterInner`) together with the per-connection
/// `SessionState`. Cloning is cheap: the `inner` configuration is shared via
/// `Arc`, and builder methods copy it on write through `Arc::make_mut`.
#[derive(Clone)]
pub struct McpRouter {
    // Shared router configuration and shared runtime state.
    inner: Arc<McpRouterInner>,
    // Lifecycle state for this connection; replaced by `with_fresh_session`.
    session: SessionState,
}
62
63impl std::fmt::Debug for McpRouter {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_struct("McpRouter")
66 .field("server_name", &self.inner.server_name)
67 .field("server_version", &self.inner.server_version)
68 .field("tools_count", &self.inner.tools.len())
69 .field("resources_count", &self.inner.resources.len())
70 .field("prompts_count", &self.inner.prompts.len())
71 .field("session_phase", &self.session.phase())
72 .finish()
73 }
74}
75
/// Settings for instructions generated from the registered capabilities
/// (see `McpRouterInner::generate_instructions`).
#[derive(Clone, Debug)]
struct AutoInstructionsConfig {
    // Text placed before the generated sections, if any.
    prefix: Option<String>,
    // Text placed after the generated sections, if any.
    suffix: Option<String>,
}
82
/// Router configuration populated by the `McpRouter` builder methods, plus
/// `Arc`-wrapped runtime state that stays shared when this struct is cloned.
#[derive(Clone)]
struct McpRouterInner {
    // Reported in `InitializeResult::server_info`.
    server_name: String,
    server_version: String,
    // Optional human-readable title reported in server info.
    server_title: Option<String>,
    // Optional description reported in server info.
    server_description: Option<String>,
    // Optional icons reported in server info.
    server_icons: Option<Vec<ToolIcon>>,
    // Optional website URL reported in server info.
    server_website_url: Option<String>,
    // Static instructions; ignored when `auto_instructions` is set.
    instructions: Option<String>,
    auto_instructions: Option<AutoInstructionsConfig>,
    // Registries keyed by tool name, resource URI and prompt name.
    tools: HashMap<String, Arc<Tool>>,
    resources: HashMap<String, Arc<Resource>>,
    // Kept in registration order; the first matching template serves a read.
    resource_templates: Vec<Arc<ResourceTemplate>>,
    prompts: HashMap<String, Arc<Prompt>>,
    // Cancellation tokens for requests currently being handled, by request id.
    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
    // Outbound channel for server notifications (logging, resource updates).
    notification_tx: Option<NotificationSender>,
    // Handle for server-to-client requests, attached to each request context.
    client_requester: Option<ClientRequesterHandle>,
    // Storage for asynchronous tool tasks.
    task_store: TaskStore,
    // URIs of resources with an active subscription.
    subscriptions: Arc<RwLock<HashSet<String>>>,
    // Optional handler for completion requests.
    completion_handler: Option<CompletionHandler>,
    // Optional per-session visibility filter for tools.
    tool_filter: Option<ToolFilter>,
    // Optional per-session visibility filter for static resources.
    resource_filter: Option<ResourceFilter>,
    // Optional per-session visibility filter for prompts.
    prompt_filter: Option<PromptFilter>,
    // User state attached to every request context.
    extensions: Arc<crate::context::Extensions>,
}
124
125impl McpRouterInner {
126 fn generate_instructions(&self, config: &AutoInstructionsConfig) -> String {
128 let mut parts = Vec::new();
129
130 if let Some(prefix) = &config.prefix {
131 parts.push(prefix.clone());
132 }
133
134 if !self.tools.is_empty() {
136 let mut lines = vec!["## Tools".to_string(), String::new()];
137 let mut tools: Vec<_> = self.tools.values().collect();
138 tools.sort_by(|a, b| a.name.cmp(&b.name));
139 for tool in tools {
140 let desc = tool.description.as_deref().unwrap_or("No description");
141 let tags = annotation_tags(tool.annotations.as_ref());
142 if tags.is_empty() {
143 lines.push(format!("- **{}**: {}", tool.name, desc));
144 } else {
145 lines.push(format!("- **{}**: {} [{}]", tool.name, desc, tags));
146 }
147 }
148 parts.push(lines.join("\n"));
149 }
150
151 if !self.resources.is_empty() || !self.resource_templates.is_empty() {
153 let mut lines = vec!["## Resources".to_string(), String::new()];
154 let mut resources: Vec<_> = self.resources.values().collect();
155 resources.sort_by(|a, b| a.uri.cmp(&b.uri));
156 for resource in resources {
157 let desc = resource.description.as_deref().unwrap_or("No description");
158 lines.push(format!("- **{}**: {}", resource.uri, desc));
159 }
160 let mut templates: Vec<_> = self.resource_templates.iter().collect();
161 templates.sort_by(|a, b| a.uri_template.cmp(&b.uri_template));
162 for template in templates {
163 let desc = template.description.as_deref().unwrap_or("No description");
164 lines.push(format!("- **{}**: {}", template.uri_template, desc));
165 }
166 parts.push(lines.join("\n"));
167 }
168
169 if !self.prompts.is_empty() {
171 let mut lines = vec!["## Prompts".to_string(), String::new()];
172 let mut prompts: Vec<_> = self.prompts.values().collect();
173 prompts.sort_by(|a, b| a.name.cmp(&b.name));
174 for prompt in prompts {
175 let desc = prompt.description.as_deref().unwrap_or("No description");
176 lines.push(format!("- **{}**: {}", prompt.name, desc));
177 }
178 parts.push(lines.join("\n"));
179 }
180
181 if let Some(suffix) = &config.suffix {
182 parts.push(suffix.clone());
183 }
184
185 parts.join("\n\n")
186 }
187}
188
189fn annotation_tags(annotations: Option<&crate::protocol::ToolAnnotations>) -> String {
195 let Some(ann) = annotations else {
196 return String::new();
197 };
198 let mut tags = Vec::new();
199 if ann.is_read_only() {
200 tags.push("read-only");
201 }
202 if ann.is_idempotent() {
203 tags.push("idempotent");
204 }
205 tags.join(", ")
206}
207
208impl McpRouter {
209 pub fn new() -> Self {
211 Self {
212 inner: Arc::new(McpRouterInner {
213 server_name: "tower-mcp".to_string(),
214 server_version: env!("CARGO_PKG_VERSION").to_string(),
215 server_title: None,
216 server_description: None,
217 server_icons: None,
218 server_website_url: None,
219 instructions: None,
220 auto_instructions: None,
221 tools: HashMap::new(),
222 resources: HashMap::new(),
223 resource_templates: Vec::new(),
224 prompts: HashMap::new(),
225 in_flight: Arc::new(RwLock::new(HashMap::new())),
226 notification_tx: None,
227 client_requester: None,
228 task_store: TaskStore::new(),
229 subscriptions: Arc::new(RwLock::new(HashSet::new())),
230 extensions: Arc::new(crate::context::Extensions::new()),
231 completion_handler: None,
232 tool_filter: None,
233 resource_filter: None,
234 prompt_filter: None,
235 }),
236 session: SessionState::new(),
237 }
238 }
239
240 pub fn with_fresh_session(&self) -> Self {
248 Self {
249 inner: self.inner.clone(),
250 session: SessionState::new(),
251 }
252 }
253
254 pub fn task_store(&self) -> &TaskStore {
256 &self.inner.task_store
257 }
258
259 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
263 Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
264 self
265 }
266
267 pub fn notification_sender(&self) -> Option<&NotificationSender> {
269 self.inner.notification_tx.as_ref()
270 }
271
272 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
277 Arc::make_mut(&mut self.inner).client_requester = Some(requester);
278 self
279 }
280
281 pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
283 self.inner.client_requester.as_ref()
284 }
285
286 pub fn with_state<T: Clone + Send + Sync + 'static>(mut self, state: T) -> Self {
329 let inner = Arc::make_mut(&mut self.inner);
330 Arc::make_mut(&mut inner.extensions).insert(state);
331 self
332 }
333
334 pub fn with_extension<T: Clone + Send + Sync + 'static>(self, value: T) -> Self {
339 self.with_state(value)
340 }
341
342 pub fn extensions(&self) -> &crate::context::Extensions {
344 &self.inner.extensions
345 }
346
347 pub fn create_context(
352 &self,
353 request_id: RequestId,
354 progress_token: Option<ProgressToken>,
355 ) -> RequestContext {
356 let ctx = RequestContext::new(request_id.clone());
357
358 let ctx = if let Some(token) = progress_token {
360 ctx.with_progress_token(token)
361 } else {
362 ctx
363 };
364
365 let ctx = if let Some(tx) = &self.inner.notification_tx {
367 ctx.with_notification_sender(tx.clone())
368 } else {
369 ctx
370 };
371
372 let ctx = if let Some(requester) = &self.inner.client_requester {
374 ctx.with_client_requester(requester.clone())
375 } else {
376 ctx
377 };
378
379 let ctx = ctx.with_extensions(self.inner.extensions.clone());
381
382 let token = ctx.cancellation_token();
384 if let Ok(mut in_flight) = self.inner.in_flight.write() {
385 in_flight.insert(request_id, token);
386 }
387
388 ctx
389 }
390
391 pub fn complete_request(&self, request_id: &RequestId) {
393 if let Ok(mut in_flight) = self.inner.in_flight.write() {
394 in_flight.remove(request_id);
395 }
396 }
397
398 fn cancel_request(&self, request_id: &RequestId) -> bool {
400 let Ok(in_flight) = self.inner.in_flight.read() else {
401 return false;
402 };
403 let Some(token) = in_flight.get(request_id) else {
404 return false;
405 };
406 token.cancel();
407 true
408 }
409
410 pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
412 let inner = Arc::make_mut(&mut self.inner);
413 inner.server_name = name.into();
414 inner.server_version = version.into();
415 self
416 }
417
418 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
420 Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
421 self
422 }
423
424 pub fn auto_instructions(mut self) -> Self {
456 Arc::make_mut(&mut self.inner).auto_instructions = Some(AutoInstructionsConfig {
457 prefix: None,
458 suffix: None,
459 });
460 self
461 }
462
463 pub fn auto_instructions_with(
480 mut self,
481 prefix: Option<impl Into<String>>,
482 suffix: Option<impl Into<String>>,
483 ) -> Self {
484 Arc::make_mut(&mut self.inner).auto_instructions = Some(AutoInstructionsConfig {
485 prefix: prefix.map(Into::into),
486 suffix: suffix.map(Into::into),
487 });
488 self
489 }
490
491 pub fn server_title(mut self, title: impl Into<String>) -> Self {
493 Arc::make_mut(&mut self.inner).server_title = Some(title.into());
494 self
495 }
496
497 pub fn server_description(mut self, description: impl Into<String>) -> Self {
499 Arc::make_mut(&mut self.inner).server_description = Some(description.into());
500 self
501 }
502
503 pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
505 Arc::make_mut(&mut self.inner).server_icons = Some(icons);
506 self
507 }
508
509 pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
511 Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
512 self
513 }
514
515 pub fn tool(mut self, tool: Tool) -> Self {
517 Arc::make_mut(&mut self.inner)
518 .tools
519 .insert(tool.name.clone(), Arc::new(tool));
520 self
521 }
522
523 pub fn resource(mut self, resource: Resource) -> Self {
525 Arc::make_mut(&mut self.inner)
526 .resources
527 .insert(resource.uri.clone(), Arc::new(resource));
528 self
529 }
530
531 pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
562 Arc::make_mut(&mut self.inner)
563 .resource_templates
564 .push(Arc::new(template));
565 self
566 }
567
568 pub fn prompt(mut self, prompt: Prompt) -> Self {
570 Arc::make_mut(&mut self.inner)
571 .prompts
572 .insert(prompt.name.clone(), Arc::new(prompt));
573 self
574 }
575
576 pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
602 tools
603 .into_iter()
604 .fold(self, |router, tool| router.tool(tool))
605 }
606
607 pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
626 resources
627 .into_iter()
628 .fold(self, |router, resource| router.resource(resource))
629 }
630
631 pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
650 prompts
651 .into_iter()
652 .fold(self, |router, prompt| router.prompt(prompt))
653 }
654
655 pub fn merge(mut self, other: McpRouter) -> Self {
700 let inner = Arc::make_mut(&mut self.inner);
701 let other_inner = other.inner;
702
703 for (name, tool) in &other_inner.tools {
705 inner.tools.insert(name.clone(), tool.clone());
706 }
707
708 for (uri, resource) in &other_inner.resources {
710 inner.resources.insert(uri.clone(), resource.clone());
711 }
712
713 for template in &other_inner.resource_templates {
716 inner.resource_templates.push(template.clone());
717 }
718
719 for (name, prompt) in &other_inner.prompts {
721 inner.prompts.insert(name.clone(), prompt.clone());
722 }
723
724 self
725 }
726
727 pub fn nest(mut self, prefix: impl Into<String>, other: McpRouter) -> Self {
767 let prefix = prefix.into();
768 let inner = Arc::make_mut(&mut self.inner);
769 let other_inner = other.inner;
770
771 for tool in other_inner.tools.values() {
773 let prefixed_tool = tool.with_name_prefix(&prefix);
774 inner
775 .tools
776 .insert(prefixed_tool.name.clone(), Arc::new(prefixed_tool));
777 }
778
779 for (uri, resource) in &other_inner.resources {
781 inner.resources.insert(uri.clone(), resource.clone());
782 }
783
784 for template in &other_inner.resource_templates {
786 inner.resource_templates.push(template.clone());
787 }
788
789 for (name, prompt) in &other_inner.prompts {
791 inner.prompts.insert(name.clone(), prompt.clone());
792 }
793
794 self
795 }
796
797 pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
824 where
825 F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
826 Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
827 {
828 Arc::make_mut(&mut self.inner).completion_handler =
829 Some(Arc::new(move |params| Box::pin(handler(params))));
830 self
831 }
832
833 pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
868 Arc::make_mut(&mut self.inner).tool_filter = Some(filter);
869 self
870 }
871
872 pub fn resource_filter(mut self, filter: ResourceFilter) -> Self {
903 Arc::make_mut(&mut self.inner).resource_filter = Some(filter);
904 self
905 }
906
907 pub fn prompt_filter(mut self, filter: PromptFilter) -> Self {
936 Arc::make_mut(&mut self.inner).prompt_filter = Some(filter);
937 self
938 }
939
940 pub fn session(&self) -> &SessionState {
942 &self.session
943 }
944
945 pub fn log(&self, params: LoggingMessageParams) -> bool {
967 let Some(tx) = &self.inner.notification_tx else {
968 return false;
969 };
970 tx.try_send(ServerNotification::LogMessage(params)).is_ok()
971 }
972
973 pub fn log_info(&self, message: &str) -> bool {
977 self.log(
978 LoggingMessageParams::new(LogLevel::Info)
979 .with_data(serde_json::json!({ "message": message })),
980 )
981 }
982
983 pub fn log_warning(&self, message: &str) -> bool {
985 self.log(
986 LoggingMessageParams::new(LogLevel::Warning)
987 .with_data(serde_json::json!({ "message": message })),
988 )
989 }
990
991 pub fn log_error(&self, message: &str) -> bool {
993 self.log(
994 LoggingMessageParams::new(LogLevel::Error)
995 .with_data(serde_json::json!({ "message": message })),
996 )
997 }
998
999 pub fn log_debug(&self, message: &str) -> bool {
1001 self.log(
1002 LoggingMessageParams::new(LogLevel::Debug)
1003 .with_data(serde_json::json!({ "message": message })),
1004 )
1005 }
1006
1007 pub fn is_subscribed(&self, uri: &str) -> bool {
1009 if let Ok(subs) = self.inner.subscriptions.read() {
1010 return subs.contains(uri);
1011 }
1012 false
1013 }
1014
1015 pub fn subscribed_uris(&self) -> Vec<String> {
1017 if let Ok(subs) = self.inner.subscriptions.read() {
1018 return subs.iter().cloned().collect();
1019 }
1020 Vec::new()
1021 }
1022
1023 fn subscribe(&self, uri: &str) -> bool {
1025 if let Ok(mut subs) = self.inner.subscriptions.write() {
1026 return subs.insert(uri.to_string());
1027 }
1028 false
1029 }
1030
1031 fn unsubscribe(&self, uri: &str) -> bool {
1033 if let Ok(mut subs) = self.inner.subscriptions.write() {
1034 return subs.remove(uri);
1035 }
1036 false
1037 }
1038
1039 pub fn notify_resource_updated(&self, uri: &str) -> bool {
1044 if !self.is_subscribed(uri) {
1046 return false;
1047 }
1048
1049 let Some(tx) = &self.inner.notification_tx else {
1050 return false;
1051 };
1052 tx.try_send(ServerNotification::ResourceUpdated {
1053 uri: uri.to_string(),
1054 })
1055 .is_ok()
1056 }
1057
1058 pub fn notify_resources_list_changed(&self) -> bool {
1062 let Some(tx) = &self.inner.notification_tx else {
1063 return false;
1064 };
1065 tx.try_send(ServerNotification::ResourcesListChanged)
1066 .is_ok()
1067 }
1068
1069 fn capabilities(&self) -> ServerCapabilities {
1071 let has_resources =
1072 !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
1073
1074 ServerCapabilities {
1075 tools: if self.inner.tools.is_empty() {
1076 None
1077 } else {
1078 Some(ToolsCapability::default())
1079 },
1080 resources: if has_resources {
1081 Some(ResourcesCapability {
1082 subscribe: true,
1083 ..Default::default()
1084 })
1085 } else {
1086 None
1087 },
1088 prompts: if self.inner.prompts.is_empty() {
1089 None
1090 } else {
1091 Some(PromptsCapability::default())
1092 },
1093 logging: if self.inner.notification_tx.is_some() {
1095 Some(LoggingCapability::default())
1096 } else {
1097 None
1098 },
1099 tasks: Some(TasksCapability::default()),
1101 completions: if self.inner.completion_handler.is_some() {
1103 Some(CompletionsCapability::default())
1104 } else {
1105 None
1106 },
1107 }
1108 }
1109
1110 async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
1112 let method = request.method_name();
1114 if !self.session.is_request_allowed(method) {
1115 tracing::warn!(
1116 method = %method,
1117 phase = ?self.session.phase(),
1118 "Request rejected: session not initialized"
1119 );
1120 return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
1121 "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
1122 method
1123 ))));
1124 }
1125
1126 match request {
1127 McpRequest::Initialize(params) => {
1128 tracing::info!(
1129 client = %params.client_info.name,
1130 version = %params.client_info.version,
1131 "Client initializing"
1132 );
1133
1134 let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
1137 .contains(¶ms.protocol_version.as_str())
1138 {
1139 params.protocol_version
1140 } else {
1141 crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
1142 };
1143
1144 self.session.mark_initializing();
1146
1147 Ok(McpResponse::Initialize(InitializeResult {
1148 protocol_version,
1149 capabilities: self.capabilities(),
1150 server_info: Implementation {
1151 name: self.inner.server_name.clone(),
1152 version: self.inner.server_version.clone(),
1153 title: self.inner.server_title.clone(),
1154 description: self.inner.server_description.clone(),
1155 icons: self.inner.server_icons.clone(),
1156 website_url: self.inner.server_website_url.clone(),
1157 },
1158 instructions: if let Some(config) = &self.inner.auto_instructions {
1159 Some(self.inner.generate_instructions(config))
1160 } else {
1161 self.inner.instructions.clone()
1162 },
1163 }))
1164 }
1165
1166 McpRequest::ListTools(_params) => {
1167 let tools: Vec<ToolDefinition> = self
1168 .inner
1169 .tools
1170 .values()
1171 .filter(|t| {
1172 self.inner
1174 .tool_filter
1175 .as_ref()
1176 .map(|f| f.is_visible(&self.session, t))
1177 .unwrap_or(true)
1178 })
1179 .map(|t| t.definition())
1180 .collect();
1181
1182 Ok(McpResponse::ListTools(ListToolsResult {
1183 tools,
1184 next_cursor: None,
1185 }))
1186 }
1187
1188 McpRequest::CallTool(params) => {
1189 let tool =
1190 self.inner.tools.get(¶ms.name).ok_or_else(|| {
1191 Error::JsonRpc(JsonRpcError::method_not_found(¶ms.name))
1192 })?;
1193
1194 if let Some(filter) = &self.inner.tool_filter
1196 && !filter.is_visible(&self.session, tool)
1197 {
1198 return Err(filter.denial_error(¶ms.name));
1199 }
1200
1201 let progress_token = params.meta.and_then(|m| m.progress_token);
1203 let ctx = self.create_context(request_id, progress_token);
1204
1205 tracing::debug!(tool = %params.name, "Calling tool");
1206 let result = tool.call_with_context(ctx, params.arguments).await;
1207
1208 Ok(McpResponse::CallTool(result))
1209 }
1210
1211 McpRequest::ListResources(_params) => {
1212 let resources: Vec<ResourceDefinition> = self
1213 .inner
1214 .resources
1215 .values()
1216 .filter(|r| {
1217 self.inner
1219 .resource_filter
1220 .as_ref()
1221 .map(|f| f.is_visible(&self.session, r))
1222 .unwrap_or(true)
1223 })
1224 .map(|r| r.definition())
1225 .collect();
1226
1227 Ok(McpResponse::ListResources(ListResourcesResult {
1228 resources,
1229 next_cursor: None,
1230 }))
1231 }
1232
1233 McpRequest::ListResourceTemplates(_params) => {
1234 let resource_templates: Vec<ResourceTemplateDefinition> = self
1235 .inner
1236 .resource_templates
1237 .iter()
1238 .map(|t| t.definition())
1239 .collect();
1240
1241 Ok(McpResponse::ListResourceTemplates(
1242 ListResourceTemplatesResult {
1243 resource_templates,
1244 next_cursor: None,
1245 },
1246 ))
1247 }
1248
1249 McpRequest::ReadResource(params) => {
1250 if let Some(resource) = self.inner.resources.get(¶ms.uri) {
1252 if let Some(filter) = &self.inner.resource_filter
1254 && !filter.is_visible(&self.session, resource)
1255 {
1256 return Err(filter.denial_error(¶ms.uri));
1257 }
1258
1259 tracing::debug!(uri = %params.uri, "Reading static resource");
1260 let result = resource.read().await;
1261 return Ok(McpResponse::ReadResource(result));
1262 }
1263
1264 for template in &self.inner.resource_templates {
1266 if let Some(variables) = template.match_uri(¶ms.uri) {
1267 tracing::debug!(
1268 uri = %params.uri,
1269 template = %template.uri_template,
1270 "Reading resource via template"
1271 );
1272 let result = template.read(¶ms.uri, variables).await?;
1273 return Ok(McpResponse::ReadResource(result));
1274 }
1275 }
1276
1277 Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1279 ¶ms.uri,
1280 )))
1281 }
1282
1283 McpRequest::SubscribeResource(params) => {
1284 if !self.inner.resources.contains_key(¶ms.uri) {
1286 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1287 ¶ms.uri,
1288 )));
1289 }
1290
1291 tracing::debug!(uri = %params.uri, "Subscribing to resource");
1292 self.subscribe(¶ms.uri);
1293
1294 Ok(McpResponse::SubscribeResource(EmptyResult {}))
1295 }
1296
1297 McpRequest::UnsubscribeResource(params) => {
1298 if !self.inner.resources.contains_key(¶ms.uri) {
1300 return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1301 ¶ms.uri,
1302 )));
1303 }
1304
1305 tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
1306 self.unsubscribe(¶ms.uri);
1307
1308 Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
1309 }
1310
1311 McpRequest::ListPrompts(_params) => {
1312 let prompts: Vec<PromptDefinition> = self
1313 .inner
1314 .prompts
1315 .values()
1316 .filter(|p| {
1317 self.inner
1319 .prompt_filter
1320 .as_ref()
1321 .map(|f| f.is_visible(&self.session, p))
1322 .unwrap_or(true)
1323 })
1324 .map(|p| p.definition())
1325 .collect();
1326
1327 Ok(McpResponse::ListPrompts(ListPromptsResult {
1328 prompts,
1329 next_cursor: None,
1330 }))
1331 }
1332
1333 McpRequest::GetPrompt(params) => {
1334 let prompt = self.inner.prompts.get(¶ms.name).ok_or_else(|| {
1335 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1336 "Prompt not found: {}",
1337 params.name
1338 )))
1339 })?;
1340
1341 if let Some(filter) = &self.inner.prompt_filter
1343 && !filter.is_visible(&self.session, prompt)
1344 {
1345 return Err(filter.denial_error(¶ms.name));
1346 }
1347
1348 tracing::debug!(name = %params.name, "Getting prompt");
1349 let result = prompt.get(params.arguments).await?;
1350
1351 Ok(McpResponse::GetPrompt(result))
1352 }
1353
1354 McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
1355
1356 McpRequest::EnqueueTask(params) => {
1357 let tool = self.inner.tools.get(¶ms.tool_name).ok_or_else(|| {
1359 Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1360 "Tool not found: {}",
1361 params.tool_name
1362 )))
1363 })?;
1364
1365 let (task_id, cancellation_token) = self.inner.task_store.create_task(
1367 ¶ms.tool_name,
1368 params.arguments.clone(),
1369 params.ttl,
1370 );
1371
1372 tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
1373
1374 let ctx = self.create_context(request_id, None);
1376
1377 let task_store = self.inner.task_store.clone();
1379 let tool = tool.clone();
1380 let arguments = params.arguments;
1381 let task_id_clone = task_id.clone();
1382
1383 tokio::spawn(async move {
1384 if cancellation_token.is_cancelled() {
1386 tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
1387 return;
1388 }
1389
1390 let result = tool.call_with_context(ctx, arguments).await;
1392
1393 if cancellation_token.is_cancelled() {
1394 tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
1395 } else if result.is_error {
1396 let error_msg = result.first_text().unwrap_or("Tool execution failed");
1398 task_store.fail_task(&task_id_clone, error_msg);
1399 tracing::warn!(task_id = %task_id_clone, error = %error_msg, "Task failed");
1400 } else {
1401 task_store.complete_task(&task_id_clone, result);
1402 tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
1403 }
1404 });
1405
1406 Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
1407 task_id,
1408 status: TaskStatus::Working,
1409 poll_interval: Some(2),
1410 }))
1411 }
1412
1413 McpRequest::ListTasks(params) => {
1414 let tasks = self.inner.task_store.list_tasks(params.status);
1415
1416 Ok(McpResponse::ListTasks(ListTasksResult {
1417 tasks,
1418 next_cursor: None,
1419 }))
1420 }
1421
1422 McpRequest::GetTaskInfo(params) => {
1423 let task = self
1424 .inner
1425 .task_store
1426 .get_task(¶ms.task_id)
1427 .ok_or_else(|| {
1428 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1429 "Task not found: {}",
1430 params.task_id
1431 )))
1432 })?;
1433
1434 Ok(McpResponse::GetTaskInfo(task))
1435 }
1436
1437 McpRequest::GetTaskResult(params) => {
1438 let (status, result, error) = self
1439 .inner
1440 .task_store
1441 .get_task_full(¶ms.task_id)
1442 .ok_or_else(|| {
1443 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1444 "Task not found: {}",
1445 params.task_id
1446 )))
1447 })?;
1448
1449 Ok(McpResponse::GetTaskResult(GetTaskResultResult {
1450 task_id: params.task_id,
1451 status,
1452 result,
1453 error,
1454 }))
1455 }
1456
1457 McpRequest::CancelTask(params) => {
1458 let status = self
1459 .inner
1460 .task_store
1461 .cancel_task(¶ms.task_id, params.reason.as_deref())
1462 .ok_or_else(|| {
1463 Error::JsonRpc(JsonRpcError::invalid_params(format!(
1464 "Task not found: {}",
1465 params.task_id
1466 )))
1467 })?;
1468
1469 let cancelled = status == TaskStatus::Cancelled;
1470
1471 Ok(McpResponse::CancelTask(CancelTaskResult {
1472 cancelled,
1473 status,
1474 }))
1475 }
1476
1477 McpRequest::SetLoggingLevel(params) => {
1478 tracing::debug!(level = ?params.level, "Client set logging level");
1482 Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
1483 }
1484
1485 McpRequest::Complete(params) => {
1486 tracing::debug!(
1487 reference = ?params.reference,
1488 argument = %params.argument.name,
1489 "Completion request"
1490 );
1491
1492 if let Some(ref handler) = self.inner.completion_handler {
1494 let result = handler(params).await?;
1495 Ok(McpResponse::Complete(result))
1496 } else {
1497 Ok(McpResponse::Complete(CompleteResult::new(vec![])))
1499 }
1500 }
1501
1502 McpRequest::Unknown { method, .. } => {
1503 Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
1504 }
1505 }
1506 }
1507
1508 pub fn handle_notification(&self, notification: McpNotification) {
1510 match notification {
1511 McpNotification::Initialized => {
1512 if self.session.mark_initialized() {
1513 tracing::info!("Session initialized, entering operation phase");
1514 } else {
1515 tracing::warn!(
1516 "Received initialized notification in unexpected state: {:?}",
1517 self.session.phase()
1518 );
1519 }
1520 }
1521 McpNotification::Cancelled(params) => {
1522 if self.cancel_request(¶ms.request_id) {
1523 tracing::info!(
1524 request_id = ?params.request_id,
1525 reason = ?params.reason,
1526 "Request cancelled"
1527 );
1528 } else {
1529 tracing::debug!(
1530 request_id = ?params.request_id,
1531 reason = ?params.reason,
1532 "Cancellation requested for unknown request"
1533 );
1534 }
1535 }
1536 McpNotification::Progress(params) => {
1537 tracing::trace!(
1538 token = ?params.progress_token,
1539 progress = params.progress,
1540 total = ?params.total,
1541 "Progress notification"
1542 );
1543 }
1545 McpNotification::RootsListChanged => {
1546 tracing::info!("Client roots list changed");
1547 }
1550 McpNotification::Unknown { method, .. } => {
1551 tracing::debug!(method = %method, "Unknown notification received");
1552 }
1553 }
1554 }
1555}
1556
1557impl Default for McpRouter {
1558 fn default() -> Self {
1559 Self::new()
1560 }
1561}
1562
1563pub use crate::context::Extensions;
1569
/// A request addressed to the router's `tower::Service` implementation.
#[derive(Debug, Clone)]
pub struct RouterRequest {
    // JSON-RPC id; echoed back in the matching `RouterResponse`.
    pub id: RequestId,
    // The parsed MCP request.
    pub inner: McpRequest,
    // Per-request extensions. `Service::call` passes only `id` and `inner`
    // to dispatch, so this field is not consumed by the router itself.
    pub extensions: Extensions,
}
1578
/// The router's reply to a `RouterRequest`.
#[derive(Debug, Clone)]
pub struct RouterResponse {
    // Id copied from the originating request.
    pub id: RequestId,
    // Successful MCP response, or the JSON-RPC error to return.
    pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
1585
1586impl RouterResponse {
1587 pub fn into_jsonrpc(self) -> JsonRpcResponse {
1589 match self.inner {
1590 Ok(response) => match serde_json::to_value(response) {
1591 Ok(result) => JsonRpcResponse::result(self.id, result),
1592 Err(e) => {
1593 tracing::error!(error = %e, "Failed to serialize response");
1594 JsonRpcResponse::error(
1595 Some(self.id),
1596 JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1597 )
1598 }
1599 },
1600 Err(error) => JsonRpcResponse::error(Some(self.id), error),
1601 }
1602 }
1603}
1604
1605impl Service<RouterRequest> for McpRouter {
1606 type Response = RouterResponse;
1607 type Error = std::convert::Infallible; type Future =
1609 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1610
1611 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1612 Poll::Ready(Ok(()))
1613 }
1614
1615 fn call(&mut self, req: RouterRequest) -> Self::Future {
1616 let router = self.clone();
1617 let request_id = req.id.clone();
1618 Box::pin(async move {
1619 let result = router.handle(req.id, req.inner).await;
1620 router.complete_request(&request_id);
1622 Ok(RouterResponse {
1623 id: request_id,
1624 inner: result.map_err(|e| match e {
1629 Error::JsonRpc(err) => err,
1630 Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1631 e => JsonRpcError::internal_error(e.to_string()),
1632 }),
1633 })
1634 })
1635 }
1636}
1637
1638#[cfg(test)]
1639mod tests {
1640 use super::*;
1641 use crate::extract::{Context, Json};
1642 use crate::jsonrpc::JsonRpcService;
1643 use crate::tool::ToolBuilder;
1644 use schemars::JsonSchema;
1645 use serde::Deserialize;
1646 use tower::ServiceExt;
1647
1648 #[derive(Debug, Deserialize, JsonSchema)]
1649 struct AddInput {
1650 a: i64,
1651 b: i64,
1652 }
1653
1654 async fn init_router(router: &mut McpRouter) {
1656 let init_req = RouterRequest {
1658 id: RequestId::Number(0),
1659 inner: McpRequest::Initialize(InitializeParams {
1660 protocol_version: "2025-11-25".to_string(),
1661 capabilities: ClientCapabilities {
1662 roots: None,
1663 sampling: None,
1664 elicitation: None,
1665 },
1666 client_info: Implementation {
1667 name: "test".to_string(),
1668 version: "1.0".to_string(),
1669 ..Default::default()
1670 },
1671 }),
1672 extensions: Extensions::new(),
1673 };
1674 let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1675 router.handle_notification(McpNotification::Initialized);
1677 }
1678
1679 #[tokio::test]
1680 async fn test_router_list_tools() {
1681 let add_tool = ToolBuilder::new("add")
1682 .description("Add two numbers")
1683 .handler(|input: AddInput| async move {
1684 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1685 })
1686 .build();
1687
1688 let mut router = McpRouter::new().tool(add_tool);
1689
1690 init_router(&mut router).await;
1692
1693 let req = RouterRequest {
1694 id: RequestId::Number(1),
1695 inner: McpRequest::ListTools(ListToolsParams::default()),
1696 extensions: Extensions::new(),
1697 };
1698
1699 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1700
1701 match resp.inner {
1702 Ok(McpResponse::ListTools(result)) => {
1703 assert_eq!(result.tools.len(), 1);
1704 assert_eq!(result.tools[0].name, "add");
1705 }
1706 _ => panic!("Expected ListTools response"),
1707 }
1708 }
1709
1710 #[tokio::test]
1711 async fn test_router_call_tool() {
1712 let add_tool = ToolBuilder::new("add")
1713 .description("Add two numbers")
1714 .handler(|input: AddInput| async move {
1715 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1716 })
1717 .build();
1718
1719 let mut router = McpRouter::new().tool(add_tool);
1720
1721 init_router(&mut router).await;
1723
1724 let req = RouterRequest {
1725 id: RequestId::Number(1),
1726 inner: McpRequest::CallTool(CallToolParams {
1727 name: "add".to_string(),
1728 arguments: serde_json::json!({"a": 2, "b": 3}),
1729 meta: None,
1730 }),
1731 extensions: Extensions::new(),
1732 };
1733
1734 let resp = router.ready().await.unwrap().call(req).await.unwrap();
1735
1736 match resp.inner {
1737 Ok(McpResponse::CallTool(result)) => {
1738 assert!(!result.is_error);
1739 match &result.content[0] {
1741 Content::Text { text, .. } => assert_eq!(text, "5"),
1742 _ => panic!("Expected text content"),
1743 }
1744 }
1745 _ => panic!("Expected CallTool response"),
1746 }
1747 }
1748
1749 async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1751 let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1752 "protocolVersion": "2025-11-25",
1753 "capabilities": {},
1754 "clientInfo": { "name": "test", "version": "1.0" }
1755 }));
1756 let _ = service.call_single(init_req).await.unwrap();
1757 router.handle_notification(McpNotification::Initialized);
1758 }
1759
1760 #[tokio::test]
1761 async fn test_jsonrpc_service() {
1762 let add_tool = ToolBuilder::new("add")
1763 .description("Add two numbers")
1764 .handler(|input: AddInput| async move {
1765 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1766 })
1767 .build();
1768
1769 let router = McpRouter::new().tool(add_tool);
1770 let mut service = JsonRpcService::new(router.clone());
1771
1772 init_jsonrpc_service(&mut service, &router).await;
1774
1775 let req = JsonRpcRequest::new(1, "tools/list");
1776
1777 let resp = service.call_single(req).await.unwrap();
1778
1779 match resp {
1780 JsonRpcResponse::Result(r) => {
1781 assert_eq!(r.id, RequestId::Number(1));
1782 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1783 assert_eq!(tools.len(), 1);
1784 }
1785 JsonRpcResponse::Error(_) => panic!("Expected success response"),
1786 }
1787 }
1788
1789 #[tokio::test]
1790 async fn test_batch_request() {
1791 let add_tool = ToolBuilder::new("add")
1792 .description("Add two numbers")
1793 .handler(|input: AddInput| async move {
1794 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1795 })
1796 .build();
1797
1798 let router = McpRouter::new().tool(add_tool);
1799 let mut service = JsonRpcService::new(router.clone());
1800
1801 init_jsonrpc_service(&mut service, &router).await;
1803
1804 let requests = vec![
1806 JsonRpcRequest::new(1, "tools/list"),
1807 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1808 "name": "add",
1809 "arguments": {"a": 10, "b": 20}
1810 })),
1811 JsonRpcRequest::new(3, "ping"),
1812 ];
1813
1814 let responses = service.call_batch(requests).await.unwrap();
1815
1816 assert_eq!(responses.len(), 3);
1817
1818 match &responses[0] {
1820 JsonRpcResponse::Result(r) => {
1821 assert_eq!(r.id, RequestId::Number(1));
1822 let tools = r.result.get("tools").unwrap().as_array().unwrap();
1823 assert_eq!(tools.len(), 1);
1824 }
1825 JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1826 }
1827
1828 match &responses[1] {
1830 JsonRpcResponse::Result(r) => {
1831 assert_eq!(r.id, RequestId::Number(2));
1832 let content = r.result.get("content").unwrap().as_array().unwrap();
1833 let text = content[0].get("text").unwrap().as_str().unwrap();
1834 assert_eq!(text, "30");
1835 }
1836 JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1837 }
1838
1839 match &responses[2] {
1841 JsonRpcResponse::Result(r) => {
1842 assert_eq!(r.id, RequestId::Number(3));
1843 }
1844 JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1845 }
1846 }
1847
1848 #[tokio::test]
1849 async fn test_empty_batch_error() {
1850 let router = McpRouter::new();
1851 let mut service = JsonRpcService::new(router);
1852
1853 let result = service.call_batch(vec![]).await;
1854 assert!(result.is_err());
1855 }
1856
1857 #[tokio::test]
1862 async fn test_progress_token_extraction() {
1863 use crate::context::{ServerNotification, notification_channel};
1864 use crate::protocol::ProgressToken;
1865 use std::sync::Arc;
1866 use std::sync::atomic::{AtomicBool, Ordering};
1867
1868 let progress_reported = Arc::new(AtomicBool::new(false));
1870 let progress_ref = progress_reported.clone();
1871
1872 let tool = ToolBuilder::new("progress_tool")
1874 .description("Tool that reports progress")
1875 .extractor_handler((), move |ctx: Context, Json(_input): Json<AddInput>| {
1876 let reported = progress_ref.clone();
1877 async move {
1878 ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1880 .await;
1881 reported.store(true, Ordering::SeqCst);
1882 Ok(CallToolResult::text("done"))
1883 }
1884 })
1885 .build();
1886
1887 let (tx, mut rx) = notification_channel(10);
1889 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1890 let mut service = JsonRpcService::new(router.clone());
1891
1892 init_jsonrpc_service(&mut service, &router).await;
1894
1895 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1897 "name": "progress_tool",
1898 "arguments": {"a": 1, "b": 2},
1899 "_meta": {
1900 "progressToken": "test-token-123"
1901 }
1902 }));
1903
1904 let resp = service.call_single(req).await.unwrap();
1905
1906 match resp {
1908 JsonRpcResponse::Result(_) => {}
1909 JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1910 }
1911
1912 assert!(progress_reported.load(Ordering::SeqCst));
1914
1915 let notification = rx.try_recv().expect("Expected progress notification");
1917 match notification {
1918 ServerNotification::Progress(params) => {
1919 assert_eq!(
1920 params.progress_token,
1921 ProgressToken::String("test-token-123".to_string())
1922 );
1923 assert_eq!(params.progress, 50.0);
1924 assert_eq!(params.total, Some(100.0));
1925 assert_eq!(params.message.as_deref(), Some("Halfway"));
1926 }
1927 _ => panic!("Expected Progress notification"),
1928 }
1929 }
1930
1931 #[tokio::test]
1932 async fn test_tool_call_without_progress_token() {
1933 use crate::context::notification_channel;
1934 use std::sync::Arc;
1935 use std::sync::atomic::{AtomicBool, Ordering};
1936
1937 let progress_attempted = Arc::new(AtomicBool::new(false));
1938 let progress_ref = progress_attempted.clone();
1939
1940 let tool = ToolBuilder::new("no_token_tool")
1941 .description("Tool that tries to report progress without token")
1942 .extractor_handler((), move |ctx: Context, Json(_input): Json<AddInput>| {
1943 let attempted = progress_ref.clone();
1944 async move {
1945 ctx.report_progress(50.0, Some(100.0), None).await;
1947 attempted.store(true, Ordering::SeqCst);
1948 Ok(CallToolResult::text("done"))
1949 }
1950 })
1951 .build();
1952
1953 let (tx, mut rx) = notification_channel(10);
1954 let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1955 let mut service = JsonRpcService::new(router.clone());
1956
1957 init_jsonrpc_service(&mut service, &router).await;
1958
1959 let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1961 "name": "no_token_tool",
1962 "arguments": {"a": 1, "b": 2}
1963 }));
1964
1965 let resp = service.call_single(req).await.unwrap();
1966 assert!(matches!(resp, JsonRpcResponse::Result(_)));
1967
1968 assert!(progress_attempted.load(Ordering::SeqCst));
1970
1971 assert!(rx.try_recv().is_err());
1973 }
1974
1975 #[tokio::test]
1976 async fn test_batch_errors_returned_not_dropped() {
1977 let add_tool = ToolBuilder::new("add")
1978 .description("Add two numbers")
1979 .handler(|input: AddInput| async move {
1980 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1981 })
1982 .build();
1983
1984 let router = McpRouter::new().tool(add_tool);
1985 let mut service = JsonRpcService::new(router.clone());
1986
1987 init_jsonrpc_service(&mut service, &router).await;
1988
1989 let requests = vec![
1991 JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1993 "name": "add",
1994 "arguments": {"a": 10, "b": 20}
1995 })),
1996 JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1998 "name": "nonexistent_tool",
1999 "arguments": {}
2000 })),
2001 JsonRpcRequest::new(3, "ping"),
2003 ];
2004
2005 let responses = service.call_batch(requests).await.unwrap();
2006
2007 assert_eq!(responses.len(), 3);
2009
2010 match &responses[0] {
2012 JsonRpcResponse::Result(r) => {
2013 assert_eq!(r.id, RequestId::Number(1));
2014 }
2015 JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
2016 }
2017
2018 match &responses[1] {
2020 JsonRpcResponse::Error(e) => {
2021 assert_eq!(e.id, Some(RequestId::Number(2)));
2022 assert!(e.error.message.contains("not found") || e.error.code == -32601);
2024 }
2025 JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
2026 }
2027
2028 match &responses[2] {
2030 JsonRpcResponse::Result(r) => {
2031 assert_eq!(r.id, RequestId::Number(3));
2032 }
2033 JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
2034 }
2035 }
2036
2037 #[tokio::test]
2042 async fn test_list_resource_templates() {
2043 use crate::resource::ResourceTemplateBuilder;
2044 use std::collections::HashMap;
2045
2046 let template = ResourceTemplateBuilder::new("file:///{path}")
2047 .name("Project Files")
2048 .description("Access project files")
2049 .handler(|uri: String, _vars: HashMap<String, String>| async move {
2050 Ok(ReadResourceResult {
2051 contents: vec![ResourceContent {
2052 uri,
2053 mime_type: None,
2054 text: None,
2055 blob: None,
2056 }],
2057 })
2058 });
2059
2060 let mut router = McpRouter::new().resource_template(template);
2061
2062 init_router(&mut router).await;
2064
2065 let req = RouterRequest {
2066 id: RequestId::Number(1),
2067 inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
2068 extensions: Extensions::new(),
2069 };
2070
2071 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2072
2073 match resp.inner {
2074 Ok(McpResponse::ListResourceTemplates(result)) => {
2075 assert_eq!(result.resource_templates.len(), 1);
2076 assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
2077 assert_eq!(result.resource_templates[0].name, "Project Files");
2078 }
2079 _ => panic!("Expected ListResourceTemplates response"),
2080 }
2081 }
2082
2083 #[tokio::test]
2084 async fn test_read_resource_via_template() {
2085 use crate::resource::ResourceTemplateBuilder;
2086 use std::collections::HashMap;
2087
2088 let template = ResourceTemplateBuilder::new("db://users/{id}")
2089 .name("User Records")
2090 .handler(|uri: String, vars: HashMap<String, String>| async move {
2091 let id = vars.get("id").unwrap().clone();
2092 Ok(ReadResourceResult {
2093 contents: vec![ResourceContent {
2094 uri,
2095 mime_type: Some("application/json".to_string()),
2096 text: Some(format!(r#"{{"id": "{}"}}"#, id)),
2097 blob: None,
2098 }],
2099 })
2100 });
2101
2102 let mut router = McpRouter::new().resource_template(template);
2103
2104 init_router(&mut router).await;
2106
2107 let req = RouterRequest {
2109 id: RequestId::Number(1),
2110 inner: McpRequest::ReadResource(ReadResourceParams {
2111 uri: "db://users/123".to_string(),
2112 }),
2113 extensions: Extensions::new(),
2114 };
2115
2116 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2117
2118 match resp.inner {
2119 Ok(McpResponse::ReadResource(result)) => {
2120 assert_eq!(result.contents.len(), 1);
2121 assert_eq!(result.contents[0].uri, "db://users/123");
2122 assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
2123 }
2124 _ => panic!("Expected ReadResource response"),
2125 }
2126 }
2127
2128 #[tokio::test]
2129 async fn test_static_resource_takes_precedence_over_template() {
2130 use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
2131 use std::collections::HashMap;
2132
2133 let template = ResourceTemplateBuilder::new("file:///{path}")
2135 .name("Files Template")
2136 .handler(|uri: String, _vars: HashMap<String, String>| async move {
2137 Ok(ReadResourceResult {
2138 contents: vec![ResourceContent {
2139 uri,
2140 mime_type: None,
2141 text: Some("from template".to_string()),
2142 blob: None,
2143 }],
2144 })
2145 });
2146
2147 let static_resource = ResourceBuilder::new("file:///README.md")
2149 .name("README")
2150 .text("from static resource");
2151
2152 let mut router = McpRouter::new()
2153 .resource_template(template)
2154 .resource(static_resource);
2155
2156 init_router(&mut router).await;
2158
2159 let req = RouterRequest {
2161 id: RequestId::Number(1),
2162 inner: McpRequest::ReadResource(ReadResourceParams {
2163 uri: "file:///README.md".to_string(),
2164 }),
2165 extensions: Extensions::new(),
2166 };
2167
2168 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2169
2170 match resp.inner {
2171 Ok(McpResponse::ReadResource(result)) => {
2172 assert_eq!(
2174 result.contents[0].text.as_deref(),
2175 Some("from static resource")
2176 );
2177 }
2178 _ => panic!("Expected ReadResource response"),
2179 }
2180 }
2181
2182 #[tokio::test]
2183 async fn test_resource_not_found_when_no_match() {
2184 use crate::resource::ResourceTemplateBuilder;
2185 use std::collections::HashMap;
2186
2187 let template = ResourceTemplateBuilder::new("db://users/{id}")
2188 .name("Users")
2189 .handler(|uri: String, _vars: HashMap<String, String>| async move {
2190 Ok(ReadResourceResult {
2191 contents: vec![ResourceContent {
2192 uri,
2193 mime_type: None,
2194 text: None,
2195 blob: None,
2196 }],
2197 })
2198 });
2199
2200 let mut router = McpRouter::new().resource_template(template);
2201
2202 init_router(&mut router).await;
2204
2205 let req = RouterRequest {
2207 id: RequestId::Number(1),
2208 inner: McpRequest::ReadResource(ReadResourceParams {
2209 uri: "db://posts/123".to_string(),
2210 }),
2211 extensions: Extensions::new(),
2212 };
2213
2214 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2215
2216 match resp.inner {
2217 Err(err) => {
2218 assert!(err.message.contains("not found"));
2219 }
2220 Ok(_) => panic!("Expected error for non-matching URI"),
2221 }
2222 }
2223
2224 #[tokio::test]
2225 async fn test_capabilities_include_resources_with_only_templates() {
2226 use crate::resource::ResourceTemplateBuilder;
2227 use std::collections::HashMap;
2228
2229 let template = ResourceTemplateBuilder::new("file:///{path}")
2230 .name("Files")
2231 .handler(|uri: String, _vars: HashMap<String, String>| async move {
2232 Ok(ReadResourceResult {
2233 contents: vec![ResourceContent {
2234 uri,
2235 mime_type: None,
2236 text: None,
2237 blob: None,
2238 }],
2239 })
2240 });
2241
2242 let mut router = McpRouter::new().resource_template(template);
2243
2244 let init_req = RouterRequest {
2246 id: RequestId::Number(0),
2247 inner: McpRequest::Initialize(InitializeParams {
2248 protocol_version: "2025-11-25".to_string(),
2249 capabilities: ClientCapabilities {
2250 roots: None,
2251 sampling: None,
2252 elicitation: None,
2253 },
2254 client_info: Implementation {
2255 name: "test".to_string(),
2256 version: "1.0".to_string(),
2257 ..Default::default()
2258 },
2259 }),
2260 extensions: Extensions::new(),
2261 };
2262 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2263
2264 match resp.inner {
2265 Ok(McpResponse::Initialize(result)) => {
2266 assert!(result.capabilities.resources.is_some());
2268 }
2269 _ => panic!("Expected Initialize response"),
2270 }
2271 }
2272
2273 #[tokio::test]
2278 async fn test_log_sends_notification() {
2279 use crate::context::notification_channel;
2280
2281 let (tx, mut rx) = notification_channel(10);
2282 let router = McpRouter::new().with_notification_sender(tx);
2283
2284 let sent = router.log_info("Test message");
2286 assert!(sent);
2287
2288 let notification = rx.try_recv().unwrap();
2290 match notification {
2291 ServerNotification::LogMessage(params) => {
2292 assert_eq!(params.level, LogLevel::Info);
2293 let data = params.data.unwrap();
2294 assert_eq!(
2295 data.get("message").unwrap().as_str().unwrap(),
2296 "Test message"
2297 );
2298 }
2299 _ => panic!("Expected LogMessage notification"),
2300 }
2301 }
2302
2303 #[tokio::test]
2304 async fn test_log_with_custom_params() {
2305 use crate::context::notification_channel;
2306
2307 let (tx, mut rx) = notification_channel(10);
2308 let router = McpRouter::new().with_notification_sender(tx);
2309
2310 let params = LoggingMessageParams::new(LogLevel::Error)
2312 .with_logger("database")
2313 .with_data(serde_json::json!({
2314 "error": "Connection failed",
2315 "host": "localhost"
2316 }));
2317
2318 let sent = router.log(params);
2319 assert!(sent);
2320
2321 let notification = rx.try_recv().unwrap();
2322 match notification {
2323 ServerNotification::LogMessage(params) => {
2324 assert_eq!(params.level, LogLevel::Error);
2325 assert_eq!(params.logger.as_deref(), Some("database"));
2326 let data = params.data.unwrap();
2327 assert_eq!(
2328 data.get("error").unwrap().as_str().unwrap(),
2329 "Connection failed"
2330 );
2331 }
2332 _ => panic!("Expected LogMessage notification"),
2333 }
2334 }
2335
2336 #[tokio::test]
2337 async fn test_log_without_channel_returns_false() {
2338 let router = McpRouter::new();
2340
2341 assert!(!router.log_info("Test"));
2343 assert!(!router.log_warning("Test"));
2344 assert!(!router.log_error("Test"));
2345 assert!(!router.log_debug("Test"));
2346 }
2347
2348 #[tokio::test]
2349 async fn test_logging_capability_with_channel() {
2350 use crate::context::notification_channel;
2351
2352 let (tx, _rx) = notification_channel(10);
2353 let mut router = McpRouter::new().with_notification_sender(tx);
2354
2355 let init_req = RouterRequest {
2357 id: RequestId::Number(0),
2358 inner: McpRequest::Initialize(InitializeParams {
2359 protocol_version: "2025-11-25".to_string(),
2360 capabilities: ClientCapabilities {
2361 roots: None,
2362 sampling: None,
2363 elicitation: None,
2364 },
2365 client_info: Implementation {
2366 name: "test".to_string(),
2367 version: "1.0".to_string(),
2368 ..Default::default()
2369 },
2370 }),
2371 extensions: Extensions::new(),
2372 };
2373 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2374
2375 match resp.inner {
2376 Ok(McpResponse::Initialize(result)) => {
2377 assert!(result.capabilities.logging.is_some());
2379 }
2380 _ => panic!("Expected Initialize response"),
2381 }
2382 }
2383
2384 #[tokio::test]
2385 async fn test_no_logging_capability_without_channel() {
2386 let mut router = McpRouter::new();
2387
2388 let init_req = RouterRequest {
2390 id: RequestId::Number(0),
2391 inner: McpRequest::Initialize(InitializeParams {
2392 protocol_version: "2025-11-25".to_string(),
2393 capabilities: ClientCapabilities {
2394 roots: None,
2395 sampling: None,
2396 elicitation: None,
2397 },
2398 client_info: Implementation {
2399 name: "test".to_string(),
2400 version: "1.0".to_string(),
2401 ..Default::default()
2402 },
2403 }),
2404 extensions: Extensions::new(),
2405 };
2406 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2407
2408 match resp.inner {
2409 Ok(McpResponse::Initialize(result)) => {
2410 assert!(result.capabilities.logging.is_none());
2412 }
2413 _ => panic!("Expected Initialize response"),
2414 }
2415 }
2416
2417 #[tokio::test]
2422 async fn test_enqueue_task() {
2423 let add_tool = ToolBuilder::new("add")
2424 .description("Add two numbers")
2425 .handler(|input: AddInput| async move {
2426 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2427 })
2428 .build();
2429
2430 let mut router = McpRouter::new().tool(add_tool);
2431 init_router(&mut router).await;
2432
2433 let req = RouterRequest {
2434 id: RequestId::Number(1),
2435 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2436 tool_name: "add".to_string(),
2437 arguments: serde_json::json!({"a": 5, "b": 10}),
2438 ttl: None,
2439 }),
2440 extensions: Extensions::new(),
2441 };
2442
2443 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2444
2445 match resp.inner {
2446 Ok(McpResponse::EnqueueTask(result)) => {
2447 assert!(result.task_id.starts_with("task-"));
2448 assert_eq!(result.status, TaskStatus::Working);
2449 }
2450 _ => panic!("Expected EnqueueTask response"),
2451 }
2452 }
2453
2454 #[tokio::test]
2455 async fn test_list_tasks_empty() {
2456 let mut router = McpRouter::new();
2457 init_router(&mut router).await;
2458
2459 let req = RouterRequest {
2460 id: RequestId::Number(1),
2461 inner: McpRequest::ListTasks(ListTasksParams::default()),
2462 extensions: Extensions::new(),
2463 };
2464
2465 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2466
2467 match resp.inner {
2468 Ok(McpResponse::ListTasks(result)) => {
2469 assert!(result.tasks.is_empty());
2470 }
2471 _ => panic!("Expected ListTasks response"),
2472 }
2473 }
2474
2475 #[tokio::test]
2476 async fn test_task_lifecycle_complete() {
2477 let add_tool = ToolBuilder::new("add")
2478 .description("Add two numbers")
2479 .handler(|input: AddInput| async move {
2480 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2481 })
2482 .build();
2483
2484 let mut router = McpRouter::new().tool(add_tool);
2485 init_router(&mut router).await;
2486
2487 let req = RouterRequest {
2489 id: RequestId::Number(1),
2490 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2491 tool_name: "add".to_string(),
2492 arguments: serde_json::json!({"a": 7, "b": 8}),
2493 ttl: None,
2494 }),
2495 extensions: Extensions::new(),
2496 };
2497
2498 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2499 let task_id = match resp.inner {
2500 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2501 _ => panic!("Expected EnqueueTask response"),
2502 };
2503
2504 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2506
2507 let req = RouterRequest {
2509 id: RequestId::Number(2),
2510 inner: McpRequest::GetTaskResult(GetTaskResultParams {
2511 task_id: task_id.clone(),
2512 }),
2513 extensions: Extensions::new(),
2514 };
2515
2516 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2517
2518 match resp.inner {
2519 Ok(McpResponse::GetTaskResult(result)) => {
2520 assert_eq!(result.task_id, task_id);
2521 assert_eq!(result.status, TaskStatus::Completed);
2522 assert!(result.result.is_some());
2523 assert!(result.error.is_none());
2524
2525 let tool_result = result.result.unwrap();
2527 match &tool_result.content[0] {
2528 Content::Text { text, .. } => assert_eq!(text, "15"),
2529 _ => panic!("Expected text content"),
2530 }
2531 }
2532 _ => panic!("Expected GetTaskResult response"),
2533 }
2534 }
2535
2536 #[tokio::test]
2537 async fn test_task_cancellation() {
2538 let slow_tool = ToolBuilder::new("slow")
2540 .description("Slow tool")
2541 .handler(|_input: serde_json::Value| async move {
2542 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2543 Ok(CallToolResult::text("done"))
2544 })
2545 .build();
2546
2547 let mut router = McpRouter::new().tool(slow_tool);
2548 init_router(&mut router).await;
2549
2550 let req = RouterRequest {
2552 id: RequestId::Number(1),
2553 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2554 tool_name: "slow".to_string(),
2555 arguments: serde_json::json!({}),
2556 ttl: None,
2557 }),
2558 extensions: Extensions::new(),
2559 };
2560
2561 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2562 let task_id = match resp.inner {
2563 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2564 _ => panic!("Expected EnqueueTask response"),
2565 };
2566
2567 let req = RouterRequest {
2569 id: RequestId::Number(2),
2570 inner: McpRequest::CancelTask(CancelTaskParams {
2571 task_id: task_id.clone(),
2572 reason: Some("Test cancellation".to_string()),
2573 }),
2574 extensions: Extensions::new(),
2575 };
2576
2577 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2578
2579 match resp.inner {
2580 Ok(McpResponse::CancelTask(result)) => {
2581 assert!(result.cancelled);
2582 assert_eq!(result.status, TaskStatus::Cancelled);
2583 }
2584 _ => panic!("Expected CancelTask response"),
2585 }
2586 }
2587
2588 #[tokio::test]
2589 async fn test_get_task_info() {
2590 let add_tool = ToolBuilder::new("add")
2591 .description("Add two numbers")
2592 .handler(|input: AddInput| async move {
2593 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2594 })
2595 .build();
2596
2597 let mut router = McpRouter::new().tool(add_tool);
2598 init_router(&mut router).await;
2599
2600 let req = RouterRequest {
2602 id: RequestId::Number(1),
2603 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2604 tool_name: "add".to_string(),
2605 arguments: serde_json::json!({"a": 1, "b": 2}),
2606 ttl: Some(600),
2607 }),
2608 extensions: Extensions::new(),
2609 };
2610
2611 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2612 let task_id = match resp.inner {
2613 Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2614 _ => panic!("Expected EnqueueTask response"),
2615 };
2616
2617 let req = RouterRequest {
2619 id: RequestId::Number(2),
2620 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2621 task_id: task_id.clone(),
2622 }),
2623 extensions: Extensions::new(),
2624 };
2625
2626 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2627
2628 match resp.inner {
2629 Ok(McpResponse::GetTaskInfo(info)) => {
2630 assert_eq!(info.task_id, task_id);
2631 assert!(info.created_at.contains('T')); assert_eq!(info.ttl, Some(600));
2633 }
2634 _ => panic!("Expected GetTaskInfo response"),
2635 }
2636 }
2637
2638 #[tokio::test]
2639 async fn test_enqueue_nonexistent_tool() {
2640 let mut router = McpRouter::new();
2641 init_router(&mut router).await;
2642
2643 let req = RouterRequest {
2644 id: RequestId::Number(1),
2645 inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2646 tool_name: "nonexistent".to_string(),
2647 arguments: serde_json::json!({}),
2648 ttl: None,
2649 }),
2650 extensions: Extensions::new(),
2651 };
2652
2653 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2654
2655 match resp.inner {
2656 Err(e) => {
2657 assert!(e.message.contains("not found"));
2658 }
2659 _ => panic!("Expected error response"),
2660 }
2661 }
2662
2663 #[tokio::test]
2664 async fn test_get_nonexistent_task() {
2665 let mut router = McpRouter::new();
2666 init_router(&mut router).await;
2667
2668 let req = RouterRequest {
2669 id: RequestId::Number(1),
2670 inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2671 task_id: "task-999".to_string(),
2672 }),
2673 extensions: Extensions::new(),
2674 };
2675
2676 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2677
2678 match resp.inner {
2679 Err(e) => {
2680 assert!(e.message.contains("not found"));
2681 }
2682 _ => panic!("Expected error response"),
2683 }
2684 }
2685
2686 #[tokio::test]
2691 async fn test_subscribe_to_resource() {
2692 use crate::resource::ResourceBuilder;
2693
2694 let resource = ResourceBuilder::new("file:///test.txt")
2695 .name("Test File")
2696 .text("Hello");
2697
2698 let mut router = McpRouter::new().resource(resource);
2699 init_router(&mut router).await;
2700
2701 let req = RouterRequest {
2703 id: RequestId::Number(1),
2704 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2705 uri: "file:///test.txt".to_string(),
2706 }),
2707 extensions: Extensions::new(),
2708 };
2709
2710 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2711
2712 match resp.inner {
2713 Ok(McpResponse::SubscribeResource(_)) => {
2714 assert!(router.is_subscribed("file:///test.txt"));
2716 }
2717 _ => panic!("Expected SubscribeResource response"),
2718 }
2719 }
2720
2721 #[tokio::test]
2722 async fn test_unsubscribe_from_resource() {
2723 use crate::resource::ResourceBuilder;
2724
2725 let resource = ResourceBuilder::new("file:///test.txt")
2726 .name("Test File")
2727 .text("Hello");
2728
2729 let mut router = McpRouter::new().resource(resource);
2730 init_router(&mut router).await;
2731
2732 let req = RouterRequest {
2734 id: RequestId::Number(1),
2735 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2736 uri: "file:///test.txt".to_string(),
2737 }),
2738 extensions: Extensions::new(),
2739 };
2740 let _ = router.ready().await.unwrap().call(req).await.unwrap();
2741 assert!(router.is_subscribed("file:///test.txt"));
2742
2743 let req = RouterRequest {
2745 id: RequestId::Number(2),
2746 inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2747 uri: "file:///test.txt".to_string(),
2748 }),
2749 extensions: Extensions::new(),
2750 };
2751
2752 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2753
2754 match resp.inner {
2755 Ok(McpResponse::UnsubscribeResource(_)) => {
2756 assert!(!router.is_subscribed("file:///test.txt"));
2758 }
2759 _ => panic!("Expected UnsubscribeResource response"),
2760 }
2761 }
2762
2763 #[tokio::test]
2764 async fn test_subscribe_nonexistent_resource() {
2765 let mut router = McpRouter::new();
2766 init_router(&mut router).await;
2767
2768 let req = RouterRequest {
2769 id: RequestId::Number(1),
2770 inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2771 uri: "file:///nonexistent.txt".to_string(),
2772 }),
2773 extensions: Extensions::new(),
2774 };
2775
2776 let resp = router.ready().await.unwrap().call(req).await.unwrap();
2777
2778 match resp.inner {
2779 Err(e) => {
2780 assert!(e.message.contains("not found"));
2781 }
2782 _ => panic!("Expected error response"),
2783 }
2784 }
2785
2786 #[tokio::test]
2787 async fn test_notify_resource_updated() {
2788 use crate::context::notification_channel;
2789 use crate::resource::ResourceBuilder;
2790
2791 let (tx, mut rx) = notification_channel(10);
2792
2793 let resource = ResourceBuilder::new("file:///test.txt")
2794 .name("Test File")
2795 .text("Hello");
2796
2797 let router = McpRouter::new()
2798 .resource(resource)
2799 .with_notification_sender(tx);
2800
2801 router.subscribe("file:///test.txt");
2803
2804 let sent = router.notify_resource_updated("file:///test.txt");
2806 assert!(sent);
2807
2808 let notification = rx.try_recv().unwrap();
2810 match notification {
2811 ServerNotification::ResourceUpdated { uri } => {
2812 assert_eq!(uri, "file:///test.txt");
2813 }
2814 _ => panic!("Expected ResourceUpdated notification"),
2815 }
2816 }
2817
2818 #[tokio::test]
2819 async fn test_notify_resource_updated_not_subscribed() {
2820 use crate::context::notification_channel;
2821 use crate::resource::ResourceBuilder;
2822
2823 let (tx, mut rx) = notification_channel(10);
2824
2825 let resource = ResourceBuilder::new("file:///test.txt")
2826 .name("Test File")
2827 .text("Hello");
2828
2829 let router = McpRouter::new()
2830 .resource(resource)
2831 .with_notification_sender(tx);
2832
2833 let sent = router.notify_resource_updated("file:///test.txt");
2835 assert!(!sent); assert!(rx.try_recv().is_err());
2839 }
2840
2841 #[tokio::test]
2842 async fn test_notify_resources_list_changed() {
2843 use crate::context::notification_channel;
2844
2845 let (tx, mut rx) = notification_channel(10);
2846 let router = McpRouter::new().with_notification_sender(tx);
2847
2848 let sent = router.notify_resources_list_changed();
2849 assert!(sent);
2850
2851 let notification = rx.try_recv().unwrap();
2852 match notification {
2853 ServerNotification::ResourcesListChanged => {}
2854 _ => panic!("Expected ResourcesListChanged notification"),
2855 }
2856 }
2857
2858 #[tokio::test]
2859 async fn test_subscribed_uris() {
2860 use crate::resource::ResourceBuilder;
2861
2862 let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2863
2864 let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2865
2866 let router = McpRouter::new().resource(resource1).resource(resource2);
2867
2868 router.subscribe("file:///a.txt");
2870 router.subscribe("file:///b.txt");
2871
2872 let uris = router.subscribed_uris();
2873 assert_eq!(uris.len(), 2);
2874 assert!(uris.contains(&"file:///a.txt".to_string()));
2875 assert!(uris.contains(&"file:///b.txt".to_string()));
2876 }
2877
2878 #[tokio::test]
2879 async fn test_subscription_capability_advertised() {
2880 use crate::resource::ResourceBuilder;
2881
2882 let resource = ResourceBuilder::new("file:///test.txt")
2883 .name("Test")
2884 .text("Hello");
2885
2886 let mut router = McpRouter::new().resource(resource);
2887
2888 let init_req = RouterRequest {
2890 id: RequestId::Number(0),
2891 inner: McpRequest::Initialize(InitializeParams {
2892 protocol_version: "2025-11-25".to_string(),
2893 capabilities: ClientCapabilities {
2894 roots: None,
2895 sampling: None,
2896 elicitation: None,
2897 },
2898 client_info: Implementation {
2899 name: "test".to_string(),
2900 version: "1.0".to_string(),
2901 ..Default::default()
2902 },
2903 }),
2904 extensions: Extensions::new(),
2905 };
2906 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2907
2908 match resp.inner {
2909 Ok(McpResponse::Initialize(result)) => {
2910 let resources_cap = result.capabilities.resources.unwrap();
2912 assert!(resources_cap.subscribe);
2913 }
2914 _ => panic!("Expected Initialize response"),
2915 }
2916 }
2917
2918 #[tokio::test]
2919 async fn test_completion_handler() {
2920 let router = McpRouter::new()
2921 .server_info("test", "1.0")
2922 .completion_handler(|params: CompleteParams| async move {
2923 let prefix = ¶ms.argument.value;
2925 let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2926 .into_iter()
2927 .filter(|s| s.starts_with(prefix))
2928 .map(String::from)
2929 .collect();
2930 Ok(CompleteResult::new(suggestions))
2931 });
2932
2933 let init_req = RouterRequest {
2935 id: RequestId::Number(0),
2936 inner: McpRequest::Initialize(InitializeParams {
2937 protocol_version: "2025-11-25".to_string(),
2938 capabilities: ClientCapabilities::default(),
2939 client_info: Implementation {
2940 name: "test".to_string(),
2941 version: "1.0".to_string(),
2942 ..Default::default()
2943 },
2944 }),
2945 extensions: Extensions::new(),
2946 };
2947 let resp = router
2948 .clone()
2949 .ready()
2950 .await
2951 .unwrap()
2952 .call(init_req)
2953 .await
2954 .unwrap();
2955
2956 match resp.inner {
2958 Ok(McpResponse::Initialize(result)) => {
2959 assert!(result.capabilities.completions.is_some());
2960 }
2961 _ => panic!("Expected Initialize response"),
2962 }
2963
2964 router.handle_notification(McpNotification::Initialized);
2966
2967 let complete_req = RouterRequest {
2969 id: RequestId::Number(1),
2970 inner: McpRequest::Complete(CompleteParams {
2971 reference: CompletionReference::prompt("test-prompt"),
2972 argument: CompletionArgument::new("query", "al"),
2973 }),
2974 extensions: Extensions::new(),
2975 };
2976 let resp = router
2977 .clone()
2978 .ready()
2979 .await
2980 .unwrap()
2981 .call(complete_req)
2982 .await
2983 .unwrap();
2984
2985 match resp.inner {
2986 Ok(McpResponse::Complete(result)) => {
2987 assert_eq!(result.completion.values, vec!["alpha"]);
2988 }
2989 _ => panic!("Expected Complete response"),
2990 }
2991 }
2992
2993 #[tokio::test]
2994 async fn test_completion_without_handler_returns_empty() {
2995 let router = McpRouter::new().server_info("test", "1.0");
2996
2997 let init_req = RouterRequest {
2999 id: RequestId::Number(0),
3000 inner: McpRequest::Initialize(InitializeParams {
3001 protocol_version: "2025-11-25".to_string(),
3002 capabilities: ClientCapabilities::default(),
3003 client_info: Implementation {
3004 name: "test".to_string(),
3005 version: "1.0".to_string(),
3006 ..Default::default()
3007 },
3008 }),
3009 extensions: Extensions::new(),
3010 };
3011 let resp = router
3012 .clone()
3013 .ready()
3014 .await
3015 .unwrap()
3016 .call(init_req)
3017 .await
3018 .unwrap();
3019
3020 match resp.inner {
3022 Ok(McpResponse::Initialize(result)) => {
3023 assert!(result.capabilities.completions.is_none());
3024 }
3025 _ => panic!("Expected Initialize response"),
3026 }
3027
3028 router.handle_notification(McpNotification::Initialized);
3030
3031 let complete_req = RouterRequest {
3033 id: RequestId::Number(1),
3034 inner: McpRequest::Complete(CompleteParams {
3035 reference: CompletionReference::prompt("test-prompt"),
3036 argument: CompletionArgument::new("query", "al"),
3037 }),
3038 extensions: Extensions::new(),
3039 };
3040 let resp = router
3041 .clone()
3042 .ready()
3043 .await
3044 .unwrap()
3045 .call(complete_req)
3046 .await
3047 .unwrap();
3048
3049 match resp.inner {
3050 Ok(McpResponse::Complete(result)) => {
3051 assert!(result.completion.values.is_empty());
3052 }
3053 _ => panic!("Expected Complete response"),
3054 }
3055 }
3056
3057 #[tokio::test]
3058 async fn test_tool_filter_list() {
3059 use crate::filter::CapabilityFilter;
3060 use crate::tool::Tool;
3061
3062 let public_tool = ToolBuilder::new("public")
3063 .description("Public tool")
3064 .handler(|_: AddInput| async move { Ok(CallToolResult::text("public")) })
3065 .build();
3066
3067 let admin_tool = ToolBuilder::new("admin")
3068 .description("Admin tool")
3069 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
3070 .build();
3071
3072 let mut router = McpRouter::new()
3073 .tool(public_tool)
3074 .tool(admin_tool)
3075 .tool_filter(CapabilityFilter::new(|_, tool: &Tool| tool.name != "admin"));
3076
3077 init_router(&mut router).await;
3079
3080 let req = RouterRequest {
3081 id: RequestId::Number(1),
3082 inner: McpRequest::ListTools(ListToolsParams::default()),
3083 extensions: Extensions::new(),
3084 };
3085
3086 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3087
3088 match resp.inner {
3089 Ok(McpResponse::ListTools(result)) => {
3090 assert_eq!(result.tools.len(), 1);
3092 assert_eq!(result.tools[0].name, "public");
3093 }
3094 _ => panic!("Expected ListTools response"),
3095 }
3096 }
3097
3098 #[tokio::test]
3099 async fn test_tool_filter_call_denied() {
3100 use crate::filter::CapabilityFilter;
3101 use crate::tool::Tool;
3102
3103 let admin_tool = ToolBuilder::new("admin")
3104 .description("Admin tool")
3105 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
3106 .build();
3107
3108 let mut router = McpRouter::new()
3109 .tool(admin_tool)
3110 .tool_filter(CapabilityFilter::new(|_, _: &Tool| false)); init_router(&mut router).await;
3114
3115 let req = RouterRequest {
3116 id: RequestId::Number(1),
3117 inner: McpRequest::CallTool(CallToolParams {
3118 name: "admin".to_string(),
3119 arguments: serde_json::json!({"a": 1, "b": 2}),
3120 meta: None,
3121 }),
3122 extensions: Extensions::new(),
3123 };
3124
3125 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3126
3127 match resp.inner {
3129 Err(e) => {
3130 assert_eq!(e.code, -32601); }
3132 _ => panic!("Expected JsonRpc error"),
3133 }
3134 }
3135
3136 #[tokio::test]
3137 async fn test_tool_filter_call_allowed() {
3138 use crate::filter::CapabilityFilter;
3139 use crate::tool::Tool;
3140
3141 let public_tool = ToolBuilder::new("public")
3142 .description("Public tool")
3143 .handler(|input: AddInput| async move {
3144 Ok(CallToolResult::text(format!("{}", input.a + input.b)))
3145 })
3146 .build();
3147
3148 let mut router = McpRouter::new()
3149 .tool(public_tool)
3150 .tool_filter(CapabilityFilter::new(|_, _: &Tool| true)); init_router(&mut router).await;
3154
3155 let req = RouterRequest {
3156 id: RequestId::Number(1),
3157 inner: McpRequest::CallTool(CallToolParams {
3158 name: "public".to_string(),
3159 arguments: serde_json::json!({"a": 1, "b": 2}),
3160 meta: None,
3161 }),
3162 extensions: Extensions::new(),
3163 };
3164
3165 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3166
3167 match resp.inner {
3168 Ok(McpResponse::CallTool(result)) => {
3169 assert!(!result.is_error);
3170 }
3171 _ => panic!("Expected CallTool response"),
3172 }
3173 }
3174
3175 #[tokio::test]
3176 async fn test_tool_filter_custom_denial() {
3177 use crate::filter::{CapabilityFilter, DenialBehavior};
3178 use crate::tool::Tool;
3179
3180 let admin_tool = ToolBuilder::new("admin")
3181 .description("Admin tool")
3182 .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
3183 .build();
3184
3185 let mut router = McpRouter::new().tool(admin_tool).tool_filter(
3186 CapabilityFilter::new(|_, _: &Tool| false)
3187 .denial_behavior(DenialBehavior::Unauthorized),
3188 );
3189
3190 init_router(&mut router).await;
3192
3193 let req = RouterRequest {
3194 id: RequestId::Number(1),
3195 inner: McpRequest::CallTool(CallToolParams {
3196 name: "admin".to_string(),
3197 arguments: serde_json::json!({"a": 1, "b": 2}),
3198 meta: None,
3199 }),
3200 extensions: Extensions::new(),
3201 };
3202
3203 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3204
3205 match resp.inner {
3207 Err(e) => {
3208 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3210 }
3211 _ => panic!("Expected JsonRpc error"),
3212 }
3213 }
3214
3215 #[tokio::test]
3216 async fn test_resource_filter_list() {
3217 use crate::filter::CapabilityFilter;
3218 use crate::resource::{Resource, ResourceBuilder};
3219
3220 let public_resource = ResourceBuilder::new("file:///public.txt")
3221 .name("Public File")
3222 .text("public content");
3223
3224 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3225 .name("Secret File")
3226 .text("secret content");
3227
3228 let mut router = McpRouter::new()
3229 .resource(public_resource)
3230 .resource(secret_resource)
3231 .resource_filter(CapabilityFilter::new(|_, r: &Resource| {
3232 !r.name.contains("Secret")
3233 }));
3234
3235 init_router(&mut router).await;
3237
3238 let req = RouterRequest {
3239 id: RequestId::Number(1),
3240 inner: McpRequest::ListResources(ListResourcesParams::default()),
3241 extensions: Extensions::new(),
3242 };
3243
3244 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3245
3246 match resp.inner {
3247 Ok(McpResponse::ListResources(result)) => {
3248 assert_eq!(result.resources.len(), 1);
3250 assert_eq!(result.resources[0].name, "Public File");
3251 }
3252 _ => panic!("Expected ListResources response"),
3253 }
3254 }
3255
3256 #[tokio::test]
3257 async fn test_resource_filter_read_denied() {
3258 use crate::filter::CapabilityFilter;
3259 use crate::resource::{Resource, ResourceBuilder};
3260
3261 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3262 .name("Secret File")
3263 .text("secret content");
3264
3265 let mut router = McpRouter::new()
3266 .resource(secret_resource)
3267 .resource_filter(CapabilityFilter::new(|_, _: &Resource| false)); init_router(&mut router).await;
3271
3272 let req = RouterRequest {
3273 id: RequestId::Number(1),
3274 inner: McpRequest::ReadResource(ReadResourceParams {
3275 uri: "file:///secret.txt".to_string(),
3276 }),
3277 extensions: Extensions::new(),
3278 };
3279
3280 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3281
3282 match resp.inner {
3284 Err(e) => {
3285 assert_eq!(e.code, -32601); }
3287 _ => panic!("Expected JsonRpc error"),
3288 }
3289 }
3290
3291 #[tokio::test]
3292 async fn test_resource_filter_read_allowed() {
3293 use crate::filter::CapabilityFilter;
3294 use crate::resource::{Resource, ResourceBuilder};
3295
3296 let public_resource = ResourceBuilder::new("file:///public.txt")
3297 .name("Public File")
3298 .text("public content");
3299
3300 let mut router = McpRouter::new()
3301 .resource(public_resource)
3302 .resource_filter(CapabilityFilter::new(|_, _: &Resource| true)); init_router(&mut router).await;
3306
3307 let req = RouterRequest {
3308 id: RequestId::Number(1),
3309 inner: McpRequest::ReadResource(ReadResourceParams {
3310 uri: "file:///public.txt".to_string(),
3311 }),
3312 extensions: Extensions::new(),
3313 };
3314
3315 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3316
3317 match resp.inner {
3318 Ok(McpResponse::ReadResource(result)) => {
3319 assert_eq!(result.contents.len(), 1);
3320 assert_eq!(result.contents[0].text.as_deref(), Some("public content"));
3321 }
3322 _ => panic!("Expected ReadResource response"),
3323 }
3324 }
3325
3326 #[tokio::test]
3327 async fn test_resource_filter_custom_denial() {
3328 use crate::filter::{CapabilityFilter, DenialBehavior};
3329 use crate::resource::{Resource, ResourceBuilder};
3330
3331 let secret_resource = ResourceBuilder::new("file:///secret.txt")
3332 .name("Secret File")
3333 .text("secret content");
3334
3335 let mut router = McpRouter::new().resource(secret_resource).resource_filter(
3336 CapabilityFilter::new(|_, _: &Resource| false)
3337 .denial_behavior(DenialBehavior::Unauthorized),
3338 );
3339
3340 init_router(&mut router).await;
3342
3343 let req = RouterRequest {
3344 id: RequestId::Number(1),
3345 inner: McpRequest::ReadResource(ReadResourceParams {
3346 uri: "file:///secret.txt".to_string(),
3347 }),
3348 extensions: Extensions::new(),
3349 };
3350
3351 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3352
3353 match resp.inner {
3355 Err(e) => {
3356 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3358 }
3359 _ => panic!("Expected JsonRpc error"),
3360 }
3361 }
3362
3363 #[tokio::test]
3364 async fn test_prompt_filter_list() {
3365 use crate::filter::CapabilityFilter;
3366 use crate::prompt::{Prompt, PromptBuilder};
3367
3368 let public_prompt = PromptBuilder::new("greeting")
3369 .description("A greeting")
3370 .user_message("Hello!");
3371
3372 let admin_prompt = PromptBuilder::new("system_debug")
3373 .description("Admin prompt")
3374 .user_message("Debug");
3375
3376 let mut router = McpRouter::new()
3377 .prompt(public_prompt)
3378 .prompt(admin_prompt)
3379 .prompt_filter(CapabilityFilter::new(|_, p: &Prompt| {
3380 !p.name.contains("system")
3381 }));
3382
3383 init_router(&mut router).await;
3385
3386 let req = RouterRequest {
3387 id: RequestId::Number(1),
3388 inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3389 extensions: Extensions::new(),
3390 };
3391
3392 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3393
3394 match resp.inner {
3395 Ok(McpResponse::ListPrompts(result)) => {
3396 assert_eq!(result.prompts.len(), 1);
3398 assert_eq!(result.prompts[0].name, "greeting");
3399 }
3400 _ => panic!("Expected ListPrompts response"),
3401 }
3402 }
3403
3404 #[tokio::test]
3405 async fn test_prompt_filter_get_denied() {
3406 use crate::filter::CapabilityFilter;
3407 use crate::prompt::{Prompt, PromptBuilder};
3408 use std::collections::HashMap;
3409
3410 let admin_prompt = PromptBuilder::new("system_debug")
3411 .description("Admin prompt")
3412 .user_message("Debug");
3413
3414 let mut router = McpRouter::new()
3415 .prompt(admin_prompt)
3416 .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| false)); init_router(&mut router).await;
3420
3421 let req = RouterRequest {
3422 id: RequestId::Number(1),
3423 inner: McpRequest::GetPrompt(GetPromptParams {
3424 name: "system_debug".to_string(),
3425 arguments: HashMap::new(),
3426 }),
3427 extensions: Extensions::new(),
3428 };
3429
3430 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3431
3432 match resp.inner {
3434 Err(e) => {
3435 assert_eq!(e.code, -32601); }
3437 _ => panic!("Expected JsonRpc error"),
3438 }
3439 }
3440
3441 #[tokio::test]
3442 async fn test_prompt_filter_get_allowed() {
3443 use crate::filter::CapabilityFilter;
3444 use crate::prompt::{Prompt, PromptBuilder};
3445 use std::collections::HashMap;
3446
3447 let public_prompt = PromptBuilder::new("greeting")
3448 .description("A greeting")
3449 .user_message("Hello!");
3450
3451 let mut router = McpRouter::new()
3452 .prompt(public_prompt)
3453 .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| true)); init_router(&mut router).await;
3457
3458 let req = RouterRequest {
3459 id: RequestId::Number(1),
3460 inner: McpRequest::GetPrompt(GetPromptParams {
3461 name: "greeting".to_string(),
3462 arguments: HashMap::new(),
3463 }),
3464 extensions: Extensions::new(),
3465 };
3466
3467 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3468
3469 match resp.inner {
3470 Ok(McpResponse::GetPrompt(result)) => {
3471 assert_eq!(result.messages.len(), 1);
3472 }
3473 _ => panic!("Expected GetPrompt response"),
3474 }
3475 }
3476
3477 #[tokio::test]
3478 async fn test_prompt_filter_custom_denial() {
3479 use crate::filter::{CapabilityFilter, DenialBehavior};
3480 use crate::prompt::{Prompt, PromptBuilder};
3481 use std::collections::HashMap;
3482
3483 let admin_prompt = PromptBuilder::new("system_debug")
3484 .description("Admin prompt")
3485 .user_message("Debug");
3486
3487 let mut router = McpRouter::new().prompt(admin_prompt).prompt_filter(
3488 CapabilityFilter::new(|_, _: &Prompt| false)
3489 .denial_behavior(DenialBehavior::Unauthorized),
3490 );
3491
3492 init_router(&mut router).await;
3494
3495 let req = RouterRequest {
3496 id: RequestId::Number(1),
3497 inner: McpRequest::GetPrompt(GetPromptParams {
3498 name: "system_debug".to_string(),
3499 arguments: HashMap::new(),
3500 }),
3501 extensions: Extensions::new(),
3502 };
3503
3504 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3505
3506 match resp.inner {
3508 Err(e) => {
3509 assert_eq!(e.code, -32007); assert!(e.message.contains("Unauthorized"));
3511 }
3512 _ => panic!("Expected JsonRpc error"),
3513 }
3514 }
3515
3516 #[derive(Debug, Deserialize, JsonSchema)]
3521 struct StringInput {
3522 value: String,
3523 }
3524
3525 #[tokio::test]
3526 async fn test_router_merge_tools() {
3527 let tool_a = ToolBuilder::new("tool_a")
3529 .description("Tool A")
3530 .handler(|_: StringInput| async move { Ok(CallToolResult::text("A")) })
3531 .build();
3532
3533 let router_a = McpRouter::new().tool(tool_a);
3534
3535 let tool_b = ToolBuilder::new("tool_b")
3537 .description("Tool B")
3538 .handler(|_: StringInput| async move { Ok(CallToolResult::text("B")) })
3539 .build();
3540 let tool_c = ToolBuilder::new("tool_c")
3541 .description("Tool C")
3542 .handler(|_: StringInput| async move { Ok(CallToolResult::text("C")) })
3543 .build();
3544
3545 let router_b = McpRouter::new().tool(tool_b).tool(tool_c);
3546
3547 let mut merged = McpRouter::new()
3549 .server_info("merged", "1.0")
3550 .merge(router_a)
3551 .merge(router_b);
3552
3553 init_router(&mut merged).await;
3554
3555 let req = RouterRequest {
3557 id: RequestId::Number(1),
3558 inner: McpRequest::ListTools(ListToolsParams::default()),
3559 extensions: Extensions::new(),
3560 };
3561
3562 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3563
3564 match resp.inner {
3565 Ok(McpResponse::ListTools(result)) => {
3566 assert_eq!(result.tools.len(), 3);
3567 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3568 assert!(names.contains(&"tool_a"));
3569 assert!(names.contains(&"tool_b"));
3570 assert!(names.contains(&"tool_c"));
3571 }
3572 _ => panic!("Expected ListTools response"),
3573 }
3574 }
3575
3576 #[tokio::test]
3577 async fn test_router_merge_overwrites_duplicates() {
3578 let tool_v1 = ToolBuilder::new("shared")
3580 .description("Version 1")
3581 .handler(|_: StringInput| async move { Ok(CallToolResult::text("v1")) })
3582 .build();
3583
3584 let router_a = McpRouter::new().tool(tool_v1);
3585
3586 let tool_v2 = ToolBuilder::new("shared")
3588 .description("Version 2")
3589 .handler(|_: StringInput| async move { Ok(CallToolResult::text("v2")) })
3590 .build();
3591
3592 let router_b = McpRouter::new().tool(tool_v2);
3593
3594 let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3596
3597 init_router(&mut merged).await;
3598
3599 let req = RouterRequest {
3600 id: RequestId::Number(1),
3601 inner: McpRequest::ListTools(ListToolsParams::default()),
3602 extensions: Extensions::new(),
3603 };
3604
3605 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3606
3607 match resp.inner {
3608 Ok(McpResponse::ListTools(result)) => {
3609 assert_eq!(result.tools.len(), 1);
3610 assert_eq!(result.tools[0].name, "shared");
3611 assert_eq!(result.tools[0].description.as_deref(), Some("Version 2"));
3612 }
3613 _ => panic!("Expected ListTools response"),
3614 }
3615 }
3616
3617 #[tokio::test]
3618 async fn test_router_merge_resources() {
3619 use crate::resource::ResourceBuilder;
3620
3621 let router_a = McpRouter::new().resource(
3623 ResourceBuilder::new("file:///a.txt")
3624 .name("File A")
3625 .text("content a"),
3626 );
3627
3628 let router_b = McpRouter::new().resource(
3629 ResourceBuilder::new("file:///b.txt")
3630 .name("File B")
3631 .text("content b"),
3632 );
3633
3634 let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3635
3636 init_router(&mut merged).await;
3637
3638 let req = RouterRequest {
3639 id: RequestId::Number(1),
3640 inner: McpRequest::ListResources(ListResourcesParams::default()),
3641 extensions: Extensions::new(),
3642 };
3643
3644 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3645
3646 match resp.inner {
3647 Ok(McpResponse::ListResources(result)) => {
3648 assert_eq!(result.resources.len(), 2);
3649 let uris: Vec<&str> = result.resources.iter().map(|r| r.uri.as_str()).collect();
3650 assert!(uris.contains(&"file:///a.txt"));
3651 assert!(uris.contains(&"file:///b.txt"));
3652 }
3653 _ => panic!("Expected ListResources response"),
3654 }
3655 }
3656
3657 #[tokio::test]
3658 async fn test_router_merge_prompts() {
3659 use crate::prompt::PromptBuilder;
3660
3661 let router_a =
3662 McpRouter::new().prompt(PromptBuilder::new("prompt_a").user_message("Hello A"));
3663
3664 let router_b =
3665 McpRouter::new().prompt(PromptBuilder::new("prompt_b").user_message("Hello B"));
3666
3667 let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3668
3669 init_router(&mut merged).await;
3670
3671 let req = RouterRequest {
3672 id: RequestId::Number(1),
3673 inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3674 extensions: Extensions::new(),
3675 };
3676
3677 let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3678
3679 match resp.inner {
3680 Ok(McpResponse::ListPrompts(result)) => {
3681 assert_eq!(result.prompts.len(), 2);
3682 let names: Vec<&str> = result.prompts.iter().map(|p| p.name.as_str()).collect();
3683 assert!(names.contains(&"prompt_a"));
3684 assert!(names.contains(&"prompt_b"));
3685 }
3686 _ => panic!("Expected ListPrompts response"),
3687 }
3688 }
3689
3690 #[tokio::test]
3691 async fn test_router_nest_prefixes_tools() {
3692 let tool_query = ToolBuilder::new("query")
3694 .description("Query the database")
3695 .handler(|_: StringInput| async move { Ok(CallToolResult::text("query result")) })
3696 .build();
3697 let tool_insert = ToolBuilder::new("insert")
3698 .description("Insert into database")
3699 .handler(|_: StringInput| async move { Ok(CallToolResult::text("insert result")) })
3700 .build();
3701
3702 let db_router = McpRouter::new().tool(tool_query).tool(tool_insert);
3703
3704 let mut router = McpRouter::new()
3706 .server_info("nested", "1.0")
3707 .nest("db", db_router);
3708
3709 init_router(&mut router).await;
3710
3711 let req = RouterRequest {
3712 id: RequestId::Number(1),
3713 inner: McpRequest::ListTools(ListToolsParams::default()),
3714 extensions: Extensions::new(),
3715 };
3716
3717 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3718
3719 match resp.inner {
3720 Ok(McpResponse::ListTools(result)) => {
3721 assert_eq!(result.tools.len(), 2);
3722 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3723 assert!(names.contains(&"db.query"));
3724 assert!(names.contains(&"db.insert"));
3725 }
3726 _ => panic!("Expected ListTools response"),
3727 }
3728 }
3729
3730 #[tokio::test]
3731 async fn test_router_nest_call_prefixed_tool() {
3732 let tool = ToolBuilder::new("echo")
3733 .description("Echo input")
3734 .handler(|input: StringInput| async move { Ok(CallToolResult::text(&input.value)) })
3735 .build();
3736
3737 let nested_router = McpRouter::new().tool(tool);
3738
3739 let mut router = McpRouter::new().nest("api", nested_router);
3740
3741 init_router(&mut router).await;
3742
3743 let req = RouterRequest {
3745 id: RequestId::Number(1),
3746 inner: McpRequest::CallTool(CallToolParams {
3747 name: "api.echo".to_string(),
3748 arguments: serde_json::json!({"value": "hello world"}),
3749 meta: None,
3750 }),
3751 extensions: Extensions::new(),
3752 };
3753
3754 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3755
3756 match resp.inner {
3757 Ok(McpResponse::CallTool(result)) => {
3758 assert!(!result.is_error);
3759 match &result.content[0] {
3760 Content::Text { text, .. } => assert_eq!(text, "hello world"),
3761 _ => panic!("Expected text content"),
3762 }
3763 }
3764 _ => panic!("Expected CallTool response"),
3765 }
3766 }
3767
3768 #[tokio::test]
3769 async fn test_router_multiple_nests() {
3770 let db_tool = ToolBuilder::new("query")
3771 .description("Database query")
3772 .handler(|_: StringInput| async move { Ok(CallToolResult::text("db")) })
3773 .build();
3774
3775 let api_tool = ToolBuilder::new("fetch")
3776 .description("API fetch")
3777 .handler(|_: StringInput| async move { Ok(CallToolResult::text("api")) })
3778 .build();
3779
3780 let db_router = McpRouter::new().tool(db_tool);
3781 let api_router = McpRouter::new().tool(api_tool);
3782
3783 let mut router = McpRouter::new()
3784 .nest("db", db_router)
3785 .nest("api", api_router);
3786
3787 init_router(&mut router).await;
3788
3789 let req = RouterRequest {
3790 id: RequestId::Number(1),
3791 inner: McpRequest::ListTools(ListToolsParams::default()),
3792 extensions: Extensions::new(),
3793 };
3794
3795 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3796
3797 match resp.inner {
3798 Ok(McpResponse::ListTools(result)) => {
3799 assert_eq!(result.tools.len(), 2);
3800 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3801 assert!(names.contains(&"db.query"));
3802 assert!(names.contains(&"api.fetch"));
3803 }
3804 _ => panic!("Expected ListTools response"),
3805 }
3806 }
3807
3808 #[tokio::test]
3809 async fn test_router_merge_and_nest_combined() {
3810 let tool_a = ToolBuilder::new("local")
3812 .description("Local tool")
3813 .handler(|_: StringInput| async move { Ok(CallToolResult::text("local")) })
3814 .build();
3815
3816 let nested_tool = ToolBuilder::new("remote")
3817 .description("Remote tool")
3818 .handler(|_: StringInput| async move { Ok(CallToolResult::text("remote")) })
3819 .build();
3820
3821 let nested_router = McpRouter::new().tool(nested_tool);
3822
3823 let mut router = McpRouter::new()
3824 .tool(tool_a)
3825 .nest("external", nested_router);
3826
3827 init_router(&mut router).await;
3828
3829 let req = RouterRequest {
3830 id: RequestId::Number(1),
3831 inner: McpRequest::ListTools(ListToolsParams::default()),
3832 extensions: Extensions::new(),
3833 };
3834
3835 let resp = router.ready().await.unwrap().call(req).await.unwrap();
3836
3837 match resp.inner {
3838 Ok(McpResponse::ListTools(result)) => {
3839 assert_eq!(result.tools.len(), 2);
3840 let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3841 assert!(names.contains(&"local"));
3842 assert!(names.contains(&"external.remote"));
3843 }
3844 _ => panic!("Expected ListTools response"),
3845 }
3846 }
3847
3848 #[tokio::test]
3849 async fn test_router_merge_preserves_server_info() {
3850 let child_router = McpRouter::new()
3851 .server_info("child", "2.0")
3852 .instructions("Child instructions");
3853
3854 let mut router = McpRouter::new()
3855 .server_info("parent", "1.0")
3856 .instructions("Parent instructions")
3857 .merge(child_router);
3858
3859 init_router(&mut router).await;
3860
3861 let init_req = RouterRequest {
3863 id: RequestId::Number(99),
3864 inner: McpRequest::Initialize(InitializeParams {
3865 protocol_version: "2025-11-25".to_string(),
3866 capabilities: ClientCapabilities::default(),
3867 client_info: Implementation {
3868 name: "test".to_string(),
3869 version: "1.0".to_string(),
3870 ..Default::default()
3871 },
3872 }),
3873 extensions: Extensions::new(),
3874 };
3875
3876 let child_router2 = McpRouter::new().server_info("child", "2.0");
3878 let mut fresh_router = McpRouter::new()
3879 .server_info("parent", "1.0")
3880 .merge(child_router2);
3881
3882 let resp = fresh_router
3883 .ready()
3884 .await
3885 .unwrap()
3886 .call(init_req)
3887 .await
3888 .unwrap();
3889
3890 match resp.inner {
3891 Ok(McpResponse::Initialize(result)) => {
3892 assert_eq!(result.server_info.name, "parent");
3893 assert_eq!(result.server_info.version, "1.0");
3894 }
3895 _ => panic!("Expected Initialize response"),
3896 }
3897 }
3898
3899 #[tokio::test]
3904 async fn test_auto_instructions_tools_only() {
3905 let tool_a = ToolBuilder::new("alpha")
3906 .description("Alpha tool")
3907 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
3908 .build();
3909 let tool_b = ToolBuilder::new("beta")
3910 .description("Beta tool")
3911 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
3912 .build();
3913
3914 let mut router = McpRouter::new()
3915 .auto_instructions()
3916 .tool(tool_a)
3917 .tool(tool_b);
3918
3919 let resp = send_initialize(&mut router).await;
3920 let instructions = resp.instructions.expect("should have instructions");
3921
3922 assert!(instructions.contains("## Tools"));
3923 assert!(instructions.contains("- **alpha**: Alpha tool"));
3924 assert!(instructions.contains("- **beta**: Beta tool"));
3925 assert!(!instructions.contains("## Resources"));
3927 assert!(!instructions.contains("## Prompts"));
3928 }
3929
3930 #[tokio::test]
3931 async fn test_auto_instructions_with_annotations() {
3932 let read_only_tool = ToolBuilder::new("query")
3933 .description("Run a query")
3934 .read_only()
3935 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
3936 .build();
3937 let destructive_tool = ToolBuilder::new("delete")
3938 .description("Delete a record")
3939 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
3940 .build();
3941 let idempotent_tool = ToolBuilder::new("upsert")
3942 .description("Upsert a record")
3943 .non_destructive()
3944 .idempotent()
3945 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
3946 .build();
3947
3948 let mut router = McpRouter::new()
3949 .auto_instructions()
3950 .tool(read_only_tool)
3951 .tool(destructive_tool)
3952 .tool(idempotent_tool);
3953
3954 let resp = send_initialize(&mut router).await;
3955 let instructions = resp.instructions.unwrap();
3956
3957 assert!(instructions.contains("- **query**: Run a query [read-only]"));
3958 assert!(instructions.contains("- **delete**: Delete a record\n"));
3960 assert!(instructions.contains("- **upsert**: Upsert a record [idempotent]"));
3961 }
3962
3963 #[tokio::test]
3964 async fn test_auto_instructions_with_resources() {
3965 use crate::resource::ResourceBuilder;
3966
3967 let resource = ResourceBuilder::new("file:///schema.sql")
3968 .name("Schema")
3969 .description("Database schema")
3970 .text("CREATE TABLE ...");
3971
3972 let mut router = McpRouter::new().auto_instructions().resource(resource);
3973
3974 let resp = send_initialize(&mut router).await;
3975 let instructions = resp.instructions.unwrap();
3976
3977 assert!(instructions.contains("## Resources"));
3978 assert!(instructions.contains("- **file:///schema.sql**: Database schema"));
3979 assert!(!instructions.contains("## Tools"));
3980 }
3981
3982 #[tokio::test]
3983 async fn test_auto_instructions_with_resource_templates() {
3984 use crate::resource::ResourceTemplateBuilder;
3985
3986 let template = ResourceTemplateBuilder::new("file:///{path}")
3987 .name("File")
3988 .description("Read a file by path")
3989 .handler(
3990 |_uri: String, _vars: std::collections::HashMap<String, String>| async move {
3991 Ok(crate::ReadResourceResult::text("content", "text/plain"))
3992 },
3993 );
3994
3995 let mut router = McpRouter::new()
3996 .auto_instructions()
3997 .resource_template(template);
3998
3999 let resp = send_initialize(&mut router).await;
4000 let instructions = resp.instructions.unwrap();
4001
4002 assert!(instructions.contains("## Resources"));
4003 assert!(instructions.contains("- **file:///{path}**: Read a file by path"));
4004 }
4005
4006 #[tokio::test]
4007 async fn test_auto_instructions_with_prompts() {
4008 use crate::prompt::PromptBuilder;
4009
4010 let prompt = PromptBuilder::new("write_query")
4011 .description("Help write a SQL query")
4012 .user_message("Write a query for: {task}");
4013
4014 let mut router = McpRouter::new().auto_instructions().prompt(prompt);
4015
4016 let resp = send_initialize(&mut router).await;
4017 let instructions = resp.instructions.unwrap();
4018
4019 assert!(instructions.contains("## Prompts"));
4020 assert!(instructions.contains("- **write_query**: Help write a SQL query"));
4021 assert!(!instructions.contains("## Tools"));
4022 }
4023
4024 #[tokio::test]
4025 async fn test_auto_instructions_all_sections() {
4026 use crate::prompt::PromptBuilder;
4027 use crate::resource::ResourceBuilder;
4028
4029 let tool = ToolBuilder::new("query")
4030 .description("Execute SQL")
4031 .read_only()
4032 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4033 .build();
4034 let resource = ResourceBuilder::new("db://schema")
4035 .name("Schema")
4036 .description("Full database schema")
4037 .text("schema");
4038 let prompt = PromptBuilder::new("write_query")
4039 .description("Help write a SQL query")
4040 .user_message("Write a query");
4041
4042 let mut router = McpRouter::new()
4043 .auto_instructions()
4044 .tool(tool)
4045 .resource(resource)
4046 .prompt(prompt);
4047
4048 let resp = send_initialize(&mut router).await;
4049 let instructions = resp.instructions.unwrap();
4050
4051 assert!(instructions.contains("## Tools"));
4053 assert!(instructions.contains("## Resources"));
4054 assert!(instructions.contains("## Prompts"));
4055
4056 let tools_pos = instructions.find("## Tools").unwrap();
4058 let resources_pos = instructions.find("## Resources").unwrap();
4059 let prompts_pos = instructions.find("## Prompts").unwrap();
4060 assert!(tools_pos < resources_pos);
4061 assert!(resources_pos < prompts_pos);
4062 }
4063
4064 #[tokio::test]
4065 async fn test_auto_instructions_with_prefix_and_suffix() {
4066 let tool = ToolBuilder::new("echo")
4067 .description("Echo input")
4068 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4069 .build();
4070
4071 let mut router = McpRouter::new()
4072 .auto_instructions_with(
4073 Some("This server provides echo capabilities."),
4074 Some("Contact admin@example.com for support."),
4075 )
4076 .tool(tool);
4077
4078 let resp = send_initialize(&mut router).await;
4079 let instructions = resp.instructions.unwrap();
4080
4081 assert!(instructions.starts_with("This server provides echo capabilities."));
4082 assert!(instructions.ends_with("Contact admin@example.com for support."));
4083 assert!(instructions.contains("## Tools"));
4084 assert!(instructions.contains("- **echo**: Echo input"));
4085 }
4086
4087 #[tokio::test]
4088 async fn test_auto_instructions_prefix_only() {
4089 let tool = ToolBuilder::new("echo")
4090 .description("Echo input")
4091 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4092 .build();
4093
4094 let mut router = McpRouter::new()
4095 .auto_instructions_with(Some("My server intro."), None::<String>)
4096 .tool(tool);
4097
4098 let resp = send_initialize(&mut router).await;
4099 let instructions = resp.instructions.unwrap();
4100
4101 assert!(instructions.starts_with("My server intro."));
4102 assert!(instructions.contains("- **echo**: Echo input"));
4103 }
4104
4105 #[tokio::test]
4106 async fn test_auto_instructions_empty_router() {
4107 let mut router = McpRouter::new().auto_instructions();
4108
4109 let resp = send_initialize(&mut router).await;
4110 let instructions = resp.instructions.expect("should have instructions");
4111
4112 assert!(!instructions.contains("## Tools"));
4114 assert!(!instructions.contains("## Resources"));
4115 assert!(!instructions.contains("## Prompts"));
4116 assert!(instructions.is_empty());
4117 }
4118
4119 #[tokio::test]
4120 async fn test_auto_instructions_overrides_manual() {
4121 let tool = ToolBuilder::new("echo")
4122 .description("Echo input")
4123 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4124 .build();
4125
4126 let mut router = McpRouter::new()
4127 .instructions("This will be overridden")
4128 .auto_instructions()
4129 .tool(tool);
4130
4131 let resp = send_initialize(&mut router).await;
4132 let instructions = resp.instructions.unwrap();
4133
4134 assert!(!instructions.contains("This will be overridden"));
4135 assert!(instructions.contains("- **echo**: Echo input"));
4136 }
4137
4138 #[tokio::test]
4139 async fn test_no_auto_instructions_returns_manual() {
4140 let tool = ToolBuilder::new("echo")
4141 .description("Echo input")
4142 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4143 .build();
4144
4145 let mut router = McpRouter::new()
4146 .instructions("Manual instructions here")
4147 .tool(tool);
4148
4149 let resp = send_initialize(&mut router).await;
4150 let instructions = resp.instructions.unwrap();
4151
4152 assert_eq!(instructions, "Manual instructions here");
4153 }
4154
4155 #[tokio::test]
4156 async fn test_auto_instructions_no_description_fallback() {
4157 let tool = ToolBuilder::new("mystery")
4158 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4159 .build();
4160
4161 let mut router = McpRouter::new().auto_instructions().tool(tool);
4162
4163 let resp = send_initialize(&mut router).await;
4164 let instructions = resp.instructions.unwrap();
4165
4166 assert!(instructions.contains("- **mystery**: No description"));
4167 }
4168
4169 #[tokio::test]
4170 async fn test_auto_instructions_sorted_alphabetically() {
4171 let tool_z = ToolBuilder::new("zebra")
4172 .description("Z tool")
4173 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4174 .build();
4175 let tool_a = ToolBuilder::new("alpha")
4176 .description("A tool")
4177 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4178 .build();
4179 let tool_m = ToolBuilder::new("middle")
4180 .description("M tool")
4181 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4182 .build();
4183
4184 let mut router = McpRouter::new()
4185 .auto_instructions()
4186 .tool(tool_z)
4187 .tool(tool_a)
4188 .tool(tool_m);
4189
4190 let resp = send_initialize(&mut router).await;
4191 let instructions = resp.instructions.unwrap();
4192
4193 let alpha_pos = instructions.find("**alpha**").unwrap();
4194 let middle_pos = instructions.find("**middle**").unwrap();
4195 let zebra_pos = instructions.find("**zebra**").unwrap();
4196 assert!(alpha_pos < middle_pos);
4197 assert!(middle_pos < zebra_pos);
4198 }
4199
4200 #[tokio::test]
4201 async fn test_auto_instructions_read_only_and_idempotent_tags() {
4202 let tool = ToolBuilder::new("safe_update")
4203 .description("Safe update operation")
4204 .idempotent()
4205 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4206 .build();
4207
4208 let mut router = McpRouter::new().auto_instructions().tool(tool);
4209
4210 let resp = send_initialize(&mut router).await;
4211 let instructions = resp.instructions.unwrap();
4212
4213 assert!(
4214 instructions.contains("[idempotent]"),
4215 "got: {}",
4216 instructions
4217 );
4218 }
4219
4220 #[tokio::test]
4221 async fn test_auto_instructions_lazy_generation() {
4222 let mut router = McpRouter::new().auto_instructions();
4225
4226 let tool = ToolBuilder::new("late_tool")
4227 .description("Added after auto_instructions")
4228 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4229 .build();
4230
4231 router = router.tool(tool);
4232
4233 let resp = send_initialize(&mut router).await;
4234 let instructions = resp.instructions.unwrap();
4235
4236 assert!(instructions.contains("- **late_tool**: Added after auto_instructions"));
4237 }
4238
4239 #[tokio::test]
4240 async fn test_auto_instructions_multiple_annotation_tags() {
4241 let tool = ToolBuilder::new("update")
4242 .description("Update a record")
4243 .annotations(ToolAnnotations {
4244 read_only_hint: true,
4245 idempotent_hint: true,
4246 ..Default::default()
4247 })
4248 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4249 .build();
4250
4251 let mut router = McpRouter::new().auto_instructions().tool(tool);
4252
4253 let resp = send_initialize(&mut router).await;
4254 let instructions = resp.instructions.unwrap();
4255
4256 assert!(
4257 instructions.contains("[read-only, idempotent]"),
4258 "got: {}",
4259 instructions
4260 );
4261 }
4262
4263 #[tokio::test]
4264 async fn test_auto_instructions_no_annotations_no_tags() {
4265 let tool = ToolBuilder::new("fetch")
4267 .description("Fetch data")
4268 .handler(|_: AddInput| async move { Ok(CallToolResult::text("ok")) })
4269 .build();
4270
4271 let mut router = McpRouter::new().auto_instructions().tool(tool);
4272
4273 let resp = send_initialize(&mut router).await;
4274 let instructions = resp.instructions.unwrap();
4275
4276 assert!(
4278 !instructions.contains('['),
4279 "should have no tags, got: {}",
4280 instructions
4281 );
4282 assert!(instructions.contains("- **fetch**: Fetch data"));
4283 }
4284
4285 async fn send_initialize(router: &mut McpRouter) -> InitializeResult {
4287 let init_req = RouterRequest {
4288 id: RequestId::Number(0),
4289 inner: McpRequest::Initialize(InitializeParams {
4290 protocol_version: "2025-11-25".to_string(),
4291 capabilities: ClientCapabilities {
4292 roots: None,
4293 sampling: None,
4294 elicitation: None,
4295 },
4296 client_info: Implementation {
4297 name: "test".to_string(),
4298 version: "1.0".to_string(),
4299 ..Default::default()
4300 },
4301 }),
4302 extensions: Extensions::new(),
4303 };
4304 let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
4305 match resp.inner {
4306 Ok(McpResponse::Initialize(result)) => result,
4307 other => panic!("Expected Initialize response, got {:?}", other),
4308 }
4309 }
4310}