Skip to main content

tower_mcp/
router.rs

1//! MCP Router - routes requests to tools, resources, and prompts
2//!
3//! The router implements Tower's `Service` trait, making it composable with
4//! standard tower middleware.
5
6use std::any::{Any, TypeId};
7use std::collections::{HashMap, HashSet};
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, RwLock};
11use std::task::{Context, Poll};
12
13use tower_service::Service;
14
15use crate::async_task::TaskStore;
16use crate::context::{
17    CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
18    ServerNotification,
19};
20use crate::error::{Error, JsonRpcError, Result};
21use crate::prompt::Prompt;
22use crate::protocol::*;
23use crate::resource::{Resource, ResourceTemplate};
24use crate::session::SessionState;
25use crate::tool::Tool;
26
27/// Type alias for completion handler function
28pub type CompletionHandler = Arc<
29    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
30        + Send
31        + Sync,
32>;
33
34/// MCP Router that dispatches requests to registered handlers
35///
36/// Implements `tower::Service<McpRequest>` for middleware composition.
37///
38/// # Example
39///
40/// ```rust
41/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
42/// use schemars::JsonSchema;
43/// use serde::Deserialize;
44///
45/// #[derive(Debug, Deserialize, JsonSchema)]
46/// struct Input { value: String }
47///
48/// let tool = ToolBuilder::new("echo")
49///     .description("Echo input")
50///     .handler(|i: Input| async move { Ok(CallToolResult::text(i.value)) })
51///     .build()
52///     .unwrap();
53///
54/// let router = McpRouter::new()
55///     .server_info("my-server", "1.0.0")
56///     .tool(tool);
57/// ```
58#[derive(Clone)]
59pub struct McpRouter {
60    inner: Arc<McpRouterInner>,
61    session: SessionState,
62}
63
64impl std::fmt::Debug for McpRouter {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("McpRouter")
67            .field("server_name", &self.inner.server_name)
68            .field("server_version", &self.inner.server_version)
69            .field("tools_count", &self.inner.tools.len())
70            .field("resources_count", &self.inner.resources.len())
71            .field("prompts_count", &self.inner.prompts.len())
72            .field("session_phase", &self.session.phase())
73            .finish()
74    }
75}
76
77/// Inner configuration that is shared across clones
78#[derive(Clone)]
79struct McpRouterInner {
80    server_name: String,
81    server_version: String,
82    /// Human-readable title for the server
83    server_title: Option<String>,
84    /// Description of the server
85    server_description: Option<String>,
86    /// Icons for the server
87    server_icons: Option<Vec<ToolIcon>>,
88    /// URL of the server's website
89    server_website_url: Option<String>,
90    instructions: Option<String>,
91    tools: HashMap<String, Arc<Tool>>,
92    resources: HashMap<String, Arc<Resource>>,
93    /// Resource templates for dynamic resource matching (keyed by uri_template)
94    resource_templates: Vec<Arc<ResourceTemplate>>,
95    prompts: HashMap<String, Arc<Prompt>>,
96    /// In-flight requests for cancellation tracking (shared across clones)
97    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
98    /// Channel for sending notifications to connected clients
99    notification_tx: Option<NotificationSender>,
100    /// Handle for sending requests to the client (for sampling, etc.)
101    client_requester: Option<ClientRequesterHandle>,
102    /// Task store for async operations
103    task_store: TaskStore,
104    /// Subscribed resource URIs
105    subscriptions: Arc<RwLock<HashSet<String>>>,
106    /// Handler for completion requests
107    completion_handler: Option<CompletionHandler>,
108}
109
110impl McpRouter {
111    /// Create a new MCP router
112    pub fn new() -> Self {
113        Self {
114            inner: Arc::new(McpRouterInner {
115                server_name: "tower-mcp".to_string(),
116                server_version: env!("CARGO_PKG_VERSION").to_string(),
117                server_title: None,
118                server_description: None,
119                server_icons: None,
120                server_website_url: None,
121                instructions: None,
122                tools: HashMap::new(),
123                resources: HashMap::new(),
124                resource_templates: Vec::new(),
125                prompts: HashMap::new(),
126                in_flight: Arc::new(RwLock::new(HashMap::new())),
127                notification_tx: None,
128                client_requester: None,
129                task_store: TaskStore::new(),
130                subscriptions: Arc::new(RwLock::new(HashSet::new())),
131                completion_handler: None,
132            }),
133            session: SessionState::new(),
134        }
135    }
136
137    /// Get access to the task store for async operations
138    pub fn task_store(&self) -> &TaskStore {
139        &self.inner.task_store
140    }
141
142    /// Set the notification sender for progress reporting
143    ///
144    /// This is typically called by the transport layer to receive notifications.
145    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
146        Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
147        self
148    }
149
150    /// Get the notification sender (if configured)
151    pub fn notification_sender(&self) -> Option<&NotificationSender> {
152        self.inner.notification_tx.as_ref()
153    }
154
155    /// Set the client requester for server-to-client requests (sampling, etc.)
156    ///
157    /// This is typically called by bidirectional transports (WebSocket, stdio)
158    /// to enable tool handlers to send requests to the client.
159    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
160        Arc::make_mut(&mut self.inner).client_requester = Some(requester);
161        self
162    }
163
164    /// Get the client requester (if configured)
165    pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
166        self.inner.client_requester.as_ref()
167    }
168
169    /// Create a request context for tracking a request
170    ///
171    /// This registers the request for cancellation tracking and sets up
172    /// progress reporting and client requests if configured.
173    pub fn create_context(
174        &self,
175        request_id: RequestId,
176        progress_token: Option<ProgressToken>,
177    ) -> RequestContext {
178        let ctx = RequestContext::new(request_id.clone());
179
180        // Set up progress token if provided
181        let ctx = if let Some(token) = progress_token {
182            ctx.with_progress_token(token)
183        } else {
184            ctx
185        };
186
187        // Set up notification sender if configured
188        let ctx = if let Some(tx) = &self.inner.notification_tx {
189            ctx.with_notification_sender(tx.clone())
190        } else {
191            ctx
192        };
193
194        // Set up client requester if configured (for sampling support)
195        let ctx = if let Some(requester) = &self.inner.client_requester {
196            ctx.with_client_requester(requester.clone())
197        } else {
198            ctx
199        };
200
201        // Register for cancellation tracking
202        let token = ctx.cancellation_token();
203        if let Ok(mut in_flight) = self.inner.in_flight.write() {
204            in_flight.insert(request_id, token);
205        }
206
207        ctx
208    }
209
210    /// Remove a request from tracking (called when request completes)
211    pub fn complete_request(&self, request_id: &RequestId) {
212        if let Ok(mut in_flight) = self.inner.in_flight.write() {
213            in_flight.remove(request_id);
214        }
215    }
216
217    /// Cancel a tracked request
218    fn cancel_request(&self, request_id: &RequestId) -> bool {
219        let Ok(in_flight) = self.inner.in_flight.read() else {
220            return false;
221        };
222        let Some(token) = in_flight.get(request_id) else {
223            return false;
224        };
225        token.cancel();
226        true
227    }
228
229    /// Set server info
230    pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
231        let inner = Arc::make_mut(&mut self.inner);
232        inner.server_name = name.into();
233        inner.server_version = version.into();
234        self
235    }
236
237    /// Set instructions for LLMs describing how to use this server
238    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
239        Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
240        self
241    }
242
243    /// Set a human-readable title for the server
244    pub fn server_title(mut self, title: impl Into<String>) -> Self {
245        Arc::make_mut(&mut self.inner).server_title = Some(title.into());
246        self
247    }
248
249    /// Set the server description
250    pub fn server_description(mut self, description: impl Into<String>) -> Self {
251        Arc::make_mut(&mut self.inner).server_description = Some(description.into());
252        self
253    }
254
255    /// Set icons for the server
256    pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
257        Arc::make_mut(&mut self.inner).server_icons = Some(icons);
258        self
259    }
260
261    /// Set the server's website URL
262    pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
263        Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
264        self
265    }
266
267    /// Register a tool
268    pub fn tool(mut self, tool: Tool) -> Self {
269        Arc::make_mut(&mut self.inner)
270            .tools
271            .insert(tool.name.clone(), Arc::new(tool));
272        self
273    }
274
275    /// Register a resource
276    pub fn resource(mut self, resource: Resource) -> Self {
277        Arc::make_mut(&mut self.inner)
278            .resources
279            .insert(resource.uri.clone(), Arc::new(resource));
280        self
281    }
282
283    /// Register a resource template
284    ///
285    /// Resource templates allow dynamic resources to be matched by URI pattern.
286    /// When a client requests a resource URI that doesn't match any static
287    /// resource, the router tries to match it against registered templates.
288    ///
289    /// # Example
290    ///
291    /// ```rust
292    /// use tower_mcp::{McpRouter, ResourceTemplateBuilder};
293    /// use tower_mcp::protocol::{ReadResourceResult, ResourceContent};
294    /// use std::collections::HashMap;
295    ///
296    /// let template = ResourceTemplateBuilder::new("file:///{path}")
297    ///     .name("Project Files")
298    ///     .handler(|uri: String, vars: HashMap<String, String>| async move {
299    ///         let path = vars.get("path").unwrap_or(&String::new()).clone();
300    ///         Ok(ReadResourceResult {
301    ///             contents: vec![ResourceContent {
302    ///                 uri,
303    ///                 mime_type: Some("text/plain".to_string()),
304    ///                 text: Some(format!("Contents of {}", path)),
305    ///                 blob: None,
306    ///             }],
307    ///         })
308    ///     });
309    ///
310    /// let router = McpRouter::new()
311    ///     .resource_template(template);
312    /// ```
313    pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
314        Arc::make_mut(&mut self.inner)
315            .resource_templates
316            .push(Arc::new(template));
317        self
318    }
319
320    /// Register a prompt
321    pub fn prompt(mut self, prompt: Prompt) -> Self {
322        Arc::make_mut(&mut self.inner)
323            .prompts
324            .insert(prompt.name.clone(), Arc::new(prompt));
325        self
326    }
327
328    /// Register multiple tools at once.
329    ///
330    /// # Example
331    ///
332    /// ```rust
333    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
334    /// use schemars::JsonSchema;
335    /// use serde::Deserialize;
336    ///
337    /// #[derive(Debug, Deserialize, JsonSchema)]
338    /// struct Input { value: String }
339    ///
340    /// let tools = vec![
341    ///     ToolBuilder::new("a")
342    ///         .description("Tool A")
343    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
344    ///         .build().unwrap(),
345    ///     ToolBuilder::new("b")
346    ///         .description("Tool B")
347    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
348    ///         .build().unwrap(),
349    /// ];
350    ///
351    /// let router = McpRouter::new().tools(tools);
352    /// ```
353    pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
354        tools
355            .into_iter()
356            .fold(self, |router, tool| router.tool(tool))
357    }
358
359    /// Register multiple resources at once.
360    ///
361    /// # Example
362    ///
363    /// ```rust
364    /// use tower_mcp::{McpRouter, ResourceBuilder};
365    ///
366    /// let resources = vec![
367    ///     ResourceBuilder::new("file:///a.txt")
368    ///         .name("File A")
369    ///         .text("contents a"),
370    ///     ResourceBuilder::new("file:///b.txt")
371    ///         .name("File B")
372    ///         .text("contents b"),
373    /// ];
374    ///
375    /// let router = McpRouter::new().resources(resources);
376    /// ```
377    pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
378        resources
379            .into_iter()
380            .fold(self, |router, resource| router.resource(resource))
381    }
382
383    /// Register multiple prompts at once.
384    ///
385    /// # Example
386    ///
387    /// ```rust
388    /// use tower_mcp::{McpRouter, PromptBuilder};
389    ///
390    /// let prompts = vec![
391    ///     PromptBuilder::new("greet")
392    ///         .description("Greet someone")
393    ///         .user_message("Hello!"),
394    ///     PromptBuilder::new("farewell")
395    ///         .description("Say goodbye")
396    ///         .user_message("Goodbye!"),
397    /// ];
398    ///
399    /// let router = McpRouter::new().prompts(prompts);
400    /// ```
401    pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
402        prompts
403            .into_iter()
404            .fold(self, |router, prompt| router.prompt(prompt))
405    }
406
407    /// Register a completion handler for `completion/complete` requests.
408    ///
409    /// The handler receives `CompleteParams` containing the reference (prompt or resource)
410    /// and the argument being completed, and should return completion suggestions.
411    ///
412    /// # Example
413    ///
414    /// ```rust
415    /// use tower_mcp::{McpRouter, CompleteResult};
416    /// use tower_mcp::protocol::{CompleteParams, CompletionReference};
417    ///
418    /// let router = McpRouter::new()
419    ///     .completion_handler(|params: CompleteParams| async move {
420    ///         // Provide completions based on the reference and argument
421    ///         match params.reference {
422    ///             CompletionReference::Prompt { name } => {
423    ///                 // Return prompt argument completions
424    ///                 Ok(CompleteResult::new(vec!["option1".to_string(), "option2".to_string()]))
425    ///             }
426    ///             CompletionReference::Resource { uri } => {
427    ///                 // Return resource URI completions
428    ///                 Ok(CompleteResult::new(vec![]))
429    ///             }
430    ///         }
431    ///     });
432    /// ```
433    pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
434    where
435        F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
436        Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
437    {
438        Arc::make_mut(&mut self.inner).completion_handler =
439            Some(Arc::new(move |params| Box::pin(handler(params))));
440        self
441    }
442
443    /// Get access to the session state
444    pub fn session(&self) -> &SessionState {
445        &self.session
446    }
447
448    /// Send a log message notification to the client
449    ///
450    /// This sends a `notifications/message` notification with the given parameters.
451    /// Returns `true` if the notification was sent, `false` if no notification channel
452    /// is configured.
453    ///
454    /// # Example
455    ///
456    /// ```rust,ignore
457    /// use tower_mcp::protocol::{LogLevel, LoggingMessageParams};
458    ///
459    /// // Simple info message
460    /// router.log(LoggingMessageParams::new(LogLevel::Info).with_data(
461    ///     serde_json::json!({"message": "Operation completed"})
462    /// ));
463    ///
464    /// // Error with logger name
465    /// router.log(LoggingMessageParams::new(LogLevel::Error)
466    ///     .with_logger("database")
467    ///     .with_data(serde_json::json!({"error": "Connection failed"})));
468    /// ```
469    pub fn log(&self, params: LoggingMessageParams) -> bool {
470        let Some(tx) = &self.inner.notification_tx else {
471            return false;
472        };
473        tx.try_send(ServerNotification::LogMessage(params)).is_ok()
474    }
475
476    /// Send an info-level log message
477    ///
478    /// Convenience method for sending an info log with optional data.
479    pub fn log_info(&self, message: &str) -> bool {
480        self.log(
481            LoggingMessageParams::new(LogLevel::Info)
482                .with_data(serde_json::json!({ "message": message })),
483        )
484    }
485
486    /// Send a warning-level log message
487    pub fn log_warning(&self, message: &str) -> bool {
488        self.log(
489            LoggingMessageParams::new(LogLevel::Warning)
490                .with_data(serde_json::json!({ "message": message })),
491        )
492    }
493
494    /// Send an error-level log message
495    pub fn log_error(&self, message: &str) -> bool {
496        self.log(
497            LoggingMessageParams::new(LogLevel::Error)
498                .with_data(serde_json::json!({ "message": message })),
499        )
500    }
501
502    /// Send a debug-level log message
503    pub fn log_debug(&self, message: &str) -> bool {
504        self.log(
505            LoggingMessageParams::new(LogLevel::Debug)
506                .with_data(serde_json::json!({ "message": message })),
507        )
508    }
509
510    /// Check if a resource URI is currently subscribed
511    pub fn is_subscribed(&self, uri: &str) -> bool {
512        if let Ok(subs) = self.inner.subscriptions.read() {
513            return subs.contains(uri);
514        }
515        false
516    }
517
518    /// Get a list of all subscribed resource URIs
519    pub fn subscribed_uris(&self) -> Vec<String> {
520        if let Ok(subs) = self.inner.subscriptions.read() {
521            return subs.iter().cloned().collect();
522        }
523        Vec::new()
524    }
525
526    /// Subscribe to a resource URI
527    fn subscribe(&self, uri: &str) -> bool {
528        if let Ok(mut subs) = self.inner.subscriptions.write() {
529            return subs.insert(uri.to_string());
530        }
531        false
532    }
533
534    /// Unsubscribe from a resource URI
535    fn unsubscribe(&self, uri: &str) -> bool {
536        if let Ok(mut subs) = self.inner.subscriptions.write() {
537            return subs.remove(uri);
538        }
539        false
540    }
541
542    /// Notify clients that a subscribed resource has been updated
543    ///
544    /// Only sends the notification if the resource is currently subscribed.
545    /// Returns `true` if the notification was sent.
546    pub fn notify_resource_updated(&self, uri: &str) -> bool {
547        // Only notify if the resource is subscribed
548        if !self.is_subscribed(uri) {
549            return false;
550        }
551
552        let Some(tx) = &self.inner.notification_tx else {
553            return false;
554        };
555        tx.try_send(ServerNotification::ResourceUpdated {
556            uri: uri.to_string(),
557        })
558        .is_ok()
559    }
560
561    /// Notify clients that the list of available resources has changed
562    ///
563    /// Returns `true` if the notification was sent.
564    pub fn notify_resources_list_changed(&self) -> bool {
565        let Some(tx) = &self.inner.notification_tx else {
566            return false;
567        };
568        tx.try_send(ServerNotification::ResourcesListChanged)
569            .is_ok()
570    }
571
572    /// Get server capabilities based on registered handlers
573    fn capabilities(&self) -> ServerCapabilities {
574        let has_resources =
575            !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
576
577        ServerCapabilities {
578            tools: if self.inner.tools.is_empty() {
579                None
580            } else {
581                Some(ToolsCapability::default())
582            },
583            resources: if has_resources {
584                Some(ResourcesCapability {
585                    subscribe: true,
586                    ..Default::default()
587                })
588            } else {
589                None
590            },
591            prompts: if self.inner.prompts.is_empty() {
592                None
593            } else {
594                Some(PromptsCapability::default())
595            },
596            // Always advertise logging capability when notification channel is configured
597            logging: if self.inner.notification_tx.is_some() {
598                Some(LoggingCapability::default())
599            } else {
600                None
601            },
602            // Tasks capability is always available
603            tasks: Some(TasksCapability::default()),
604            // Completions capability when a handler is registered
605            completions: if self.inner.completion_handler.is_some() {
606                Some(CompletionsCapability::default())
607            } else {
608                None
609            },
610        }
611    }
612
613    /// Handle an MCP request
614    async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
615        // Enforce session state - reject requests before initialization
616        let method = request.method_name();
617        if !self.session.is_request_allowed(method) {
618            tracing::warn!(
619                method = %method,
620                phase = ?self.session.phase(),
621                "Request rejected: session not initialized"
622            );
623            return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
624                "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
625                method
626            ))));
627        }
628
629        match request {
630            McpRequest::Initialize(params) => {
631                tracing::info!(
632                    client = %params.client_info.name,
633                    version = %params.client_info.version,
634                    "Client initializing"
635                );
636
637                // Protocol version negotiation: respond with same version if supported,
638                // otherwise respond with our latest supported version
639                let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
640                    .contains(&params.protocol_version.as_str())
641                {
642                    params.protocol_version
643                } else {
644                    crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
645                };
646
647                // Transition session state to Initializing
648                self.session.mark_initializing();
649
650                Ok(McpResponse::Initialize(InitializeResult {
651                    protocol_version,
652                    capabilities: self.capabilities(),
653                    server_info: Implementation {
654                        name: self.inner.server_name.clone(),
655                        version: self.inner.server_version.clone(),
656                        title: self.inner.server_title.clone(),
657                        description: self.inner.server_description.clone(),
658                        icons: self.inner.server_icons.clone(),
659                        website_url: self.inner.server_website_url.clone(),
660                    },
661                    instructions: self.inner.instructions.clone(),
662                }))
663            }
664
665            McpRequest::ListTools(_params) => {
666                let tools: Vec<ToolDefinition> =
667                    self.inner.tools.values().map(|t| t.definition()).collect();
668
669                Ok(McpResponse::ListTools(ListToolsResult {
670                    tools,
671                    next_cursor: None,
672                }))
673            }
674
675            McpRequest::CallTool(params) => {
676                let tool =
677                    self.inner.tools.get(&params.name).ok_or_else(|| {
678                        Error::JsonRpc(JsonRpcError::method_not_found(&params.name))
679                    })?;
680
681                // Extract progress token from request metadata
682                let progress_token = params.meta.and_then(|m| m.progress_token);
683                let ctx = self.create_context(request_id, progress_token);
684
685                tracing::debug!(tool = %params.name, "Calling tool");
686                let result = tool.call_with_context(ctx, params.arguments).await?;
687
688                Ok(McpResponse::CallTool(result))
689            }
690
691            McpRequest::ListResources(_params) => {
692                let resources: Vec<ResourceDefinition> = self
693                    .inner
694                    .resources
695                    .values()
696                    .map(|r| r.definition())
697                    .collect();
698
699                Ok(McpResponse::ListResources(ListResourcesResult {
700                    resources,
701                    next_cursor: None,
702                }))
703            }
704
705            McpRequest::ListResourceTemplates(_params) => {
706                let resource_templates: Vec<ResourceTemplateDefinition> = self
707                    .inner
708                    .resource_templates
709                    .iter()
710                    .map(|t| t.definition())
711                    .collect();
712
713                Ok(McpResponse::ListResourceTemplates(
714                    ListResourceTemplatesResult {
715                        resource_templates,
716                        next_cursor: None,
717                    },
718                ))
719            }
720
721            McpRequest::ReadResource(params) => {
722                // First, try to find a static resource
723                if let Some(resource) = self.inner.resources.get(&params.uri) {
724                    tracing::debug!(uri = %params.uri, "Reading static resource");
725                    let result = resource.read().await?;
726                    return Ok(McpResponse::ReadResource(result));
727                }
728
729                // If no static resource found, try to match against templates
730                for template in &self.inner.resource_templates {
731                    if let Some(variables) = template.match_uri(&params.uri) {
732                        tracing::debug!(
733                            uri = %params.uri,
734                            template = %template.uri_template,
735                            "Reading resource via template"
736                        );
737                        let result = template.read(&params.uri, variables).await?;
738                        return Ok(McpResponse::ReadResource(result));
739                    }
740                }
741
742                // No match found
743                Err(Error::JsonRpc(JsonRpcError::resource_not_found(
744                    &params.uri,
745                )))
746            }
747
748            McpRequest::SubscribeResource(params) => {
749                // Verify the resource exists
750                if !self.inner.resources.contains_key(&params.uri) {
751                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
752                        &params.uri,
753                    )));
754                }
755
756                tracing::debug!(uri = %params.uri, "Subscribing to resource");
757                self.subscribe(&params.uri);
758
759                Ok(McpResponse::SubscribeResource(EmptyResult {}))
760            }
761
762            McpRequest::UnsubscribeResource(params) => {
763                // Verify the resource exists
764                if !self.inner.resources.contains_key(&params.uri) {
765                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
766                        &params.uri,
767                    )));
768                }
769
770                tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
771                self.unsubscribe(&params.uri);
772
773                Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
774            }
775
776            McpRequest::ListPrompts(_params) => {
777                let prompts: Vec<PromptDefinition> = self
778                    .inner
779                    .prompts
780                    .values()
781                    .map(|p| p.definition())
782                    .collect();
783
784                Ok(McpResponse::ListPrompts(ListPromptsResult {
785                    prompts,
786                    next_cursor: None,
787                }))
788            }
789
790            McpRequest::GetPrompt(params) => {
791                let prompt = self.inner.prompts.get(&params.name).ok_or_else(|| {
792                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
793                        "Prompt not found: {}",
794                        params.name
795                    )))
796                })?;
797
798                tracing::debug!(name = %params.name, "Getting prompt");
799                let result = prompt.get(params.arguments).await?;
800
801                Ok(McpResponse::GetPrompt(result))
802            }
803
804            McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
805
806            McpRequest::EnqueueTask(params) => {
807                // Verify the tool exists
808                let tool = self.inner.tools.get(&params.tool_name).ok_or_else(|| {
809                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
810                        "Tool not found: {}",
811                        params.tool_name
812                    )))
813                })?;
814
815                // Create the task
816                let (task_id, cancellation_token) = self.inner.task_store.create_task(
817                    &params.tool_name,
818                    params.arguments.clone(),
819                    params.ttl,
820                );
821
822                tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
823
824                // Create a context for the async task execution
825                let ctx = self.create_context(request_id, None);
826
827                // Spawn the task execution in the background
828                let task_store = self.inner.task_store.clone();
829                let tool = tool.clone();
830                let arguments = params.arguments;
831                let task_id_clone = task_id.clone();
832
833                tokio::spawn(async move {
834                    // Check for cancellation before starting
835                    if cancellation_token.is_cancelled() {
836                        tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
837                        return;
838                    }
839
840                    // Execute the tool
841                    match tool.call_with_context(ctx, arguments).await {
842                        Ok(result) => {
843                            if cancellation_token.is_cancelled() {
844                                tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
845                            } else {
846                                task_store.complete_task(&task_id_clone, result);
847                                tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
848                            }
849                        }
850                        Err(e) => {
851                            task_store.fail_task(&task_id_clone, &e.to_string());
852                            tracing::warn!(task_id = %task_id_clone, error = %e, "Task failed");
853                        }
854                    }
855                });
856
857                Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
858                    task_id,
859                    status: TaskStatus::Working,
860                    poll_interval: Some(2),
861                }))
862            }
863
864            McpRequest::ListTasks(params) => {
865                let tasks = self.inner.task_store.list_tasks(params.status);
866
867                Ok(McpResponse::ListTasks(ListTasksResult {
868                    tasks,
869                    next_cursor: None,
870                }))
871            }
872
873            McpRequest::GetTaskInfo(params) => {
874                let task = self
875                    .inner
876                    .task_store
877                    .get_task(&params.task_id)
878                    .ok_or_else(|| {
879                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
880                            "Task not found: {}",
881                            params.task_id
882                        )))
883                    })?;
884
885                Ok(McpResponse::GetTaskInfo(task))
886            }
887
888            McpRequest::GetTaskResult(params) => {
889                let (status, result, error) = self
890                    .inner
891                    .task_store
892                    .get_task_full(&params.task_id)
893                    .ok_or_else(|| {
894                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
895                            "Task not found: {}",
896                            params.task_id
897                        )))
898                    })?;
899
900                Ok(McpResponse::GetTaskResult(GetTaskResultResult {
901                    task_id: params.task_id,
902                    status,
903                    result,
904                    error,
905                }))
906            }
907
908            McpRequest::CancelTask(params) => {
909                let status = self
910                    .inner
911                    .task_store
912                    .cancel_task(&params.task_id, params.reason.as_deref())
913                    .ok_or_else(|| {
914                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
915                            "Task not found: {}",
916                            params.task_id
917                        )))
918                    })?;
919
920                let cancelled = status == TaskStatus::Cancelled;
921
922                Ok(McpResponse::CancelTask(CancelTaskResult {
923                    cancelled,
924                    status,
925                }))
926            }
927
928            McpRequest::SetLoggingLevel(params) => {
929                // Store the log level for filtering outgoing log notifications
930                // For now, we just accept the request - actual filtering would be
931                // implemented in the notification sending logic
932                tracing::debug!(level = ?params.level, "Client set logging level");
933                Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
934            }
935
936            McpRequest::Complete(params) => {
937                tracing::debug!(
938                    reference = ?params.reference,
939                    argument = %params.argument.name,
940                    "Completion request"
941                );
942
943                // Delegate to registered completion handler if available
944                if let Some(ref handler) = self.inner.completion_handler {
945                    let result = handler(params).await?;
946                    Ok(McpResponse::Complete(result))
947                } else {
948                    // No completion handler registered, return empty completions
949                    Ok(McpResponse::Complete(CompleteResult::new(vec![])))
950                }
951            }
952
953            McpRequest::Unknown { method, .. } => {
954                Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
955            }
956        }
957    }
958
959    /// Handle an MCP notification (no response expected)
960    pub fn handle_notification(&self, notification: McpNotification) {
961        match notification {
962            McpNotification::Initialized => {
963                if self.session.mark_initialized() {
964                    tracing::info!("Session initialized, entering operation phase");
965                } else {
966                    tracing::warn!(
967                        "Received initialized notification in unexpected state: {:?}",
968                        self.session.phase()
969                    );
970                }
971            }
972            McpNotification::Cancelled(params) => {
973                if self.cancel_request(&params.request_id) {
974                    tracing::info!(
975                        request_id = ?params.request_id,
976                        reason = ?params.reason,
977                        "Request cancelled"
978                    );
979                } else {
980                    tracing::debug!(
981                        request_id = ?params.request_id,
982                        reason = ?params.reason,
983                        "Cancellation requested for unknown request"
984                    );
985                }
986            }
987            McpNotification::Progress(params) => {
988                tracing::trace!(
989                    token = ?params.progress_token,
990                    progress = params.progress,
991                    total = ?params.total,
992                    "Progress notification"
993                );
994                // Progress notifications from client are unusual but valid
995            }
996            McpNotification::RootsListChanged => {
997                tracing::info!("Client roots list changed");
998                // Server should re-request roots if needed
999                // This is handled by the application layer
1000            }
1001            McpNotification::Unknown { method, .. } => {
1002                tracing::debug!(method = %method, "Unknown notification received");
1003            }
1004        }
1005    }
1006}
1007
1008impl Default for McpRouter {
1009    fn default() -> Self {
1010        Self::new()
1011    }
1012}
1013
1014// =============================================================================
1015// Tower Service implementation
1016// =============================================================================
1017
/// Type-keyed bag of values carried alongside a request through middleware.
///
/// Each entry lives behind an `Arc<dyn Any>`, so cloning the whole map only
/// bumps reference counts. Cheap clones matter for batch requests, where one
/// HTTP request fans out into several `RouterRequest`s that share this data.
///
/// # Example
///
/// ```rust
/// use tower_mcp::Extensions;
///
/// let mut ext = Extensions::new();
/// ext.insert(42u32);
/// assert_eq!(ext.get::<u32>(), Some(&42));
/// ```
#[derive(Clone, Default)]
pub struct Extensions {
    // One slot per concrete type, keyed by that type's `TypeId`.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Returns a new, empty map.
    pub fn new() -> Self {
        Extensions {
            map: HashMap::new(),
        }
    }

