Skip to main content

tower_mcp/
router.rs

1//! MCP Router - routes requests to tools, resources, and prompts
2//!
3//! The router implements Tower's `Service` trait, making it composable with
4//! standard tower middleware.
5
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11
12use tower_service::Service;
13
14use crate::async_task::TaskStore;
15use crate::context::{
16    CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
17    ServerNotification,
18};
19use crate::error::{Error, JsonRpcError, Result};
20use crate::filter::{PromptFilter, ResourceFilter, ToolFilter};
21use crate::prompt::Prompt;
22use crate::protocol::*;
23use crate::resource::{Resource, ResourceTemplate};
24use crate::session::SessionState;
25use crate::tool::Tool;
26
27/// Type alias for completion handler function
28pub type CompletionHandler = Arc<
29    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
30        + Send
31        + Sync,
32>;
33
34/// MCP Router that dispatches requests to registered handlers
35///
36/// Implements `tower::Service<McpRequest>` for middleware composition.
37///
38/// # Example
39///
40/// ```rust
41/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
42/// use schemars::JsonSchema;
43/// use serde::Deserialize;
44///
45/// #[derive(Debug, Deserialize, JsonSchema)]
46/// struct Input { value: String }
47///
48/// let tool = ToolBuilder::new("echo")
49///     .description("Echo input")
50///     .handler(|i: Input| async move { Ok(CallToolResult::text(i.value)) })
51///     .build()
52///     .unwrap();
53///
54/// let router = McpRouter::new()
55///     .server_info("my-server", "1.0.0")
56///     .tool(tool);
57/// ```
58#[derive(Clone)]
59pub struct McpRouter {
60    inner: Arc<McpRouterInner>,
61    session: SessionState,
62}
63
64impl std::fmt::Debug for McpRouter {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("McpRouter")
67            .field("server_name", &self.inner.server_name)
68            .field("server_version", &self.inner.server_version)
69            .field("tools_count", &self.inner.tools.len())
70            .field("resources_count", &self.inner.resources.len())
71            .field("prompts_count", &self.inner.prompts.len())
72            .field("session_phase", &self.session.phase())
73            .finish()
74    }
75}
76
77/// Inner configuration that is shared across clones
78#[derive(Clone)]
79struct McpRouterInner {
80    server_name: String,
81    server_version: String,
82    /// Human-readable title for the server
83    server_title: Option<String>,
84    /// Description of the server
85    server_description: Option<String>,
86    /// Icons for the server
87    server_icons: Option<Vec<ToolIcon>>,
88    /// URL of the server's website
89    server_website_url: Option<String>,
90    instructions: Option<String>,
91    tools: HashMap<String, Arc<Tool>>,
92    resources: HashMap<String, Arc<Resource>>,
93    /// Resource templates for dynamic resource matching (keyed by uri_template)
94    resource_templates: Vec<Arc<ResourceTemplate>>,
95    prompts: HashMap<String, Arc<Prompt>>,
96    /// In-flight requests for cancellation tracking (shared across clones)
97    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
98    /// Channel for sending notifications to connected clients
99    notification_tx: Option<NotificationSender>,
100    /// Handle for sending requests to the client (for sampling, etc.)
101    client_requester: Option<ClientRequesterHandle>,
102    /// Task store for async operations
103    task_store: TaskStore,
104    /// Subscribed resource URIs
105    subscriptions: Arc<RwLock<HashSet<String>>>,
106    /// Handler for completion requests
107    completion_handler: Option<CompletionHandler>,
108    /// Filter for tools based on session state
109    tool_filter: Option<ToolFilter>,
110    /// Filter for resources based on session state
111    resource_filter: Option<ResourceFilter>,
112    /// Filter for prompts based on session state
113    prompt_filter: Option<PromptFilter>,
114    /// Router-level extensions (for state and middleware data)
115    extensions: Arc<crate::context::Extensions>,
116}
117
118impl McpRouter {
119    /// Create a new MCP router
120    pub fn new() -> Self {
121        Self {
122            inner: Arc::new(McpRouterInner {
123                server_name: "tower-mcp".to_string(),
124                server_version: env!("CARGO_PKG_VERSION").to_string(),
125                server_title: None,
126                server_description: None,
127                server_icons: None,
128                server_website_url: None,
129                instructions: None,
130                tools: HashMap::new(),
131                resources: HashMap::new(),
132                resource_templates: Vec::new(),
133                prompts: HashMap::new(),
134                in_flight: Arc::new(RwLock::new(HashMap::new())),
135                notification_tx: None,
136                client_requester: None,
137                task_store: TaskStore::new(),
138                subscriptions: Arc::new(RwLock::new(HashSet::new())),
139                extensions: Arc::new(crate::context::Extensions::new()),
140                completion_handler: None,
141                tool_filter: None,
142                resource_filter: None,
143                prompt_filter: None,
144            }),
145            session: SessionState::new(),
146        }
147    }
148
149    /// Get access to the task store for async operations
150    pub fn task_store(&self) -> &TaskStore {
151        &self.inner.task_store
152    }
153
154    /// Set the notification sender for progress reporting
155    ///
156    /// This is typically called by the transport layer to receive notifications.
157    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
158        Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
159        self
160    }
161
162    /// Get the notification sender (if configured)
163    pub fn notification_sender(&self) -> Option<&NotificationSender> {
164        self.inner.notification_tx.as_ref()
165    }
166
167    /// Set the client requester for server-to-client requests (sampling, etc.)
168    ///
169    /// This is typically called by bidirectional transports (WebSocket, stdio)
170    /// to enable tool handlers to send requests to the client.
171    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
172        Arc::make_mut(&mut self.inner).client_requester = Some(requester);
173        self
174    }
175
176    /// Get the client requester (if configured)
177    pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
178        self.inner.client_requester.as_ref()
179    }
180
181    /// Add router-level state that handlers can access via the `Extension<T>` extractor.
182    ///
183    /// This is the recommended way to share state across all tools, resources, and prompts
184    /// in a router. The state is available to handlers via the [`crate::extract::Extension`]
185    /// extractor.
186    ///
187    /// # Example
188    ///
189    /// ```rust
190    /// use std::sync::Arc;
191    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
192    /// use tower_mcp::extract::{Extension, Json};
193    /// use schemars::JsonSchema;
194    /// use serde::Deserialize;
195    ///
196    /// #[derive(Clone)]
197    /// struct AppState {
198    ///     db_url: String,
199    /// }
200    ///
201    /// #[derive(Deserialize, JsonSchema)]
202    /// struct QueryInput {
203    ///     sql: String,
204    /// }
205    ///
206    /// let state = Arc::new(AppState { db_url: "postgres://...".into() });
207    ///
208    /// // Tool extracts state via Extension<T>
209    /// let query_tool = ToolBuilder::new("query")
210    ///     .description("Run a database query")
211    ///     .extractor_handler_typed::<_, _, _, QueryInput>(
212    ///         (),
213    ///         |Extension(state): Extension<Arc<AppState>>, Json(input): Json<QueryInput>| async move {
214    ///             Ok(CallToolResult::text(format!("Query on {}: {}", state.db_url, input.sql)))
215    ///         },
216    ///     )
217    ///     .build()
218    ///     .unwrap();
219    ///
220    /// let router = McpRouter::new()
221    ///     .with_state(state)  // State is now available to all handlers
222    ///     .tool(query_tool);
223    /// ```
224    pub fn with_state<T: Clone + Send + Sync + 'static>(mut self, state: T) -> Self {
225        let inner = Arc::make_mut(&mut self.inner);
226        Arc::make_mut(&mut inner.extensions).insert(state);
227        self
228    }
229
230    /// Add an extension value that handlers can access via the `Extension<T>` extractor.
231    ///
232    /// This is a more general form of `with_state()` for when you need multiple
233    /// typed values available to handlers.
234    pub fn with_extension<T: Clone + Send + Sync + 'static>(self, value: T) -> Self {
235        self.with_state(value)
236    }
237
238    /// Get the router's extensions.
239    pub fn extensions(&self) -> &crate::context::Extensions {
240        &self.inner.extensions
241    }
242
243    /// Create a request context for tracking a request
244    ///
245    /// This registers the request for cancellation tracking and sets up
246    /// progress reporting, client requests, and router extensions if configured.
247    pub fn create_context(
248        &self,
249        request_id: RequestId,
250        progress_token: Option<ProgressToken>,
251    ) -> RequestContext {
252        let ctx = RequestContext::new(request_id.clone());
253
254        // Set up progress token if provided
255        let ctx = if let Some(token) = progress_token {
256            ctx.with_progress_token(token)
257        } else {
258            ctx
259        };
260
261        // Set up notification sender if configured
262        let ctx = if let Some(tx) = &self.inner.notification_tx {
263            ctx.with_notification_sender(tx.clone())
264        } else {
265            ctx
266        };
267
268        // Set up client requester if configured (for sampling support)
269        let ctx = if let Some(requester) = &self.inner.client_requester {
270            ctx.with_client_requester(requester.clone())
271        } else {
272            ctx
273        };
274
275        // Include router extensions (for with_state() and middleware data)
276        let ctx = ctx.with_extensions(self.inner.extensions.clone());
277
278        // Register for cancellation tracking
279        let token = ctx.cancellation_token();
280        if let Ok(mut in_flight) = self.inner.in_flight.write() {
281            in_flight.insert(request_id, token);
282        }
283
284        ctx
285    }
286
287    /// Remove a request from tracking (called when request completes)
288    pub fn complete_request(&self, request_id: &RequestId) {
289        if let Ok(mut in_flight) = self.inner.in_flight.write() {
290            in_flight.remove(request_id);
291        }
292    }
293
294    /// Cancel a tracked request
295    fn cancel_request(&self, request_id: &RequestId) -> bool {
296        let Ok(in_flight) = self.inner.in_flight.read() else {
297            return false;
298        };
299        let Some(token) = in_flight.get(request_id) else {
300            return false;
301        };
302        token.cancel();
303        true
304    }
305
306    /// Set server info
307    pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
308        let inner = Arc::make_mut(&mut self.inner);
309        inner.server_name = name.into();
310        inner.server_version = version.into();
311        self
312    }
313
314    /// Set instructions for LLMs describing how to use this server
315    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
316        Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
317        self
318    }
319
320    /// Set a human-readable title for the server
321    pub fn server_title(mut self, title: impl Into<String>) -> Self {
322        Arc::make_mut(&mut self.inner).server_title = Some(title.into());
323        self
324    }
325
326    /// Set the server description
327    pub fn server_description(mut self, description: impl Into<String>) -> Self {
328        Arc::make_mut(&mut self.inner).server_description = Some(description.into());
329        self
330    }
331
332    /// Set icons for the server
333    pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
334        Arc::make_mut(&mut self.inner).server_icons = Some(icons);
335        self
336    }
337
338    /// Set the server's website URL
339    pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
340        Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
341        self
342    }
343
344    /// Register a tool
345    pub fn tool(mut self, tool: Tool) -> Self {
346        Arc::make_mut(&mut self.inner)
347            .tools
348            .insert(tool.name.clone(), Arc::new(tool));
349        self
350    }
351
352    /// Register a resource
353    pub fn resource(mut self, resource: Resource) -> Self {
354        Arc::make_mut(&mut self.inner)
355            .resources
356            .insert(resource.uri.clone(), Arc::new(resource));
357        self
358    }
359
360    /// Register a resource template
361    ///
362    /// Resource templates allow dynamic resources to be matched by URI pattern.
363    /// When a client requests a resource URI that doesn't match any static
364    /// resource, the router tries to match it against registered templates.
365    ///
366    /// # Example
367    ///
368    /// ```rust
369    /// use tower_mcp::{McpRouter, ResourceTemplateBuilder};
370    /// use tower_mcp::protocol::{ReadResourceResult, ResourceContent};
371    /// use std::collections::HashMap;
372    ///
373    /// let template = ResourceTemplateBuilder::new("file:///{path}")
374    ///     .name("Project Files")
375    ///     .handler(|uri: String, vars: HashMap<String, String>| async move {
376    ///         let path = vars.get("path").unwrap_or(&String::new()).clone();
377    ///         Ok(ReadResourceResult {
378    ///             contents: vec![ResourceContent {
379    ///                 uri,
380    ///                 mime_type: Some("text/plain".to_string()),
381    ///                 text: Some(format!("Contents of {}", path)),
382    ///                 blob: None,
383    ///             }],
384    ///         })
385    ///     });
386    ///
387    /// let router = McpRouter::new()
388    ///     .resource_template(template);
389    /// ```
390    pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
391        Arc::make_mut(&mut self.inner)
392            .resource_templates
393            .push(Arc::new(template));
394        self
395    }
396
397    /// Register a prompt
398    pub fn prompt(mut self, prompt: Prompt) -> Self {
399        Arc::make_mut(&mut self.inner)
400            .prompts
401            .insert(prompt.name.clone(), Arc::new(prompt));
402        self
403    }
404
405    /// Register multiple tools at once.
406    ///
407    /// # Example
408    ///
409    /// ```rust
410    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
411    /// use schemars::JsonSchema;
412    /// use serde::Deserialize;
413    ///
414    /// #[derive(Debug, Deserialize, JsonSchema)]
415    /// struct Input { value: String }
416    ///
417    /// let tools = vec![
418    ///     ToolBuilder::new("a")
419    ///         .description("Tool A")
420    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
421    ///         .build().unwrap(),
422    ///     ToolBuilder::new("b")
423    ///         .description("Tool B")
424    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
425    ///         .build().unwrap(),
426    /// ];
427    ///
428    /// let router = McpRouter::new().tools(tools);
429    /// ```
430    pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
431        tools
432            .into_iter()
433            .fold(self, |router, tool| router.tool(tool))
434    }
435
436    /// Register multiple resources at once.
437    ///
438    /// # Example
439    ///
440    /// ```rust
441    /// use tower_mcp::{McpRouter, ResourceBuilder};
442    ///
443    /// let resources = vec![
444    ///     ResourceBuilder::new("file:///a.txt")
445    ///         .name("File A")
446    ///         .text("contents a"),
447    ///     ResourceBuilder::new("file:///b.txt")
448    ///         .name("File B")
449    ///         .text("contents b"),
450    /// ];
451    ///
452    /// let router = McpRouter::new().resources(resources);
453    /// ```
454    pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
455        resources
456            .into_iter()
457            .fold(self, |router, resource| router.resource(resource))
458    }
459
460    /// Register multiple prompts at once.
461    ///
462    /// # Example
463    ///
464    /// ```rust
465    /// use tower_mcp::{McpRouter, PromptBuilder};
466    ///
467    /// let prompts = vec![
468    ///     PromptBuilder::new("greet")
469    ///         .description("Greet someone")
470    ///         .user_message("Hello!"),
471    ///     PromptBuilder::new("farewell")
472    ///         .description("Say goodbye")
473    ///         .user_message("Goodbye!"),
474    /// ];
475    ///
476    /// let router = McpRouter::new().prompts(prompts);
477    /// ```
478    pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
479        prompts
480            .into_iter()
481            .fold(self, |router, prompt| router.prompt(prompt))
482    }
483
484    /// Merge another router's capabilities into this one.
485    ///
486    /// This combines all tools, resources, resource templates, and prompts from
487    /// the other router into this router. Uses "last wins" semantics for conflicts,
488    /// meaning if both routers have a tool/resource/prompt with the same name,
489    /// the one from `other` will replace the one in `self`.
490    ///
491    /// Server info, instructions, filters, and other router-level configuration
492    /// are NOT merged - only the root router's settings are used.
493    ///
494    /// # Example
495    ///
496    /// ```rust
497    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult, ResourceBuilder};
498    /// use schemars::JsonSchema;
499    /// use serde::Deserialize;
500    ///
501    /// #[derive(Debug, Deserialize, JsonSchema)]
502    /// struct Input { value: String }
503    ///
504    /// // Create a router with database tools
505    /// let db_tools = McpRouter::new()
506    ///     .tool(
507    ///         ToolBuilder::new("query")
508    ///             .description("Query the database")
509    ///             .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
510    ///             .build()
511    ///             .unwrap()
512    ///     );
513    ///
514    /// // Create a router with API tools
515    /// let api_tools = McpRouter::new()
516    ///     .tool(
517    ///         ToolBuilder::new("fetch")
518    ///             .description("Fetch from API")
519    ///             .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
520    ///             .build()
521    ///             .unwrap()
522    ///     );
523    ///
524    /// // Merge them together
525    /// let router = McpRouter::new()
526    ///     .server_info("combined", "1.0")
527    ///     .merge(db_tools)
528    ///     .merge(api_tools);
529    /// ```
530    pub fn merge(mut self, other: McpRouter) -> Self {
531        let inner = Arc::make_mut(&mut self.inner);
532        let other_inner = other.inner;
533
534        // Merge tools (last wins)
535        for (name, tool) in &other_inner.tools {
536            inner.tools.insert(name.clone(), tool.clone());
537        }
538
539        // Merge resources (last wins)
540        for (uri, resource) in &other_inner.resources {
541            inner.resources.insert(uri.clone(), resource.clone());
542        }
543
544        // Merge resource templates (append - no deduplication since templates
545        // can have complex matching behavior)
546        for template in &other_inner.resource_templates {
547            inner.resource_templates.push(template.clone());
548        }
549
550        // Merge prompts (last wins)
551        for (name, prompt) in &other_inner.prompts {
552            inner.prompts.insert(name.clone(), prompt.clone());
553        }
554
555        self
556    }
557
558    /// Nest another router's capabilities under a prefix.
559    ///
560    /// This is similar to `merge()`, but all tool names from the nested router
561    /// are prefixed with the given string and a dot separator. For example,
562    /// nesting with prefix "db" will turn a tool named "query" into "db.query".
563    ///
    /// Resources, resource templates, and prompts are merged without modification:
    /// resources and templates are identified by URI, and prompt names are not
    /// prefixed (a nested prompt replaces an existing prompt with the same name).
