Skip to main content

tower_mcp/
router.rs

1//! MCP Router - routes requests to tools, resources, and prompts
2//!
3//! The router implements Tower's `Service` trait, making it composable with
4//! standard tower middleware.
5
6use std::any::{Any, TypeId};
7use std::collections::{HashMap, HashSet};
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, RwLock};
11use std::task::{Context, Poll};
12
13use tower_service::Service;
14
15use crate::async_task::TaskStore;
16use crate::context::{
17    CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
18    ServerNotification,
19};
20use crate::error::{Error, JsonRpcError, Result};
21use crate::filter::{PromptFilter, ResourceFilter, ToolFilter};
22use crate::prompt::Prompt;
23use crate::protocol::*;
24use crate::resource::{Resource, ResourceTemplate};
25use crate::session::SessionState;
26use crate::tool::Tool;
27
28/// Type alias for completion handler function
29pub type CompletionHandler = Arc<
30    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
31        + Send
32        + Sync,
33>;
34
35/// MCP Router that dispatches requests to registered handlers
36///
37/// Implements `tower::Service<McpRequest>` for middleware composition.
38///
39/// # Example
40///
41/// ```rust
42/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
43/// use schemars::JsonSchema;
44/// use serde::Deserialize;
45///
46/// #[derive(Debug, Deserialize, JsonSchema)]
47/// struct Input { value: String }
48///
49/// let tool = ToolBuilder::new("echo")
50///     .description("Echo input")
51///     .handler(|i: Input| async move { Ok(CallToolResult::text(i.value)) })
52///     .build()
53///     .unwrap();
54///
55/// let router = McpRouter::new()
56///     .server_info("my-server", "1.0.0")
57///     .tool(tool);
58/// ```
59#[derive(Clone)]
60pub struct McpRouter {
61    inner: Arc<McpRouterInner>,
62    session: SessionState,
63}
64
65impl std::fmt::Debug for McpRouter {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("McpRouter")
68            .field("server_name", &self.inner.server_name)
69            .field("server_version", &self.inner.server_version)
70            .field("tools_count", &self.inner.tools.len())
71            .field("resources_count", &self.inner.resources.len())
72            .field("prompts_count", &self.inner.prompts.len())
73            .field("session_phase", &self.session.phase())
74            .finish()
75    }
76}
77
78/// Inner configuration that is shared across clones
79#[derive(Clone)]
80struct McpRouterInner {
81    server_name: String,
82    server_version: String,
83    /// Human-readable title for the server
84    server_title: Option<String>,
85    /// Description of the server
86    server_description: Option<String>,
87    /// Icons for the server
88    server_icons: Option<Vec<ToolIcon>>,
89    /// URL of the server's website
90    server_website_url: Option<String>,
91    instructions: Option<String>,
92    tools: HashMap<String, Arc<Tool>>,
93    resources: HashMap<String, Arc<Resource>>,
94    /// Resource templates for dynamic resource matching (keyed by uri_template)
95    resource_templates: Vec<Arc<ResourceTemplate>>,
96    prompts: HashMap<String, Arc<Prompt>>,
97    /// In-flight requests for cancellation tracking (shared across clones)
98    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
99    /// Channel for sending notifications to connected clients
100    notification_tx: Option<NotificationSender>,
101    /// Handle for sending requests to the client (for sampling, etc.)
102    client_requester: Option<ClientRequesterHandle>,
103    /// Task store for async operations
104    task_store: TaskStore,
105    /// Subscribed resource URIs
106    subscriptions: Arc<RwLock<HashSet<String>>>,
107    /// Handler for completion requests
108    completion_handler: Option<CompletionHandler>,
109    /// Filter for tools based on session state
110    tool_filter: Option<ToolFilter>,
111    /// Filter for resources based on session state
112    resource_filter: Option<ResourceFilter>,
113    /// Filter for prompts based on session state
114    prompt_filter: Option<PromptFilter>,
115}
116
117impl McpRouter {
118    /// Create a new MCP router
119    pub fn new() -> Self {
120        Self {
121            inner: Arc::new(McpRouterInner {
122                server_name: "tower-mcp".to_string(),
123                server_version: env!("CARGO_PKG_VERSION").to_string(),
124                server_title: None,
125                server_description: None,
126                server_icons: None,
127                server_website_url: None,
128                instructions: None,
129                tools: HashMap::new(),
130                resources: HashMap::new(),
131                resource_templates: Vec::new(),
132                prompts: HashMap::new(),
133                in_flight: Arc::new(RwLock::new(HashMap::new())),
134                notification_tx: None,
135                client_requester: None,
136                task_store: TaskStore::new(),
137                subscriptions: Arc::new(RwLock::new(HashSet::new())),
138                completion_handler: None,
139                tool_filter: None,
140                resource_filter: None,
141                prompt_filter: None,
142            }),
143            session: SessionState::new(),
144        }
145    }
146
147    /// Get access to the task store for async operations
148    pub fn task_store(&self) -> &TaskStore {
149        &self.inner.task_store
150    }
151
152    /// Set the notification sender for progress reporting
153    ///
154    /// This is typically called by the transport layer to receive notifications.
155    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
156        Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
157        self
158    }
159
160    /// Get the notification sender (if configured)
161    pub fn notification_sender(&self) -> Option<&NotificationSender> {
162        self.inner.notification_tx.as_ref()
163    }
164
165    /// Set the client requester for server-to-client requests (sampling, etc.)
166    ///
167    /// This is typically called by bidirectional transports (WebSocket, stdio)
168    /// to enable tool handlers to send requests to the client.
169    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
170        Arc::make_mut(&mut self.inner).client_requester = Some(requester);
171        self
172    }
173
174    /// Get the client requester (if configured)
175    pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
176        self.inner.client_requester.as_ref()
177    }
178
179    /// Create a request context for tracking a request
180    ///
181    /// This registers the request for cancellation tracking and sets up
182    /// progress reporting and client requests if configured.
183    pub fn create_context(
184        &self,
185        request_id: RequestId,
186        progress_token: Option<ProgressToken>,
187    ) -> RequestContext {
188        let ctx = RequestContext::new(request_id.clone());
189
190        // Set up progress token if provided
191        let ctx = if let Some(token) = progress_token {
192            ctx.with_progress_token(token)
193        } else {
194            ctx
195        };
196
197        // Set up notification sender if configured
198        let ctx = if let Some(tx) = &self.inner.notification_tx {
199            ctx.with_notification_sender(tx.clone())
200        } else {
201            ctx
202        };
203
204        // Set up client requester if configured (for sampling support)
205        let ctx = if let Some(requester) = &self.inner.client_requester {
206            ctx.with_client_requester(requester.clone())
207        } else {
208            ctx
209        };
210
211        // Register for cancellation tracking
212        let token = ctx.cancellation_token();
213        if let Ok(mut in_flight) = self.inner.in_flight.write() {
214            in_flight.insert(request_id, token);
215        }
216
217        ctx
218    }
219
220    /// Remove a request from tracking (called when request completes)
221    pub fn complete_request(&self, request_id: &RequestId) {
222        if let Ok(mut in_flight) = self.inner.in_flight.write() {
223            in_flight.remove(request_id);
224        }
225    }
226
227    /// Cancel a tracked request
228    fn cancel_request(&self, request_id: &RequestId) -> bool {
229        let Ok(in_flight) = self.inner.in_flight.read() else {
230            return false;
231        };
232        let Some(token) = in_flight.get(request_id) else {
233            return false;
234        };
235        token.cancel();
236        true
237    }
238
239    /// Set server info
240    pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
241        let inner = Arc::make_mut(&mut self.inner);
242        inner.server_name = name.into();
243        inner.server_version = version.into();
244        self
245    }
246
247    /// Set instructions for LLMs describing how to use this server
248    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
249        Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
250        self
251    }
252
253    /// Set a human-readable title for the server
254    pub fn server_title(mut self, title: impl Into<String>) -> Self {
255        Arc::make_mut(&mut self.inner).server_title = Some(title.into());
256        self
257    }
258
259    /// Set the server description
260    pub fn server_description(mut self, description: impl Into<String>) -> Self {
261        Arc::make_mut(&mut self.inner).server_description = Some(description.into());
262        self
263    }
264
265    /// Set icons for the server
266    pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
267        Arc::make_mut(&mut self.inner).server_icons = Some(icons);
268        self
269    }
270
271    /// Set the server's website URL
272    pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
273        Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
274        self
275    }
276
277    /// Register a tool
278    pub fn tool(mut self, tool: Tool) -> Self {
279        Arc::make_mut(&mut self.inner)
280            .tools
281            .insert(tool.name.clone(), Arc::new(tool));
282        self
283    }
284
285    /// Register a resource
286    pub fn resource(mut self, resource: Resource) -> Self {
287        Arc::make_mut(&mut self.inner)
288            .resources
289            .insert(resource.uri.clone(), Arc::new(resource));
290        self
291    }
292
293    /// Register a resource template
294    ///
295    /// Resource templates allow dynamic resources to be matched by URI pattern.
296    /// When a client requests a resource URI that doesn't match any static
297    /// resource, the router tries to match it against registered templates.
298    ///
299    /// # Example
300    ///
301    /// ```rust
302    /// use tower_mcp::{McpRouter, ResourceTemplateBuilder};
303    /// use tower_mcp::protocol::{ReadResourceResult, ResourceContent};
304    /// use std::collections::HashMap;
305    ///
306    /// let template = ResourceTemplateBuilder::new("file:///{path}")
307    ///     .name("Project Files")
308    ///     .handler(|uri: String, vars: HashMap<String, String>| async move {
309    ///         let path = vars.get("path").unwrap_or(&String::new()).clone();
310    ///         Ok(ReadResourceResult {
311    ///             contents: vec![ResourceContent {
312    ///                 uri,
313    ///                 mime_type: Some("text/plain".to_string()),
314    ///                 text: Some(format!("Contents of {}", path)),
315    ///                 blob: None,
316    ///             }],
317    ///         })
318    ///     });
319    ///
320    /// let router = McpRouter::new()
321    ///     .resource_template(template);
322    /// ```
323    pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
324        Arc::make_mut(&mut self.inner)
325            .resource_templates
326            .push(Arc::new(template));
327        self
328    }
329
330    /// Register a prompt
331    pub fn prompt(mut self, prompt: Prompt) -> Self {
332        Arc::make_mut(&mut self.inner)
333            .prompts
334            .insert(prompt.name.clone(), Arc::new(prompt));
335        self
336    }
337
338    /// Register multiple tools at once.
339    ///
340    /// # Example
341    ///
342    /// ```rust
343    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
344    /// use schemars::JsonSchema;
345    /// use serde::Deserialize;
346    ///
347    /// #[derive(Debug, Deserialize, JsonSchema)]
348    /// struct Input { value: String }
349    ///
350    /// let tools = vec![
351    ///     ToolBuilder::new("a")
352    ///         .description("Tool A")
353    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
354    ///         .build().unwrap(),
355    ///     ToolBuilder::new("b")
356    ///         .description("Tool B")
357    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
358    ///         .build().unwrap(),
359    /// ];
360    ///
361    /// let router = McpRouter::new().tools(tools);
362    /// ```
363    pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
364        tools
365            .into_iter()
366            .fold(self, |router, tool| router.tool(tool))
367    }
368
369    /// Register multiple resources at once.
370    ///
371    /// # Example
372    ///
373    /// ```rust
374    /// use tower_mcp::{McpRouter, ResourceBuilder};
375    ///
376    /// let resources = vec![
377    ///     ResourceBuilder::new("file:///a.txt")
378    ///         .name("File A")
379    ///         .text("contents a"),
380    ///     ResourceBuilder::new("file:///b.txt")
381    ///         .name("File B")
382    ///         .text("contents b"),
383    /// ];
384    ///
385    /// let router = McpRouter::new().resources(resources);
386    /// ```
387    pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
388        resources
389            .into_iter()
390            .fold(self, |router, resource| router.resource(resource))
391    }
392
393    /// Register multiple prompts at once.
394    ///
395    /// # Example
396    ///
397    /// ```rust
398    /// use tower_mcp::{McpRouter, PromptBuilder};
399    ///
400    /// let prompts = vec![
401    ///     PromptBuilder::new("greet")
402    ///         .description("Greet someone")
403    ///         .user_message("Hello!"),
404    ///     PromptBuilder::new("farewell")
405    ///         .description("Say goodbye")
406    ///         .user_message("Goodbye!"),
407    /// ];
408    ///
409    /// let router = McpRouter::new().prompts(prompts);
410    /// ```
411    pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
412        prompts
413            .into_iter()
414            .fold(self, |router, prompt| router.prompt(prompt))
415    }
416
417    /// Register a completion handler for `completion/complete` requests.
418    ///
419    /// The handler receives `CompleteParams` containing the reference (prompt or resource)
420    /// and the argument being completed, and should return completion suggestions.
421    ///
422    /// # Example
423    ///
424    /// ```rust
425    /// use tower_mcp::{McpRouter, CompleteResult};
426    /// use tower_mcp::protocol::{CompleteParams, CompletionReference};
427    ///
428    /// let router = McpRouter::new()
429    ///     .completion_handler(|params: CompleteParams| async move {
430    ///         // Provide completions based on the reference and argument
431    ///         match params.reference {
432    ///             CompletionReference::Prompt { name } => {
433    ///                 // Return prompt argument completions
434    ///                 Ok(CompleteResult::new(vec!["option1".to_string(), "option2".to_string()]))
435    ///             }
436    ///             CompletionReference::Resource { uri } => {
437    ///                 // Return resource URI completions
438    ///                 Ok(CompleteResult::new(vec![]))
439    ///             }
440    ///         }
441    ///     });
442    /// ```
443    pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
444    where
445        F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
446        Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
447    {
448        Arc::make_mut(&mut self.inner).completion_handler =
449            Some(Arc::new(move |params| Box::pin(handler(params))));
450        self
451    }
452
453    /// Set a filter for tools based on session state.
454    ///
455    /// The filter determines which tools are visible to each session. Tools that
456    /// don't pass the filter will not appear in `tools/list` responses and will
457    /// return an error if called directly.
458    ///
459    /// # Example
460    ///
461    /// ```rust
462    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult, CapabilityFilter, Tool, Filterable};
463    /// use schemars::JsonSchema;
464    /// use serde::Deserialize;
465    ///
466    /// #[derive(Debug, Deserialize, JsonSchema)]
467    /// struct Input { value: String }
468    ///
469    /// let public_tool = ToolBuilder::new("public")
470    ///     .description("Available to everyone")
471    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
472    ///     .build()
473    ///     .unwrap();
474    ///
475    /// let admin_tool = ToolBuilder::new("admin")
476    ///     .description("Admin only")
477    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
478    ///     .build()
479    ///     .unwrap();
480    ///
481    /// let router = McpRouter::new()
482    ///     .tool(public_tool)
483    ///     .tool(admin_tool)
484    ///     .tool_filter(CapabilityFilter::new(|_session, tool: &Tool| {
485    ///         // In real code, check session.extensions() for auth claims
486    ///         tool.name() != "admin"
487    ///     }));
488    /// ```
489    pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
490        Arc::make_mut(&mut self.inner).tool_filter = Some(filter);
491        self
492    }
493
494    /// Set a filter for resources based on session state.
495    ///
496    /// The filter receives the current session state and each resource, returning
497    /// `true` if the resource should be visible to this session. Resources that
498    /// don't pass the filter will not appear in `resources/list` responses and will
499    /// return an error if read directly.
500    ///
501    /// # Example
502    ///
503    /// ```rust
504    /// use tower_mcp::{McpRouter, ResourceBuilder, ReadResourceResult, CapabilityFilter, Resource, Filterable};
505    ///
506    /// let public_resource = ResourceBuilder::new("file:///public.txt")
507    ///     .name("Public File")
508    ///     .description("Available to everyone")
509    ///     .text("public content");
510    ///
511    /// let secret_resource = ResourceBuilder::new("file:///secret.txt")
512    ///     .name("Secret File")
513    ///     .description("Admin only")
514    ///     .text("secret content");
515    ///
516    /// let router = McpRouter::new()
517    ///     .resource(public_resource)
518    ///     .resource(secret_resource)
519    ///     .resource_filter(CapabilityFilter::new(|_session, resource: &Resource| {
520    ///         // In real code, check session.extensions() for auth claims
521    ///         !resource.name().contains("Secret")
522    ///     }));
523    /// ```
524    pub fn resource_filter(mut self, filter: ResourceFilter) -> Self {
525        Arc::make_mut(&mut self.inner).resource_filter = Some(filter);
526        self
527    }
528
529    /// Set a filter for prompts based on session state.
530    ///
531    /// The filter receives the current session state and each prompt, returning
532    /// `true` if the prompt should be visible to this session. Prompts that
533    /// don't pass the filter will not appear in `prompts/list` responses and will
534    /// return an error if accessed directly.
535    ///
536    /// # Example
537    ///
538    /// ```rust
539    /// use tower_mcp::{McpRouter, PromptBuilder, CapabilityFilter, Prompt, Filterable};
540    ///
541    /// let public_prompt = PromptBuilder::new("greeting")
542    ///     .description("A friendly greeting")
543    ///     .user_message("Hello!");
544    ///
545    /// let admin_prompt = PromptBuilder::new("system_debug")
546    ///     .description("Admin debugging prompt")
547    ///     .user_message("Debug info");
548    ///
549    /// let router = McpRouter::new()
550    ///     .prompt(public_prompt)
551    ///     .prompt(admin_prompt)
552    ///     .prompt_filter(CapabilityFilter::new(|_session, prompt: &Prompt| {
553    ///         // In real code, check session.extensions() for auth claims
554    ///         !prompt.name().contains("system")
555    ///     }));
556    /// ```
557    pub fn prompt_filter(mut self, filter: PromptFilter) -> Self {
558        Arc::make_mut(&mut self.inner).prompt_filter = Some(filter);
559        self
560    }
561
562    /// Get access to the session state
563    pub fn session(&self) -> &SessionState {
564        &self.session
565    }
566
567    /// Send a log message notification to the client
568    ///
569    /// This sends a `notifications/message` notification with the given parameters.
570    /// Returns `true` if the notification was sent, `false` if no notification channel
571    /// is configured.
572    ///
573    /// # Example
574    ///
575    /// ```rust,ignore
576    /// use tower_mcp::protocol::{LogLevel, LoggingMessageParams};
577    ///
578    /// // Simple info message
579    /// router.log(LoggingMessageParams::new(LogLevel::Info).with_data(
580    ///     serde_json::json!({"message": "Operation completed"})
581    /// ));
582    ///
583    /// // Error with logger name
584    /// router.log(LoggingMessageParams::new(LogLevel::Error)
585    ///     .with_logger("database")
586    ///     .with_data(serde_json::json!({"error": "Connection failed"})));
587    /// ```
588    pub fn log(&self, params: LoggingMessageParams) -> bool {
589        let Some(tx) = &self.inner.notification_tx else {
590            return false;
591        };
592        tx.try_send(ServerNotification::LogMessage(params)).is_ok()
593    }
594
595    /// Send an info-level log message
596    ///
597    /// Convenience method for sending an info log with optional data.
598    pub fn log_info(&self, message: &str) -> bool {
599        self.log(
600            LoggingMessageParams::new(LogLevel::Info)
601                .with_data(serde_json::json!({ "message": message })),
602        )
603    }
604
605    /// Send a warning-level log message
606    pub fn log_warning(&self, message: &str) -> bool {
607        self.log(
608            LoggingMessageParams::new(LogLevel::Warning)
609                .with_data(serde_json::json!({ "message": message })),
610        )
611    }
612
613    /// Send an error-level log message
614    pub fn log_error(&self, message: &str) -> bool {
615        self.log(
616            LoggingMessageParams::new(LogLevel::Error)
617                .with_data(serde_json::json!({ "message": message })),
618        )
619    }
620
621    /// Send a debug-level log message
622    pub fn log_debug(&self, message: &str) -> bool {
623        self.log(
624            LoggingMessageParams::new(LogLevel::Debug)
625                .with_data(serde_json::json!({ "message": message })),
626        )
627    }
628
629    /// Check if a resource URI is currently subscribed
630    pub fn is_subscribed(&self, uri: &str) -> bool {
631        if let Ok(subs) = self.inner.subscriptions.read() {
632            return subs.contains(uri);
633        }
634        false
635    }
636
637    /// Get a list of all subscribed resource URIs
638    pub fn subscribed_uris(&self) -> Vec<String> {
639        if let Ok(subs) = self.inner.subscriptions.read() {
640            return subs.iter().cloned().collect();
641        }
642        Vec::new()
643    }
644
645    /// Subscribe to a resource URI
646    fn subscribe(&self, uri: &str) -> bool {
647        if let Ok(mut subs) = self.inner.subscriptions.write() {
648            return subs.insert(uri.to_string());
649        }
650        false
651    }
652
653    /// Unsubscribe from a resource URI
654    fn unsubscribe(&self, uri: &str) -> bool {
655        if let Ok(mut subs) = self.inner.subscriptions.write() {
656            return subs.remove(uri);
657        }
658        false
659    }
660
661    /// Notify clients that a subscribed resource has been updated
662    ///
663    /// Only sends the notification if the resource is currently subscribed.
664    /// Returns `true` if the notification was sent.
665    pub fn notify_resource_updated(&self, uri: &str) -> bool {
666        // Only notify if the resource is subscribed
667        if !self.is_subscribed(uri) {
668            return false;
669        }
670
671        let Some(tx) = &self.inner.notification_tx else {
672            return false;
673        };
674        tx.try_send(ServerNotification::ResourceUpdated {
675            uri: uri.to_string(),
676        })
677        .is_ok()
678    }
679
680    /// Notify clients that the list of available resources has changed
681    ///
682    /// Returns `true` if the notification was sent.
683    pub fn notify_resources_list_changed(&self) -> bool {
684        let Some(tx) = &self.inner.notification_tx else {
685            return false;
686        };
687        tx.try_send(ServerNotification::ResourcesListChanged)
688            .is_ok()
689    }
690
691    /// Get server capabilities based on registered handlers
692    fn capabilities(&self) -> ServerCapabilities {
693        let has_resources =
694            !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
695
696        ServerCapabilities {
697            tools: if self.inner.tools.is_empty() {
698                None
699            } else {
700                Some(ToolsCapability::default())
701            },
702            resources: if has_resources {
703                Some(ResourcesCapability {
704                    subscribe: true,
705                    ..Default::default()
706                })
707            } else {
708                None
709            },
710            prompts: if self.inner.prompts.is_empty() {
711                None
712            } else {
713                Some(PromptsCapability::default())
714            },
715            // Always advertise logging capability when notification channel is configured
716            logging: if self.inner.notification_tx.is_some() {
717                Some(LoggingCapability::default())
718            } else {
719                None
720            },
721            // Tasks capability is always available
722            tasks: Some(TasksCapability::default()),
723            // Completions capability when a handler is registered
724            completions: if self.inner.completion_handler.is_some() {
725                Some(CompletionsCapability::default())
726            } else {
727                None
728            },
729        }
730    }
731
732    /// Handle an MCP request
733    async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
734        // Enforce session state - reject requests before initialization
735        let method = request.method_name();
736        if !self.session.is_request_allowed(method) {
737            tracing::warn!(
738                method = %method,
739                phase = ?self.session.phase(),
740                "Request rejected: session not initialized"
741            );
742            return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
743                "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
744                method
745            ))));
746        }
747
748        match request {
749            McpRequest::Initialize(params) => {
750                tracing::info!(
751                    client = %params.client_info.name,
752                    version = %params.client_info.version,
753                    "Client initializing"
754                );
755
756                // Protocol version negotiation: respond with same version if supported,
757                // otherwise respond with our latest supported version
758                let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
759                    .contains(&params.protocol_version.as_str())
760                {
761                    params.protocol_version
762                } else {
763                    crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
764                };
765
766                // Transition session state to Initializing
767                self.session.mark_initializing();
768
769                Ok(McpResponse::Initialize(InitializeResult {
770                    protocol_version,
771                    capabilities: self.capabilities(),
772                    server_info: Implementation {
773                        name: self.inner.server_name.clone(),
774                        version: self.inner.server_version.clone(),
775                        title: self.inner.server_title.clone(),
776                        description: self.inner.server_description.clone(),
777                        icons: self.inner.server_icons.clone(),
778                        website_url: self.inner.server_website_url.clone(),
779                    },
780                    instructions: self.inner.instructions.clone(),
781                }))
782            }
783
784            McpRequest::ListTools(_params) => {
785                let tools: Vec<ToolDefinition> = self
786                    .inner
787                    .tools
788                    .values()
789                    .filter(|t| {
790                        // Apply tool filter if configured
791                        self.inner
792                            .tool_filter
793                            .as_ref()
794                            .map(|f| f.is_visible(&self.session, t))
795                            .unwrap_or(true)
796                    })
797                    .map(|t| t.definition())
798                    .collect();
799
800                Ok(McpResponse::ListTools(ListToolsResult {
801                    tools,
802                    next_cursor: None,
803                }))
804            }
805
806            McpRequest::CallTool(params) => {
807                let tool =
808                    self.inner.tools.get(&params.name).ok_or_else(|| {
809                        Error::JsonRpc(JsonRpcError::method_not_found(&params.name))
810                    })?;
811
812                // Check tool filter if configured
813                if let Some(filter) = &self.inner.tool_filter {
814                    if !filter.is_visible(&self.session, tool) {
815                        return Err(filter.denial_error(&params.name));
816                    }
817                }
818
819                // Extract progress token from request metadata
820                let progress_token = params.meta.and_then(|m| m.progress_token);
821                let ctx = self.create_context(request_id, progress_token);
822
823                tracing::debug!(tool = %params.name, "Calling tool");
824                let result = tool.call_with_context(ctx, params.arguments).await?;
825
826                Ok(McpResponse::CallTool(result))
827            }
828
829            McpRequest::ListResources(_params) => {
830                let resources: Vec<ResourceDefinition> = self
831                    .inner
832                    .resources
833                    .values()
834                    .filter(|r| {
835                        // Apply resource filter if configured
836                        self.inner
837                            .resource_filter
838                            .as_ref()
839                            .map(|f| f.is_visible(&self.session, r))
840                            .unwrap_or(true)
841                    })
842                    .map(|r| r.definition())
843                    .collect();
844
845                Ok(McpResponse::ListResources(ListResourcesResult {
846                    resources,
847                    next_cursor: None,
848                }))
849            }
850
851            McpRequest::ListResourceTemplates(_params) => {
852                let resource_templates: Vec<ResourceTemplateDefinition> = self
853                    .inner
854                    .resource_templates
855                    .iter()
856                    .map(|t| t.definition())
857                    .collect();
858
859                Ok(McpResponse::ListResourceTemplates(
860                    ListResourceTemplatesResult {
861                        resource_templates,
862                        next_cursor: None,
863                    },
864                ))
865            }
866
867            McpRequest::ReadResource(params) => {
868                // First, try to find a static resource
869                if let Some(resource) = self.inner.resources.get(&params.uri) {
870                    // Check resource filter if configured
871                    if let Some(filter) = &self.inner.resource_filter {
872                        if !filter.is_visible(&self.session, resource) {
873                            return Err(filter.denial_error(&params.uri));
874                        }
875                    }
876
877                    tracing::debug!(uri = %params.uri, "Reading static resource");
878                    let result = resource.read().await?;
879                    return Ok(McpResponse::ReadResource(result));
880                }
881
882                // If no static resource found, try to match against templates
883                for template in &self.inner.resource_templates {
884                    if let Some(variables) = template.match_uri(&params.uri) {
885                        tracing::debug!(
886                            uri = %params.uri,
887                            template = %template.uri_template,
888                            "Reading resource via template"
889                        );
890                        let result = template.read(&params.uri, variables).await?;
891                        return Ok(McpResponse::ReadResource(result));
892                    }
893                }
894
895                // No match found
896                Err(Error::JsonRpc(JsonRpcError::resource_not_found(
897                    &params.uri,
898                )))
899            }
900
901            McpRequest::SubscribeResource(params) => {
902                // Verify the resource exists
903                if !self.inner.resources.contains_key(&params.uri) {
904                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
905                        &params.uri,
906                    )));
907                }
908
909                tracing::debug!(uri = %params.uri, "Subscribing to resource");
910                self.subscribe(&params.uri);
911
912                Ok(McpResponse::SubscribeResource(EmptyResult {}))
913            }
914
915            McpRequest::UnsubscribeResource(params) => {
916                // Verify the resource exists
917                if !self.inner.resources.contains_key(&params.uri) {
918                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
919                        &params.uri,
920                    )));
921                }
922
923                tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
924                self.unsubscribe(&params.uri);
925
926                Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
927            }
928
929            McpRequest::ListPrompts(_params) => {
930                let prompts: Vec<PromptDefinition> = self
931                    .inner
932                    .prompts
933                    .values()
934                    .filter(|p| {
935                        // Apply prompt filter if configured
936                        self.inner
937                            .prompt_filter
938                            .as_ref()
939                            .map(|f| f.is_visible(&self.session, p))
940                            .unwrap_or(true)
941                    })
942                    .map(|p| p.definition())
943                    .collect();
944
945                Ok(McpResponse::ListPrompts(ListPromptsResult {
946                    prompts,
947                    next_cursor: None,
948                }))
949            }
950
951            McpRequest::GetPrompt(params) => {
952                let prompt = self.inner.prompts.get(&params.name).ok_or_else(|| {
953                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
954                        "Prompt not found: {}",
955                        params.name
956                    )))
957                })?;
958
959                // Check prompt filter if configured
960                if let Some(filter) = &self.inner.prompt_filter {
961                    if !filter.is_visible(&self.session, prompt) {
962                        return Err(filter.denial_error(&params.name));
963                    }
964                }
965
966                tracing::debug!(name = %params.name, "Getting prompt");
967                let result = prompt.get(params.arguments).await?;
968
969                Ok(McpResponse::GetPrompt(result))
970            }
971
972            McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
973
974            McpRequest::EnqueueTask(params) => {
975                // Verify the tool exists
976                let tool = self.inner.tools.get(&params.tool_name).ok_or_else(|| {
977                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
978                        "Tool not found: {}",
979                        params.tool_name
980                    )))
981                })?;
982
983                // Create the task
984                let (task_id, cancellation_token) = self.inner.task_store.create_task(
985                    &params.tool_name,
986                    params.arguments.clone(),
987                    params.ttl,
988                );
989
990                tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
991
992                // Create a context for the async task execution
993                let ctx = self.create_context(request_id, None);
994
995                // Spawn the task execution in the background
996                let task_store = self.inner.task_store.clone();
997                let tool = tool.clone();
998                let arguments = params.arguments;
999                let task_id_clone = task_id.clone();
1000
1001                tokio::spawn(async move {
1002                    // Check for cancellation before starting
1003                    if cancellation_token.is_cancelled() {
1004                        tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
1005                        return;
1006                    }
1007
1008                    // Execute the tool
1009                    match tool.call_with_context(ctx, arguments).await {
1010                        Ok(result) => {
1011                            if cancellation_token.is_cancelled() {
1012                                tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
1013                            } else {
1014                                task_store.complete_task(&task_id_clone, result);
1015                                tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
1016                            }
1017                        }
1018                        Err(e) => {
1019                            task_store.fail_task(&task_id_clone, &e.to_string());
1020                            tracing::warn!(task_id = %task_id_clone, error = %e, "Task failed");
1021                        }
1022                    }
1023                });
1024
1025                Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
1026                    task_id,
1027                    status: TaskStatus::Working,
1028                    poll_interval: Some(2),
1029                }))
1030            }
1031
1032            McpRequest::ListTasks(params) => {
1033                let tasks = self.inner.task_store.list_tasks(params.status);
1034
1035                Ok(McpResponse::ListTasks(ListTasksResult {
1036                    tasks,
1037                    next_cursor: None,
1038                }))
1039            }
1040
1041            McpRequest::GetTaskInfo(params) => {
1042                let task = self
1043                    .inner
1044                    .task_store
1045                    .get_task(&params.task_id)
1046                    .ok_or_else(|| {
1047                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
1048                            "Task not found: {}",
1049                            params.task_id
1050                        )))
1051                    })?;
1052
1053                Ok(McpResponse::GetTaskInfo(task))
1054            }
1055
1056            McpRequest::GetTaskResult(params) => {
1057                let (status, result, error) = self
1058                    .inner
1059                    .task_store
1060                    .get_task_full(&params.task_id)
1061                    .ok_or_else(|| {
1062                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
1063                            "Task not found: {}",
1064                            params.task_id
1065                        )))
1066                    })?;
1067
1068                Ok(McpResponse::GetTaskResult(GetTaskResultResult {
1069                    task_id: params.task_id,
1070                    status,
1071                    result,
1072                    error,
1073                }))
1074            }
1075
1076            McpRequest::CancelTask(params) => {
1077                let status = self
1078                    .inner
1079                    .task_store
1080                    .cancel_task(&params.task_id, params.reason.as_deref())
1081                    .ok_or_else(|| {
1082                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
1083                            "Task not found: {}",
1084                            params.task_id
1085                        )))
1086                    })?;
1087
1088                let cancelled = status == TaskStatus::Cancelled;
1089
1090                Ok(McpResponse::CancelTask(CancelTaskResult {
1091                    cancelled,
1092                    status,
1093                }))
1094            }
1095
1096            McpRequest::SetLoggingLevel(params) => {
1097                // Store the log level for filtering outgoing log notifications
1098                // For now, we just accept the request - actual filtering would be
1099                // implemented in the notification sending logic
1100                tracing::debug!(level = ?params.level, "Client set logging level");
1101                Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
1102            }
1103
1104            McpRequest::Complete(params) => {
1105                tracing::debug!(
1106                    reference = ?params.reference,
1107                    argument = %params.argument.name,
1108                    "Completion request"
1109                );
1110
1111                // Delegate to registered completion handler if available
1112                if let Some(ref handler) = self.inner.completion_handler {
1113                    let result = handler(params).await?;
1114                    Ok(McpResponse::Complete(result))
1115                } else {
1116                    // No completion handler registered, return empty completions
1117                    Ok(McpResponse::Complete(CompleteResult::new(vec![])))
1118                }
1119            }
1120
1121            McpRequest::Unknown { method, .. } => {
1122                Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
1123            }
1124        }
1125    }
1126
1127    /// Handle an MCP notification (no response expected)
1128    pub fn handle_notification(&self, notification: McpNotification) {
1129        match notification {
1130            McpNotification::Initialized => {
1131                if self.session.mark_initialized() {
1132                    tracing::info!("Session initialized, entering operation phase");
1133                } else {
1134                    tracing::warn!(
1135                        "Received initialized notification in unexpected state: {:?}",
1136                        self.session.phase()
1137                    );
1138                }
1139            }
1140            McpNotification::Cancelled(params) => {
1141                if self.cancel_request(&params.request_id) {
1142                    tracing::info!(
1143                        request_id = ?params.request_id,
1144                        reason = ?params.reason,
1145                        "Request cancelled"
1146                    );
1147                } else {
1148                    tracing::debug!(
1149                        request_id = ?params.request_id,
1150                        reason = ?params.reason,
1151                        "Cancellation requested for unknown request"
1152                    );
1153                }
1154            }
1155            McpNotification::Progress(params) => {
1156                tracing::trace!(
1157                    token = ?params.progress_token,
1158                    progress = params.progress,
1159                    total = ?params.total,
1160                    "Progress notification"
1161                );
1162                // Progress notifications from client are unusual but valid
1163            }
1164            McpNotification::RootsListChanged => {
1165                tracing::info!("Client roots list changed");
1166                // Server should re-request roots if needed
1167                // This is handled by the application layer
1168            }
1169            McpNotification::Unknown { method, .. } => {
1170                tracing::debug!(method = %method, "Unknown notification received");
1171            }
1172        }
1173    }
1174}
1175
1176impl Default for McpRouter {
1177    fn default() -> Self {
1178        Self::new()
1179    }
1180}
1181
1182// =============================================================================
1183// Tower Service implementation
1184// =============================================================================
1185
/// Type-keyed storage that lets middleware attach arbitrary values to a request.
///
/// At most one value is kept per concrete type. Values live behind
/// `Arc<dyn Any>`, so cloning an `Extensions` only bumps reference counts.
/// That keeps it cheap to hand the same extensions to every `RouterRequest`
/// produced from one batched HTTP request.
///
/// # Example
///
/// ```rust
/// use tower_mcp::Extensions;
///
/// let mut ext = Extensions::new();
/// ext.insert(42u32);
/// assert_eq!(ext.get::<u32>(), Some(&42));
/// ```
#[derive(Default, Clone)]
pub struct Extensions {
    /// One entry per stored type, keyed by that type's `TypeId`.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Returns an `Extensions` that holds no values.
    pub fn new() -> Self {
        Extensions {
            map: HashMap::new(),
        }
    }