    /// Stores `val`, keyed by its type.
    ///
    /// Any value of the same type that was stored earlier is overwritten.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = TypeId::of::<T>();
        let stored: Arc<dyn Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, stored);
    }

    /// Looks up the value stored for type `T`.
    ///
    /// Yields `None` when nothing of that type has been inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let entry = self.map.get(&TypeId::of::<T>())?;
        entry.downcast_ref::<T>()
    }
}

impl std::fmt::Debug for Extensions {
    // The stored values are type-erased, so only the entry count is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.map.len();
        f.debug_struct("Extensions").field("len", &len).finish()
    }
}
1068
1069/// Request type for the tower Service implementation
1070#[derive(Debug)]
1071pub struct RouterRequest {
1072    pub id: RequestId,
1073    pub inner: McpRequest,
1074    /// Type-map for passing data (e.g., `TokenClaims`) through middleware.
1075    pub extensions: Extensions,
1076}
1077
1078/// Response type for the tower Service implementation
1079#[derive(Debug)]
1080pub struct RouterResponse {
1081    pub id: RequestId,
1082    pub inner: std::result::Result<McpResponse, JsonRpcError>,
1083}
1084
1085impl RouterResponse {
1086    /// Convert to JSON-RPC response
1087    pub fn into_jsonrpc(self) -> JsonRpcResponse {
1088        match self.inner {
1089            Ok(response) => match serde_json::to_value(response) {
1090                Ok(result) => JsonRpcResponse::result(self.id, result),
1091                Err(e) => {
1092                    tracing::error!(error = %e, "Failed to serialize response");
1093                    JsonRpcResponse::error(
1094                        Some(self.id),
1095                        JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1096                    )
1097                }
1098            },
1099            Err(error) => JsonRpcResponse::error(Some(self.id), error),
1100        }
1101    }
1102}
1103
1104impl Service<RouterRequest> for McpRouter {
1105    type Response = RouterResponse;
1106    type Error = std::convert::Infallible; // Errors are in the response
1107    type Future =
1108        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1109
1110    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1111        Poll::Ready(Ok(()))
1112    }
1113
1114    fn call(&mut self, req: RouterRequest) -> Self::Future {
1115        let router = self.clone();
1116        let request_id = req.id.clone();
1117        Box::pin(async move {
1118            let result = router.handle(req.id, req.inner).await;
1119            // Clean up tracking after request completes
1120            router.complete_request(&request_id);
1121            Ok(RouterResponse {
1122                id: request_id,
1123                // Map tower-mcp errors to JSON-RPC errors:
1124                // - Error::JsonRpc: forwarded as-is (preserves original code)
1125                // - Error::Tool: mapped to -32603 (Internal Error)
1126                // - All others: mapped to -32603 (Internal Error)
1127                inner: result.map_err(|e| match e {
1128                    Error::JsonRpc(err) => err,
1129                    Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1130                    e => JsonRpcError::internal_error(e.to_string()),
1131                }),
1132            })
1133        })
1134    }
1135}
1136
1137#[cfg(test)]
1138mod tests {
1139    use super::*;
1140    use crate::jsonrpc::JsonRpcService;
1141    use crate::tool::ToolBuilder;
1142    use schemars::JsonSchema;
1143    use serde::Deserialize;
1144    use tower::ServiceExt;
1145
1146    #[derive(Debug, Deserialize, JsonSchema)]
1147    struct AddInput {
1148        a: i64,
1149        b: i64,
1150    }
1151
1152    /// Helper to initialize a router for testing
1153    async fn init_router(router: &mut McpRouter) {
1154        // Send initialize request
1155        let init_req = RouterRequest {
1156            id: RequestId::Number(0),
1157            inner: McpRequest::Initialize(InitializeParams {
1158                protocol_version: "2025-11-25".to_string(),
1159                capabilities: ClientCapabilities {
1160                    roots: None,
1161                    sampling: None,
1162                    elicitation: None,
1163                },
1164                client_info: Implementation {
1165                    name: "test".to_string(),
1166                    version: "1.0".to_string(),
1167                    ..Default::default()
1168                },
1169            }),
1170            extensions: Extensions::new(),
1171        };
1172        let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1173        // Send initialized notification
1174        router.handle_notification(McpNotification::Initialized);
1175    }
1176
1177    #[tokio::test]
1178    async fn test_router_list_tools() {
1179        let add_tool = ToolBuilder::new("add")
1180            .description("Add two numbers")
1181            .handler(|input: AddInput| async move {
1182                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1183            })
1184            .build()
1185            .expect("valid tool name");
1186
1187        let mut router = McpRouter::new().tool(add_tool);
1188
1189        // Initialize session first
1190        init_router(&mut router).await;
1191
1192        let req = RouterRequest {
1193            id: RequestId::Number(1),
1194            inner: McpRequest::ListTools(ListToolsParams::default()),
1195            extensions: Extensions::new(),
1196        };
1197
1198        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1199
1200        match resp.inner {
1201            Ok(McpResponse::ListTools(result)) => {
1202                assert_eq!(result.tools.len(), 1);
1203                assert_eq!(result.tools[0].name, "add");
1204            }
1205            _ => panic!("Expected ListTools response"),
1206        }
1207    }
1208
1209    #[tokio::test]
1210    async fn test_router_call_tool() {
1211        let add_tool = ToolBuilder::new("add")
1212            .description("Add two numbers")
1213            .handler(|input: AddInput| async move {
1214                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1215            })
1216            .build()
1217            .expect("valid tool name");
1218
1219        let mut router = McpRouter::new().tool(add_tool);
1220
1221        // Initialize session first
1222        init_router(&mut router).await;
1223
1224        let req = RouterRequest {
1225            id: RequestId::Number(1),
1226            inner: McpRequest::CallTool(CallToolParams {
1227                name: "add".to_string(),
1228                arguments: serde_json::json!({"a": 2, "b": 3}),
1229                meta: None,
1230            }),
1231            extensions: Extensions::new(),
1232        };
1233
1234        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1235
1236        match resp.inner {
1237            Ok(McpResponse::CallTool(result)) => {
1238                assert!(!result.is_error);
1239                // Check the text content
1240                match &result.content[0] {
1241                    Content::Text { text, .. } => assert_eq!(text, "5"),
1242                    _ => panic!("Expected text content"),
1243                }
1244            }
1245            _ => panic!("Expected CallTool response"),
1246        }
1247    }
1248
1249    /// Helper to initialize a JsonRpcService for testing
1250    async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1251        let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1252            "protocolVersion": "2025-11-25",
1253            "capabilities": {},
1254            "clientInfo": { "name": "test", "version": "1.0" }
1255        }));
1256        let _ = service.call_single(init_req).await.unwrap();
1257        router.handle_notification(McpNotification::Initialized);
1258    }
1259
1260    #[tokio::test]
1261    async fn test_jsonrpc_service() {
1262        let add_tool = ToolBuilder::new("add")
1263            .description("Add two numbers")
1264            .handler(|input: AddInput| async move {
1265                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1266            })
1267            .build()
1268            .expect("valid tool name");
1269
1270        let router = McpRouter::new().tool(add_tool);
1271        let mut service = JsonRpcService::new(router.clone());
1272
1273        // Initialize session first
1274        init_jsonrpc_service(&mut service, &router).await;
1275
1276        let req = JsonRpcRequest::new(1, "tools/list");
1277
1278        let resp = service.call_single(req).await.unwrap();
1279
1280        match resp {
1281            JsonRpcResponse::Result(r) => {
1282                assert_eq!(r.id, RequestId::Number(1));
1283                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1284                assert_eq!(tools.len(), 1);
1285            }
1286            JsonRpcResponse::Error(_) => panic!("Expected success response"),
1287        }
1288    }
1289
1290    #[tokio::test]
1291    async fn test_batch_request() {
1292        let add_tool = ToolBuilder::new("add")
1293            .description("Add two numbers")
1294            .handler(|input: AddInput| async move {
1295                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1296            })
1297            .build()
1298            .expect("valid tool name");
1299
1300        let router = McpRouter::new().tool(add_tool);
1301        let mut service = JsonRpcService::new(router.clone());
1302
1303        // Initialize session first
1304        init_jsonrpc_service(&mut service, &router).await;
1305
1306        // Create a batch of requests
1307        let requests = vec![
1308            JsonRpcRequest::new(1, "tools/list"),
1309            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1310                "name": "add",
1311                "arguments": {"a": 10, "b": 20}
1312            })),
1313            JsonRpcRequest::new(3, "ping"),
1314        ];
1315
1316        let responses = service.call_batch(requests).await.unwrap();
1317
1318        assert_eq!(responses.len(), 3);
1319
1320        // Check first response (tools/list)
1321        match &responses[0] {
1322            JsonRpcResponse::Result(r) => {
1323                assert_eq!(r.id, RequestId::Number(1));
1324                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1325                assert_eq!(tools.len(), 1);
1326            }
1327            JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1328        }
1329
1330        // Check second response (tools/call)
1331        match &responses[1] {
1332            JsonRpcResponse::Result(r) => {
1333                assert_eq!(r.id, RequestId::Number(2));
1334                let content = r.result.get("content").unwrap().as_array().unwrap();
1335                let text = content[0].get("text").unwrap().as_str().unwrap();
1336                assert_eq!(text, "30");
1337            }
1338            JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1339        }
1340
1341        // Check third response (ping)
1342        match &responses[2] {
1343            JsonRpcResponse::Result(r) => {
1344                assert_eq!(r.id, RequestId::Number(3));
1345            }
1346            JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1347        }
1348    }
1349
1350    #[tokio::test]
1351    async fn test_empty_batch_error() {
1352        let router = McpRouter::new();
1353        let mut service = JsonRpcService::new(router);
1354
1355        let result = service.call_batch(vec![]).await;
1356        assert!(result.is_err());
1357    }
1358
1359    // =========================================================================
1360    // Progress Token Tests
1361    // =========================================================================
1362
1363    #[tokio::test]
1364    async fn test_progress_token_extraction() {
1365        use crate::context::{RequestContext, ServerNotification, notification_channel};
1366        use crate::protocol::ProgressToken;
1367        use std::sync::Arc;
1368        use std::sync::atomic::{AtomicBool, Ordering};
1369
1370        // Track whether progress was reported
1371        let progress_reported = Arc::new(AtomicBool::new(false));
1372        let progress_ref = progress_reported.clone();
1373
1374        // Create a tool that reports progress
1375        let tool = ToolBuilder::new("progress_tool")
1376            .description("Tool that reports progress")
1377            .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1378                let reported = progress_ref.clone();
1379                async move {
1380                    // Report progress - this should work if token was extracted
1381                    ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1382                        .await;
1383                    reported.store(true, Ordering::SeqCst);
1384                    Ok(CallToolResult::text("done"))
1385                }
1386            })
1387            .build()
1388            .expect("valid tool name");
1389
1390        // Set up notification channel
1391        let (tx, mut rx) = notification_channel(10);
1392        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1393        let mut service = JsonRpcService::new(router.clone());
1394
1395        // Initialize
1396        init_jsonrpc_service(&mut service, &router).await;
1397
1398        // Call tool WITH progress token in _meta
1399        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1400            "name": "progress_tool",
1401            "arguments": {"a": 1, "b": 2},
1402            "_meta": {
1403                "progressToken": "test-token-123"
1404            }
1405        }));
1406
1407        let resp = service.call_single(req).await.unwrap();
1408
1409        // Verify the tool was called successfully
1410        match resp {
1411            JsonRpcResponse::Result(_) => {}
1412            JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1413        }
1414
1415        // Verify progress was reported by handler
1416        assert!(progress_reported.load(Ordering::SeqCst));
1417
1418        // Verify progress notification was sent through channel
1419        let notification = rx.try_recv().expect("Expected progress notification");
1420        match notification {
1421            ServerNotification::Progress(params) => {
1422                assert_eq!(
1423                    params.progress_token,
1424                    ProgressToken::String("test-token-123".to_string())
1425                );
1426                assert_eq!(params.progress, 50.0);
1427                assert_eq!(params.total, Some(100.0));
1428                assert_eq!(params.message.as_deref(), Some("Halfway"));
1429            }
1430            _ => panic!("Expected Progress notification"),
1431        }
1432    }
1433
1434    #[tokio::test]
1435    async fn test_tool_call_without_progress_token() {
1436        use crate::context::{RequestContext, notification_channel};
1437        use std::sync::Arc;
1438        use std::sync::atomic::{AtomicBool, Ordering};
1439
1440        let progress_attempted = Arc::new(AtomicBool::new(false));
1441        let progress_ref = progress_attempted.clone();
1442
1443        let tool = ToolBuilder::new("no_token_tool")
1444            .description("Tool that tries to report progress without token")
1445            .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1446                let attempted = progress_ref.clone();
1447                async move {
1448                    // Try to report progress - should be a no-op without token
1449                    ctx.report_progress(50.0, Some(100.0), None).await;
1450                    attempted.store(true, Ordering::SeqCst);
1451                    Ok(CallToolResult::text("done"))
1452                }
1453            })
1454            .build()
1455            .expect("valid tool name");
1456
1457        let (tx, mut rx) = notification_channel(10);
1458        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1459        let mut service = JsonRpcService::new(router.clone());
1460
1461        init_jsonrpc_service(&mut service, &router).await;
1462
1463        // Call tool WITHOUT progress token
1464        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1465            "name": "no_token_tool",
1466            "arguments": {"a": 1, "b": 2}
1467        }));
1468
1469        let resp = service.call_single(req).await.unwrap();
1470        assert!(matches!(resp, JsonRpcResponse::Result(_)));
1471
1472        // Handler was called
1473        assert!(progress_attempted.load(Ordering::SeqCst));
1474
1475        // But no notification was sent (no progress token)
1476        assert!(rx.try_recv().is_err());
1477    }
1478
1479    #[tokio::test]
1480    async fn test_batch_errors_returned_not_dropped() {
1481        let add_tool = ToolBuilder::new("add")
1482            .description("Add two numbers")
1483            .handler(|input: AddInput| async move {
1484                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1485            })
1486            .build()
1487            .expect("valid tool name");
1488
1489        let router = McpRouter::new().tool(add_tool);
1490        let mut service = JsonRpcService::new(router.clone());
1491
1492        init_jsonrpc_service(&mut service, &router).await;
1493
1494        // Create a batch with one valid and one invalid request
1495        let requests = vec![
1496            // Valid request
1497            JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1498                "name": "add",
1499                "arguments": {"a": 10, "b": 20}
1500            })),
1501            // Invalid request - tool doesn't exist
1502            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1503                "name": "nonexistent_tool",
1504                "arguments": {}
1505            })),
1506            // Another valid request
1507            JsonRpcRequest::new(3, "ping"),
1508        ];
1509
1510        let responses = service.call_batch(requests).await.unwrap();
1511
1512        // All three requests should have responses (errors are not dropped)
1513        assert_eq!(responses.len(), 3);
1514
1515        // First should be success
1516        match &responses[0] {
1517            JsonRpcResponse::Result(r) => {
1518                assert_eq!(r.id, RequestId::Number(1));
1519            }
1520            JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1521        }
1522
1523        // Second should be an error (tool not found)
1524        match &responses[1] {
1525            JsonRpcResponse::Error(e) => {
1526                assert_eq!(e.id, Some(RequestId::Number(2)));
1527                // Error should indicate method not found
1528                assert!(e.error.message.contains("not found") || e.error.code == -32601);
1529            }
1530            JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1531        }
1532
1533        // Third should be success
1534        match &responses[2] {
1535            JsonRpcResponse::Result(r) => {
1536                assert_eq!(r.id, RequestId::Number(3));
1537            }
1538            JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1539        }
1540    }
1541
1542    // =========================================================================
1543    // Resource Template Tests
1544    // =========================================================================
1545
1546    #[tokio::test]
1547    async fn test_list_resource_templates() {
1548        use crate::resource::ResourceTemplateBuilder;
1549        use std::collections::HashMap;
1550
1551        let template = ResourceTemplateBuilder::new("file:///{path}")
1552            .name("Project Files")
1553            .description("Access project files")
1554            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1555                Ok(ReadResourceResult {
1556                    contents: vec![ResourceContent {
1557                        uri,
1558                        mime_type: None,
1559                        text: None,
1560                        blob: None,
1561                    }],
1562                })
1563            });
1564
1565        let mut router = McpRouter::new().resource_template(template);
1566
1567        // Initialize session
1568        init_router(&mut router).await;
1569
1570        let req = RouterRequest {
1571            id: RequestId::Number(1),
1572            inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1573            extensions: Extensions::new(),
1574        };
1575
1576        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1577
1578        match resp.inner {
1579            Ok(McpResponse::ListResourceTemplates(result)) => {
1580                assert_eq!(result.resource_templates.len(), 1);
1581                assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1582                assert_eq!(result.resource_templates[0].name, "Project Files");
1583            }
1584            _ => panic!("Expected ListResourceTemplates response"),
1585        }
1586    }
1587
1588    #[tokio::test]
1589    async fn test_read_resource_via_template() {
1590        use crate::resource::ResourceTemplateBuilder;
1591        use std::collections::HashMap;
1592
1593        let template = ResourceTemplateBuilder::new("db://users/{id}")
1594            .name("User Records")
1595            .handler(|uri: String, vars: HashMap<String, String>| async move {
1596                let id = vars.get("id").unwrap().clone();
1597                Ok(ReadResourceResult {
1598                    contents: vec![ResourceContent {
1599                        uri,
1600                        mime_type: Some("application/json".to_string()),
1601                        text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1602                        blob: None,
1603                    }],
1604                })
1605            });
1606
1607        let mut router = McpRouter::new().resource_template(template);
1608
1609        // Initialize session
1610        init_router(&mut router).await;
1611
1612        // Read a resource that matches the template
1613        let req = RouterRequest {
1614            id: RequestId::Number(1),
1615            inner: McpRequest::ReadResource(ReadResourceParams {
1616                uri: "db://users/123".to_string(),
1617            }),
1618            extensions: Extensions::new(),
1619        };
1620
1621        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1622
1623        match resp.inner {
1624            Ok(McpResponse::ReadResource(result)) => {
1625                assert_eq!(result.contents.len(), 1);
1626                assert_eq!(result.contents[0].uri, "db://users/123");
1627                assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1628            }
1629            _ => panic!("Expected ReadResource response"),
1630        }
1631    }
1632
1633    #[tokio::test]
1634    async fn test_static_resource_takes_precedence_over_template() {
1635        use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1636        use std::collections::HashMap;
1637
1638        // Template that would match the same URI
1639        let template = ResourceTemplateBuilder::new("file:///{path}")
1640            .name("Files Template")
1641            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1642                Ok(ReadResourceResult {
1643                    contents: vec![ResourceContent {
1644                        uri,
1645                        mime_type: None,
1646                        text: Some("from template".to_string()),
1647                        blob: None,
1648                    }],
1649                })
1650            });
1651
1652        // Static resource with exact URI
1653        let static_resource = ResourceBuilder::new("file:///README.md")
1654            .name("README")
1655            .text("from static resource");
1656
1657        let mut router = McpRouter::new()
1658            .resource_template(template)
1659            .resource(static_resource);
1660
1661        // Initialize session
1662        init_router(&mut router).await;
1663
1664        // Read the static resource - should NOT go through template
1665        let req = RouterRequest {
1666            id: RequestId::Number(1),
1667            inner: McpRequest::ReadResource(ReadResourceParams {
1668                uri: "file:///README.md".to_string(),
1669            }),
1670            extensions: Extensions::new(),
1671        };
1672
1673        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1674
1675        match resp.inner {
1676            Ok(McpResponse::ReadResource(result)) => {
1677                // Should get static resource, not template
1678                assert_eq!(
1679                    result.contents[0].text.as_deref(),
1680                    Some("from static resource")
1681                );
1682            }
1683            _ => panic!("Expected ReadResource response"),
1684        }
1685    }
1686
1687    #[tokio::test]
1688    async fn test_resource_not_found_when_no_match() {
1689        use crate::resource::ResourceTemplateBuilder;
1690        use std::collections::HashMap;
1691
1692        let template = ResourceTemplateBuilder::new("db://users/{id}")
1693            .name("Users")
1694            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1695                Ok(ReadResourceResult {
1696                    contents: vec![ResourceContent {
1697                        uri,
1698                        mime_type: None,
1699                        text: None,
1700                        blob: None,
1701                    }],
1702                })
1703            });
1704
1705        let mut router = McpRouter::new().resource_template(template);
1706
1707        // Initialize session
1708        init_router(&mut router).await;
1709
1710        // Try to read a URI that doesn't match any resource or template
1711        let req = RouterRequest {
1712            id: RequestId::Number(1),
1713            inner: McpRequest::ReadResource(ReadResourceParams {
1714                uri: "db://posts/123".to_string(),
1715            }),
1716            extensions: Extensions::new(),
1717        };
1718
1719        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1720
1721        match resp.inner {
1722            Err(err) => {
1723                assert!(err.message.contains("not found"));
1724            }
1725            Ok(_) => panic!("Expected error for non-matching URI"),
1726        }
1727    }
1728
1729    #[tokio::test]
1730    async fn test_capabilities_include_resources_with_only_templates() {
1731        use crate::resource::ResourceTemplateBuilder;
1732        use std::collections::HashMap;
1733
1734        let template = ResourceTemplateBuilder::new("file:///{path}")
1735            .name("Files")
1736            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1737                Ok(ReadResourceResult {
1738                    contents: vec![ResourceContent {
1739                        uri,
1740                        mime_type: None,
1741                        text: None,
1742                        blob: None,
1743                    }],
1744                })
1745            });
1746
1747        let mut router = McpRouter::new().resource_template(template);
1748
1749        // Send initialize request and check capabilities
1750        let init_req = RouterRequest {
1751            id: RequestId::Number(0),
1752            inner: McpRequest::Initialize(InitializeParams {
1753                protocol_version: "2025-11-25".to_string(),
1754                capabilities: ClientCapabilities {
1755                    roots: None,
1756                    sampling: None,
1757                    elicitation: None,
1758                },
1759                client_info: Implementation {
1760                    name: "test".to_string(),
1761                    version: "1.0".to_string(),
1762                    ..Default::default()
1763                },
1764            }),
1765            extensions: Extensions::new(),
1766        };
1767        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1768
1769        match resp.inner {
1770            Ok(McpResponse::Initialize(result)) => {
1771                // Should have resources capability even though only templates registered
1772                assert!(result.capabilities.resources.is_some());
1773            }
1774            _ => panic!("Expected Initialize response"),
1775        }
1776    }
1777
1778    // =========================================================================
1779    // Logging Notification Tests
1780    // =========================================================================
1781
1782    #[tokio::test]
1783    async fn test_log_sends_notification() {
1784        use crate::context::notification_channel;
1785
1786        let (tx, mut rx) = notification_channel(10);
1787        let router = McpRouter::new().with_notification_sender(tx);
1788
1789        // Send an info log
1790        let sent = router.log_info("Test message");
1791        assert!(sent);
1792
1793        // Should receive the notification
1794        let notification = rx.try_recv().unwrap();
1795        match notification {
1796            ServerNotification::LogMessage(params) => {
1797                assert_eq!(params.level, LogLevel::Info);
1798                let data = params.data.unwrap();
1799                assert_eq!(
1800                    data.get("message").unwrap().as_str().unwrap(),
1801                    "Test message"
1802                );
1803            }
1804            _ => panic!("Expected LogMessage notification"),
1805        }
1806    }
1807
1808    #[tokio::test]
1809    async fn test_log_with_custom_params() {
1810        use crate::context::notification_channel;
1811
1812        let (tx, mut rx) = notification_channel(10);
1813        let router = McpRouter::new().with_notification_sender(tx);
1814
1815        // Send a custom log message
1816        let params = LoggingMessageParams::new(LogLevel::Error)
1817            .with_logger("database")
1818            .with_data(serde_json::json!({
1819                "error": "Connection failed",
1820                "host": "localhost"
1821            }));
1822
1823        let sent = router.log(params);
1824        assert!(sent);
1825
1826        let notification = rx.try_recv().unwrap();
1827        match notification {
1828            ServerNotification::LogMessage(params) => {
1829                assert_eq!(params.level, LogLevel::Error);
1830                assert_eq!(params.logger.as_deref(), Some("database"));
1831                let data = params.data.unwrap();
1832                assert_eq!(
1833                    data.get("error").unwrap().as_str().unwrap(),
1834                    "Connection failed"
1835                );
1836            }
1837            _ => panic!("Expected LogMessage notification"),
1838        }
1839    }
1840
1841    #[tokio::test]
1842    async fn test_log_without_channel_returns_false() {
1843        // Router without notification channel
1844        let router = McpRouter::new();
1845
1846        // Should return false when no channel configured
1847        assert!(!router.log_info("Test"));
1848        assert!(!router.log_warning("Test"));
1849        assert!(!router.log_error("Test"));
1850        assert!(!router.log_debug("Test"));
1851    }
1852
1853    #[tokio::test]
1854    async fn test_logging_capability_with_channel() {
1855        use crate::context::notification_channel;
1856
1857        let (tx, _rx) = notification_channel(10);
1858        let mut router = McpRouter::new().with_notification_sender(tx);
1859
1860        // Initialize and check capabilities
1861        let init_req = RouterRequest {
1862            id: RequestId::Number(0),
1863            inner: McpRequest::Initialize(InitializeParams {
1864                protocol_version: "2025-11-25".to_string(),
1865                capabilities: ClientCapabilities {
1866                    roots: None,
1867                    sampling: None,
1868                    elicitation: None,
1869                },
1870                client_info: Implementation {
1871                    name: "test".to_string(),
1872                    version: "1.0".to_string(),
1873                    ..Default::default()
1874                },
1875            }),
1876            extensions: Extensions::new(),
1877        };
1878        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1879
1880        match resp.inner {
1881            Ok(McpResponse::Initialize(result)) => {
1882                // Should have logging capability when notification channel is set
1883                assert!(result.capabilities.logging.is_some());
1884            }
1885            _ => panic!("Expected Initialize response"),
1886        }
1887    }
1888
1889    #[tokio::test]
1890    async fn test_no_logging_capability_without_channel() {
1891        let mut router = McpRouter::new();
1892
1893        // Initialize and check capabilities
1894        let init_req = RouterRequest {
1895            id: RequestId::Number(0),
1896            inner: McpRequest::Initialize(InitializeParams {
1897                protocol_version: "2025-11-25".to_string(),
1898                capabilities: ClientCapabilities {
1899                    roots: None,
1900                    sampling: None,
1901                    elicitation: None,
1902                },
1903                client_info: Implementation {
1904                    name: "test".to_string(),
1905                    version: "1.0".to_string(),
1906                    ..Default::default()
1907                },
1908            }),
1909            extensions: Extensions::new(),
1910        };
1911        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1912
1913        match resp.inner {
1914            Ok(McpResponse::Initialize(result)) => {
1915                // Should NOT have logging capability without notification channel
1916                assert!(result.capabilities.logging.is_none());
1917            }
1918            _ => panic!("Expected Initialize response"),
1919        }
1920    }
1921
1922    // =========================================================================
1923    // Task Lifecycle Tests
1924    // =========================================================================
1925
1926    #[tokio::test]
1927    async fn test_enqueue_task() {
1928        let add_tool = ToolBuilder::new("add")
1929            .description("Add two numbers")
1930            .handler(|input: AddInput| async move {
1931                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1932            })
1933            .build()
1934            .expect("valid tool name");
1935
1936        let mut router = McpRouter::new().tool(add_tool);
1937        init_router(&mut router).await;
1938
1939        let req = RouterRequest {
1940            id: RequestId::Number(1),
1941            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1942                tool_name: "add".to_string(),
1943                arguments: serde_json::json!({"a": 5, "b": 10}),
1944                ttl: None,
1945            }),
1946            extensions: Extensions::new(),
1947        };
1948
1949        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1950
1951        match resp.inner {
1952            Ok(McpResponse::EnqueueTask(result)) => {
1953                assert!(result.task_id.starts_with("task-"));
1954                assert_eq!(result.status, TaskStatus::Working);
1955            }
1956            _ => panic!("Expected EnqueueTask response"),
1957        }
1958    }
1959
1960    #[tokio::test]
1961    async fn test_list_tasks_empty() {
1962        let mut router = McpRouter::new();
1963        init_router(&mut router).await;
1964
1965        let req = RouterRequest {
1966            id: RequestId::Number(1),
1967            inner: McpRequest::ListTasks(ListTasksParams::default()),
1968            extensions: Extensions::new(),
1969        };
1970
1971        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1972
1973        match resp.inner {
1974            Ok(McpResponse::ListTasks(result)) => {
1975                assert!(result.tasks.is_empty());
1976            }
1977            _ => panic!("Expected ListTasks response"),
1978        }
1979    }
1980
1981    #[tokio::test]
1982    async fn test_task_lifecycle_complete() {
1983        let add_tool = ToolBuilder::new("add")
1984            .description("Add two numbers")
1985            .handler(|input: AddInput| async move {
1986                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1987            })
1988            .build()
1989            .expect("valid tool name");
1990
1991        let mut router = McpRouter::new().tool(add_tool);
1992        init_router(&mut router).await;
1993
1994        // Enqueue task
1995        let req = RouterRequest {
1996            id: RequestId::Number(1),
1997            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1998                tool_name: "add".to_string(),
1999                arguments: serde_json::json!({"a": 7, "b": 8}),
2000                ttl: None,
2001            }),
2002            extensions: Extensions::new(),
2003        };
2004
2005        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2006        let task_id = match resp.inner {
2007            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2008            _ => panic!("Expected EnqueueTask response"),
2009        };
2010
2011        // Wait for task to complete
2012        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2013
2014        // Get task result
2015        let req = RouterRequest {
2016            id: RequestId::Number(2),
2017            inner: McpRequest::GetTaskResult(GetTaskResultParams {
2018                task_id: task_id.clone(),
2019            }),
2020            extensions: Extensions::new(),
2021        };
2022
2023        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2024
2025        match resp.inner {
2026            Ok(McpResponse::GetTaskResult(result)) => {
2027                assert_eq!(result.task_id, task_id);
2028                assert_eq!(result.status, TaskStatus::Completed);
2029                assert!(result.result.is_some());
2030                assert!(result.error.is_none());
2031
2032                // Check the result content
2033                let tool_result = result.result.unwrap();
2034                match &tool_result.content[0] {
2035                    Content::Text { text, .. } => assert_eq!(text, "15"),
2036                    _ => panic!("Expected text content"),
2037                }
2038            }
2039            _ => panic!("Expected GetTaskResult response"),
2040        }
2041    }
2042
2043    #[tokio::test]
2044    async fn test_task_cancellation() {
2045        // Use a slow tool to test cancellation
2046        let slow_tool = ToolBuilder::new("slow")
2047            .description("Slow tool")
2048            .handler(|_input: serde_json::Value| async move {
2049                tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2050                Ok(CallToolResult::text("done"))
2051            })
2052            .build()
2053            .expect("valid tool name");
2054
2055        let mut router = McpRouter::new().tool(slow_tool);
2056        init_router(&mut router).await;
2057
2058        // Enqueue task
2059        let req = RouterRequest {
2060            id: RequestId::Number(1),
2061            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2062                tool_name: "slow".to_string(),
2063                arguments: serde_json::json!({}),
2064                ttl: None,
2065            }),
2066            extensions: Extensions::new(),
2067        };
2068
2069        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2070        let task_id = match resp.inner {
2071            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2072            _ => panic!("Expected EnqueueTask response"),
2073        };
2074
2075        // Cancel the task
2076        let req = RouterRequest {
2077            id: RequestId::Number(2),
2078            inner: McpRequest::CancelTask(CancelTaskParams {
2079                task_id: task_id.clone(),
2080                reason: Some("Test cancellation".to_string()),
2081            }),
2082            extensions: Extensions::new(),
2083        };
2084
2085        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2086
2087        match resp.inner {
2088            Ok(McpResponse::CancelTask(result)) => {
2089                assert!(result.cancelled);
2090                assert_eq!(result.status, TaskStatus::Cancelled);
2091            }
2092            _ => panic!("Expected CancelTask response"),
2093        }
2094    }
2095
2096    #[tokio::test]
2097    async fn test_get_task_info() {
2098        let add_tool = ToolBuilder::new("add")
2099            .description("Add two numbers")
2100            .handler(|input: AddInput| async move {
2101                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2102            })
2103            .build()
2104            .expect("valid tool name");
2105
2106        let mut router = McpRouter::new().tool(add_tool);
2107        init_router(&mut router).await;
2108
2109        // Enqueue task
2110        let req = RouterRequest {
2111            id: RequestId::Number(1),
2112            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2113                tool_name: "add".to_string(),
2114                arguments: serde_json::json!({"a": 1, "b": 2}),
2115                ttl: Some(600),
2116            }),
2117            extensions: Extensions::new(),
2118        };
2119
2120        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2121        let task_id = match resp.inner {
2122            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2123            _ => panic!("Expected EnqueueTask response"),
2124        };
2125
2126        // Get task info
2127        let req = RouterRequest {
2128            id: RequestId::Number(2),
2129            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2130                task_id: task_id.clone(),
2131            }),
2132            extensions: Extensions::new(),
2133        };
2134
2135        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2136
2137        match resp.inner {
2138            Ok(McpResponse::GetTaskInfo(info)) => {
2139                assert_eq!(info.task_id, task_id);
2140                assert!(info.created_at.contains('T')); // ISO 8601
2141                assert_eq!(info.ttl, Some(600));
2142            }
2143            _ => panic!("Expected GetTaskInfo response"),
2144        }
2145    }
2146
2147    #[tokio::test]
2148    async fn test_enqueue_nonexistent_tool() {
2149        let mut router = McpRouter::new();
2150        init_router(&mut router).await;
2151
2152        let req = RouterRequest {
2153            id: RequestId::Number(1),
2154            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2155                tool_name: "nonexistent".to_string(),
2156                arguments: serde_json::json!({}),
2157                ttl: None,
2158            }),
2159            extensions: Extensions::new(),
2160        };
2161
2162        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2163
2164        match resp.inner {
2165            Err(e) => {
2166                assert!(e.message.contains("not found"));
2167            }
2168            _ => panic!("Expected error response"),
2169        }
2170    }
2171
2172    #[tokio::test]
2173    async fn test_get_nonexistent_task() {
2174        let mut router = McpRouter::new();
2175        init_router(&mut router).await;
2176
2177        let req = RouterRequest {
2178            id: RequestId::Number(1),
2179            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2180                task_id: "task-999".to_string(),
2181            }),
2182            extensions: Extensions::new(),
2183        };
2184
2185        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2186
2187        match resp.inner {
2188            Err(e) => {
2189                assert!(e.message.contains("not found"));
2190            }
2191            _ => panic!("Expected error response"),
2192        }
2193    }
2194
2195    // =========================================================================
2196    // Resource Subscription Tests
2197    // =========================================================================
2198
2199    #[tokio::test]
2200    async fn test_subscribe_to_resource() {
2201        use crate::resource::ResourceBuilder;
2202
2203        let resource = ResourceBuilder::new("file:///test.txt")
2204            .name("Test File")
2205            .text("Hello");
2206
2207        let mut router = McpRouter::new().resource(resource);
2208        init_router(&mut router).await;
2209
2210        // Subscribe to the resource
2211        let req = RouterRequest {
2212            id: RequestId::Number(1),
2213            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2214                uri: "file:///test.txt".to_string(),
2215            }),
2216            extensions: Extensions::new(),
2217        };
2218
2219        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2220
2221        match resp.inner {
2222            Ok(McpResponse::SubscribeResource(_)) => {
2223                // Should be subscribed now
2224                assert!(router.is_subscribed("file:///test.txt"));
2225            }
2226            _ => panic!("Expected SubscribeResource response"),
2227        }
2228    }
2229
2230    #[tokio::test]
2231    async fn test_unsubscribe_from_resource() {
2232        use crate::resource::ResourceBuilder;
2233
2234        let resource = ResourceBuilder::new("file:///test.txt")
2235            .name("Test File")
2236            .text("Hello");
2237
2238        let mut router = McpRouter::new().resource(resource);
2239        init_router(&mut router).await;
2240
2241        // Subscribe first
2242        let req = RouterRequest {
2243            id: RequestId::Number(1),
2244            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2245                uri: "file:///test.txt".to_string(),
2246            }),
2247            extensions: Extensions::new(),
2248        };
2249        let _ = router.ready().await.unwrap().call(req).await.unwrap();
2250        assert!(router.is_subscribed("file:///test.txt"));
2251
2252        // Now unsubscribe
2253        let req = RouterRequest {
2254            id: RequestId::Number(2),
2255            inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2256                uri: "file:///test.txt".to_string(),
2257            }),
2258            extensions: Extensions::new(),
2259        };
2260
2261        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2262
2263        match resp.inner {
2264            Ok(McpResponse::UnsubscribeResource(_)) => {
2265                // Should no longer be subscribed
2266                assert!(!router.is_subscribed("file:///test.txt"));
2267            }
2268            _ => panic!("Expected UnsubscribeResource response"),
2269        }
2270    }
2271
2272    #[tokio::test]
2273    async fn test_subscribe_nonexistent_resource() {
2274        let mut router = McpRouter::new();
2275        init_router(&mut router).await;
2276
2277        let req = RouterRequest {
2278            id: RequestId::Number(1),
2279            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2280                uri: "file:///nonexistent.txt".to_string(),
2281            }),
2282            extensions: Extensions::new(),
2283        };
2284
2285        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2286
2287        match resp.inner {
2288            Err(e) => {
2289                assert!(e.message.contains("not found"));
2290            }
2291            _ => panic!("Expected error response"),
2292        }
2293    }
2294
2295    #[tokio::test]
2296    async fn test_notify_resource_updated() {
2297        use crate::context::notification_channel;
2298        use crate::resource::ResourceBuilder;
2299
2300        let (tx, mut rx) = notification_channel(10);
2301
2302        let resource = ResourceBuilder::new("file:///test.txt")
2303            .name("Test File")
2304            .text("Hello");
2305
2306        let router = McpRouter::new()
2307            .resource(resource)
2308            .with_notification_sender(tx);
2309
2310        // First, manually subscribe (simulate subscription)
2311        router.subscribe("file:///test.txt");
2312
2313        // Now notify
2314        let sent = router.notify_resource_updated("file:///test.txt");
2315        assert!(sent);
2316
2317        // Check the notification was sent
2318        let notification = rx.try_recv().unwrap();
2319        match notification {
2320            ServerNotification::ResourceUpdated { uri } => {
2321                assert_eq!(uri, "file:///test.txt");
2322            }
2323            _ => panic!("Expected ResourceUpdated notification"),
2324        }
2325    }
2326
2327    #[tokio::test]
2328    async fn test_notify_resource_updated_not_subscribed() {
2329        use crate::context::notification_channel;
2330        use crate::resource::ResourceBuilder;
2331
2332        let (tx, mut rx) = notification_channel(10);
2333
2334        let resource = ResourceBuilder::new("file:///test.txt")
2335            .name("Test File")
2336            .text("Hello");
2337
2338        let router = McpRouter::new()
2339            .resource(resource)
2340            .with_notification_sender(tx);
2341
2342        // Try to notify without subscribing
2343        let sent = router.notify_resource_updated("file:///test.txt");
2344        assert!(!sent); // Should not send because not subscribed
2345
2346        // Channel should be empty
2347        assert!(rx.try_recv().is_err());
2348    }
2349
2350    #[tokio::test]
2351    async fn test_notify_resources_list_changed() {
2352        use crate::context::notification_channel;
2353
2354        let (tx, mut rx) = notification_channel(10);
2355        let router = McpRouter::new().with_notification_sender(tx);
2356
2357        let sent = router.notify_resources_list_changed();
2358        assert!(sent);
2359
2360        let notification = rx.try_recv().unwrap();
2361        match notification {
2362            ServerNotification::ResourcesListChanged => {}
2363            _ => panic!("Expected ResourcesListChanged notification"),
2364        }
2365    }
2366
2367    #[tokio::test]
2368    async fn test_subscribed_uris() {
2369        use crate::resource::ResourceBuilder;
2370
2371        let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2372
2373        let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2374
2375        let router = McpRouter::new().resource(resource1).resource(resource2);
2376
2377        // Subscribe to both
2378        router.subscribe("file:///a.txt");
2379        router.subscribe("file:///b.txt");
2380
2381        let uris = router.subscribed_uris();
2382        assert_eq!(uris.len(), 2);
2383        assert!(uris.contains(&"file:///a.txt".to_string()));
2384        assert!(uris.contains(&"file:///b.txt".to_string()));
2385    }
2386
2387    #[tokio::test]
2388    async fn test_subscription_capability_advertised() {
2389        use crate::resource::ResourceBuilder;
2390
2391        let resource = ResourceBuilder::new("file:///test.txt")
2392            .name("Test")
2393            .text("Hello");
2394
2395        let mut router = McpRouter::new().resource(resource);
2396
2397        // Initialize and check capabilities
2398        let init_req = RouterRequest {
2399            id: RequestId::Number(0),
2400            inner: McpRequest::Initialize(InitializeParams {
2401                protocol_version: "2025-11-25".to_string(),
2402                capabilities: ClientCapabilities {
2403                    roots: None,
2404                    sampling: None,
2405                    elicitation: None,
2406                },
2407                client_info: Implementation {
2408                    name: "test".to_string(),
2409                    version: "1.0".to_string(),
2410                    ..Default::default()
2411                },
2412            }),
2413            extensions: Extensions::new(),
2414        };
2415        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2416
2417        match resp.inner {
2418            Ok(McpResponse::Initialize(result)) => {
2419                // Should have resources capability with subscribe enabled
2420                let resources_cap = result.capabilities.resources.unwrap();
2421                assert!(resources_cap.subscribe);
2422            }
2423            _ => panic!("Expected Initialize response"),
2424        }
2425    }
2426
2427    #[tokio::test]
2428    async fn test_completion_handler() {
2429        let router = McpRouter::new()
2430            .server_info("test", "1.0")
2431            .completion_handler(|params: CompleteParams| async move {
2432                // Return suggestions based on the argument value
2433                let prefix = &params.argument.value;
2434                let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2435                    .into_iter()
2436                    .filter(|s| s.starts_with(prefix))
2437                    .map(String::from)
2438                    .collect();
2439                Ok(CompleteResult::new(suggestions))
2440            });
2441
2442        // Initialize
2443        let init_req = RouterRequest {
2444            id: RequestId::Number(0),
2445            inner: McpRequest::Initialize(InitializeParams {
2446                protocol_version: "2025-11-25".to_string(),
2447                capabilities: ClientCapabilities::default(),
2448                client_info: Implementation {
2449                    name: "test".to_string(),
2450                    version: "1.0".to_string(),
2451                    ..Default::default()
2452                },
2453            }),
2454            extensions: Extensions::new(),
2455        };
2456        let resp = router
2457            .clone()
2458            .ready()
2459            .await
2460            .unwrap()
2461            .call(init_req)
2462            .await
2463            .unwrap();
2464
2465        // Check that completions capability is advertised
2466        match resp.inner {
2467            Ok(McpResponse::Initialize(result)) => {
2468                assert!(result.capabilities.completions.is_some());
2469            }
2470            _ => panic!("Expected Initialize response"),
2471        }
2472
2473        // Send initialized notification
2474        router.handle_notification(McpNotification::Initialized);
2475
2476        // Test completion request
2477        let complete_req = RouterRequest {
2478            id: RequestId::Number(1),
2479            inner: McpRequest::Complete(CompleteParams {
2480                reference: CompletionReference::prompt("test-prompt"),
2481                argument: CompletionArgument::new("query", "al"),
2482            }),
2483            extensions: Extensions::new(),
2484        };
2485        let resp = router
2486            .clone()
2487            .ready()
2488            .await
2489            .unwrap()
2490            .call(complete_req)
2491            .await
2492            .unwrap();
2493
2494        match resp.inner {
2495            Ok(McpResponse::Complete(result)) => {
2496                assert_eq!(result.completion.values, vec!["alpha"]);
2497            }
2498            _ => panic!("Expected Complete response"),
2499        }
2500    }
2501
2502    #[tokio::test]
2503    async fn test_completion_without_handler_returns_empty() {
2504        let router = McpRouter::new().server_info("test", "1.0");
2505
2506        // Initialize
2507        let init_req = RouterRequest {
2508            id: RequestId::Number(0),
2509            inner: McpRequest::Initialize(InitializeParams {
2510                protocol_version: "2025-11-25".to_string(),
2511                capabilities: ClientCapabilities::default(),
2512                client_info: Implementation {
2513                    name: "test".to_string(),
2514                    version: "1.0".to_string(),
2515                    ..Default::default()
2516                },
2517            }),
2518            extensions: Extensions::new(),
2519        };
2520        let resp = router
2521            .clone()
2522            .ready()
2523            .await
2524            .unwrap()
2525            .call(init_req)
2526            .await
2527            .unwrap();
2528
2529        // Check that completions capability is NOT advertised
2530        match resp.inner {
2531            Ok(McpResponse::Initialize(result)) => {
2532                assert!(result.capabilities.completions.is_none());
2533            }
2534            _ => panic!("Expected Initialize response"),
2535        }
2536
2537        // Send initialized notification
2538        router.handle_notification(McpNotification::Initialized);
2539
2540        // Test completion request still works but returns empty
2541        let complete_req = RouterRequest {
2542            id: RequestId::Number(1),
2543            inner: McpRequest::Complete(CompleteParams {
2544                reference: CompletionReference::prompt("test-prompt"),
2545                argument: CompletionArgument::new("query", "al"),
2546            }),
2547            extensions: Extensions::new(),
2548        };
2549        let resp = router
2550            .clone()
2551            .ready()
2552            .await
2553            .unwrap()
2554            .call(complete_req)
2555            .await
2556            .unwrap();
2557
2558        match resp.inner {
2559            Ok(McpResponse::Complete(result)) => {
2560                assert!(result.completion.values.is_empty());
2561            }
2562            _ => panic!("Expected Complete response"),
2563        }
2564    }
2565}