566    ///
567    /// # Example
568    ///
569    /// ```rust
570    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
571    /// use schemars::JsonSchema;
572    /// use serde::Deserialize;
573    ///
574    /// #[derive(Debug, Deserialize, JsonSchema)]
575    /// struct Input { value: String }
576    ///
577    /// // Create a router with database tools
578    /// let db_tools = McpRouter::new()
579    ///     .tool(
580    ///         ToolBuilder::new("query")
581    ///             .description("Query the database")
582    ///             .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
583    ///             .build()
584    ///             .unwrap()
585    ///     )
586    ///     .tool(
587    ///         ToolBuilder::new("insert")
588    ///             .description("Insert into database")
589    ///             .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
590    ///             .build()
591    ///             .unwrap()
592    ///     );
593    ///
594    /// // Nest under "db" prefix - tools become "db.query" and "db.insert"
595    /// let router = McpRouter::new()
596    ///     .server_info("combined", "1.0")
597    ///     .nest("db", db_tools);
598    /// ```
599    pub fn nest(mut self, prefix: impl Into<String>, other: McpRouter) -> Self {
600        let prefix = prefix.into();
601        let inner = Arc::make_mut(&mut self.inner);
602        let other_inner = other.inner;
603
604        // Nest tools with prefix
605        for tool in other_inner.tools.values() {
606            let prefixed_tool = tool.with_name_prefix(&prefix);
607            inner
608                .tools
609                .insert(prefixed_tool.name.clone(), Arc::new(prefixed_tool));
610        }
611
612        // Merge resources (no prefix - URIs are already namespaced)
613        for (uri, resource) in &other_inner.resources {
614            inner.resources.insert(uri.clone(), resource.clone());
615        }
616
617        // Merge resource templates (no prefix)
618        for template in &other_inner.resource_templates {
619            inner.resource_templates.push(template.clone());
620        }
621
622        // Merge prompts (no prefix - could be added in future if needed)
623        for (name, prompt) in &other_inner.prompts {
624            inner.prompts.insert(name.clone(), prompt.clone());
625        }
626
627        self
628    }
629
630    /// Register a completion handler for `completion/complete` requests.
631    ///
632    /// The handler receives `CompleteParams` containing the reference (prompt or resource)
633    /// and the argument being completed, and should return completion suggestions.
634    ///
635    /// # Example
636    ///
637    /// ```rust
638    /// use tower_mcp::{McpRouter, CompleteResult};
639    /// use tower_mcp::protocol::{CompleteParams, CompletionReference};
640    ///
641    /// let router = McpRouter::new()
642    ///     .completion_handler(|params: CompleteParams| async move {
643    ///         // Provide completions based on the reference and argument
644    ///         match params.reference {
645    ///             CompletionReference::Prompt { name } => {
646    ///                 // Return prompt argument completions
647    ///                 Ok(CompleteResult::new(vec!["option1".to_string(), "option2".to_string()]))
648    ///             }
649    ///             CompletionReference::Resource { uri } => {
650    ///                 // Return resource URI completions
651    ///                 Ok(CompleteResult::new(vec![]))
652    ///             }
653    ///         }
654    ///     });
655    /// ```
656    pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
657    where
658        F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
659        Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
660    {
661        Arc::make_mut(&mut self.inner).completion_handler =
662            Some(Arc::new(move |params| Box::pin(handler(params))));
663        self
664    }
665
666    /// Set a filter for tools based on session state.
667    ///
668    /// The filter determines which tools are visible to each session. Tools that
669    /// don't pass the filter will not appear in `tools/list` responses and will
670    /// return an error if called directly.
671    ///
672    /// # Example
673    ///
674    /// ```rust
675    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult, CapabilityFilter, Tool, Filterable};
676    /// use schemars::JsonSchema;
677    /// use serde::Deserialize;
678    ///
679    /// #[derive(Debug, Deserialize, JsonSchema)]
680    /// struct Input { value: String }
681    ///
682    /// let public_tool = ToolBuilder::new("public")
683    ///     .description("Available to everyone")
684    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
685    ///     .build()
686    ///     .unwrap();
687    ///
688    /// let admin_tool = ToolBuilder::new("admin")
689    ///     .description("Admin only")
690    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
691    ///     .build()
692    ///     .unwrap();
693    ///
694    /// let router = McpRouter::new()
695    ///     .tool(public_tool)
696    ///     .tool(admin_tool)
697    ///     .tool_filter(CapabilityFilter::new(|_session, tool: &Tool| {
698    ///         // In real code, check session.extensions() for auth claims
699    ///         tool.name() != "admin"
700    ///     }));
701    /// ```
702    pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
703        Arc::make_mut(&mut self.inner).tool_filter = Some(filter);
704        self
705    }
706
707    /// Set a filter for resources based on session state.
708    ///
709    /// The filter receives the current session state and each resource, returning
710    /// `true` if the resource should be visible to this session. Resources that
711    /// don't pass the filter will not appear in `resources/list` responses and will
712    /// return an error if read directly.
713    ///
714    /// # Example
715    ///
716    /// ```rust
717    /// use tower_mcp::{McpRouter, ResourceBuilder, ReadResourceResult, CapabilityFilter, Resource, Filterable};
718    ///
719    /// let public_resource = ResourceBuilder::new("file:///public.txt")
720    ///     .name("Public File")
721    ///     .description("Available to everyone")
722    ///     .text("public content");
723    ///
724    /// let secret_resource = ResourceBuilder::new("file:///secret.txt")
725    ///     .name("Secret File")
726    ///     .description("Admin only")
727    ///     .text("secret content");
728    ///
729    /// let router = McpRouter::new()
730    ///     .resource(public_resource)
731    ///     .resource(secret_resource)
732    ///     .resource_filter(CapabilityFilter::new(|_session, resource: &Resource| {
733    ///         // In real code, check session.extensions() for auth claims
734    ///         !resource.name().contains("Secret")
735    ///     }));
736    /// ```
737    pub fn resource_filter(mut self, filter: ResourceFilter) -> Self {
738        Arc::make_mut(&mut self.inner).resource_filter = Some(filter);
739        self
740    }
741
742    /// Set a filter for prompts based on session state.
743    ///
744    /// The filter receives the current session state and each prompt, returning
745    /// `true` if the prompt should be visible to this session. Prompts that
746    /// don't pass the filter will not appear in `prompts/list` responses and will
747    /// return an error if accessed directly.
748    ///
749    /// # Example
750    ///
751    /// ```rust
752    /// use tower_mcp::{McpRouter, PromptBuilder, CapabilityFilter, Prompt, Filterable};
753    ///
754    /// let public_prompt = PromptBuilder::new("greeting")
755    ///     .description("A friendly greeting")
756    ///     .user_message("Hello!");
757    ///
758    /// let admin_prompt = PromptBuilder::new("system_debug")
759    ///     .description("Admin debugging prompt")
760    ///     .user_message("Debug info");
761    ///
762    /// let router = McpRouter::new()
763    ///     .prompt(public_prompt)
764    ///     .prompt(admin_prompt)
765    ///     .prompt_filter(CapabilityFilter::new(|_session, prompt: &Prompt| {
766    ///         // In real code, check session.extensions() for auth claims
767    ///         !prompt.name().contains("system")
768    ///     }));
769    /// ```
770    pub fn prompt_filter(mut self, filter: PromptFilter) -> Self {
771        Arc::make_mut(&mut self.inner).prompt_filter = Some(filter);
772        self
773    }
774
775    /// Get access to the session state
776    pub fn session(&self) -> &SessionState {
777        &self.session
778    }
779
780    /// Send a log message notification to the client
781    ///
782    /// This sends a `notifications/message` notification with the given parameters.
    /// Returns `true` if the notification was handed to the notification channel,
    /// `false` if no channel is configured or the channel did not accept it.