    /// Stores `val` under its type `T`.
    ///
    /// Any value of type `T` inserted earlier is overwritten.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let entry: Arc<dyn Any + Send + Sync> = Arc::new(val);
        self.map.insert(TypeId::of::<T>(), entry);
    }

    /// Borrows the stored value of type `T`.
    ///
    /// Returns `None` when no value of that type has been inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let stored = self.map.get(&TypeId::of::<T>())?;
        stored.downcast_ref::<T>()
    }
}

impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Stored values are type-erased and need not implement Debug, so only
        // the number of entries is shown.
        let len = self.map.len();
        let mut out = f.debug_struct("Extensions");
        out.field("len", &len);
        out.finish()
    }
}
1236
/// Request type for the tower Service implementation
///
/// `McpRouter`'s `Service::call` dispatches `inner` and answers with a
/// `RouterResponse` that carries the same `id`.
#[derive(Debug)]
pub struct RouterRequest {
    /// JSON-RPC id of the request; copied into the resulting `RouterResponse`.
    pub id: RequestId,
    /// The MCP request to dispatch to the router's handlers.
    pub inner: McpRequest,
    /// Type-map for passing data (e.g., `TokenClaims`) through middleware.
    /// `McpRouter::call` itself does not read it.
    pub extensions: Extensions,
}
1245
/// Response type for the tower Service implementation
///
/// Use [`RouterResponse::into_jsonrpc`] to turn it into a wire-level JSON-RPC response.
#[derive(Debug)]
pub struct RouterResponse {
    /// Id of the request this response answers.
    pub id: RequestId,
    /// The handler's result, or the JSON-RPC error to report to the client.
    pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
1252
1253impl RouterResponse {
1254    /// Convert to JSON-RPC response
1255    pub fn into_jsonrpc(self) -> JsonRpcResponse {
1256        match self.inner {
1257            Ok(response) => match serde_json::to_value(response) {
1258                Ok(result) => JsonRpcResponse::result(self.id, result),
1259                Err(e) => {
1260                    tracing::error!(error = %e, "Failed to serialize response");
1261                    JsonRpcResponse::error(
1262                        Some(self.id),
1263                        JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1264                    )
1265                }
1266            },
1267            Err(error) => JsonRpcResponse::error(Some(self.id), error),
1268        }
1269    }
1270}
1271
1272impl Service<RouterRequest> for McpRouter {
1273    type Response = RouterResponse;
1274    type Error = std::convert::Infallible; // Errors are in the response
1275    type Future =
1276        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1277
1278    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1279        Poll::Ready(Ok(()))
1280    }
1281
1282    fn call(&mut self, req: RouterRequest) -> Self::Future {
1283        let router = self.clone();
1284        let request_id = req.id.clone();
1285        Box::pin(async move {
1286            let result = router.handle(req.id, req.inner).await;
1287            // Clean up tracking after request completes
1288            router.complete_request(&request_id);
1289            Ok(RouterResponse {
1290                id: request_id,
1291                // Map tower-mcp errors to JSON-RPC errors:
1292                // - Error::JsonRpc: forwarded as-is (preserves original code)
1293                // - Error::Tool: mapped to -32603 (Internal Error)
1294                // - All others: mapped to -32603 (Internal Error)
1295                inner: result.map_err(|e| match e {
1296                    Error::JsonRpc(err) => err,
1297                    Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1298                    e => JsonRpcError::internal_error(e.to_string()),
1299                }),
1300            })
1301        })
1302    }
1303}
1304
1305#[cfg(test)]
1306mod tests {
1307    use super::*;
1308    use crate::jsonrpc::JsonRpcService;
1309    use crate::tool::ToolBuilder;
1310    use schemars::JsonSchema;
1311    use serde::Deserialize;
1312    use tower::ServiceExt;
1313
    // Arguments for the `add` test tool. Plain `//` comments are used on purpose:
    // schemars would copy `///` doc comments into the tool's input schema.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        // First addend.
        a: i64,
        // Second addend.
        b: i64,
    }