785    ///
786    /// # Example
787    ///
788    /// ```rust,ignore
789    /// use tower_mcp::protocol::{LogLevel, LoggingMessageParams};
790    ///
791    /// // Simple info message
792    /// router.log(LoggingMessageParams::new(LogLevel::Info).with_data(
793    ///     serde_json::json!({"message": "Operation completed"})
794    /// ));
795    ///
796    /// // Error with logger name
797    /// router.log(LoggingMessageParams::new(LogLevel::Error)
798    ///     .with_logger("database")
799    ///     .with_data(serde_json::json!({"error": "Connection failed"})));
800    /// ```
801    pub fn log(&self, params: LoggingMessageParams) -> bool {
802        let Some(tx) = &self.inner.notification_tx else {
803            return false;
804        };
805        tx.try_send(ServerNotification::LogMessage(params)).is_ok()
806    }
807
808    /// Send an info-level log message
809    ///
810    /// Convenience method for sending an info log with optional data.
811    pub fn log_info(&self, message: &str) -> bool {
812        self.log(
813            LoggingMessageParams::new(LogLevel::Info)
814                .with_data(serde_json::json!({ "message": message })),
815        )
816    }
817
818    /// Send a warning-level log message
819    pub fn log_warning(&self, message: &str) -> bool {
820        self.log(
821            LoggingMessageParams::new(LogLevel::Warning)
822                .with_data(serde_json::json!({ "message": message })),
823        )
824    }
825
826    /// Send an error-level log message
827    pub fn log_error(&self, message: &str) -> bool {
828        self.log(
829            LoggingMessageParams::new(LogLevel::Error)
830                .with_data(serde_json::json!({ "message": message })),
831        )
832    }
833
834    /// Send a debug-level log message
835    pub fn log_debug(&self, message: &str) -> bool {
836        self.log(
837            LoggingMessageParams::new(LogLevel::Debug)
838                .with_data(serde_json::json!({ "message": message })),
839        )
840    }
841
842    /// Check if a resource URI is currently subscribed
843    pub fn is_subscribed(&self, uri: &str) -> bool {
844        if let Ok(subs) = self.inner.subscriptions.read() {
845            return subs.contains(uri);
846        }
847        false
848    }
849
850    /// Get a list of all subscribed resource URIs
851    pub fn subscribed_uris(&self) -> Vec<String> {
852        if let Ok(subs) = self.inner.subscriptions.read() {
853            return subs.iter().cloned().collect();
854        }
855        Vec::new()
856    }
857
858    /// Subscribe to a resource URI
859    fn subscribe(&self, uri: &str) -> bool {
860        if let Ok(mut subs) = self.inner.subscriptions.write() {
861            return subs.insert(uri.to_string());
862        }
863        false
864    }
865
866    /// Unsubscribe from a resource URI
867    fn unsubscribe(&self, uri: &str) -> bool {
868        if let Ok(mut subs) = self.inner.subscriptions.write() {
869            return subs.remove(uri);
870        }
871        false
872    }
873
874    /// Notify clients that a subscribed resource has been updated
875    ///
876    /// Only sends the notification if the resource is currently subscribed.
877    /// Returns `true` if the notification was sent.
878    pub fn notify_resource_updated(&self, uri: &str) -> bool {
879        // Only notify if the resource is subscribed
880        if !self.is_subscribed(uri) {
881            return false;
882        }
883
884        let Some(tx) = &self.inner.notification_tx else {
885            return false;
886        };
887        tx.try_send(ServerNotification::ResourceUpdated {
888            uri: uri.to_string(),
889        })
890        .is_ok()
891    }
892
893    /// Notify clients that the list of available resources has changed
894    ///
895    /// Returns `true` if the notification was sent.
896    pub fn notify_resources_list_changed(&self) -> bool {
897        let Some(tx) = &self.inner.notification_tx else {
898            return false;
899        };
900        tx.try_send(ServerNotification::ResourcesListChanged)
901            .is_ok()
902    }
903
904    /// Get server capabilities based on registered handlers
905    fn capabilities(&self) -> ServerCapabilities {
906        let has_resources =
907            !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
908
909        ServerCapabilities {
910            tools: if self.inner.tools.is_empty() {
911                None
912            } else {
913                Some(ToolsCapability::default())
914            },
915            resources: if has_resources {
916                Some(ResourcesCapability {
917                    subscribe: true,
918                    ..Default::default()
919                })
920            } else {
921                None
922            },
923            prompts: if self.inner.prompts.is_empty() {
924                None
925            } else {
926                Some(PromptsCapability::default())
927            },
928            // Always advertise logging capability when notification channel is configured
929            logging: if self.inner.notification_tx.is_some() {
930                Some(LoggingCapability::default())
931            } else {
932                None
933            },
934            // Tasks capability is always available
935            tasks: Some(TasksCapability::default()),
936            // Completions capability when a handler is registered
937            completions: if self.inner.completion_handler.is_some() {
938                Some(CompletionsCapability::default())
939            } else {
940                None
941            },
942        }
943    }
944
945    /// Handle an MCP request
946    async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
947        // Enforce session state - reject requests before initialization
948        let method = request.method_name();
949        if !self.session.is_request_allowed(method) {
950            tracing::warn!(
951                method = %method,
952                phase = ?self.session.phase(),
953                "Request rejected: session not initialized"
954            );
955            return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
956                "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
957                method
958            ))));
959        }
960
961        match request {
962            McpRequest::Initialize(params) => {
963                tracing::info!(
964                    client = %params.client_info.name,
965                    version = %params.client_info.version,
966                    "Client initializing"
967                );
968
969                // Protocol version negotiation: respond with same version if supported,
970                // otherwise respond with our latest supported version
971                let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
972                    .contains(&params.protocol_version.as_str())
973                {
974                    params.protocol_version
975                } else {
976                    crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
977                };
978
979                // Transition session state to Initializing
980                self.session.mark_initializing();
981
982                Ok(McpResponse::Initialize(InitializeResult {
983                    protocol_version,
984                    capabilities: self.capabilities(),
985                    server_info: Implementation {
986                        name: self.inner.server_name.clone(),
987                        version: self.inner.server_version.clone(),
988                        title: self.inner.server_title.clone(),
989                        description: self.inner.server_description.clone(),
990                        icons: self.inner.server_icons.clone(),
991                        website_url: self.inner.server_website_url.clone(),
992                    },
993                    instructions: self.inner.instructions.clone(),
994                }))
995            }
996
997            McpRequest::ListTools(_params) => {
998                let tools: Vec<ToolDefinition> = self
999                    .inner
1000                    .tools
1001                    .values()
1002                    .filter(|t| {
1003                        // Apply tool filter if configured
1004                        self.inner
1005                            .tool_filter
1006                            .as_ref()
1007                            .map(|f| f.is_visible(&self.session, t))
1008                            .unwrap_or(true)
1009                    })
1010                    .map(|t| t.definition())
1011                    .collect();
1012
1013                Ok(McpResponse::ListTools(ListToolsResult {
1014                    tools,
1015                    next_cursor: None,
1016                }))
1017            }
1018
1019            McpRequest::CallTool(params) => {
1020                let tool =
1021                    self.inner.tools.get(&params.name).ok_or_else(|| {
1022                        Error::JsonRpc(JsonRpcError::method_not_found(&params.name))
1023                    })?;
1024
1025                // Check tool filter if configured
1026                if let Some(filter) = &self.inner.tool_filter {
1027                    if !filter.is_visible(&self.session, tool) {
1028                        return Err(filter.denial_error(&params.name));
1029                    }
1030                }
1031
1032                // Extract progress token from request metadata
1033                let progress_token = params.meta.and_then(|m| m.progress_token);
1034                let ctx = self.create_context(request_id, progress_token);
1035
1036                tracing::debug!(tool = %params.name, "Calling tool");
1037                let result = tool.call_with_context(ctx, params.arguments).await;
1038
1039                Ok(McpResponse::CallTool(result))
1040            }
1041
1042            McpRequest::ListResources(_params) => {
1043                let resources: Vec<ResourceDefinition> = self
1044                    .inner
1045                    .resources
1046                    .values()
1047                    .filter(|r| {
1048                        // Apply resource filter if configured
1049                        self.inner
1050                            .resource_filter
1051                            .as_ref()
1052                            .map(|f| f.is_visible(&self.session, r))
1053                            .unwrap_or(true)
1054                    })
1055                    .map(|r| r.definition())
1056                    .collect();
1057
1058                Ok(McpResponse::ListResources(ListResourcesResult {
1059                    resources,
1060                    next_cursor: None,
1061                }))
1062            }
1063
1064            McpRequest::ListResourceTemplates(_params) => {
1065                let resource_templates: Vec<ResourceTemplateDefinition> = self
1066                    .inner
1067                    .resource_templates
1068                    .iter()
1069                    .map(|t| t.definition())
1070                    .collect();
1071
1072                Ok(McpResponse::ListResourceTemplates(
1073                    ListResourceTemplatesResult {
1074                        resource_templates,
1075                        next_cursor: None,
1076                    },
1077                ))
1078            }
1079
1080            McpRequest::ReadResource(params) => {
1081                // First, try to find a static resource
1082                if let Some(resource) = self.inner.resources.get(&params.uri) {
1083                    // Check resource filter if configured
1084                    if let Some(filter) = &self.inner.resource_filter {
1085                        if !filter.is_visible(&self.session, resource) {
1086                            return Err(filter.denial_error(&params.uri));
1087                        }
1088                    }
1089
1090                    tracing::debug!(uri = %params.uri, "Reading static resource");
1091                    let result = resource.read().await;
1092                    return Ok(McpResponse::ReadResource(result));
1093                }
1094
1095                // If no static resource found, try to match against templates
1096                for template in &self.inner.resource_templates {
1097                    if let Some(variables) = template.match_uri(&params.uri) {
1098                        tracing::debug!(
1099                            uri = %params.uri,
1100                            template = %template.uri_template,
1101                            "Reading resource via template"
1102                        );
1103                        let result = template.read(&params.uri, variables).await?;
1104                        return Ok(McpResponse::ReadResource(result));
1105                    }
1106                }
1107
1108                // No match found
1109                Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1110                    &params.uri,
1111                )))
1112            }
1113
1114            McpRequest::SubscribeResource(params) => {
1115                // Verify the resource exists
1116                if !self.inner.resources.contains_key(&params.uri) {
1117                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1118                        &params.uri,
1119                    )));
1120                }
1121
1122                tracing::debug!(uri = %params.uri, "Subscribing to resource");
1123                self.subscribe(&params.uri);
1124
1125                Ok(McpResponse::SubscribeResource(EmptyResult {}))
1126            }
1127
1128            McpRequest::UnsubscribeResource(params) => {
1129                // Verify the resource exists
1130                if !self.inner.resources.contains_key(&params.uri) {
1131                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1132                        &params.uri,
1133                    )));
1134                }
1135
1136                tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
1137                self.unsubscribe(&params.uri);
1138
1139                Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
1140            }
1141
1142            McpRequest::ListPrompts(_params) => {
1143                let prompts: Vec<PromptDefinition> = self
1144                    .inner
1145                    .prompts
1146                    .values()
1147                    .filter(|p| {
1148                        // Apply prompt filter if configured
1149                        self.inner
1150                            .prompt_filter
1151                            .as_ref()
1152                            .map(|f| f.is_visible(&self.session, p))
1153                            .unwrap_or(true)
1154                    })
1155                    .map(|p| p.definition())
1156                    .collect();
1157
1158                Ok(McpResponse::ListPrompts(ListPromptsResult {
1159                    prompts,
1160                    next_cursor: None,
1161                }))
1162            }
1163
1164            McpRequest::GetPrompt(params) => {
1165                let prompt = self.inner.prompts.get(&params.name).ok_or_else(|| {
1166                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1167                        "Prompt not found: {}",
1168                        params.name
1169                    )))
1170                })?;
1171
1172                // Check prompt filter if configured
1173                if let Some(filter) = &self.inner.prompt_filter {
1174                    if !filter.is_visible(&self.session, prompt) {
1175                        return Err(filter.denial_error(&params.name));
1176                    }
1177                }
1178
1179                tracing::debug!(name = %params.name, "Getting prompt");
1180                let result = prompt.get(params.arguments).await?;
1181
1182                Ok(McpResponse::GetPrompt(result))
1183            }
1184
1185            McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
1186
1187            McpRequest::EnqueueTask(params) => {
1188                // Verify the tool exists
1189                let tool = self.inner.tools.get(&params.tool_name).ok_or_else(|| {
1190                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1191                        "Tool not found: {}",
1192                        params.tool_name
1193                    )))
1194                })?;
1195
1196                // Create the task
1197                let (task_id, cancellation_token) = self.inner.task_store.create_task(
1198                    &params.tool_name,
1199                    params.arguments.clone(),
1200                    params.ttl,
1201                );
1202
1203                tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
1204
1205                // Create a context for the async task execution
1206                let ctx = self.create_context(request_id, None);
1207
1208                // Spawn the task execution in the background
1209                let task_store = self.inner.task_store.clone();
1210                let tool = tool.clone();
1211                let arguments = params.arguments;
1212                let task_id_clone = task_id.clone();
1213
1214                tokio::spawn(async move {
1215                    // Check for cancellation before starting
1216                    if cancellation_token.is_cancelled() {
1217                        tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
1218                        return;
1219                    }
1220
1221                    // Execute the tool
1222                    let result = tool.call_with_context(ctx, arguments).await;
1223
1224                    if cancellation_token.is_cancelled() {
1225                        tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
1226                    } else if result.is_error {
1227                        // Tool returned an error result
1228                        let error_msg = result.first_text().unwrap_or("Tool execution failed");
1229                        task_store.fail_task(&task_id_clone, error_msg);
1230                        tracing::warn!(task_id = %task_id_clone, error = %error_msg, "Task failed");
1231                    } else {
1232                        task_store.complete_task(&task_id_clone, result);
1233                        tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
1234                    }
1235                });
1236
1237                Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
1238                    task_id,
1239                    status: TaskStatus::Working,
1240                    poll_interval: Some(2),
1241                }))
1242            }
1243
1244            McpRequest::ListTasks(params) => {
1245                let tasks = self.inner.task_store.list_tasks(params.status);
1246
1247                Ok(McpResponse::ListTasks(ListTasksResult {
1248                    tasks,
1249                    next_cursor: None,
1250                }))
1251            }
1252
1253            McpRequest::GetTaskInfo(params) => {
1254                let task = self
1255                    .inner
1256                    .task_store
1257                    .get_task(&params.task_id)
1258                    .ok_or_else(|| {
1259                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
1260                            "Task not found: {}",
1261                            params.task_id
1262                        )))
1263                    })?;
1264
1265                Ok(McpResponse::GetTaskInfo(task))
1266            }
1267
1268            McpRequest::GetTaskResult(params) => {
1269                let (status, result, error) = self
1270                    .inner
1271                    .task_store
1272                    .get_task_full(&params.task_id)
1273                    .ok_or_else(|| {
1274                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
1275                            "Task not found: {}",
1276                            params.task_id
1277                        )))
1278                    })?;
1279
1280                Ok(McpResponse::GetTaskResult(GetTaskResultResult {
1281                    task_id: params.task_id,
1282                    status,
1283                    result,
1284                    error,
1285                }))
1286            }
1287
1288            McpRequest::CancelTask(params) => {
1289                let status = self
1290                    .inner
1291                    .task_store
1292                    .cancel_task(&params.task_id, params.reason.as_deref())
1293                    .ok_or_else(|| {
1294                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
1295                            "Task not found: {}",
1296                            params.task_id
1297                        )))
1298                    })?;
1299
1300                let cancelled = status == TaskStatus::Cancelled;
1301
1302                Ok(McpResponse::CancelTask(CancelTaskResult {
1303                    cancelled,
1304                    status,
1305                }))
1306            }
1307
1308            McpRequest::SetLoggingLevel(params) => {
1309                // Store the log level for filtering outgoing log notifications
1310                // For now, we just accept the request - actual filtering would be
1311                // implemented in the notification sending logic
1312                tracing::debug!(level = ?params.level, "Client set logging level");
1313                Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
1314            }
1315
1316            McpRequest::Complete(params) => {
1317                tracing::debug!(
1318                    reference = ?params.reference,
1319                    argument = %params.argument.name,
1320                    "Completion request"
1321                );
1322
1323                // Delegate to registered completion handler if available
1324                if let Some(ref handler) = self.inner.completion_handler {
1325                    let result = handler(params).await?;
1326                    Ok(McpResponse::Complete(result))
1327                } else {
1328                    // No completion handler registered, return empty completions
1329                    Ok(McpResponse::Complete(CompleteResult::new(vec![])))
1330                }
1331            }
1332
1333            McpRequest::Unknown { method, .. } => {
1334                Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
1335            }
1336        }
1337    }
1338
1339    /// Handle an MCP notification (no response expected)
1340    pub fn handle_notification(&self, notification: McpNotification) {
1341        match notification {
1342            McpNotification::Initialized => {
1343                if self.session.mark_initialized() {
1344                    tracing::info!("Session initialized, entering operation phase");
1345                } else {
1346                    tracing::warn!(
1347                        "Received initialized notification in unexpected state: {:?}",
1348                        self.session.phase()
1349                    );
1350                }
1351            }
1352            McpNotification::Cancelled(params) => {
1353                if self.cancel_request(&params.request_id) {
1354                    tracing::info!(
1355                        request_id = ?params.request_id,
1356                        reason = ?params.reason,
1357                        "Request cancelled"
1358                    );
1359                } else {
1360                    tracing::debug!(
1361                        request_id = ?params.request_id,
1362                        reason = ?params.reason,
1363                        "Cancellation requested for unknown request"
1364                    );
1365                }
1366            }
1367            McpNotification::Progress(params) => {
1368                tracing::trace!(
1369                    token = ?params.progress_token,
1370                    progress = params.progress,
1371                    total = ?params.total,
1372                    "Progress notification"
1373                );
1374                // Progress notifications from client are unusual but valid
1375            }
1376            McpNotification::RootsListChanged => {
1377                tracing::info!("Client roots list changed");
1378                // Server should re-request roots if needed
1379                // This is handled by the application layer
1380            }
1381            McpNotification::Unknown { method, .. } => {
1382                tracing::debug!(method = %method, "Unknown notification received");
1383            }
1384        }
1385    }
1386}
1387
1388impl Default for McpRouter {
1389    fn default() -> Self {
1390        Self::new()
1391    }
1392}
1393
1394// =============================================================================
1395// Tower Service implementation
1396// =============================================================================
1397
1398// Re-export Extensions from context for backwards compatibility
1399pub use crate::context::Extensions;
1400
/// Request type for the tower Service implementation
///
/// Pairs a parsed `McpRequest` with its request id and a type-map of extensions.
#[derive(Debug)]
pub struct RouterRequest {
    /// Id of the request; the router echoes it back in `RouterResponse::id`.
    pub id: RequestId,
    /// The MCP request to dispatch.
    pub inner: McpRequest,
    /// Type-map for passing data (e.g., `TokenClaims`) through middleware.
    pub extensions: Extensions,
}
1409
/// Response type for the tower Service implementation
///
/// Failures are carried inside `inner` as a `JsonRpcError` rather than as a
/// service-level error.
#[derive(Debug)]
pub struct RouterResponse {
    /// Id of the request this response answers.
    pub id: RequestId,
    /// The successful MCP response, or the JSON-RPC error to report to the client.
    pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
1416
1417impl RouterResponse {
1418    /// Convert to JSON-RPC response
1419    pub fn into_jsonrpc(self) -> JsonRpcResponse {
1420        match self.inner {
1421            Ok(response) => match serde_json::to_value(response) {
1422                Ok(result) => JsonRpcResponse::result(self.id, result),
1423                Err(e) => {
1424                    tracing::error!(error = %e, "Failed to serialize response");
1425                    JsonRpcResponse::error(
1426                        Some(self.id),
1427                        JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1428                    )
1429                }
1430            },
1431            Err(error) => JsonRpcResponse::error(Some(self.id), error),
1432        }
1433    }
1434}
1435
1436impl Service<RouterRequest> for McpRouter {
1437    type Response = RouterResponse;
1438    type Error = std::convert::Infallible; // Errors are in the response
1439    type Future =
1440        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1441
1442    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1443        Poll::Ready(Ok(()))
1444    }
1445
1446    fn call(&mut self, req: RouterRequest) -> Self::Future {
1447        let router = self.clone();
1448        let request_id = req.id.clone();
1449        Box::pin(async move {
1450            let result = router.handle(req.id, req.inner).await;
1451            // Clean up tracking after request completes
1452            router.complete_request(&request_id);
1453            Ok(RouterResponse {
1454                id: request_id,
1455                // Map tower-mcp errors to JSON-RPC errors:
1456                // - Error::JsonRpc: forwarded as-is (preserves original code)
1457                // - Error::Tool: mapped to -32603 (Internal Error)
1458                // - All others: mapped to -32603 (Internal Error)
1459                inner: result.map_err(|e| match e {
1460                    Error::JsonRpc(err) => err,
1461                    Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1462                    e => JsonRpcError::internal_error(e.to_string()),
1463                }),
1464            })
1465        })
1466    }
1467}
1468
1469#[cfg(test)]
1470mod tests {
1471    use super::*;
1472    use crate::extract::{Context, Json};
1473    use crate::jsonrpc::JsonRpcService;
1474    use crate::tool::ToolBuilder;
1475    use schemars::JsonSchema;
1476    use serde::Deserialize;
1477    use tower::ServiceExt;
1478
/// Input for the `add` test tool: the two integers its handler sums.
#[derive(Debug, Deserialize, JsonSchema)]
struct AddInput {
    a: i64,
    b: i64,
}
1484
1485    /// Helper to initialize a router for testing
1486    async fn init_router(router: &mut McpRouter) {
1487        // Send initialize request
1488        let init_req = RouterRequest {
1489            id: RequestId::Number(0),
1490            inner: McpRequest::Initialize(InitializeParams {
1491                protocol_version: "2025-11-25".to_string(),
1492                capabilities: ClientCapabilities {
1493                    roots: None,
1494                    sampling: None,
1495                    elicitation: None,
1496                },
1497                client_info: Implementation {
1498                    name: "test".to_string(),
1499                    version: "1.0".to_string(),
1500                    ..Default::default()
1501                },
1502            }),
1503            extensions: Extensions::new(),
1504        };
1505        let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1506        // Send initialized notification
1507        router.handle_notification(McpNotification::Initialized);
1508    }
1509
1510    #[tokio::test]
1511    async fn test_router_list_tools() {
1512        let add_tool = ToolBuilder::new("add")
1513            .description("Add two numbers")
1514            .handler(|input: AddInput| async move {
1515                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1516            })
1517            .build()
1518            .expect("valid tool name");
1519
1520        let mut router = McpRouter::new().tool(add_tool);
1521
1522        // Initialize session first
1523        init_router(&mut router).await;
1524
1525        let req = RouterRequest {
1526            id: RequestId::Number(1),
1527            inner: McpRequest::ListTools(ListToolsParams::default()),
1528            extensions: Extensions::new(),
1529        };
1530
1531        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1532
1533        match resp.inner {
1534            Ok(McpResponse::ListTools(result)) => {
1535                assert_eq!(result.tools.len(), 1);
1536                assert_eq!(result.tools[0].name, "add");
1537            }
1538            _ => panic!("Expected ListTools response"),
1539        }
1540    }
1541
1542    #[tokio::test]
1543    async fn test_router_call_tool() {
1544        let add_tool = ToolBuilder::new("add")
1545            .description("Add two numbers")
1546            .handler(|input: AddInput| async move {
1547                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1548            })
1549            .build()
1550            .expect("valid tool name");
1551
1552        let mut router = McpRouter::new().tool(add_tool);
1553
1554        // Initialize session first
1555        init_router(&mut router).await;
1556
1557        let req = RouterRequest {
1558            id: RequestId::Number(1),
1559            inner: McpRequest::CallTool(CallToolParams {
1560                name: "add".to_string(),
1561                arguments: serde_json::json!({"a": 2, "b": 3}),
1562                meta: None,
1563            }),
1564            extensions: Extensions::new(),
1565        };
1566
1567        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1568
1569        match resp.inner {
1570            Ok(McpResponse::CallTool(result)) => {
1571                assert!(!result.is_error);
1572                // Check the text content
1573                match &result.content[0] {
1574                    Content::Text { text, .. } => assert_eq!(text, "5"),
1575                    _ => panic!("Expected text content"),
1576                }
1577            }
1578            _ => panic!("Expected CallTool response"),
1579        }
1580    }
1581
1582    /// Helper to initialize a JsonRpcService for testing
1583    async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1584        let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1585            "protocolVersion": "2025-11-25",
1586            "capabilities": {},
1587            "clientInfo": { "name": "test", "version": "1.0" }
1588        }));
1589        let _ = service.call_single(init_req).await.unwrap();
1590        router.handle_notification(McpNotification::Initialized);
1591    }
1592
1593    #[tokio::test]
1594    async fn test_jsonrpc_service() {
1595        let add_tool = ToolBuilder::new("add")
1596            .description("Add two numbers")
1597            .handler(|input: AddInput| async move {
1598                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1599            })
1600            .build()
1601            .expect("valid tool name");
1602
1603        let router = McpRouter::new().tool(add_tool);
1604        let mut service = JsonRpcService::new(router.clone());
1605
1606        // Initialize session first
1607        init_jsonrpc_service(&mut service, &router).await;
1608
1609        let req = JsonRpcRequest::new(1, "tools/list");
1610
1611        let resp = service.call_single(req).await.unwrap();
1612
1613        match resp {
1614            JsonRpcResponse::Result(r) => {
1615                assert_eq!(r.id, RequestId::Number(1));
1616                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1617                assert_eq!(tools.len(), 1);
1618            }
1619            JsonRpcResponse::Error(_) => panic!("Expected success response"),
1620        }
1621    }
1622
1623    #[tokio::test]
1624    async fn test_batch_request() {
1625        let add_tool = ToolBuilder::new("add")
1626            .description("Add two numbers")
1627            .handler(|input: AddInput| async move {
1628                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1629            })
1630            .build()
1631            .expect("valid tool name");
1632
1633        let router = McpRouter::new().tool(add_tool);
1634        let mut service = JsonRpcService::new(router.clone());
1635
1636        // Initialize session first
1637        init_jsonrpc_service(&mut service, &router).await;
1638
1639        // Create a batch of requests
1640        let requests = vec![
1641            JsonRpcRequest::new(1, "tools/list"),
1642            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1643                "name": "add",
1644                "arguments": {"a": 10, "b": 20}
1645            })),
1646            JsonRpcRequest::new(3, "ping"),
1647        ];
1648
1649        let responses = service.call_batch(requests).await.unwrap();
1650
1651        assert_eq!(responses.len(), 3);
1652
1653        // Check first response (tools/list)
1654        match &responses[0] {
1655            JsonRpcResponse::Result(r) => {
1656                assert_eq!(r.id, RequestId::Number(1));
1657                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1658                assert_eq!(tools.len(), 1);
1659            }
1660            JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1661        }
1662
1663        // Check second response (tools/call)
1664        match &responses[1] {
1665            JsonRpcResponse::Result(r) => {
1666                assert_eq!(r.id, RequestId::Number(2));
1667                let content = r.result.get("content").unwrap().as_array().unwrap();
1668                let text = content[0].get("text").unwrap().as_str().unwrap();
1669                assert_eq!(text, "30");
1670            }
1671            JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1672        }
1673
1674        // Check third response (ping)
1675        match &responses[2] {
1676            JsonRpcResponse::Result(r) => {
1677                assert_eq!(r.id, RequestId::Number(3));
1678            }
1679            JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1680        }
1681    }
1682
1683    #[tokio::test]
1684    async fn test_empty_batch_error() {
1685        let router = McpRouter::new();
1686        let mut service = JsonRpcService::new(router);
1687
1688        let result = service.call_batch(vec![]).await;
1689        assert!(result.is_err());
1690    }
1691
1692    // =========================================================================
1693    // Progress Token Tests
1694    // =========================================================================
1695
1696    #[tokio::test]
1697    async fn test_progress_token_extraction() {
1698        use crate::context::{ServerNotification, notification_channel};
1699        use crate::protocol::ProgressToken;
1700        use std::sync::Arc;
1701        use std::sync::atomic::{AtomicBool, Ordering};
1702
1703        // Track whether progress was reported
1704        let progress_reported = Arc::new(AtomicBool::new(false));
1705        let progress_ref = progress_reported.clone();
1706
1707        // Create a tool that reports progress
1708        let tool = ToolBuilder::new("progress_tool")
1709            .description("Tool that reports progress")
1710            .extractor_handler_typed::<_, _, _, AddInput>(
1711                (),
1712                move |ctx: Context, Json(_input): Json<AddInput>| {
1713                    let reported = progress_ref.clone();
1714                    async move {
1715                        // Report progress - this should work if token was extracted
1716                        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1717                            .await;
1718                        reported.store(true, Ordering::SeqCst);
1719                        Ok(CallToolResult::text("done"))
1720                    }
1721                },
1722            )
1723            .build()
1724            .expect("valid tool name");
1725
1726        // Set up notification channel
1727        let (tx, mut rx) = notification_channel(10);
1728        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1729        let mut service = JsonRpcService::new(router.clone());
1730
1731        // Initialize
1732        init_jsonrpc_service(&mut service, &router).await;
1733
1734        // Call tool WITH progress token in _meta
1735        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1736            "name": "progress_tool",
1737            "arguments": {"a": 1, "b": 2},
1738            "_meta": {
1739                "progressToken": "test-token-123"
1740            }
1741        }));
1742
1743        let resp = service.call_single(req).await.unwrap();
1744
1745        // Verify the tool was called successfully
1746        match resp {
1747            JsonRpcResponse::Result(_) => {}
1748            JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1749        }
1750
1751        // Verify progress was reported by handler
1752        assert!(progress_reported.load(Ordering::SeqCst));
1753
1754        // Verify progress notification was sent through channel
1755        let notification = rx.try_recv().expect("Expected progress notification");
1756        match notification {
1757            ServerNotification::Progress(params) => {
1758                assert_eq!(
1759                    params.progress_token,
1760                    ProgressToken::String("test-token-123".to_string())
1761                );
1762                assert_eq!(params.progress, 50.0);
1763                assert_eq!(params.total, Some(100.0));
1764                assert_eq!(params.message.as_deref(), Some("Halfway"));
1765            }
1766            _ => panic!("Expected Progress notification"),
1767        }
1768    }
1769
1770    #[tokio::test]
1771    async fn test_tool_call_without_progress_token() {
1772        use crate::context::notification_channel;
1773        use std::sync::Arc;
1774        use std::sync::atomic::{AtomicBool, Ordering};
1775
1776        let progress_attempted = Arc::new(AtomicBool::new(false));
1777        let progress_ref = progress_attempted.clone();
1778
1779        let tool = ToolBuilder::new("no_token_tool")
1780            .description("Tool that tries to report progress without token")
1781            .extractor_handler_typed::<_, _, _, AddInput>(
1782                (),
1783                move |ctx: Context, Json(_input): Json<AddInput>| {
1784                    let attempted = progress_ref.clone();
1785                    async move {
1786                        // Try to report progress - should be a no-op without token
1787                        ctx.report_progress(50.0, Some(100.0), None).await;
1788                        attempted.store(true, Ordering::SeqCst);
1789                        Ok(CallToolResult::text("done"))
1790                    }
1791                },
1792            )
1793            .build()
1794            .expect("valid tool name");
1795
1796        let (tx, mut rx) = notification_channel(10);
1797        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1798        let mut service = JsonRpcService::new(router.clone());
1799
1800        init_jsonrpc_service(&mut service, &router).await;
1801
1802        // Call tool WITHOUT progress token
1803        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1804            "name": "no_token_tool",
1805            "arguments": {"a": 1, "b": 2}
1806        }));
1807
1808        let resp = service.call_single(req).await.unwrap();
1809        assert!(matches!(resp, JsonRpcResponse::Result(_)));
1810
1811        // Handler was called
1812        assert!(progress_attempted.load(Ordering::SeqCst));
1813
1814        // But no notification was sent (no progress token)
1815        assert!(rx.try_recv().is_err());
1816    }
1817
1818    #[tokio::test]
1819    async fn test_batch_errors_returned_not_dropped() {
1820        let add_tool = ToolBuilder::new("add")
1821            .description("Add two numbers")
1822            .handler(|input: AddInput| async move {
1823                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1824            })
1825            .build()
1826            .expect("valid tool name");
1827
1828        let router = McpRouter::new().tool(add_tool);
1829        let mut service = JsonRpcService::new(router.clone());
1830
1831        init_jsonrpc_service(&mut service, &router).await;
1832
1833        // Create a batch with one valid and one invalid request
1834        let requests = vec![
1835            // Valid request
1836            JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1837                "name": "add",
1838                "arguments": {"a": 10, "b": 20}
1839            })),
1840            // Invalid request - tool doesn't exist
1841            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1842                "name": "nonexistent_tool",
1843                "arguments": {}
1844            })),
1845            // Another valid request
1846            JsonRpcRequest::new(3, "ping"),
1847        ];
1848
1849        let responses = service.call_batch(requests).await.unwrap();
1850
1851        // All three requests should have responses (errors are not dropped)
1852        assert_eq!(responses.len(), 3);
1853
1854        // First should be success
1855        match &responses[0] {
1856            JsonRpcResponse::Result(r) => {
1857                assert_eq!(r.id, RequestId::Number(1));
1858            }
1859            JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1860        }
1861
1862        // Second should be an error (tool not found)
1863        match &responses[1] {
1864            JsonRpcResponse::Error(e) => {
1865                assert_eq!(e.id, Some(RequestId::Number(2)));
1866                // Error should indicate method not found
1867                assert!(e.error.message.contains("not found") || e.error.code == -32601);
1868            }
1869            JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1870        }
1871
1872        // Third should be success
1873        match &responses[2] {
1874            JsonRpcResponse::Result(r) => {
1875                assert_eq!(r.id, RequestId::Number(3));
1876            }
1877            JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1878        }
1879    }
1880
1881    // =========================================================================
1882    // Resource Template Tests
1883    // =========================================================================
1884
1885    #[tokio::test]
1886    async fn test_list_resource_templates() {
1887        use crate::resource::ResourceTemplateBuilder;
1888        use std::collections::HashMap;
1889
1890        let template = ResourceTemplateBuilder::new("file:///{path}")
1891            .name("Project Files")
1892            .description("Access project files")
1893            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1894                Ok(ReadResourceResult {
1895                    contents: vec![ResourceContent {
1896                        uri,
1897                        mime_type: None,
1898                        text: None,
1899                        blob: None,
1900                    }],
1901                })
1902            });
1903
1904        let mut router = McpRouter::new().resource_template(template);
1905
1906        // Initialize session
1907        init_router(&mut router).await;
1908
1909        let req = RouterRequest {
1910            id: RequestId::Number(1),
1911            inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1912            extensions: Extensions::new(),
1913        };
1914
1915        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1916
1917        match resp.inner {
1918            Ok(McpResponse::ListResourceTemplates(result)) => {
1919                assert_eq!(result.resource_templates.len(), 1);
1920                assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1921                assert_eq!(result.resource_templates[0].name, "Project Files");
1922            }
1923            _ => panic!("Expected ListResourceTemplates response"),
1924        }
1925    }
1926
1927    #[tokio::test]
1928    async fn test_read_resource_via_template() {
1929        use crate::resource::ResourceTemplateBuilder;
1930        use std::collections::HashMap;
1931
1932        let template = ResourceTemplateBuilder::new("db://users/{id}")
1933            .name("User Records")
1934            .handler(|uri: String, vars: HashMap<String, String>| async move {
1935                let id = vars.get("id").unwrap().clone();
1936                Ok(ReadResourceResult {
1937                    contents: vec![ResourceContent {
1938                        uri,
1939                        mime_type: Some("application/json".to_string()),
1940                        text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1941                        blob: None,
1942                    }],
1943                })
1944            });
1945
1946        let mut router = McpRouter::new().resource_template(template);
1947
1948        // Initialize session
1949        init_router(&mut router).await;
1950
1951        // Read a resource that matches the template
1952        let req = RouterRequest {
1953            id: RequestId::Number(1),
1954            inner: McpRequest::ReadResource(ReadResourceParams {
1955                uri: "db://users/123".to_string(),
1956            }),
1957            extensions: Extensions::new(),
1958        };
1959
1960        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1961
1962        match resp.inner {
1963            Ok(McpResponse::ReadResource(result)) => {
1964                assert_eq!(result.contents.len(), 1);
1965                assert_eq!(result.contents[0].uri, "db://users/123");
1966                assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1967            }
1968            _ => panic!("Expected ReadResource response"),
1969        }
1970    }
1971
1972    #[tokio::test]
1973    async fn test_static_resource_takes_precedence_over_template() {
1974        use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1975        use std::collections::HashMap;
1976
1977        // Template that would match the same URI
1978        let template = ResourceTemplateBuilder::new("file:///{path}")
1979            .name("Files Template")
1980            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1981                Ok(ReadResourceResult {
1982                    contents: vec![ResourceContent {
1983                        uri,
1984                        mime_type: None,
1985                        text: Some("from template".to_string()),
1986                        blob: None,
1987                    }],
1988                })
1989            });
1990
1991        // Static resource with exact URI
1992        let static_resource = ResourceBuilder::new("file:///README.md")
1993            .name("README")
1994            .text("from static resource");
1995
1996        let mut router = McpRouter::new()
1997            .resource_template(template)
1998            .resource(static_resource);
1999
2000        // Initialize session
2001        init_router(&mut router).await;
2002
2003        // Read the static resource - should NOT go through template
2004        let req = RouterRequest {
2005            id: RequestId::Number(1),
2006            inner: McpRequest::ReadResource(ReadResourceParams {
2007                uri: "file:///README.md".to_string(),
2008            }),
2009            extensions: Extensions::new(),
2010        };
2011
2012        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2013
2014        match resp.inner {
2015            Ok(McpResponse::ReadResource(result)) => {
2016                // Should get static resource, not template
2017                assert_eq!(
2018                    result.contents[0].text.as_deref(),
2019                    Some("from static resource")
2020                );
2021            }
2022            _ => panic!("Expected ReadResource response"),
2023        }
2024    }
2025
2026    #[tokio::test]
2027    async fn test_resource_not_found_when_no_match() {
2028        use crate::resource::ResourceTemplateBuilder;
2029        use std::collections::HashMap;
2030
2031        let template = ResourceTemplateBuilder::new("db://users/{id}")
2032            .name("Users")
2033            .handler(|uri: String, _vars: HashMap<String, String>| async move {
2034                Ok(ReadResourceResult {
2035                    contents: vec![ResourceContent {
2036                        uri,
2037                        mime_type: None,
2038                        text: None,
2039                        blob: None,
2040                    }],
2041                })
2042            });
2043
2044        let mut router = McpRouter::new().resource_template(template);
2045
2046        // Initialize session
2047        init_router(&mut router).await;
2048
2049        // Try to read a URI that doesn't match any resource or template
2050        let req = RouterRequest {
2051            id: RequestId::Number(1),
2052            inner: McpRequest::ReadResource(ReadResourceParams {
2053                uri: "db://posts/123".to_string(),
2054            }),
2055            extensions: Extensions::new(),
2056        };
2057
2058        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2059
2060        match resp.inner {
2061            Err(err) => {
2062                assert!(err.message.contains("not found"));
2063            }
2064            Ok(_) => panic!("Expected error for non-matching URI"),
2065        }
2066    }
2067
2068    #[tokio::test]
2069    async fn test_capabilities_include_resources_with_only_templates() {
2070        use crate::resource::ResourceTemplateBuilder;
2071        use std::collections::HashMap;
2072
2073        let template = ResourceTemplateBuilder::new("file:///{path}")
2074            .name("Files")
2075            .handler(|uri: String, _vars: HashMap<String, String>| async move {
2076                Ok(ReadResourceResult {
2077                    contents: vec![ResourceContent {
2078                        uri,
2079                        mime_type: None,
2080                        text: None,
2081                        blob: None,
2082                    }],
2083                })
2084            });
2085
2086        let mut router = McpRouter::new().resource_template(template);
2087
2088        // Send initialize request and check capabilities
2089        let init_req = RouterRequest {
2090            id: RequestId::Number(0),
2091            inner: McpRequest::Initialize(InitializeParams {
2092                protocol_version: "2025-11-25".to_string(),
2093                capabilities: ClientCapabilities {
2094                    roots: None,
2095                    sampling: None,
2096                    elicitation: None,
2097                },
2098                client_info: Implementation {
2099                    name: "test".to_string(),
2100                    version: "1.0".to_string(),
2101                    ..Default::default()
2102                },
2103            }),
2104            extensions: Extensions::new(),
2105        };
2106        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2107
2108        match resp.inner {
2109            Ok(McpResponse::Initialize(result)) => {
2110                // Should have resources capability even though only templates registered
2111                assert!(result.capabilities.resources.is_some());
2112            }
2113            _ => panic!("Expected Initialize response"),
2114        }
2115    }
2116
2117    // =========================================================================
2118    // Logging Notification Tests
2119    // =========================================================================
2120
2121    #[tokio::test]
2122    async fn test_log_sends_notification() {
2123        use crate::context::notification_channel;
2124
2125        let (tx, mut rx) = notification_channel(10);
2126        let router = McpRouter::new().with_notification_sender(tx);
2127
2128        // Send an info log
2129        let sent = router.log_info("Test message");
2130        assert!(sent);
2131
2132        // Should receive the notification
2133        let notification = rx.try_recv().unwrap();
2134        match notification {
2135            ServerNotification::LogMessage(params) => {
2136                assert_eq!(params.level, LogLevel::Info);
2137                let data = params.data.unwrap();
2138                assert_eq!(
2139                    data.get("message").unwrap().as_str().unwrap(),
2140                    "Test message"
2141                );
2142            }
2143            _ => panic!("Expected LogMessage notification"),
2144        }
2145    }
2146
2147    #[tokio::test]
2148    async fn test_log_with_custom_params() {
2149        use crate::context::notification_channel;
2150
2151        let (tx, mut rx) = notification_channel(10);
2152        let router = McpRouter::new().with_notification_sender(tx);
2153
2154        // Send a custom log message
2155        let params = LoggingMessageParams::new(LogLevel::Error)
2156            .with_logger("database")
2157            .with_data(serde_json::json!({
2158                "error": "Connection failed",
2159                "host": "localhost"
2160            }));
2161
2162        let sent = router.log(params);
2163        assert!(sent);
2164
2165        let notification = rx.try_recv().unwrap();
2166        match notification {
2167            ServerNotification::LogMessage(params) => {
2168                assert_eq!(params.level, LogLevel::Error);
2169                assert_eq!(params.logger.as_deref(), Some("database"));
2170                let data = params.data.unwrap();
2171                assert_eq!(
2172                    data.get("error").unwrap().as_str().unwrap(),
2173                    "Connection failed"
2174                );
2175            }
2176            _ => panic!("Expected LogMessage notification"),
2177        }
2178    }
2179
2180    #[tokio::test]
2181    async fn test_log_without_channel_returns_false() {
2182        // Router without notification channel
2183        let router = McpRouter::new();
2184
2185        // Should return false when no channel configured
2186        assert!(!router.log_info("Test"));
2187        assert!(!router.log_warning("Test"));
2188        assert!(!router.log_error("Test"));
2189        assert!(!router.log_debug("Test"));
2190    }
2191
2192    #[tokio::test]
2193    async fn test_logging_capability_with_channel() {
2194        use crate::context::notification_channel;
2195
2196        let (tx, _rx) = notification_channel(10);
2197        let mut router = McpRouter::new().with_notification_sender(tx);
2198
2199        // Initialize and check capabilities
2200        let init_req = RouterRequest {
2201            id: RequestId::Number(0),
2202            inner: McpRequest::Initialize(InitializeParams {
2203                protocol_version: "2025-11-25".to_string(),
2204                capabilities: ClientCapabilities {
2205                    roots: None,
2206                    sampling: None,
2207                    elicitation: None,
2208                },
2209                client_info: Implementation {
2210                    name: "test".to_string(),
2211                    version: "1.0".to_string(),
2212                    ..Default::default()
2213                },
2214            }),
2215            extensions: Extensions::new(),
2216        };
2217        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2218
2219        match resp.inner {
2220            Ok(McpResponse::Initialize(result)) => {
2221                // Should have logging capability when notification channel is set
2222                assert!(result.capabilities.logging.is_some());
2223            }
2224            _ => panic!("Expected Initialize response"),
2225        }
2226    }
2227
2228    #[tokio::test]
2229    async fn test_no_logging_capability_without_channel() {
2230        let mut router = McpRouter::new();
2231
2232        // Initialize and check capabilities
2233        let init_req = RouterRequest {
2234            id: RequestId::Number(0),
2235            inner: McpRequest::Initialize(InitializeParams {
2236                protocol_version: "2025-11-25".to_string(),
2237                capabilities: ClientCapabilities {
2238                    roots: None,
2239                    sampling: None,
2240                    elicitation: None,
2241                },
2242                client_info: Implementation {
2243                    name: "test".to_string(),
2244                    version: "1.0".to_string(),
2245                    ..Default::default()
2246                },
2247            }),
2248            extensions: Extensions::new(),
2249        };
2250        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2251
2252        match resp.inner {
2253            Ok(McpResponse::Initialize(result)) => {
2254                // Should NOT have logging capability without notification channel
2255                assert!(result.capabilities.logging.is_none());
2256            }
2257            _ => panic!("Expected Initialize response"),
2258        }
2259    }
2260
2261    // =========================================================================
2262    // Task Lifecycle Tests
2263    // =========================================================================
2264
2265    #[tokio::test]
2266    async fn test_enqueue_task() {
2267        let add_tool = ToolBuilder::new("add")
2268            .description("Add two numbers")
2269            .handler(|input: AddInput| async move {
2270                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2271            })
2272            .build()
2273            .expect("valid tool name");
2274
2275        let mut router = McpRouter::new().tool(add_tool);
2276        init_router(&mut router).await;
2277
2278        let req = RouterRequest {
2279            id: RequestId::Number(1),
2280            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2281                tool_name: "add".to_string(),
2282                arguments: serde_json::json!({"a": 5, "b": 10}),
2283                ttl: None,
2284            }),
2285            extensions: Extensions::new(),
2286        };
2287
2288        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2289
2290        match resp.inner {
2291            Ok(McpResponse::EnqueueTask(result)) => {
2292                assert!(result.task_id.starts_with("task-"));
2293                assert_eq!(result.status, TaskStatus::Working);
2294            }
2295            _ => panic!("Expected EnqueueTask response"),
2296        }
2297    }
2298
2299    #[tokio::test]
2300    async fn test_list_tasks_empty() {
2301        let mut router = McpRouter::new();
2302        init_router(&mut router).await;
2303
2304        let req = RouterRequest {
2305            id: RequestId::Number(1),
2306            inner: McpRequest::ListTasks(ListTasksParams::default()),
2307            extensions: Extensions::new(),
2308        };
2309
2310        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2311
2312        match resp.inner {
2313            Ok(McpResponse::ListTasks(result)) => {
2314                assert!(result.tasks.is_empty());
2315            }
2316            _ => panic!("Expected ListTasks response"),
2317        }
2318    }
2319
2320    #[tokio::test]
2321    async fn test_task_lifecycle_complete() {
2322        let add_tool = ToolBuilder::new("add")
2323            .description("Add two numbers")
2324            .handler(|input: AddInput| async move {
2325                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2326            })
2327            .build()
2328            .expect("valid tool name");
2329
2330        let mut router = McpRouter::new().tool(add_tool);
2331        init_router(&mut router).await;
2332
2333        // Enqueue task
2334        let req = RouterRequest {
2335            id: RequestId::Number(1),
2336            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2337                tool_name: "add".to_string(),
2338                arguments: serde_json::json!({"a": 7, "b": 8}),
2339                ttl: None,
2340            }),
2341            extensions: Extensions::new(),
2342        };
2343
2344        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2345        let task_id = match resp.inner {
2346            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2347            _ => panic!("Expected EnqueueTask response"),
2348        };
2349
2350        // Wait for task to complete
2351        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2352
2353        // Get task result
2354        let req = RouterRequest {
2355            id: RequestId::Number(2),
2356            inner: McpRequest::GetTaskResult(GetTaskResultParams {
2357                task_id: task_id.clone(),
2358            }),
2359            extensions: Extensions::new(),
2360        };
2361
2362        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2363
2364        match resp.inner {
2365            Ok(McpResponse::GetTaskResult(result)) => {
2366                assert_eq!(result.task_id, task_id);
2367                assert_eq!(result.status, TaskStatus::Completed);
2368                assert!(result.result.is_some());
2369                assert!(result.error.is_none());
2370
2371                // Check the result content
2372                let tool_result = result.result.unwrap();
2373                match &tool_result.content[0] {
2374                    Content::Text { text, .. } => assert_eq!(text, "15"),
2375                    _ => panic!("Expected text content"),
2376                }
2377            }
2378            _ => panic!("Expected GetTaskResult response"),
2379        }
2380    }
2381
2382    #[tokio::test]
2383    async fn test_task_cancellation() {
2384        // Use a slow tool to test cancellation
2385        let slow_tool = ToolBuilder::new("slow")
2386            .description("Slow tool")
2387            .handler(|_input: serde_json::Value| async move {
2388                tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2389                Ok(CallToolResult::text("done"))
2390            })
2391            .build()
2392            .expect("valid tool name");
2393
2394        let mut router = McpRouter::new().tool(slow_tool);
2395        init_router(&mut router).await;
2396
2397        // Enqueue task
2398        let req = RouterRequest {
2399            id: RequestId::Number(1),
2400            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2401                tool_name: "slow".to_string(),
2402                arguments: serde_json::json!({}),
2403                ttl: None,
2404            }),
2405            extensions: Extensions::new(),
2406        };
2407
2408        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2409        let task_id = match resp.inner {
2410            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2411            _ => panic!("Expected EnqueueTask response"),
2412        };
2413
2414        // Cancel the task
2415        let req = RouterRequest {
2416            id: RequestId::Number(2),
2417            inner: McpRequest::CancelTask(CancelTaskParams {
2418                task_id: task_id.clone(),
2419                reason: Some("Test cancellation".to_string()),
2420            }),
2421            extensions: Extensions::new(),
2422        };
2423
2424        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2425
2426        match resp.inner {
2427            Ok(McpResponse::CancelTask(result)) => {
2428                assert!(result.cancelled);
2429                assert_eq!(result.status, TaskStatus::Cancelled);
2430            }
2431            _ => panic!("Expected CancelTask response"),
2432        }
2433    }
2434
2435    #[tokio::test]
2436    async fn test_get_task_info() {
2437        let add_tool = ToolBuilder::new("add")
2438            .description("Add two numbers")
2439            .handler(|input: AddInput| async move {
2440                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2441            })
2442            .build()
2443            .expect("valid tool name");
2444
2445        let mut router = McpRouter::new().tool(add_tool);
2446        init_router(&mut router).await;
2447
2448        // Enqueue task
2449        let req = RouterRequest {
2450            id: RequestId::Number(1),
2451            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2452                tool_name: "add".to_string(),
2453                arguments: serde_json::json!({"a": 1, "b": 2}),
2454                ttl: Some(600),
2455            }),
2456            extensions: Extensions::new(),
2457        };
2458
2459        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2460        let task_id = match resp.inner {
2461            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2462            _ => panic!("Expected EnqueueTask response"),
2463        };
2464
2465        // Get task info
2466        let req = RouterRequest {
2467            id: RequestId::Number(2),
2468            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2469                task_id: task_id.clone(),
2470            }),
2471            extensions: Extensions::new(),
2472        };
2473
2474        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2475
2476        match resp.inner {
2477            Ok(McpResponse::GetTaskInfo(info)) => {
2478                assert_eq!(info.task_id, task_id);
2479                assert!(info.created_at.contains('T')); // ISO 8601
2480                assert_eq!(info.ttl, Some(600));
2481            }
2482            _ => panic!("Expected GetTaskInfo response"),
2483        }
2484    }
2485
2486    #[tokio::test]
2487    async fn test_enqueue_nonexistent_tool() {
2488        let mut router = McpRouter::new();
2489        init_router(&mut router).await;
2490
2491        let req = RouterRequest {
2492            id: RequestId::Number(1),
2493            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2494                tool_name: "nonexistent".to_string(),
2495                arguments: serde_json::json!({}),
2496                ttl: None,
2497            }),
2498            extensions: Extensions::new(),
2499        };
2500
2501        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2502
2503        match resp.inner {
2504            Err(e) => {
2505                assert!(e.message.contains("not found"));
2506            }
2507            _ => panic!("Expected error response"),
2508        }
2509    }
2510
2511    #[tokio::test]
2512    async fn test_get_nonexistent_task() {
2513        let mut router = McpRouter::new();
2514        init_router(&mut router).await;
2515
2516        let req = RouterRequest {
2517            id: RequestId::Number(1),
2518            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2519                task_id: "task-999".to_string(),
2520            }),
2521            extensions: Extensions::new(),
2522        };
2523
2524        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2525
2526        match resp.inner {
2527            Err(e) => {
2528                assert!(e.message.contains("not found"));
2529            }
2530            _ => panic!("Expected error response"),
2531        }
2532    }
2533
2534    // =========================================================================
2535    // Resource Subscription Tests
2536    // =========================================================================
2537
2538    #[tokio::test]
2539    async fn test_subscribe_to_resource() {
2540        use crate::resource::ResourceBuilder;
2541
2542        let resource = ResourceBuilder::new("file:///test.txt")
2543            .name("Test File")
2544            .text("Hello");
2545
2546        let mut router = McpRouter::new().resource(resource);
2547        init_router(&mut router).await;
2548
2549        // Subscribe to the resource
2550        let req = RouterRequest {
2551            id: RequestId::Number(1),
2552            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2553                uri: "file:///test.txt".to_string(),
2554            }),
2555            extensions: Extensions::new(),
2556        };
2557
2558        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2559
2560        match resp.inner {
2561            Ok(McpResponse::SubscribeResource(_)) => {
2562                // Should be subscribed now
2563                assert!(router.is_subscribed("file:///test.txt"));
2564            }
2565            _ => panic!("Expected SubscribeResource response"),
2566        }
2567    }
2568
2569    #[tokio::test]
2570    async fn test_unsubscribe_from_resource() {
2571        use crate::resource::ResourceBuilder;
2572
2573        let resource = ResourceBuilder::new("file:///test.txt")
2574            .name("Test File")
2575            .text("Hello");
2576
2577        let mut router = McpRouter::new().resource(resource);
2578        init_router(&mut router).await;
2579
2580        // Subscribe first
2581        let req = RouterRequest {
2582            id: RequestId::Number(1),
2583            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2584                uri: "file:///test.txt".to_string(),
2585            }),
2586            extensions: Extensions::new(),
2587        };
2588        let _ = router.ready().await.unwrap().call(req).await.unwrap();
2589        assert!(router.is_subscribed("file:///test.txt"));
2590
2591        // Now unsubscribe
2592        let req = RouterRequest {
2593            id: RequestId::Number(2),
2594            inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2595                uri: "file:///test.txt".to_string(),
2596            }),
2597            extensions: Extensions::new(),
2598        };
2599
2600        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2601
2602        match resp.inner {
2603            Ok(McpResponse::UnsubscribeResource(_)) => {
2604                // Should no longer be subscribed
2605                assert!(!router.is_subscribed("file:///test.txt"));
2606            }
2607            _ => panic!("Expected UnsubscribeResource response"),
2608        }
2609    }
2610
2611    #[tokio::test]
2612    async fn test_subscribe_nonexistent_resource() {
2613        let mut router = McpRouter::new();
2614        init_router(&mut router).await;
2615
2616        let req = RouterRequest {
2617            id: RequestId::Number(1),
2618            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2619                uri: "file:///nonexistent.txt".to_string(),
2620            }),
2621            extensions: Extensions::new(),
2622        };
2623
2624        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2625
2626        match resp.inner {
2627            Err(e) => {
2628                assert!(e.message.contains("not found"));
2629            }
2630            _ => panic!("Expected error response"),
2631        }
2632    }
2633
2634    #[tokio::test]
2635    async fn test_notify_resource_updated() {
2636        use crate::context::notification_channel;
2637        use crate::resource::ResourceBuilder;
2638
2639        let (tx, mut rx) = notification_channel(10);
2640
2641        let resource = ResourceBuilder::new("file:///test.txt")
2642            .name("Test File")
2643            .text("Hello");
2644
2645        let router = McpRouter::new()
2646            .resource(resource)
2647            .with_notification_sender(tx);
2648
2649        // First, manually subscribe (simulate subscription)
2650        router.subscribe("file:///test.txt");
2651
2652        // Now notify
2653        let sent = router.notify_resource_updated("file:///test.txt");
2654        assert!(sent);
2655
2656        // Check the notification was sent
2657        let notification = rx.try_recv().unwrap();
2658        match notification {
2659            ServerNotification::ResourceUpdated { uri } => {
2660                assert_eq!(uri, "file:///test.txt");
2661            }
2662            _ => panic!("Expected ResourceUpdated notification"),
2663        }
2664    }
2665
2666    #[tokio::test]
2667    async fn test_notify_resource_updated_not_subscribed() {
2668        use crate::context::notification_channel;
2669        use crate::resource::ResourceBuilder;
2670
2671        let (tx, mut rx) = notification_channel(10);
2672
2673        let resource = ResourceBuilder::new("file:///test.txt")
2674            .name("Test File")
2675            .text("Hello");
2676
2677        let router = McpRouter::new()
2678            .resource(resource)
2679            .with_notification_sender(tx);
2680
2681        // Try to notify without subscribing
2682        let sent = router.notify_resource_updated("file:///test.txt");
2683        assert!(!sent); // Should not send because not subscribed
2684
2685        // Channel should be empty
2686        assert!(rx.try_recv().is_err());
2687    }
2688
2689    #[tokio::test]
2690    async fn test_notify_resources_list_changed() {
2691        use crate::context::notification_channel;
2692
2693        let (tx, mut rx) = notification_channel(10);
2694        let router = McpRouter::new().with_notification_sender(tx);
2695
2696        let sent = router.notify_resources_list_changed();
2697        assert!(sent);
2698
2699        let notification = rx.try_recv().unwrap();
2700        match notification {
2701            ServerNotification::ResourcesListChanged => {}
2702            _ => panic!("Expected ResourcesListChanged notification"),
2703        }
2704    }
2705
2706    #[tokio::test]
2707    async fn test_subscribed_uris() {
2708        use crate::resource::ResourceBuilder;
2709
2710        let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2711
2712        let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2713
2714        let router = McpRouter::new().resource(resource1).resource(resource2);
2715
2716        // Subscribe to both
2717        router.subscribe("file:///a.txt");
2718        router.subscribe("file:///b.txt");
2719
2720        let uris = router.subscribed_uris();
2721        assert_eq!(uris.len(), 2);
2722        assert!(uris.contains(&"file:///a.txt".to_string()));
2723        assert!(uris.contains(&"file:///b.txt".to_string()));
2724    }
2725
2726    #[tokio::test]
2727    async fn test_subscription_capability_advertised() {
2728        use crate::resource::ResourceBuilder;
2729
2730        let resource = ResourceBuilder::new("file:///test.txt")
2731            .name("Test")
2732            .text("Hello");
2733
2734        let mut router = McpRouter::new().resource(resource);
2735
2736        // Initialize and check capabilities
2737        let init_req = RouterRequest {
2738            id: RequestId::Number(0),
2739            inner: McpRequest::Initialize(InitializeParams {
2740                protocol_version: "2025-11-25".to_string(),
2741                capabilities: ClientCapabilities {
2742                    roots: None,
2743                    sampling: None,
2744                    elicitation: None,
2745                },
2746                client_info: Implementation {
2747                    name: "test".to_string(),
2748                    version: "1.0".to_string(),
2749                    ..Default::default()
2750                },
2751            }),
2752            extensions: Extensions::new(),
2753        };
2754        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2755
2756        match resp.inner {
2757            Ok(McpResponse::Initialize(result)) => {
2758                // Should have resources capability with subscribe enabled
2759                let resources_cap = result.capabilities.resources.unwrap();
2760                assert!(resources_cap.subscribe);
2761            }
2762            _ => panic!("Expected Initialize response"),
2763        }
2764    }
2765
2766    #[tokio::test]
2767    async fn test_completion_handler() {
2768        let router = McpRouter::new()
2769            .server_info("test", "1.0")
2770            .completion_handler(|params: CompleteParams| async move {
2771                // Return suggestions based on the argument value
2772                let prefix = &params.argument.value;
2773                let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2774                    .into_iter()
2775                    .filter(|s| s.starts_with(prefix))
2776                    .map(String::from)
2777                    .collect();
2778                Ok(CompleteResult::new(suggestions))
2779            });
2780
2781        // Initialize
2782        let init_req = RouterRequest {
2783            id: RequestId::Number(0),
2784            inner: McpRequest::Initialize(InitializeParams {
2785                protocol_version: "2025-11-25".to_string(),
2786                capabilities: ClientCapabilities::default(),
2787                client_info: Implementation {
2788                    name: "test".to_string(),
2789                    version: "1.0".to_string(),
2790                    ..Default::default()
2791                },
2792            }),
2793            extensions: Extensions::new(),
2794        };
2795        let resp = router
2796            .clone()
2797            .ready()
2798            .await
2799            .unwrap()
2800            .call(init_req)
2801            .await
2802            .unwrap();
2803
2804        // Check that completions capability is advertised
2805        match resp.inner {
2806            Ok(McpResponse::Initialize(result)) => {
2807                assert!(result.capabilities.completions.is_some());
2808            }
2809            _ => panic!("Expected Initialize response"),
2810        }
2811
2812        // Send initialized notification
2813        router.handle_notification(McpNotification::Initialized);
2814
2815        // Test completion request
2816        let complete_req = RouterRequest {
2817            id: RequestId::Number(1),
2818            inner: McpRequest::Complete(CompleteParams {
2819                reference: CompletionReference::prompt("test-prompt"),
2820                argument: CompletionArgument::new("query", "al"),
2821            }),
2822            extensions: Extensions::new(),
2823        };
2824        let resp = router
2825            .clone()
2826            .ready()
2827            .await
2828            .unwrap()
2829            .call(complete_req)
2830            .await
2831            .unwrap();
2832
2833        match resp.inner {
2834            Ok(McpResponse::Complete(result)) => {
2835                assert_eq!(result.completion.values, vec!["alpha"]);
2836            }
2837            _ => panic!("Expected Complete response"),
2838        }
2839    }
2840
2841    #[tokio::test]
2842    async fn test_completion_without_handler_returns_empty() {
2843        let router = McpRouter::new().server_info("test", "1.0");
2844
2845        // Initialize
2846        let init_req = RouterRequest {
2847            id: RequestId::Number(0),
2848            inner: McpRequest::Initialize(InitializeParams {
2849                protocol_version: "2025-11-25".to_string(),
2850                capabilities: ClientCapabilities::default(),
2851                client_info: Implementation {
2852                    name: "test".to_string(),
2853                    version: "1.0".to_string(),
2854                    ..Default::default()
2855                },
2856            }),
2857            extensions: Extensions::new(),
2858        };
2859        let resp = router
2860            .clone()
2861            .ready()
2862            .await
2863            .unwrap()
2864            .call(init_req)
2865            .await
2866            .unwrap();
2867
2868        // Check that completions capability is NOT advertised
2869        match resp.inner {
2870            Ok(McpResponse::Initialize(result)) => {
2871                assert!(result.capabilities.completions.is_none());
2872            }
2873            _ => panic!("Expected Initialize response"),
2874        }
2875
2876        // Send initialized notification
2877        router.handle_notification(McpNotification::Initialized);
2878
2879        // Test completion request still works but returns empty
2880        let complete_req = RouterRequest {
2881            id: RequestId::Number(1),
2882            inner: McpRequest::Complete(CompleteParams {
2883                reference: CompletionReference::prompt("test-prompt"),
2884                argument: CompletionArgument::new("query", "al"),
2885            }),
2886            extensions: Extensions::new(),
2887        };
2888        let resp = router
2889            .clone()
2890            .ready()
2891            .await
2892            .unwrap()
2893            .call(complete_req)
2894            .await
2895            .unwrap();
2896
2897        match resp.inner {
2898            Ok(McpResponse::Complete(result)) => {
2899                assert!(result.completion.values.is_empty());
2900            }
2901            _ => panic!("Expected Complete response"),
2902        }
2903    }
2904
2905    #[tokio::test]
2906    async fn test_tool_filter_list() {
2907        use crate::filter::CapabilityFilter;
2908        use crate::tool::Tool;
2909
2910        let public_tool = ToolBuilder::new("public")
2911            .description("Public tool")
2912            .handler(|_: AddInput| async move { Ok(CallToolResult::text("public")) })
2913            .build()
2914            .expect("valid tool name");
2915
2916        let admin_tool = ToolBuilder::new("admin")
2917            .description("Admin tool")
2918            .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2919            .build()
2920            .expect("valid tool name");
2921
2922        let mut router = McpRouter::new()
2923            .tool(public_tool)
2924            .tool(admin_tool)
2925            .tool_filter(CapabilityFilter::new(|_, tool: &Tool| tool.name != "admin"));
2926
2927        // Initialize session
2928        init_router(&mut router).await;
2929
2930        let req = RouterRequest {
2931            id: RequestId::Number(1),
2932            inner: McpRequest::ListTools(ListToolsParams::default()),
2933            extensions: Extensions::new(),
2934        };
2935
2936        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2937
2938        match resp.inner {
2939            Ok(McpResponse::ListTools(result)) => {
2940                // Only public tool should be visible
2941                assert_eq!(result.tools.len(), 1);
2942                assert_eq!(result.tools[0].name, "public");
2943            }
2944            _ => panic!("Expected ListTools response"),
2945        }
2946    }
2947
2948    #[tokio::test]
2949    async fn test_tool_filter_call_denied() {
2950        use crate::filter::CapabilityFilter;
2951        use crate::tool::Tool;
2952
2953        let admin_tool = ToolBuilder::new("admin")
2954            .description("Admin tool")
2955            .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2956            .build()
2957            .expect("valid tool name");
2958
2959        let mut router = McpRouter::new()
2960            .tool(admin_tool)
2961            .tool_filter(CapabilityFilter::new(|_, _: &Tool| false)); // Deny all
2962
2963        // Initialize session
2964        init_router(&mut router).await;
2965
2966        let req = RouterRequest {
2967            id: RequestId::Number(1),
2968            inner: McpRequest::CallTool(CallToolParams {
2969                name: "admin".to_string(),
2970                arguments: serde_json::json!({"a": 1, "b": 2}),
2971                meta: None,
2972            }),
2973            extensions: Extensions::new(),
2974        };
2975
2976        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2977
2978        // Should get method not found error (default denial behavior)
2979        match resp.inner {
2980            Err(e) => {
2981                assert_eq!(e.code, -32601); // Method not found
2982            }
2983            _ => panic!("Expected JsonRpc error"),
2984        }
2985    }
2986
2987    #[tokio::test]
2988    async fn test_tool_filter_call_allowed() {
2989        use crate::filter::CapabilityFilter;
2990        use crate::tool::Tool;
2991
2992        let public_tool = ToolBuilder::new("public")
2993            .description("Public tool")
2994            .handler(|input: AddInput| async move {
2995                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2996            })
2997            .build()
2998            .expect("valid tool name");
2999
3000        let mut router = McpRouter::new()
3001            .tool(public_tool)
3002            .tool_filter(CapabilityFilter::new(|_, _: &Tool| true)); // Allow all
3003
3004        // Initialize session
3005        init_router(&mut router).await;
3006
3007        let req = RouterRequest {
3008            id: RequestId::Number(1),
3009            inner: McpRequest::CallTool(CallToolParams {
3010                name: "public".to_string(),
3011                arguments: serde_json::json!({"a": 1, "b": 2}),
3012                meta: None,
3013            }),
3014            extensions: Extensions::new(),
3015        };
3016
3017        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3018
3019        match resp.inner {
3020            Ok(McpResponse::CallTool(result)) => {
3021                assert!(!result.is_error);
3022            }
3023            _ => panic!("Expected CallTool response"),
3024        }
3025    }
3026
3027    #[tokio::test]
3028    async fn test_tool_filter_custom_denial() {
3029        use crate::filter::{CapabilityFilter, DenialBehavior};
3030        use crate::tool::Tool;
3031
3032        let admin_tool = ToolBuilder::new("admin")
3033            .description("Admin tool")
3034            .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
3035            .build()
3036            .expect("valid tool name");
3037
3038        let mut router = McpRouter::new().tool(admin_tool).tool_filter(
3039            CapabilityFilter::new(|_, _: &Tool| false)
3040                .denial_behavior(DenialBehavior::Unauthorized),
3041        );
3042
3043        // Initialize session
3044        init_router(&mut router).await;
3045
3046        let req = RouterRequest {
3047            id: RequestId::Number(1),
3048            inner: McpRequest::CallTool(CallToolParams {
3049                name: "admin".to_string(),
3050                arguments: serde_json::json!({"a": 1, "b": 2}),
3051                meta: None,
3052            }),
3053            extensions: Extensions::new(),
3054        };
3055
3056        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3057
3058        // Should get forbidden error
3059        match resp.inner {
3060            Err(e) => {
3061                assert_eq!(e.code, -32007); // Forbidden
3062                assert!(e.message.contains("Unauthorized"));
3063            }
3064            _ => panic!("Expected JsonRpc error"),
3065        }
3066    }
3067
3068    #[tokio::test]
3069    async fn test_resource_filter_list() {
3070        use crate::filter::CapabilityFilter;
3071        use crate::resource::{Resource, ResourceBuilder};
3072
3073        let public_resource = ResourceBuilder::new("file:///public.txt")
3074            .name("Public File")
3075            .text("public content");
3076
3077        let secret_resource = ResourceBuilder::new("file:///secret.txt")
3078            .name("Secret File")
3079            .text("secret content");
3080
3081        let mut router = McpRouter::new()
3082            .resource(public_resource)
3083            .resource(secret_resource)
3084            .resource_filter(CapabilityFilter::new(|_, r: &Resource| {
3085                !r.name.contains("Secret")
3086            }));
3087
3088        // Initialize session
3089        init_router(&mut router).await;
3090
3091        let req = RouterRequest {
3092            id: RequestId::Number(1),
3093            inner: McpRequest::ListResources(ListResourcesParams::default()),
3094            extensions: Extensions::new(),
3095        };
3096
3097        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3098
3099        match resp.inner {
3100            Ok(McpResponse::ListResources(result)) => {
3101                // Should only see public resource
3102                assert_eq!(result.resources.len(), 1);
3103                assert_eq!(result.resources[0].name, "Public File");
3104            }
3105            _ => panic!("Expected ListResources response"),
3106        }
3107    }
3108
3109    #[tokio::test]
3110    async fn test_resource_filter_read_denied() {
3111        use crate::filter::CapabilityFilter;
3112        use crate::resource::{Resource, ResourceBuilder};
3113
3114        let secret_resource = ResourceBuilder::new("file:///secret.txt")
3115            .name("Secret File")
3116            .text("secret content");
3117
3118        let mut router = McpRouter::new()
3119            .resource(secret_resource)
3120            .resource_filter(CapabilityFilter::new(|_, _: &Resource| false)); // Deny all
3121
3122        // Initialize session
3123        init_router(&mut router).await;
3124
3125        let req = RouterRequest {
3126            id: RequestId::Number(1),
3127            inner: McpRequest::ReadResource(ReadResourceParams {
3128                uri: "file:///secret.txt".to_string(),
3129            }),
3130            extensions: Extensions::new(),
3131        };
3132
3133        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3134
3135        // Should get method not found error (default denial behavior)
3136        match resp.inner {
3137            Err(e) => {
3138                assert_eq!(e.code, -32601); // Method not found
3139            }
3140            _ => panic!("Expected JsonRpc error"),
3141        }
3142    }
3143
3144    #[tokio::test]
3145    async fn test_resource_filter_read_allowed() {
3146        use crate::filter::CapabilityFilter;
3147        use crate::resource::{Resource, ResourceBuilder};
3148
3149        let public_resource = ResourceBuilder::new("file:///public.txt")
3150            .name("Public File")
3151            .text("public content");
3152
3153        let mut router = McpRouter::new()
3154            .resource(public_resource)
3155            .resource_filter(CapabilityFilter::new(|_, _: &Resource| true)); // Allow all
3156
3157        // Initialize session
3158        init_router(&mut router).await;
3159
3160        let req = RouterRequest {
3161            id: RequestId::Number(1),
3162            inner: McpRequest::ReadResource(ReadResourceParams {
3163                uri: "file:///public.txt".to_string(),
3164            }),
3165            extensions: Extensions::new(),
3166        };
3167
3168        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3169
3170        match resp.inner {
3171            Ok(McpResponse::ReadResource(result)) => {
3172                assert_eq!(result.contents.len(), 1);
3173                assert_eq!(result.contents[0].text.as_deref(), Some("public content"));
3174            }
3175            _ => panic!("Expected ReadResource response"),
3176        }
3177    }
3178
3179    #[tokio::test]
3180    async fn test_resource_filter_custom_denial() {
3181        use crate::filter::{CapabilityFilter, DenialBehavior};
3182        use crate::resource::{Resource, ResourceBuilder};
3183
3184        let secret_resource = ResourceBuilder::new("file:///secret.txt")
3185            .name("Secret File")
3186            .text("secret content");
3187
3188        let mut router = McpRouter::new().resource(secret_resource).resource_filter(
3189            CapabilityFilter::new(|_, _: &Resource| false)
3190                .denial_behavior(DenialBehavior::Unauthorized),
3191        );
3192
3193        // Initialize session
3194        init_router(&mut router).await;
3195
3196        let req = RouterRequest {
3197            id: RequestId::Number(1),
3198            inner: McpRequest::ReadResource(ReadResourceParams {
3199                uri: "file:///secret.txt".to_string(),
3200            }),
3201            extensions: Extensions::new(),
3202        };
3203
3204        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3205
3206        // Should get forbidden error
3207        match resp.inner {
3208            Err(e) => {
3209                assert_eq!(e.code, -32007); // Forbidden
3210                assert!(e.message.contains("Unauthorized"));
3211            }
3212            _ => panic!("Expected JsonRpc error"),
3213        }
3214    }
3215
3216    #[tokio::test]
3217    async fn test_prompt_filter_list() {
3218        use crate::filter::CapabilityFilter;
3219        use crate::prompt::{Prompt, PromptBuilder};
3220
3221        let public_prompt = PromptBuilder::new("greeting")
3222            .description("A greeting")
3223            .user_message("Hello!");
3224
3225        let admin_prompt = PromptBuilder::new("system_debug")
3226            .description("Admin prompt")
3227            .user_message("Debug");
3228
3229        let mut router = McpRouter::new()
3230            .prompt(public_prompt)
3231            .prompt(admin_prompt)
3232            .prompt_filter(CapabilityFilter::new(|_, p: &Prompt| {
3233                !p.name.contains("system")
3234            }));
3235
3236        // Initialize session
3237        init_router(&mut router).await;
3238
3239        let req = RouterRequest {
3240            id: RequestId::Number(1),
3241            inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3242            extensions: Extensions::new(),
3243        };
3244
3245        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3246
3247        match resp.inner {
3248            Ok(McpResponse::ListPrompts(result)) => {
3249                // Should only see public prompt
3250                assert_eq!(result.prompts.len(), 1);
3251                assert_eq!(result.prompts[0].name, "greeting");
3252            }
3253            _ => panic!("Expected ListPrompts response"),
3254        }
3255    }
3256
3257    #[tokio::test]
3258    async fn test_prompt_filter_get_denied() {
3259        use crate::filter::CapabilityFilter;
3260        use crate::prompt::{Prompt, PromptBuilder};
3261        use std::collections::HashMap;
3262
3263        let admin_prompt = PromptBuilder::new("system_debug")
3264            .description("Admin prompt")
3265            .user_message("Debug");
3266
3267        let mut router = McpRouter::new()
3268            .prompt(admin_prompt)
3269            .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| false)); // Deny all
3270
3271        // Initialize session
3272        init_router(&mut router).await;
3273
3274        let req = RouterRequest {
3275            id: RequestId::Number(1),
3276            inner: McpRequest::GetPrompt(GetPromptParams {
3277                name: "system_debug".to_string(),
3278                arguments: HashMap::new(),
3279            }),
3280            extensions: Extensions::new(),
3281        };
3282
3283        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3284
3285        // Should get method not found error (default denial behavior)
3286        match resp.inner {
3287            Err(e) => {
3288                assert_eq!(e.code, -32601); // Method not found
3289            }
3290            _ => panic!("Expected JsonRpc error"),
3291        }
3292    }
3293
3294    #[tokio::test]
3295    async fn test_prompt_filter_get_allowed() {
3296        use crate::filter::CapabilityFilter;
3297        use crate::prompt::{Prompt, PromptBuilder};
3298        use std::collections::HashMap;
3299
3300        let public_prompt = PromptBuilder::new("greeting")
3301            .description("A greeting")
3302            .user_message("Hello!");
3303
3304        let mut router = McpRouter::new()
3305            .prompt(public_prompt)
3306            .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| true)); // Allow all
3307
3308        // Initialize session
3309        init_router(&mut router).await;
3310
3311        let req = RouterRequest {
3312            id: RequestId::Number(1),
3313            inner: McpRequest::GetPrompt(GetPromptParams {
3314                name: "greeting".to_string(),
3315                arguments: HashMap::new(),
3316            }),
3317            extensions: Extensions::new(),
3318        };
3319
3320        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3321
3322        match resp.inner {
3323            Ok(McpResponse::GetPrompt(result)) => {
3324                assert_eq!(result.messages.len(), 1);
3325            }
3326            _ => panic!("Expected GetPrompt response"),
3327        }
3328    }
3329
3330    #[tokio::test]
3331    async fn test_prompt_filter_custom_denial() {
3332        use crate::filter::{CapabilityFilter, DenialBehavior};
3333        use crate::prompt::{Prompt, PromptBuilder};
3334        use std::collections::HashMap;
3335
3336        let admin_prompt = PromptBuilder::new("system_debug")
3337            .description("Admin prompt")
3338            .user_message("Debug");
3339
3340        let mut router = McpRouter::new().prompt(admin_prompt).prompt_filter(
3341            CapabilityFilter::new(|_, _: &Prompt| false)
3342                .denial_behavior(DenialBehavior::Unauthorized),
3343        );
3344
3345        // Initialize session
3346        init_router(&mut router).await;
3347
3348        let req = RouterRequest {
3349            id: RequestId::Number(1),
3350            inner: McpRequest::GetPrompt(GetPromptParams {
3351                name: "system_debug".to_string(),
3352                arguments: HashMap::new(),
3353            }),
3354            extensions: Extensions::new(),
3355        };
3356
3357        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3358
3359        // Should get forbidden error
3360        match resp.inner {
3361            Err(e) => {
3362                assert_eq!(e.code, -32007); // Forbidden
3363                assert!(e.message.contains("Unauthorized"));
3364            }
3365            _ => panic!("Expected JsonRpc error"),
3366        }
3367    }
3368
3369    // =========================================================================
3370    // Router Composition Tests (merge/nest)
3371    // =========================================================================
3372
3373    #[derive(Debug, Deserialize, JsonSchema)]
3374    struct StringInput {
3375        value: String,
3376    }
3377
3378    #[tokio::test]
3379    async fn test_router_merge_tools() {
3380        // Create first router with a tool
3381        let tool_a = ToolBuilder::new("tool_a")
3382            .description("Tool A")
3383            .handler(|_: StringInput| async move { Ok(CallToolResult::text("A")) })
3384            .build()
3385            .unwrap();
3386
3387        let router_a = McpRouter::new().tool(tool_a);
3388
3389        // Create second router with different tools
3390        let tool_b = ToolBuilder::new("tool_b")
3391            .description("Tool B")
3392            .handler(|_: StringInput| async move { Ok(CallToolResult::text("B")) })
3393            .build()
3394            .unwrap();
3395        let tool_c = ToolBuilder::new("tool_c")
3396            .description("Tool C")
3397            .handler(|_: StringInput| async move { Ok(CallToolResult::text("C")) })
3398            .build()
3399            .unwrap();
3400
3401        let router_b = McpRouter::new().tool(tool_b).tool(tool_c);
3402
3403        // Merge them
3404        let mut merged = McpRouter::new()
3405            .server_info("merged", "1.0")
3406            .merge(router_a)
3407            .merge(router_b);
3408
3409        init_router(&mut merged).await;
3410
3411        // List tools
3412        let req = RouterRequest {
3413            id: RequestId::Number(1),
3414            inner: McpRequest::ListTools(ListToolsParams::default()),
3415            extensions: Extensions::new(),
3416        };
3417
3418        let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3419
3420        match resp.inner {
3421            Ok(McpResponse::ListTools(result)) => {
3422                assert_eq!(result.tools.len(), 3);
3423                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3424                assert!(names.contains(&"tool_a"));
3425                assert!(names.contains(&"tool_b"));
3426                assert!(names.contains(&"tool_c"));
3427            }
3428            _ => panic!("Expected ListTools response"),
3429        }
3430    }
3431
3432    #[tokio::test]
3433    async fn test_router_merge_overwrites_duplicates() {
3434        // Create first router with a tool
3435        let tool_v1 = ToolBuilder::new("shared")
3436            .description("Version 1")
3437            .handler(|_: StringInput| async move { Ok(CallToolResult::text("v1")) })
3438            .build()
3439            .unwrap();
3440
3441        let router_a = McpRouter::new().tool(tool_v1);
3442
3443        // Create second router with same tool name but different description
3444        let tool_v2 = ToolBuilder::new("shared")
3445            .description("Version 2")
3446            .handler(|_: StringInput| async move { Ok(CallToolResult::text("v2")) })
3447            .build()
3448            .unwrap();
3449
3450        let router_b = McpRouter::new().tool(tool_v2);
3451
3452        // Merge - second should win
3453        let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3454
3455        init_router(&mut merged).await;
3456
3457        let req = RouterRequest {
3458            id: RequestId::Number(1),
3459            inner: McpRequest::ListTools(ListToolsParams::default()),
3460            extensions: Extensions::new(),
3461        };
3462
3463        let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3464
3465        match resp.inner {
3466            Ok(McpResponse::ListTools(result)) => {
3467                assert_eq!(result.tools.len(), 1);
3468                assert_eq!(result.tools[0].name, "shared");
3469                assert_eq!(result.tools[0].description.as_deref(), Some("Version 2"));
3470            }
3471            _ => panic!("Expected ListTools response"),
3472        }
3473    }
3474
3475    #[tokio::test]
3476    async fn test_router_merge_resources() {
3477        use crate::resource::ResourceBuilder;
3478
3479        // Create routers with different resources
3480        let router_a = McpRouter::new().resource(
3481            ResourceBuilder::new("file:///a.txt")
3482                .name("File A")
3483                .text("content a"),
3484        );
3485
3486        let router_b = McpRouter::new().resource(
3487            ResourceBuilder::new("file:///b.txt")
3488                .name("File B")
3489                .text("content b"),
3490        );
3491
3492        let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3493
3494        init_router(&mut merged).await;
3495
3496        let req = RouterRequest {
3497            id: RequestId::Number(1),
3498            inner: McpRequest::ListResources(ListResourcesParams::default()),
3499            extensions: Extensions::new(),
3500        };
3501
3502        let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3503
3504        match resp.inner {
3505            Ok(McpResponse::ListResources(result)) => {
3506                assert_eq!(result.resources.len(), 2);
3507                let uris: Vec<&str> = result.resources.iter().map(|r| r.uri.as_str()).collect();
3508                assert!(uris.contains(&"file:///a.txt"));
3509                assert!(uris.contains(&"file:///b.txt"));
3510            }
3511            _ => panic!("Expected ListResources response"),
3512        }
3513    }
3514
3515    #[tokio::test]
3516    async fn test_router_merge_prompts() {
3517        use crate::prompt::PromptBuilder;
3518
3519        let router_a =
3520            McpRouter::new().prompt(PromptBuilder::new("prompt_a").user_message("Hello A"));
3521
3522        let router_b =
3523            McpRouter::new().prompt(PromptBuilder::new("prompt_b").user_message("Hello B"));
3524
3525        let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3526
3527        init_router(&mut merged).await;
3528
3529        let req = RouterRequest {
3530            id: RequestId::Number(1),
3531            inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3532            extensions: Extensions::new(),
3533        };
3534
3535        let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3536
3537        match resp.inner {
3538            Ok(McpResponse::ListPrompts(result)) => {
3539                assert_eq!(result.prompts.len(), 2);
3540                let names: Vec<&str> = result.prompts.iter().map(|p| p.name.as_str()).collect();
3541                assert!(names.contains(&"prompt_a"));
3542                assert!(names.contains(&"prompt_b"));
3543            }
3544            _ => panic!("Expected ListPrompts response"),
3545        }
3546    }
3547
3548    #[tokio::test]
3549    async fn test_router_nest_prefixes_tools() {
3550        // Create a router with tools
3551        let tool_query = ToolBuilder::new("query")
3552            .description("Query the database")
3553            .handler(|_: StringInput| async move { Ok(CallToolResult::text("query result")) })
3554            .build()
3555            .unwrap();
3556        let tool_insert = ToolBuilder::new("insert")
3557            .description("Insert into database")
3558            .handler(|_: StringInput| async move { Ok(CallToolResult::text("insert result")) })
3559            .build()
3560            .unwrap();
3561
3562        let db_router = McpRouter::new().tool(tool_query).tool(tool_insert);
3563
3564        // Nest under "db" prefix
3565        let mut router = McpRouter::new()
3566            .server_info("nested", "1.0")
3567            .nest("db", db_router);
3568
3569        init_router(&mut router).await;
3570
3571        let req = RouterRequest {
3572            id: RequestId::Number(1),
3573            inner: McpRequest::ListTools(ListToolsParams::default()),
3574            extensions: Extensions::new(),
3575        };
3576
3577        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3578
3579        match resp.inner {
3580            Ok(McpResponse::ListTools(result)) => {
3581                assert_eq!(result.tools.len(), 2);
3582                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3583                assert!(names.contains(&"db.query"));
3584                assert!(names.contains(&"db.insert"));
3585            }
3586            _ => panic!("Expected ListTools response"),
3587        }
3588    }
3589
3590    #[tokio::test]
3591    async fn test_router_nest_call_prefixed_tool() {
3592        let tool = ToolBuilder::new("echo")
3593            .description("Echo input")
3594            .handler(|input: StringInput| async move { Ok(CallToolResult::text(&input.value)) })
3595            .build()
3596            .unwrap();
3597
3598        let nested_router = McpRouter::new().tool(tool);
3599
3600        let mut router = McpRouter::new().nest("api", nested_router);
3601
3602        init_router(&mut router).await;
3603
3604        // Call the prefixed tool
3605        let req = RouterRequest {
3606            id: RequestId::Number(1),
3607            inner: McpRequest::CallTool(CallToolParams {
3608                name: "api.echo".to_string(),
3609                arguments: serde_json::json!({"value": "hello world"}),
3610                meta: None,
3611            }),
3612            extensions: Extensions::new(),
3613        };
3614
3615        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3616
3617        match resp.inner {
3618            Ok(McpResponse::CallTool(result)) => {
3619                assert!(!result.is_error);
3620                match &result.content[0] {
3621                    Content::Text { text, .. } => assert_eq!(text, "hello world"),
3622                    _ => panic!("Expected text content"),
3623                }
3624            }
3625            _ => panic!("Expected CallTool response"),
3626        }
3627    }
3628
3629    #[tokio::test]
3630    async fn test_router_multiple_nests() {
3631        let db_tool = ToolBuilder::new("query")
3632            .description("Database query")
3633            .handler(|_: StringInput| async move { Ok(CallToolResult::text("db")) })
3634            .build()
3635            .unwrap();
3636
3637        let api_tool = ToolBuilder::new("fetch")
3638            .description("API fetch")
3639            .handler(|_: StringInput| async move { Ok(CallToolResult::text("api")) })
3640            .build()
3641            .unwrap();
3642
3643        let db_router = McpRouter::new().tool(db_tool);
3644        let api_router = McpRouter::new().tool(api_tool);
3645
3646        let mut router = McpRouter::new()
3647            .nest("db", db_router)
3648            .nest("api", api_router);
3649
3650        init_router(&mut router).await;
3651
3652        let req = RouterRequest {
3653            id: RequestId::Number(1),
3654            inner: McpRequest::ListTools(ListToolsParams::default()),
3655            extensions: Extensions::new(),
3656        };
3657
3658        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3659
3660        match resp.inner {
3661            Ok(McpResponse::ListTools(result)) => {
3662                assert_eq!(result.tools.len(), 2);
3663                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3664                assert!(names.contains(&"db.query"));
3665                assert!(names.contains(&"api.fetch"));
3666            }
3667            _ => panic!("Expected ListTools response"),
3668        }
3669    }
3670
3671    #[tokio::test]
3672    async fn test_router_merge_and_nest_combined() {
3673        // Test combining merge and nest
3674        let tool_a = ToolBuilder::new("local")
3675            .description("Local tool")
3676            .handler(|_: StringInput| async move { Ok(CallToolResult::text("local")) })
3677            .build()
3678            .unwrap();
3679
3680        let nested_tool = ToolBuilder::new("remote")
3681            .description("Remote tool")
3682            .handler(|_: StringInput| async move { Ok(CallToolResult::text("remote")) })
3683            .build()
3684            .unwrap();
3685
3686        let nested_router = McpRouter::new().tool(nested_tool);
3687
3688        let mut router = McpRouter::new()
3689            .tool(tool_a)
3690            .nest("external", nested_router);
3691
3692        init_router(&mut router).await;
3693
3694        let req = RouterRequest {
3695            id: RequestId::Number(1),
3696            inner: McpRequest::ListTools(ListToolsParams::default()),
3697            extensions: Extensions::new(),
3698        };
3699
3700        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3701
3702        match resp.inner {
3703            Ok(McpResponse::ListTools(result)) => {
3704                assert_eq!(result.tools.len(), 2);
3705                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3706                assert!(names.contains(&"local"));
3707                assert!(names.contains(&"external.remote"));
3708            }
3709            _ => panic!("Expected ListTools response"),
3710        }
3711    }
3712
3713    #[tokio::test]
3714    async fn test_router_merge_preserves_server_info() {
3715        let child_router = McpRouter::new()
3716            .server_info("child", "2.0")
3717            .instructions("Child instructions");
3718
3719        let mut router = McpRouter::new()
3720            .server_info("parent", "1.0")
3721            .instructions("Parent instructions")
3722            .merge(child_router);
3723
3724        init_router(&mut router).await;
3725
3726        // Initialize response should have parent's server info
3727        let init_req = RouterRequest {
3728            id: RequestId::Number(99),
3729            inner: McpRequest::Initialize(InitializeParams {
3730                protocol_version: "2025-11-25".to_string(),
3731                capabilities: ClientCapabilities::default(),
3732                client_info: Implementation {
3733                    name: "test".to_string(),
3734                    version: "1.0".to_string(),
3735                    ..Default::default()
3736                },
3737            }),
3738            extensions: Extensions::new(),
3739        };
3740
3741        // Create fresh router for this test since we need to call initialize
3742        let child_router2 = McpRouter::new().server_info("child", "2.0");
3743        let mut fresh_router = McpRouter::new()
3744            .server_info("parent", "1.0")
3745            .merge(child_router2);
3746
3747        let resp = fresh_router
3748            .ready()
3749            .await
3750            .unwrap()
3751            .call(init_req)
3752            .await
3753            .unwrap();
3754
3755        match resp.inner {
3756            Ok(McpResponse::Initialize(result)) => {
3757                assert_eq!(result.server_info.name, "parent");
3758                assert_eq!(result.server_info.version, "1.0");
3759            }
3760            _ => panic!("Expected Initialize response"),
3761        }
3762    }
3763}