1319
1320    /// Helper to initialize a router for testing
1321    async fn init_router(router: &mut McpRouter) {
1322        // Send initialize request
1323        let init_req = RouterRequest {
1324            id: RequestId::Number(0),
1325            inner: McpRequest::Initialize(InitializeParams {
1326                protocol_version: "2025-11-25".to_string(),
1327                capabilities: ClientCapabilities {
1328                    roots: None,
1329                    sampling: None,
1330                    elicitation: None,
1331                },
1332                client_info: Implementation {
1333                    name: "test".to_string(),
1334                    version: "1.0".to_string(),
1335                    ..Default::default()
1336                },
1337            }),
1338            extensions: Extensions::new(),
1339        };
1340        let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1341        // Send initialized notification
1342        router.handle_notification(McpNotification::Initialized);
1343    }
1344
1345    #[tokio::test]
1346    async fn test_router_list_tools() {
1347        let add_tool = ToolBuilder::new("add")
1348            .description("Add two numbers")
1349            .handler(|input: AddInput| async move {
1350                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1351            })
1352            .build()
1353            .expect("valid tool name");
1354
1355        let mut router = McpRouter::new().tool(add_tool);
1356
1357        // Initialize session first
1358        init_router(&mut router).await;
1359
1360        let req = RouterRequest {
1361            id: RequestId::Number(1),
1362            inner: McpRequest::ListTools(ListToolsParams::default()),
1363            extensions: Extensions::new(),
1364        };
1365
1366        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1367
1368        match resp.inner {
1369            Ok(McpResponse::ListTools(result)) => {
1370                assert_eq!(result.tools.len(), 1);
1371                assert_eq!(result.tools[0].name, "add");
1372            }
1373            _ => panic!("Expected ListTools response"),
1374        }
1375    }
1376
1377    #[tokio::test]
1378    async fn test_router_call_tool() {
1379        let add_tool = ToolBuilder::new("add")
1380            .description("Add two numbers")
1381            .handler(|input: AddInput| async move {
1382                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1383            })
1384            .build()
1385            .expect("valid tool name");
1386
1387        let mut router = McpRouter::new().tool(add_tool);
1388
1389        // Initialize session first
1390        init_router(&mut router).await;
1391
1392        let req = RouterRequest {
1393            id: RequestId::Number(1),
1394            inner: McpRequest::CallTool(CallToolParams {
1395                name: "add".to_string(),
1396                arguments: serde_json::json!({"a": 2, "b": 3}),
1397                meta: None,
1398            }),
1399            extensions: Extensions::new(),
1400        };
1401
1402        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1403
1404        match resp.inner {
1405            Ok(McpResponse::CallTool(result)) => {
1406                assert!(!result.is_error);
1407                // Check the text content
1408                match &result.content[0] {
1409                    Content::Text { text, .. } => assert_eq!(text, "5"),
1410                    _ => panic!("Expected text content"),
1411                }
1412            }
1413            _ => panic!("Expected CallTool response"),
1414        }
1415    }
1416
1417    /// Helper to initialize a JsonRpcService for testing
1418    async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1419        let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1420            "protocolVersion": "2025-11-25",
1421            "capabilities": {},
1422            "clientInfo": { "name": "test", "version": "1.0" }
1423        }));
1424        let _ = service.call_single(init_req).await.unwrap();
1425        router.handle_notification(McpNotification::Initialized);
1426    }
1427
1428    #[tokio::test]
1429    async fn test_jsonrpc_service() {
1430        let add_tool = ToolBuilder::new("add")
1431            .description("Add two numbers")
1432            .handler(|input: AddInput| async move {
1433                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1434            })
1435            .build()
1436            .expect("valid tool name");
1437
1438        let router = McpRouter::new().tool(add_tool);
1439        let mut service = JsonRpcService::new(router.clone());
1440
1441        // Initialize session first
1442        init_jsonrpc_service(&mut service, &router).await;
1443
1444        let req = JsonRpcRequest::new(1, "tools/list");
1445
1446        let resp = service.call_single(req).await.unwrap();
1447
1448        match resp {
1449            JsonRpcResponse::Result(r) => {
1450                assert_eq!(r.id, RequestId::Number(1));
1451                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1452                assert_eq!(tools.len(), 1);
1453            }
1454            JsonRpcResponse::Error(_) => panic!("Expected success response"),
1455        }
1456    }
1457
1458    #[tokio::test]
1459    async fn test_batch_request() {
1460        let add_tool = ToolBuilder::new("add")
1461            .description("Add two numbers")
1462            .handler(|input: AddInput| async move {
1463                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1464            })
1465            .build()
1466            .expect("valid tool name");
1467
1468        let router = McpRouter::new().tool(add_tool);
1469        let mut service = JsonRpcService::new(router.clone());
1470
1471        // Initialize session first
1472        init_jsonrpc_service(&mut service, &router).await;
1473
1474        // Create a batch of requests
1475        let requests = vec![
1476            JsonRpcRequest::new(1, "tools/list"),
1477            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1478                "name": "add",
1479                "arguments": {"a": 10, "b": 20}
1480            })),
1481            JsonRpcRequest::new(3, "ping"),
1482        ];
1483
1484        let responses = service.call_batch(requests).await.unwrap();
1485
1486        assert_eq!(responses.len(), 3);
1487
1488        // Check first response (tools/list)
1489        match &responses[0] {
1490            JsonRpcResponse::Result(r) => {
1491                assert_eq!(r.id, RequestId::Number(1));
1492                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1493                assert_eq!(tools.len(), 1);
1494            }
1495            JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1496        }
1497
1498        // Check second response (tools/call)
1499        match &responses[1] {
1500            JsonRpcResponse::Result(r) => {
1501                assert_eq!(r.id, RequestId::Number(2));
1502                let content = r.result.get("content").unwrap().as_array().unwrap();
1503                let text = content[0].get("text").unwrap().as_str().unwrap();
1504                assert_eq!(text, "30");
1505            }
1506            JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1507        }
1508
1509        // Check third response (ping)
1510        match &responses[2] {
1511            JsonRpcResponse::Result(r) => {
1512                assert_eq!(r.id, RequestId::Number(3));
1513            }
1514            JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1515        }
1516    }
1517
1518    #[tokio::test]
1519    async fn test_empty_batch_error() {
1520        let router = McpRouter::new();
1521        let mut service = JsonRpcService::new(router);
1522
1523        let result = service.call_batch(vec![]).await;
1524        assert!(result.is_err());
1525    }
1526
1527    // =========================================================================
1528    // Progress Token Tests
1529    // =========================================================================
1530
1531    #[tokio::test]
1532    async fn test_progress_token_extraction() {
1533        use crate::context::{RequestContext, ServerNotification, notification_channel};
1534        use crate::protocol::ProgressToken;
1535        use std::sync::Arc;
1536        use std::sync::atomic::{AtomicBool, Ordering};
1537
1538        // Track whether progress was reported
1539        let progress_reported = Arc::new(AtomicBool::new(false));
1540        let progress_ref = progress_reported.clone();
1541
1542        // Create a tool that reports progress
1543        let tool = ToolBuilder::new("progress_tool")
1544            .description("Tool that reports progress")
1545            .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1546                let reported = progress_ref.clone();
1547                async move {
1548                    // Report progress - this should work if token was extracted
1549                    ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1550                        .await;
1551                    reported.store(true, Ordering::SeqCst);
1552                    Ok(CallToolResult::text("done"))
1553                }
1554            })
1555            .build()
1556            .expect("valid tool name");
1557
1558        // Set up notification channel
1559        let (tx, mut rx) = notification_channel(10);
1560        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1561        let mut service = JsonRpcService::new(router.clone());
1562
1563        // Initialize
1564        init_jsonrpc_service(&mut service, &router).await;
1565
1566        // Call tool WITH progress token in _meta
1567        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1568            "name": "progress_tool",
1569            "arguments": {"a": 1, "b": 2},
1570            "_meta": {
1571                "progressToken": "test-token-123"
1572            }
1573        }));
1574
1575        let resp = service.call_single(req).await.unwrap();
1576
1577        // Verify the tool was called successfully
1578        match resp {
1579            JsonRpcResponse::Result(_) => {}
1580            JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1581        }
1582
1583        // Verify progress was reported by handler
1584        assert!(progress_reported.load(Ordering::SeqCst));
1585
1586        // Verify progress notification was sent through channel
1587        let notification = rx.try_recv().expect("Expected progress notification");
1588        match notification {
1589            ServerNotification::Progress(params) => {
1590                assert_eq!(
1591                    params.progress_token,
1592                    ProgressToken::String("test-token-123".to_string())
1593                );
1594                assert_eq!(params.progress, 50.0);
1595                assert_eq!(params.total, Some(100.0));
1596                assert_eq!(params.message.as_deref(), Some("Halfway"));
1597            }
1598            _ => panic!("Expected Progress notification"),
1599        }
1600    }
1601
1602    #[tokio::test]
1603    async fn test_tool_call_without_progress_token() {
1604        use crate::context::{RequestContext, notification_channel};
1605        use std::sync::Arc;
1606        use std::sync::atomic::{AtomicBool, Ordering};
1607
1608        let progress_attempted = Arc::new(AtomicBool::new(false));
1609        let progress_ref = progress_attempted.clone();
1610
1611        let tool = ToolBuilder::new("no_token_tool")
1612            .description("Tool that tries to report progress without token")
1613            .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1614                let attempted = progress_ref.clone();
1615                async move {
1616                    // Try to report progress - should be a no-op without token
1617                    ctx.report_progress(50.0, Some(100.0), None).await;
1618                    attempted.store(true, Ordering::SeqCst);
1619                    Ok(CallToolResult::text("done"))
1620                }
1621            })
1622            .build()
1623            .expect("valid tool name");
1624
1625        let (tx, mut rx) = notification_channel(10);
1626        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1627        let mut service = JsonRpcService::new(router.clone());
1628
1629        init_jsonrpc_service(&mut service, &router).await;
1630
1631        // Call tool WITHOUT progress token
1632        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1633            "name": "no_token_tool",
1634            "arguments": {"a": 1, "b": 2}
1635        }));
1636
1637        let resp = service.call_single(req).await.unwrap();
1638        assert!(matches!(resp, JsonRpcResponse::Result(_)));
1639
1640        // Handler was called
1641        assert!(progress_attempted.load(Ordering::SeqCst));
1642
1643        // But no notification was sent (no progress token)
1644        assert!(rx.try_recv().is_err());
1645    }
1646
1647    #[tokio::test]
1648    async fn test_batch_errors_returned_not_dropped() {
1649        let add_tool = ToolBuilder::new("add")
1650            .description("Add two numbers")
1651            .handler(|input: AddInput| async move {
1652                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1653            })
1654            .build()
1655            .expect("valid tool name");
1656
1657        let router = McpRouter::new().tool(add_tool);
1658        let mut service = JsonRpcService::new(router.clone());
1659
1660        init_jsonrpc_service(&mut service, &router).await;
1661
1662        // Create a batch with one valid and one invalid request
1663        let requests = vec![
1664            // Valid request
1665            JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1666                "name": "add",
1667                "arguments": {"a": 10, "b": 20}
1668            })),
1669            // Invalid request - tool doesn't exist
1670            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1671                "name": "nonexistent_tool",
1672                "arguments": {}
1673            })),
1674            // Another valid request
1675            JsonRpcRequest::new(3, "ping"),
1676        ];
1677
1678        let responses = service.call_batch(requests).await.unwrap();
1679
1680        // All three requests should have responses (errors are not dropped)
1681        assert_eq!(responses.len(), 3);
1682
1683        // First should be success
1684        match &responses[0] {
1685            JsonRpcResponse::Result(r) => {
1686                assert_eq!(r.id, RequestId::Number(1));
1687            }
1688            JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1689        }
1690
1691        // Second should be an error (tool not found)
1692        match &responses[1] {
1693            JsonRpcResponse::Error(e) => {
1694                assert_eq!(e.id, Some(RequestId::Number(2)));
1695                // Error should indicate method not found
1696                assert!(e.error.message.contains("not found") || e.error.code == -32601);
1697            }
1698            JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1699        }
1700
1701        // Third should be success
1702        match &responses[2] {
1703            JsonRpcResponse::Result(r) => {
1704                assert_eq!(r.id, RequestId::Number(3));
1705            }
1706            JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1707        }
1708    }
1709
1710    // =========================================================================
1711    // Resource Template Tests
1712    // =========================================================================
1713
1714    #[tokio::test]
1715    async fn test_list_resource_templates() {
1716        use crate::resource::ResourceTemplateBuilder;
1717        use std::collections::HashMap;
1718
1719        let template = ResourceTemplateBuilder::new("file:///{path}")
1720            .name("Project Files")
1721            .description("Access project files")
1722            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1723                Ok(ReadResourceResult {
1724                    contents: vec![ResourceContent {
1725                        uri,
1726                        mime_type: None,
1727                        text: None,
1728                        blob: None,
1729                    }],
1730                })
1731            });
1732
1733        let mut router = McpRouter::new().resource_template(template);
1734
1735        // Initialize session
1736        init_router(&mut router).await;
1737
1738        let req = RouterRequest {
1739            id: RequestId::Number(1),
1740            inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1741            extensions: Extensions::new(),
1742        };
1743
1744        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1745
1746        match resp.inner {
1747            Ok(McpResponse::ListResourceTemplates(result)) => {
1748                assert_eq!(result.resource_templates.len(), 1);
1749                assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1750                assert_eq!(result.resource_templates[0].name, "Project Files");
1751            }
1752            _ => panic!("Expected ListResourceTemplates response"),
1753        }
1754    }
1755
1756    #[tokio::test]
1757    async fn test_read_resource_via_template() {
1758        use crate::resource::ResourceTemplateBuilder;
1759        use std::collections::HashMap;
1760
1761        let template = ResourceTemplateBuilder::new("db://users/{id}")
1762            .name("User Records")
1763            .handler(|uri: String, vars: HashMap<String, String>| async move {
1764                let id = vars.get("id").unwrap().clone();
1765                Ok(ReadResourceResult {
1766                    contents: vec![ResourceContent {
1767                        uri,
1768                        mime_type: Some("application/json".to_string()),
1769                        text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1770                        blob: None,
1771                    }],
1772                })
1773            });
1774
1775        let mut router = McpRouter::new().resource_template(template);
1776
1777        // Initialize session
1778        init_router(&mut router).await;
1779
1780        // Read a resource that matches the template
1781        let req = RouterRequest {
1782            id: RequestId::Number(1),
1783            inner: McpRequest::ReadResource(ReadResourceParams {
1784                uri: "db://users/123".to_string(),
1785            }),
1786            extensions: Extensions::new(),
1787        };
1788
1789        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1790
1791        match resp.inner {
1792            Ok(McpResponse::ReadResource(result)) => {
1793                assert_eq!(result.contents.len(), 1);
1794                assert_eq!(result.contents[0].uri, "db://users/123");
1795                assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1796            }
1797            _ => panic!("Expected ReadResource response"),
1798        }
1799    }
1800
1801    #[tokio::test]
1802    async fn test_static_resource_takes_precedence_over_template() {
1803        use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1804        use std::collections::HashMap;
1805
1806        // Template that would match the same URI
1807        let template = ResourceTemplateBuilder::new("file:///{path}")
1808            .name("Files Template")
1809            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1810                Ok(ReadResourceResult {
1811                    contents: vec![ResourceContent {
1812                        uri,
1813                        mime_type: None,
1814                        text: Some("from template".to_string()),
1815                        blob: None,
1816                    }],
1817                })
1818            });
1819
1820        // Static resource with exact URI
1821        let static_resource = ResourceBuilder::new("file:///README.md")
1822            .name("README")
1823            .text("from static resource");
1824
1825        let mut router = McpRouter::new()
1826            .resource_template(template)
1827            .resource(static_resource);
1828
1829        // Initialize session
1830        init_router(&mut router).await;
1831
1832        // Read the static resource - should NOT go through template
1833        let req = RouterRequest {
1834            id: RequestId::Number(1),
1835            inner: McpRequest::ReadResource(ReadResourceParams {
1836                uri: "file:///README.md".to_string(),
1837            }),
1838            extensions: Extensions::new(),
1839        };
1840
1841        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1842
1843        match resp.inner {
1844            Ok(McpResponse::ReadResource(result)) => {
1845                // Should get static resource, not template
1846                assert_eq!(
1847                    result.contents[0].text.as_deref(),
1848                    Some("from static resource")
1849                );
1850            }
1851            _ => panic!("Expected ReadResource response"),
1852        }
1853    }
1854
1855    #[tokio::test]
1856    async fn test_resource_not_found_when_no_match() {
1857        use crate::resource::ResourceTemplateBuilder;
1858        use std::collections::HashMap;
1859
1860        let template = ResourceTemplateBuilder::new("db://users/{id}")
1861            .name("Users")
1862            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1863                Ok(ReadResourceResult {
1864                    contents: vec![ResourceContent {
1865                        uri,
1866                        mime_type: None,
1867                        text: None,
1868                        blob: None,
1869                    }],
1870                })
1871            });
1872
1873        let mut router = McpRouter::new().resource_template(template);
1874
1875        // Initialize session
1876        init_router(&mut router).await;
1877
1878        // Try to read a URI that doesn't match any resource or template
1879        let req = RouterRequest {
1880            id: RequestId::Number(1),
1881            inner: McpRequest::ReadResource(ReadResourceParams {
1882                uri: "db://posts/123".to_string(),
1883            }),
1884            extensions: Extensions::new(),
1885        };
1886
1887        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1888
1889        match resp.inner {
1890            Err(err) => {
1891                assert!(err.message.contains("not found"));
1892            }
1893            Ok(_) => panic!("Expected error for non-matching URI"),
1894        }
1895    }
1896
1897    #[tokio::test]
1898    async fn test_capabilities_include_resources_with_only_templates() {
1899        use crate::resource::ResourceTemplateBuilder;
1900        use std::collections::HashMap;
1901
1902        let template = ResourceTemplateBuilder::new("file:///{path}")
1903            .name("Files")
1904            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1905                Ok(ReadResourceResult {
1906                    contents: vec![ResourceContent {
1907                        uri,
1908                        mime_type: None,
1909                        text: None,
1910                        blob: None,
1911                    }],
1912                })
1913            });
1914
1915        let mut router = McpRouter::new().resource_template(template);
1916
1917        // Send initialize request and check capabilities
1918        let init_req = RouterRequest {
1919            id: RequestId::Number(0),
1920            inner: McpRequest::Initialize(InitializeParams {
1921                protocol_version: "2025-11-25".to_string(),
1922                capabilities: ClientCapabilities {
1923                    roots: None,
1924                    sampling: None,
1925                    elicitation: None,
1926                },
1927                client_info: Implementation {
1928                    name: "test".to_string(),
1929                    version: "1.0".to_string(),
1930                    ..Default::default()
1931                },
1932            }),
1933            extensions: Extensions::new(),
1934        };
1935        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1936
1937        match resp.inner {
1938            Ok(McpResponse::Initialize(result)) => {
1939                // Should have resources capability even though only templates registered
1940                assert!(result.capabilities.resources.is_some());
1941            }
1942            _ => panic!("Expected Initialize response"),
1943        }
1944    }
1945
1946    // =========================================================================
1947    // Logging Notification Tests
1948    // =========================================================================
1949
1950    #[tokio::test]
1951    async fn test_log_sends_notification() {
1952        use crate::context::notification_channel;
1953
1954        let (tx, mut rx) = notification_channel(10);
1955        let router = McpRouter::new().with_notification_sender(tx);
1956
1957        // Send an info log
1958        let sent = router.log_info("Test message");
1959        assert!(sent);
1960
1961        // Should receive the notification
1962        let notification = rx.try_recv().unwrap();
1963        match notification {
1964            ServerNotification::LogMessage(params) => {
1965                assert_eq!(params.level, LogLevel::Info);
1966                let data = params.data.unwrap();
1967                assert_eq!(
1968                    data.get("message").unwrap().as_str().unwrap(),
1969                    "Test message"
1970                );
1971            }
1972            _ => panic!("Expected LogMessage notification"),
1973        }
1974    }
1975
1976    #[tokio::test]
1977    async fn test_log_with_custom_params() {
1978        use crate::context::notification_channel;
1979
1980        let (tx, mut rx) = notification_channel(10);
1981        let router = McpRouter::new().with_notification_sender(tx);
1982
1983        // Send a custom log message
1984        let params = LoggingMessageParams::new(LogLevel::Error)
1985            .with_logger("database")
1986            .with_data(serde_json::json!({
1987                "error": "Connection failed",
1988                "host": "localhost"
1989            }));
1990
1991        let sent = router.log(params);
1992        assert!(sent);
1993
1994        let notification = rx.try_recv().unwrap();
1995        match notification {
1996            ServerNotification::LogMessage(params) => {
1997                assert_eq!(params.level, LogLevel::Error);
1998                assert_eq!(params.logger.as_deref(), Some("database"));
1999                let data = params.data.unwrap();
2000                assert_eq!(
2001                    data.get("error").unwrap().as_str().unwrap(),
2002                    "Connection failed"
2003                );
2004            }
2005            _ => panic!("Expected LogMessage notification"),
2006        }
2007    }
2008
2009    #[tokio::test]
2010    async fn test_log_without_channel_returns_false() {
2011        // Router without notification channel
2012        let router = McpRouter::new();
2013
2014        // Should return false when no channel configured
2015        assert!(!router.log_info("Test"));
2016        assert!(!router.log_warning("Test"));
2017        assert!(!router.log_error("Test"));
2018        assert!(!router.log_debug("Test"));
2019    }
2020
2021    #[tokio::test]
2022    async fn test_logging_capability_with_channel() {
2023        use crate::context::notification_channel;
2024
2025        let (tx, _rx) = notification_channel(10);
2026        let mut router = McpRouter::new().with_notification_sender(tx);
2027
2028        // Initialize and check capabilities
2029        let init_req = RouterRequest {
2030            id: RequestId::Number(0),
2031            inner: McpRequest::Initialize(InitializeParams {
2032                protocol_version: "2025-11-25".to_string(),
2033                capabilities: ClientCapabilities {
2034                    roots: None,
2035                    sampling: None,
2036                    elicitation: None,
2037                },
2038                client_info: Implementation {
2039                    name: "test".to_string(),
2040                    version: "1.0".to_string(),
2041                    ..Default::default()
2042                },
2043            }),
2044            extensions: Extensions::new(),
2045        };
2046        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2047
2048        match resp.inner {
2049            Ok(McpResponse::Initialize(result)) => {
2050                // Should have logging capability when notification channel is set
2051                assert!(result.capabilities.logging.is_some());
2052            }
2053            _ => panic!("Expected Initialize response"),
2054        }
2055    }
2056
2057    #[tokio::test]
2058    async fn test_no_logging_capability_without_channel() {
2059        let mut router = McpRouter::new();
2060
2061        // Initialize and check capabilities
2062        let init_req = RouterRequest {
2063            id: RequestId::Number(0),
2064            inner: McpRequest::Initialize(InitializeParams {
2065                protocol_version: "2025-11-25".to_string(),
2066                capabilities: ClientCapabilities {
2067                    roots: None,
2068                    sampling: None,
2069                    elicitation: None,
2070                },
2071                client_info: Implementation {
2072                    name: "test".to_string(),
2073                    version: "1.0".to_string(),
2074                    ..Default::default()
2075                },
2076            }),
2077            extensions: Extensions::new(),
2078        };
2079        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2080
2081        match resp.inner {
2082            Ok(McpResponse::Initialize(result)) => {
2083                // Should NOT have logging capability without notification channel
2084                assert!(result.capabilities.logging.is_none());
2085            }
2086            _ => panic!("Expected Initialize response"),
2087        }
2088    }
2089
2090    // =========================================================================
2091    // Task Lifecycle Tests
2092    // =========================================================================
2093
2094    #[tokio::test]
2095    async fn test_enqueue_task() {
2096        let add_tool = ToolBuilder::new("add")
2097            .description("Add two numbers")
2098            .handler(|input: AddInput| async move {
2099                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2100            })
2101            .build()
2102            .expect("valid tool name");
2103
2104        let mut router = McpRouter::new().tool(add_tool);
2105        init_router(&mut router).await;
2106
2107        let req = RouterRequest {
2108            id: RequestId::Number(1),
2109            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2110                tool_name: "add".to_string(),
2111                arguments: serde_json::json!({"a": 5, "b": 10}),
2112                ttl: None,
2113            }),
2114            extensions: Extensions::new(),
2115        };
2116
2117        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2118
2119        match resp.inner {
2120            Ok(McpResponse::EnqueueTask(result)) => {
2121                assert!(result.task_id.starts_with("task-"));
2122                assert_eq!(result.status, TaskStatus::Working);
2123            }
2124            _ => panic!("Expected EnqueueTask response"),
2125        }
2126    }
2127
2128    #[tokio::test]
2129    async fn test_list_tasks_empty() {
2130        let mut router = McpRouter::new();
2131        init_router(&mut router).await;
2132
2133        let req = RouterRequest {
2134            id: RequestId::Number(1),
2135            inner: McpRequest::ListTasks(ListTasksParams::default()),
2136            extensions: Extensions::new(),
2137        };
2138
2139        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2140
2141        match resp.inner {
2142            Ok(McpResponse::ListTasks(result)) => {
2143                assert!(result.tasks.is_empty());
2144            }
2145            _ => panic!("Expected ListTasks response"),
2146        }
2147    }
2148
2149    #[tokio::test]
2150    async fn test_task_lifecycle_complete() {
2151        let add_tool = ToolBuilder::new("add")
2152            .description("Add two numbers")
2153            .handler(|input: AddInput| async move {
2154                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2155            })
2156            .build()
2157            .expect("valid tool name");
2158
2159        let mut router = McpRouter::new().tool(add_tool);
2160        init_router(&mut router).await;
2161
2162        // Enqueue task
2163        let req = RouterRequest {
2164            id: RequestId::Number(1),
2165            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2166                tool_name: "add".to_string(),
2167                arguments: serde_json::json!({"a": 7, "b": 8}),
2168                ttl: None,
2169            }),
2170            extensions: Extensions::new(),
2171        };
2172
2173        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2174        let task_id = match resp.inner {
2175            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2176            _ => panic!("Expected EnqueueTask response"),
2177        };
2178
2179        // Wait for task to complete
2180        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2181
2182        // Get task result
2183        let req = RouterRequest {
2184            id: RequestId::Number(2),
2185            inner: McpRequest::GetTaskResult(GetTaskResultParams {
2186                task_id: task_id.clone(),
2187            }),
2188            extensions: Extensions::new(),
2189        };
2190
2191        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2192
2193        match resp.inner {
2194            Ok(McpResponse::GetTaskResult(result)) => {
2195                assert_eq!(result.task_id, task_id);
2196                assert_eq!(result.status, TaskStatus::Completed);
2197                assert!(result.result.is_some());
2198                assert!(result.error.is_none());
2199
2200                // Check the result content
2201                let tool_result = result.result.unwrap();
2202                match &tool_result.content[0] {
2203                    Content::Text { text, .. } => assert_eq!(text, "15"),
2204                    _ => panic!("Expected text content"),
2205                }
2206            }
2207            _ => panic!("Expected GetTaskResult response"),
2208        }
2209    }
2210
2211    #[tokio::test]
2212    async fn test_task_cancellation() {
2213        // Use a slow tool to test cancellation
2214        let slow_tool = ToolBuilder::new("slow")
2215            .description("Slow tool")
2216            .handler(|_input: serde_json::Value| async move {
2217                tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2218                Ok(CallToolResult::text("done"))
2219            })
2220            .build()
2221            .expect("valid tool name");
2222
2223        let mut router = McpRouter::new().tool(slow_tool);
2224        init_router(&mut router).await;
2225
2226        // Enqueue task
2227        let req = RouterRequest {
2228            id: RequestId::Number(1),
2229            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2230                tool_name: "slow".to_string(),
2231                arguments: serde_json::json!({}),
2232                ttl: None,
2233            }),
2234            extensions: Extensions::new(),
2235        };
2236
2237        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2238        let task_id = match resp.inner {
2239            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2240            _ => panic!("Expected EnqueueTask response"),
2241        };
2242
2243        // Cancel the task
2244        let req = RouterRequest {
2245            id: RequestId::Number(2),
2246            inner: McpRequest::CancelTask(CancelTaskParams {
2247                task_id: task_id.clone(),
2248                reason: Some("Test cancellation".to_string()),
2249            }),
2250            extensions: Extensions::new(),
2251        };
2252
2253        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2254
2255        match resp.inner {
2256            Ok(McpResponse::CancelTask(result)) => {
2257                assert!(result.cancelled);
2258                assert_eq!(result.status, TaskStatus::Cancelled);
2259            }
2260            _ => panic!("Expected CancelTask response"),
2261        }
2262    }
2263
2264    #[tokio::test]
2265    async fn test_get_task_info() {
2266        let add_tool = ToolBuilder::new("add")
2267            .description("Add two numbers")
2268            .handler(|input: AddInput| async move {
2269                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2270            })
2271            .build()
2272            .expect("valid tool name");
2273
2274        let mut router = McpRouter::new().tool(add_tool);
2275        init_router(&mut router).await;
2276
2277        // Enqueue task
2278        let req = RouterRequest {
2279            id: RequestId::Number(1),
2280            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2281                tool_name: "add".to_string(),
2282                arguments: serde_json::json!({"a": 1, "b": 2}),
2283                ttl: Some(600),
2284            }),
2285            extensions: Extensions::new(),
2286        };
2287
2288        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2289        let task_id = match resp.inner {
2290            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2291            _ => panic!("Expected EnqueueTask response"),
2292        };
2293
2294        // Get task info
2295        let req = RouterRequest {
2296            id: RequestId::Number(2),
2297            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2298                task_id: task_id.clone(),
2299            }),
2300            extensions: Extensions::new(),
2301        };
2302
2303        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2304
2305        match resp.inner {
2306            Ok(McpResponse::GetTaskInfo(info)) => {
2307                assert_eq!(info.task_id, task_id);
2308                assert!(info.created_at.contains('T')); // ISO 8601
2309                assert_eq!(info.ttl, Some(600));
2310            }
2311            _ => panic!("Expected GetTaskInfo response"),
2312        }
2313    }
2314
2315    #[tokio::test]
2316    async fn test_enqueue_nonexistent_tool() {
2317        let mut router = McpRouter::new();
2318        init_router(&mut router).await;
2319
2320        let req = RouterRequest {
2321            id: RequestId::Number(1),
2322            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2323                tool_name: "nonexistent".to_string(),
2324                arguments: serde_json::json!({}),
2325                ttl: None,
2326            }),
2327            extensions: Extensions::new(),
2328        };
2329
2330        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2331
2332        match resp.inner {
2333            Err(e) => {
2334                assert!(e.message.contains("not found"));
2335            }
2336            _ => panic!("Expected error response"),
2337        }
2338    }
2339
2340    #[tokio::test]
2341    async fn test_get_nonexistent_task() {
2342        let mut router = McpRouter::new();
2343        init_router(&mut router).await;
2344
2345        let req = RouterRequest {
2346            id: RequestId::Number(1),
2347            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2348                task_id: "task-999".to_string(),
2349            }),
2350            extensions: Extensions::new(),
2351        };
2352
2353        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2354
2355        match resp.inner {
2356            Err(e) => {
2357                assert!(e.message.contains("not found"));
2358            }
2359            _ => panic!("Expected error response"),
2360        }
2361    }
2362
2363    // =========================================================================
2364    // Resource Subscription Tests
2365    // =========================================================================
2366
2367    #[tokio::test]
2368    async fn test_subscribe_to_resource() {
2369        use crate::resource::ResourceBuilder;
2370
2371        let resource = ResourceBuilder::new("file:///test.txt")
2372            .name("Test File")
2373            .text("Hello");
2374
2375        let mut router = McpRouter::new().resource(resource);
2376        init_router(&mut router).await;
2377
2378        // Subscribe to the resource
2379        let req = RouterRequest {
2380            id: RequestId::Number(1),
2381            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2382                uri: "file:///test.txt".to_string(),
2383            }),
2384            extensions: Extensions::new(),
2385        };
2386
2387        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2388
2389        match resp.inner {
2390            Ok(McpResponse::SubscribeResource(_)) => {
2391                // Should be subscribed now
2392                assert!(router.is_subscribed("file:///test.txt"));
2393            }
2394            _ => panic!("Expected SubscribeResource response"),
2395        }
2396    }
2397
2398    #[tokio::test]
2399    async fn test_unsubscribe_from_resource() {
2400        use crate::resource::ResourceBuilder;
2401
2402        let resource = ResourceBuilder::new("file:///test.txt")
2403            .name("Test File")
2404            .text("Hello");
2405
2406        let mut router = McpRouter::new().resource(resource);
2407        init_router(&mut router).await;
2408
2409        // Subscribe first
2410        let req = RouterRequest {
2411            id: RequestId::Number(1),
2412            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2413                uri: "file:///test.txt".to_string(),
2414            }),
2415            extensions: Extensions::new(),
2416        };
2417        let _ = router.ready().await.unwrap().call(req).await.unwrap();
2418        assert!(router.is_subscribed("file:///test.txt"));
2419
2420        // Now unsubscribe
2421        let req = RouterRequest {
2422            id: RequestId::Number(2),
2423            inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2424                uri: "file:///test.txt".to_string(),
2425            }),
2426            extensions: Extensions::new(),
2427        };
2428
2429        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2430
2431        match resp.inner {
2432            Ok(McpResponse::UnsubscribeResource(_)) => {
2433                // Should no longer be subscribed
2434                assert!(!router.is_subscribed("file:///test.txt"));
2435            }
2436            _ => panic!("Expected UnsubscribeResource response"),
2437        }
2438    }
2439
2440    #[tokio::test]
2441    async fn test_subscribe_nonexistent_resource() {
2442        let mut router = McpRouter::new();
2443        init_router(&mut router).await;
2444
2445        let req = RouterRequest {
2446            id: RequestId::Number(1),
2447            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2448                uri: "file:///nonexistent.txt".to_string(),
2449            }),
2450            extensions: Extensions::new(),
2451        };
2452
2453        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2454
2455        match resp.inner {
2456            Err(e) => {
2457                assert!(e.message.contains("not found"));
2458            }
2459            _ => panic!("Expected error response"),
2460        }
2461    }
2462
2463    #[tokio::test]
2464    async fn test_notify_resource_updated() {
2465        use crate::context::notification_channel;
2466        use crate::resource::ResourceBuilder;
2467
2468        let (tx, mut rx) = notification_channel(10);
2469
2470        let resource = ResourceBuilder::new("file:///test.txt")
2471            .name("Test File")
2472            .text("Hello");
2473
2474        let router = McpRouter::new()
2475            .resource(resource)
2476            .with_notification_sender(tx);
2477
2478        // First, manually subscribe (simulate subscription)
2479        router.subscribe("file:///test.txt");
2480
2481        // Now notify
2482        let sent = router.notify_resource_updated("file:///test.txt");
2483        assert!(sent);
2484
2485        // Check the notification was sent
2486        let notification = rx.try_recv().unwrap();
2487        match notification {
2488            ServerNotification::ResourceUpdated { uri } => {
2489                assert_eq!(uri, "file:///test.txt");
2490            }
2491            _ => panic!("Expected ResourceUpdated notification"),
2492        }
2493    }
2494
2495    #[tokio::test]
2496    async fn test_notify_resource_updated_not_subscribed() {
2497        use crate::context::notification_channel;
2498        use crate::resource::ResourceBuilder;
2499
2500        let (tx, mut rx) = notification_channel(10);
2501
2502        let resource = ResourceBuilder::new("file:///test.txt")
2503            .name("Test File")
2504            .text("Hello");
2505
2506        let router = McpRouter::new()
2507            .resource(resource)
2508            .with_notification_sender(tx);
2509
2510        // Try to notify without subscribing
2511        let sent = router.notify_resource_updated("file:///test.txt");
2512        assert!(!sent); // Should not send because not subscribed
2513
2514        // Channel should be empty
2515        assert!(rx.try_recv().is_err());
2516    }
2517
2518    #[tokio::test]
2519    async fn test_notify_resources_list_changed() {
2520        use crate::context::notification_channel;
2521
2522        let (tx, mut rx) = notification_channel(10);
2523        let router = McpRouter::new().with_notification_sender(tx);
2524
2525        let sent = router.notify_resources_list_changed();
2526        assert!(sent);
2527
2528        let notification = rx.try_recv().unwrap();
2529        match notification {
2530            ServerNotification::ResourcesListChanged => {}
2531            _ => panic!("Expected ResourcesListChanged notification"),
2532        }
2533    }
2534
2535    #[tokio::test]
2536    async fn test_subscribed_uris() {
2537        use crate::resource::ResourceBuilder;
2538
2539        let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2540
2541        let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2542
2543        let router = McpRouter::new().resource(resource1).resource(resource2);
2544
2545        // Subscribe to both
2546        router.subscribe("file:///a.txt");
2547        router.subscribe("file:///b.txt");
2548
2549        let uris = router.subscribed_uris();
2550        assert_eq!(uris.len(), 2);
2551        assert!(uris.contains(&"file:///a.txt".to_string()));
2552        assert!(uris.contains(&"file:///b.txt".to_string()));
2553    }
2554
2555    #[tokio::test]
2556    async fn test_subscription_capability_advertised() {
2557        use crate::resource::ResourceBuilder;
2558
2559        let resource = ResourceBuilder::new("file:///test.txt")
2560            .name("Test")
2561            .text("Hello");
2562
2563        let mut router = McpRouter::new().resource(resource);
2564
2565        // Initialize and check capabilities
2566        let init_req = RouterRequest {
2567            id: RequestId::Number(0),
2568            inner: McpRequest::Initialize(InitializeParams {
2569                protocol_version: "2025-11-25".to_string(),
2570                capabilities: ClientCapabilities {
2571                    roots: None,
2572                    sampling: None,
2573                    elicitation: None,
2574                },
2575                client_info: Implementation {
2576                    name: "test".to_string(),
2577                    version: "1.0".to_string(),
2578                    ..Default::default()
2579                },
2580            }),
2581            extensions: Extensions::new(),
2582        };
2583        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2584
2585        match resp.inner {
2586            Ok(McpResponse::Initialize(result)) => {
2587                // Should have resources capability with subscribe enabled
2588                let resources_cap = result.capabilities.resources.unwrap();
2589                assert!(resources_cap.subscribe);
2590            }
2591            _ => panic!("Expected Initialize response"),
2592        }
2593    }
2594
2595    #[tokio::test]
2596    async fn test_completion_handler() {
2597        let router = McpRouter::new()
2598            .server_info("test", "1.0")
2599            .completion_handler(|params: CompleteParams| async move {
2600                // Return suggestions based on the argument value
2601                let prefix = &params.argument.value;
2602                let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2603                    .into_iter()
2604                    .filter(|s| s.starts_with(prefix))
2605                    .map(String::from)
2606                    .collect();
2607                Ok(CompleteResult::new(suggestions))
2608            });
2609
2610        // Initialize
2611        let init_req = RouterRequest {
2612            id: RequestId::Number(0),
2613            inner: McpRequest::Initialize(InitializeParams {
2614                protocol_version: "2025-11-25".to_string(),
2615                capabilities: ClientCapabilities::default(),
2616                client_info: Implementation {
2617                    name: "test".to_string(),
2618                    version: "1.0".to_string(),
2619                    ..Default::default()
2620                },
2621            }),
2622            extensions: Extensions::new(),
2623        };
2624        let resp = router
2625            .clone()
2626            .ready()
2627            .await
2628            .unwrap()
2629            .call(init_req)
2630            .await
2631            .unwrap();
2632
2633        // Check that completions capability is advertised
2634        match resp.inner {
2635            Ok(McpResponse::Initialize(result)) => {
2636                assert!(result.capabilities.completions.is_some());
2637            }
2638            _ => panic!("Expected Initialize response"),
2639        }
2640
2641        // Send initialized notification
2642        router.handle_notification(McpNotification::Initialized);
2643
2644        // Test completion request
2645        let complete_req = RouterRequest {
2646            id: RequestId::Number(1),
2647            inner: McpRequest::Complete(CompleteParams {
2648                reference: CompletionReference::prompt("test-prompt"),
2649                argument: CompletionArgument::new("query", "al"),
2650            }),
2651            extensions: Extensions::new(),
2652        };
2653        let resp = router
2654            .clone()
2655            .ready()
2656            .await
2657            .unwrap()
2658            .call(complete_req)
2659            .await
2660            .unwrap();
2661
2662        match resp.inner {
2663            Ok(McpResponse::Complete(result)) => {
2664                assert_eq!(result.completion.values, vec!["alpha"]);
2665            }
2666            _ => panic!("Expected Complete response"),
2667        }
2668    }
2669
2670    #[tokio::test]
2671    async fn test_completion_without_handler_returns_empty() {
2672        let router = McpRouter::new().server_info("test", "1.0");
2673
2674        // Initialize
2675        let init_req = RouterRequest {
2676            id: RequestId::Number(0),
2677            inner: McpRequest::Initialize(InitializeParams {
2678                protocol_version: "2025-11-25".to_string(),
2679                capabilities: ClientCapabilities::default(),
2680                client_info: Implementation {
2681                    name: "test".to_string(),
2682                    version: "1.0".to_string(),
2683                    ..Default::default()
2684                },
2685            }),
2686            extensions: Extensions::new(),
2687        };
2688        let resp = router
2689            .clone()
2690            .ready()
2691            .await
2692            .unwrap()
2693            .call(init_req)
2694            .await
2695            .unwrap();
2696
2697        // Check that completions capability is NOT advertised
2698        match resp.inner {
2699            Ok(McpResponse::Initialize(result)) => {
2700                assert!(result.capabilities.completions.is_none());
2701            }
2702            _ => panic!("Expected Initialize response"),
2703        }
2704
2705        // Send initialized notification
2706        router.handle_notification(McpNotification::Initialized);
2707
2708        // Test completion request still works but returns empty
2709        let complete_req = RouterRequest {
2710            id: RequestId::Number(1),
2711            inner: McpRequest::Complete(CompleteParams {
2712                reference: CompletionReference::prompt("test-prompt"),
2713                argument: CompletionArgument::new("query", "al"),
2714            }),
2715            extensions: Extensions::new(),
2716        };
2717        let resp = router
2718            .clone()
2719            .ready()
2720            .await
2721            .unwrap()
2722            .call(complete_req)
2723            .await
2724            .unwrap();
2725
2726        match resp.inner {
2727            Ok(McpResponse::Complete(result)) => {
2728                assert!(result.completion.values.is_empty());
2729            }
2730            _ => panic!("Expected Complete response"),
2731        }
2732    }
2733
2734    #[tokio::test]
2735    async fn test_tool_filter_list() {
2736        use crate::filter::CapabilityFilter;
2737        use crate::tool::Tool;
2738
2739        let public_tool = ToolBuilder::new("public")
2740            .description("Public tool")
2741            .handler(|_: AddInput| async move { Ok(CallToolResult::text("public")) })
2742            .build()
2743            .expect("valid tool name");
2744
2745        let admin_tool = ToolBuilder::new("admin")
2746            .description("Admin tool")
2747            .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2748            .build()
2749            .expect("valid tool name");
2750
2751        let mut router = McpRouter::new()
2752            .tool(public_tool)
2753            .tool(admin_tool)
2754            .tool_filter(CapabilityFilter::new(|_, tool: &Tool| tool.name != "admin"));
2755
2756        // Initialize session
2757        init_router(&mut router).await;
2758
2759        let req = RouterRequest {
2760            id: RequestId::Number(1),
2761            inner: McpRequest::ListTools(ListToolsParams::default()),
2762            extensions: Extensions::new(),
2763        };
2764
2765        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2766
2767        match resp.inner {
2768            Ok(McpResponse::ListTools(result)) => {
2769                // Only public tool should be visible
2770                assert_eq!(result.tools.len(), 1);
2771                assert_eq!(result.tools[0].name, "public");
2772            }
2773            _ => panic!("Expected ListTools response"),
2774        }
2775    }
2776
2777    #[tokio::test]
2778    async fn test_tool_filter_call_denied() {
2779        use crate::filter::CapabilityFilter;
2780        use crate::tool::Tool;
2781
2782        let admin_tool = ToolBuilder::new("admin")
2783            .description("Admin tool")
2784            .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2785            .build()
2786            .expect("valid tool name");
2787
2788        let mut router = McpRouter::new()
2789            .tool(admin_tool)
2790            .tool_filter(CapabilityFilter::new(|_, _: &Tool| false)); // Deny all
2791
2792        // Initialize session
2793        init_router(&mut router).await;
2794
2795        let req = RouterRequest {
2796            id: RequestId::Number(1),
2797            inner: McpRequest::CallTool(CallToolParams {
2798                name: "admin".to_string(),
2799                arguments: serde_json::json!({"a": 1, "b": 2}),
2800                meta: None,
2801            }),
2802            extensions: Extensions::new(),
2803        };
2804
2805        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2806
2807        // Should get method not found error (default denial behavior)
2808        match resp.inner {
2809            Err(e) => {
2810                assert_eq!(e.code, -32601); // Method not found
2811            }
2812            _ => panic!("Expected JsonRpc error"),
2813        }
2814    }
2815
2816    #[tokio::test]
2817    async fn test_tool_filter_call_allowed() {
2818        use crate::filter::CapabilityFilter;
2819        use crate::tool::Tool;
2820
2821        let public_tool = ToolBuilder::new("public")
2822            .description("Public tool")
2823            .handler(|input: AddInput| async move {
2824                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2825            })
2826            .build()
2827            .expect("valid tool name");
2828
2829        let mut router = McpRouter::new()
2830            .tool(public_tool)
2831            .tool_filter(CapabilityFilter::new(|_, _: &Tool| true)); // Allow all
2832
2833        // Initialize session
2834        init_router(&mut router).await;
2835
2836        let req = RouterRequest {
2837            id: RequestId::Number(1),
2838            inner: McpRequest::CallTool(CallToolParams {
2839                name: "public".to_string(),
2840                arguments: serde_json::json!({"a": 1, "b": 2}),
2841                meta: None,
2842            }),
2843            extensions: Extensions::new(),
2844        };
2845
2846        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2847
2848        match resp.inner {
2849            Ok(McpResponse::CallTool(result)) => {
2850                assert!(!result.is_error);
2851            }
2852            _ => panic!("Expected CallTool response"),
2853        }
2854    }
2855
2856    #[tokio::test]
2857    async fn test_tool_filter_custom_denial() {
2858        use crate::filter::{CapabilityFilter, DenialBehavior};
2859        use crate::tool::Tool;
2860
2861        let admin_tool = ToolBuilder::new("admin")
2862            .description("Admin tool")
2863            .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2864            .build()
2865            .expect("valid tool name");
2866
2867        let mut router = McpRouter::new().tool(admin_tool).tool_filter(
2868            CapabilityFilter::new(|_, _: &Tool| false)
2869                .denial_behavior(DenialBehavior::Unauthorized),
2870        );
2871
2872        // Initialize session
2873        init_router(&mut router).await;
2874
2875        let req = RouterRequest {
2876            id: RequestId::Number(1),
2877            inner: McpRequest::CallTool(CallToolParams {
2878                name: "admin".to_string(),
2879                arguments: serde_json::json!({"a": 1, "b": 2}),
2880                meta: None,
2881            }),
2882            extensions: Extensions::new(),
2883        };
2884
2885        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2886
2887        // Should get forbidden error
2888        match resp.inner {
2889            Err(e) => {
2890                assert_eq!(e.code, -32007); // Forbidden
2891                assert!(e.message.contains("Unauthorized"));
2892            }
2893            _ => panic!("Expected JsonRpc error"),
2894        }
2895    }
2896
2897    #[tokio::test]
2898    async fn test_resource_filter_list() {
2899        use crate::filter::CapabilityFilter;
2900        use crate::resource::{Resource, ResourceBuilder};
2901
2902        let public_resource = ResourceBuilder::new("file:///public.txt")
2903            .name("Public File")
2904            .text("public content");
2905
2906        let secret_resource = ResourceBuilder::new("file:///secret.txt")
2907            .name("Secret File")
2908            .text("secret content");
2909
2910        let mut router = McpRouter::new()
2911            .resource(public_resource)
2912            .resource(secret_resource)
2913            .resource_filter(CapabilityFilter::new(|_, r: &Resource| {
2914                !r.name.contains("Secret")
2915            }));
2916
2917        // Initialize session
2918        init_router(&mut router).await;
2919
2920        let req = RouterRequest {
2921            id: RequestId::Number(1),
2922            inner: McpRequest::ListResources(ListResourcesParams::default()),
2923            extensions: Extensions::new(),
2924        };
2925
2926        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2927
2928        match resp.inner {
2929            Ok(McpResponse::ListResources(result)) => {
2930                // Should only see public resource
2931                assert_eq!(result.resources.len(), 1);
2932                assert_eq!(result.resources[0].name, "Public File");
2933            }
2934            _ => panic!("Expected ListResources response"),
2935        }
2936    }
2937
2938    #[tokio::test]
2939    async fn test_resource_filter_read_denied() {
2940        use crate::filter::CapabilityFilter;
2941        use crate::resource::{Resource, ResourceBuilder};
2942
2943        let secret_resource = ResourceBuilder::new("file:///secret.txt")
2944            .name("Secret File")
2945            .text("secret content");
2946
2947        let mut router = McpRouter::new()
2948            .resource(secret_resource)
2949            .resource_filter(CapabilityFilter::new(|_, _: &Resource| false)); // Deny all
2950
2951        // Initialize session
2952        init_router(&mut router).await;
2953
2954        let req = RouterRequest {
2955            id: RequestId::Number(1),
2956            inner: McpRequest::ReadResource(ReadResourceParams {
2957                uri: "file:///secret.txt".to_string(),
2958            }),
2959            extensions: Extensions::new(),
2960        };
2961
2962        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2963
2964        // Should get method not found error (default denial behavior)
2965        match resp.inner {
2966            Err(e) => {
2967                assert_eq!(e.code, -32601); // Method not found
2968            }
2969            _ => panic!("Expected JsonRpc error"),
2970        }
2971    }
2972
2973    #[tokio::test]
2974    async fn test_resource_filter_read_allowed() {
2975        use crate::filter::CapabilityFilter;
2976        use crate::resource::{Resource, ResourceBuilder};
2977
2978        let public_resource = ResourceBuilder::new("file:///public.txt")
2979            .name("Public File")
2980            .text("public content");
2981
2982        let mut router = McpRouter::new()
2983            .resource(public_resource)
2984            .resource_filter(CapabilityFilter::new(|_, _: &Resource| true)); // Allow all
2985
2986        // Initialize session
2987        init_router(&mut router).await;
2988
2989        let req = RouterRequest {
2990            id: RequestId::Number(1),
2991            inner: McpRequest::ReadResource(ReadResourceParams {
2992                uri: "file:///public.txt".to_string(),
2993            }),
2994            extensions: Extensions::new(),
2995        };
2996
2997        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2998
2999        match resp.inner {
3000            Ok(McpResponse::ReadResource(result)) => {
3001                assert_eq!(result.contents.len(), 1);
3002                assert_eq!(result.contents[0].text.as_deref(), Some("public content"));
3003            }
3004            _ => panic!("Expected ReadResource response"),
3005        }
3006    }
3007
3008    #[tokio::test]
3009    async fn test_resource_filter_custom_denial() {
3010        use crate::filter::{CapabilityFilter, DenialBehavior};
3011        use crate::resource::{Resource, ResourceBuilder};
3012
3013        let secret_resource = ResourceBuilder::new("file:///secret.txt")
3014            .name("Secret File")
3015            .text("secret content");
3016
3017        let mut router = McpRouter::new().resource(secret_resource).resource_filter(
3018            CapabilityFilter::new(|_, _: &Resource| false)
3019                .denial_behavior(DenialBehavior::Unauthorized),
3020        );
3021
3022        // Initialize session
3023        init_router(&mut router).await;
3024
3025        let req = RouterRequest {
3026            id: RequestId::Number(1),
3027            inner: McpRequest::ReadResource(ReadResourceParams {
3028                uri: "file:///secret.txt".to_string(),
3029            }),
3030            extensions: Extensions::new(),
3031        };
3032
3033        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3034
3035        // Should get forbidden error
3036        match resp.inner {
3037            Err(e) => {
3038                assert_eq!(e.code, -32007); // Forbidden
3039                assert!(e.message.contains("Unauthorized"));
3040            }
3041            _ => panic!("Expected JsonRpc error"),
3042        }
3043    }
3044
3045    #[tokio::test]
3046    async fn test_prompt_filter_list() {
3047        use crate::filter::CapabilityFilter;
3048        use crate::prompt::{Prompt, PromptBuilder};
3049
3050        let public_prompt = PromptBuilder::new("greeting")
3051            .description("A greeting")
3052            .user_message("Hello!");
3053
3054        let admin_prompt = PromptBuilder::new("system_debug")
3055            .description("Admin prompt")
3056            .user_message("Debug");
3057
3058        let mut router = McpRouter::new()
3059            .prompt(public_prompt)
3060            .prompt(admin_prompt)
3061            .prompt_filter(CapabilityFilter::new(|_, p: &Prompt| {
3062                !p.name.contains("system")
3063            }));
3064
3065        // Initialize session
3066        init_router(&mut router).await;
3067
3068        let req = RouterRequest {
3069            id: RequestId::Number(1),
3070            inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3071            extensions: Extensions::new(),
3072        };
3073
3074        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3075
3076        match resp.inner {
3077            Ok(McpResponse::ListPrompts(result)) => {
3078                // Should only see public prompt
3079                assert_eq!(result.prompts.len(), 1);
3080                assert_eq!(result.prompts[0].name, "greeting");
3081            }
3082            _ => panic!("Expected ListPrompts response"),
3083        }
3084    }
3085
3086    #[tokio::test]
3087    async fn test_prompt_filter_get_denied() {
3088        use crate::filter::CapabilityFilter;
3089        use crate::prompt::{Prompt, PromptBuilder};
3090        use std::collections::HashMap;
3091
3092        let admin_prompt = PromptBuilder::new("system_debug")
3093            .description("Admin prompt")
3094            .user_message("Debug");
3095
3096        let mut router = McpRouter::new()
3097            .prompt(admin_prompt)
3098            .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| false)); // Deny all
3099
3100        // Initialize session
3101        init_router(&mut router).await;
3102
3103        let req = RouterRequest {
3104            id: RequestId::Number(1),
3105            inner: McpRequest::GetPrompt(GetPromptParams {
3106                name: "system_debug".to_string(),
3107                arguments: HashMap::new(),
3108            }),
3109            extensions: Extensions::new(),
3110        };
3111
3112        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3113
3114        // Should get method not found error (default denial behavior)
3115        match resp.inner {
3116            Err(e) => {
3117                assert_eq!(e.code, -32601); // Method not found
3118            }
3119            _ => panic!("Expected JsonRpc error"),
3120        }
3121    }
3122
3123    #[tokio::test]
3124    async fn test_prompt_filter_get_allowed() {
3125        use crate::filter::CapabilityFilter;
3126        use crate::prompt::{Prompt, PromptBuilder};
3127        use std::collections::HashMap;
3128
3129        let public_prompt = PromptBuilder::new("greeting")
3130            .description("A greeting")
3131            .user_message("Hello!");
3132
3133        let mut router = McpRouter::new()
3134            .prompt(public_prompt)
3135            .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| true)); // Allow all
3136
3137        // Initialize session
3138        init_router(&mut router).await;
3139
3140        let req = RouterRequest {
3141            id: RequestId::Number(1),
3142            inner: McpRequest::GetPrompt(GetPromptParams {
3143                name: "greeting".to_string(),
3144                arguments: HashMap::new(),
3145            }),
3146            extensions: Extensions::new(),
3147        };
3148
3149        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3150
3151        match resp.inner {
3152            Ok(McpResponse::GetPrompt(result)) => {
3153                assert_eq!(result.messages.len(), 1);
3154            }
3155            _ => panic!("Expected GetPrompt response"),
3156        }
3157    }
3158
3159    #[tokio::test]
3160    async fn test_prompt_filter_custom_denial() {
3161        use crate::filter::{CapabilityFilter, DenialBehavior};
3162        use crate::prompt::{Prompt, PromptBuilder};
3163        use std::collections::HashMap;
3164
3165        let admin_prompt = PromptBuilder::new("system_debug")
3166            .description("Admin prompt")
3167            .user_message("Debug");
3168
3169        let mut router = McpRouter::new().prompt(admin_prompt).prompt_filter(
3170            CapabilityFilter::new(|_, _: &Prompt| false)
3171                .denial_behavior(DenialBehavior::Unauthorized),
3172        );
3173
3174        // Initialize session
3175        init_router(&mut router).await;
3176
3177        let req = RouterRequest {
3178            id: RequestId::Number(1),
3179            inner: McpRequest::GetPrompt(GetPromptParams {
3180                name: "system_debug".to_string(),
3181                arguments: HashMap::new(),
3182            }),
3183            extensions: Extensions::new(),
3184        };
3185
3186        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3187
3188        // Should get forbidden error
3189        match resp.inner {
3190            Err(e) => {
3191                assert_eq!(e.code, -32007); // Forbidden
3192                assert!(e.message.contains("Unauthorized"));
3193            }
3194            _ => panic!("Expected JsonRpc error"),
3195        }
3196    }
3197}