Skip to main content

tower_mcp/
router.rs

//! MCP router: dispatches requests to registered tools, resources, and prompts.
//!
//! The router implements Tower's `Service` trait, making it composable with
//! standard tower middleware.

6use std::any::{Any, TypeId};
7use std::collections::{HashMap, HashSet};
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, RwLock};
11use std::task::{Context, Poll};
12
13use tower_service::Service;
14
15use crate::async_task::TaskStore;
16use crate::context::{
17    CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
18    ServerNotification,
19};
20use crate::error::{Error, JsonRpcError, Result};
21use crate::prompt::Prompt;
22use crate::protocol::*;
23use crate::resource::{Resource, ResourceTemplate};
24use crate::session::SessionState;
25use crate::tool::Tool;
26
27/// Type alias for completion handler function
28pub type CompletionHandler = Arc<
29    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
30        + Send
31        + Sync,
32>;
33
34/// MCP Router that dispatches requests to registered handlers
35///
36/// Implements `tower::Service<McpRequest>` for middleware composition.
37///
38/// # Example
39///
40/// ```rust
41/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
42/// use schemars::JsonSchema;
43/// use serde::Deserialize;
44///
45/// #[derive(Debug, Deserialize, JsonSchema)]
46/// struct Input { value: String }
47///
48/// let tool = ToolBuilder::new("echo")
49///     .description("Echo input")
50///     .handler(|i: Input| async move { Ok(CallToolResult::text(i.value)) })
51///     .build()
52///     .unwrap();
53///
54/// let router = McpRouter::new()
55///     .server_info("my-server", "1.0.0")
56///     .tool(tool);
57/// ```
58#[derive(Clone)]
59pub struct McpRouter {
60    inner: Arc<McpRouterInner>,
61    session: SessionState,
62}
63
64impl std::fmt::Debug for McpRouter {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("McpRouter")
67            .field("server_name", &self.inner.server_name)
68            .field("server_version", &self.inner.server_version)
69            .field("tools_count", &self.inner.tools.len())
70            .field("resources_count", &self.inner.resources.len())
71            .field("prompts_count", &self.inner.prompts.len())
72            .field("session_phase", &self.session.phase())
73            .finish()
74    }
75}
76
77/// Inner configuration that is shared across clones
78#[derive(Clone)]
79struct McpRouterInner {
80    server_name: String,
81    server_version: String,
82    /// Human-readable title for the server
83    server_title: Option<String>,
84    /// Description of the server
85    server_description: Option<String>,
86    /// Icons for the server
87    server_icons: Option<Vec<ToolIcon>>,
88    /// URL of the server's website
89    server_website_url: Option<String>,
90    instructions: Option<String>,
91    tools: HashMap<String, Arc<Tool>>,
92    resources: HashMap<String, Arc<Resource>>,
93    /// Resource templates for dynamic resource matching (keyed by uri_template)
94    resource_templates: Vec<Arc<ResourceTemplate>>,
95    prompts: HashMap<String, Arc<Prompt>>,
96    /// In-flight requests for cancellation tracking (shared across clones)
97    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
98    /// Channel for sending notifications to connected clients
99    notification_tx: Option<NotificationSender>,
100    /// Handle for sending requests to the client (for sampling, etc.)
101    client_requester: Option<ClientRequesterHandle>,
102    /// Task store for async operations
103    task_store: TaskStore,
104    /// Subscribed resource URIs
105    subscriptions: Arc<RwLock<HashSet<String>>>,
106    /// Handler for completion requests
107    completion_handler: Option<CompletionHandler>,
108}
109
110impl McpRouter {
111    /// Create a new MCP router
112    pub fn new() -> Self {
113        Self {
114            inner: Arc::new(McpRouterInner {
115                server_name: "tower-mcp".to_string(),
116                server_version: env!("CARGO_PKG_VERSION").to_string(),
117                server_title: None,
118                server_description: None,
119                server_icons: None,
120                server_website_url: None,
121                instructions: None,
122                tools: HashMap::new(),
123                resources: HashMap::new(),
124                resource_templates: Vec::new(),
125                prompts: HashMap::new(),
126                in_flight: Arc::new(RwLock::new(HashMap::new())),
127                notification_tx: None,
128                client_requester: None,
129                task_store: TaskStore::new(),
130                subscriptions: Arc::new(RwLock::new(HashSet::new())),
131                completion_handler: None,
132            }),
133            session: SessionState::new(),
134        }
135    }
136
137    /// Get access to the task store for async operations
138    pub fn task_store(&self) -> &TaskStore {
139        &self.inner.task_store
140    }
141
142    /// Set the notification sender for progress reporting
143    ///
144    /// This is typically called by the transport layer to receive notifications.
145    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
146        Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
147        self
148    }
149
150    /// Get the notification sender (if configured)
151    pub fn notification_sender(&self) -> Option<&NotificationSender> {
152        self.inner.notification_tx.as_ref()
153    }
154
155    /// Set the client requester for server-to-client requests (sampling, etc.)
156    ///
157    /// This is typically called by bidirectional transports (WebSocket, stdio)
158    /// to enable tool handlers to send requests to the client.
159    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
160        Arc::make_mut(&mut self.inner).client_requester = Some(requester);
161        self
162    }
163
164    /// Get the client requester (if configured)
165    pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
166        self.inner.client_requester.as_ref()
167    }
168
169    /// Create a request context for tracking a request
170    ///
171    /// This registers the request for cancellation tracking and sets up
172    /// progress reporting and client requests if configured.
173    pub fn create_context(
174        &self,
175        request_id: RequestId,
176        progress_token: Option<ProgressToken>,
177    ) -> RequestContext {
178        let ctx = RequestContext::new(request_id.clone());
179
180        // Set up progress token if provided
181        let ctx = if let Some(token) = progress_token {
182            ctx.with_progress_token(token)
183        } else {
184            ctx
185        };
186
187        // Set up notification sender if configured
188        let ctx = if let Some(tx) = &self.inner.notification_tx {
189            ctx.with_notification_sender(tx.clone())
190        } else {
191            ctx
192        };
193
194        // Set up client requester if configured (for sampling support)
195        let ctx = if let Some(requester) = &self.inner.client_requester {
196            ctx.with_client_requester(requester.clone())
197        } else {
198            ctx
199        };
200
201        // Register for cancellation tracking
202        let token = ctx.cancellation_token();
203        if let Ok(mut in_flight) = self.inner.in_flight.write() {
204            in_flight.insert(request_id, token);
205        }
206
207        ctx
208    }
209
210    /// Remove a request from tracking (called when request completes)
211    pub fn complete_request(&self, request_id: &RequestId) {
212        if let Ok(mut in_flight) = self.inner.in_flight.write() {
213            in_flight.remove(request_id);
214        }
215    }
216
217    /// Cancel a tracked request
218    fn cancel_request(&self, request_id: &RequestId) -> bool {
219        if let Ok(in_flight) = self.inner.in_flight.read()
220            && let Some(token) = in_flight.get(request_id)
221        {
222            token.cancel();
223            return true;
224        }
225        false
226    }
227
228    /// Set server info
229    pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
230        let inner = Arc::make_mut(&mut self.inner);
231        inner.server_name = name.into();
232        inner.server_version = version.into();
233        self
234    }
235
236    /// Set instructions for LLMs describing how to use this server
237    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
238        Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
239        self
240    }
241
242    /// Set a human-readable title for the server
243    pub fn server_title(mut self, title: impl Into<String>) -> Self {
244        Arc::make_mut(&mut self.inner).server_title = Some(title.into());
245        self
246    }
247
248    /// Set the server description
249    pub fn server_description(mut self, description: impl Into<String>) -> Self {
250        Arc::make_mut(&mut self.inner).server_description = Some(description.into());
251        self
252    }
253
254    /// Set icons for the server
255    pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
256        Arc::make_mut(&mut self.inner).server_icons = Some(icons);
257        self
258    }
259
260    /// Set the server's website URL
261    pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
262        Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
263        self
264    }
265
266    /// Register a tool
267    pub fn tool(mut self, tool: Tool) -> Self {
268        Arc::make_mut(&mut self.inner)
269            .tools
270            .insert(tool.name.clone(), Arc::new(tool));
271        self
272    }
273
274    /// Register a resource
275    pub fn resource(mut self, resource: Resource) -> Self {
276        Arc::make_mut(&mut self.inner)
277            .resources
278            .insert(resource.uri.clone(), Arc::new(resource));
279        self
280    }
281
282    /// Register a resource template
283    ///
284    /// Resource templates allow dynamic resources to be matched by URI pattern.
285    /// When a client requests a resource URI that doesn't match any static
286    /// resource, the router tries to match it against registered templates.
287    ///
288    /// # Example
289    ///
290    /// ```rust
291    /// use tower_mcp::{McpRouter, ResourceTemplateBuilder};
292    /// use tower_mcp::protocol::{ReadResourceResult, ResourceContent};
293    /// use std::collections::HashMap;
294    ///
295    /// let template = ResourceTemplateBuilder::new("file:///{path}")
296    ///     .name("Project Files")
297    ///     .handler(|uri: String, vars: HashMap<String, String>| async move {
298    ///         let path = vars.get("path").unwrap_or(&String::new()).clone();
299    ///         Ok(ReadResourceResult {
300    ///             contents: vec![ResourceContent {
301    ///                 uri,
302    ///                 mime_type: Some("text/plain".to_string()),
303    ///                 text: Some(format!("Contents of {}", path)),
304    ///                 blob: None,
305    ///             }],
306    ///         })
307    ///     });
308    ///
309    /// let router = McpRouter::new()
310    ///     .resource_template(template);
311    /// ```
312    pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
313        Arc::make_mut(&mut self.inner)
314            .resource_templates
315            .push(Arc::new(template));
316        self
317    }
318
319    /// Register a prompt
320    pub fn prompt(mut self, prompt: Prompt) -> Self {
321        Arc::make_mut(&mut self.inner)
322            .prompts
323            .insert(prompt.name.clone(), Arc::new(prompt));
324        self
325    }
326
327    /// Register multiple tools at once.
328    ///
329    /// # Example
330    ///
331    /// ```rust
332    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
333    /// use schemars::JsonSchema;
334    /// use serde::Deserialize;
335    ///
336    /// #[derive(Debug, Deserialize, JsonSchema)]
337    /// struct Input { value: String }
338    ///
339    /// let tools = vec![
340    ///     ToolBuilder::new("a")
341    ///         .description("Tool A")
342    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
343    ///         .build().unwrap(),
344    ///     ToolBuilder::new("b")
345    ///         .description("Tool B")
346    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
347    ///         .build().unwrap(),
348    /// ];
349    ///
350    /// let router = McpRouter::new().tools(tools);
351    /// ```
352    pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
353        tools
354            .into_iter()
355            .fold(self, |router, tool| router.tool(tool))
356    }
357
358    /// Register multiple resources at once.
359    ///
360    /// # Example
361    ///
362    /// ```rust
363    /// use tower_mcp::{McpRouter, ResourceBuilder};
364    ///
365    /// let resources = vec![
366    ///     ResourceBuilder::new("file:///a.txt")
367    ///         .name("File A")
368    ///         .text("contents a"),
369    ///     ResourceBuilder::new("file:///b.txt")
370    ///         .name("File B")
371    ///         .text("contents b"),
372    /// ];
373    ///
374    /// let router = McpRouter::new().resources(resources);
375    /// ```
376    pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
377        resources
378            .into_iter()
379            .fold(self, |router, resource| router.resource(resource))
380    }
381
382    /// Register multiple prompts at once.
383    ///
384    /// # Example
385    ///
386    /// ```rust
387    /// use tower_mcp::{McpRouter, PromptBuilder};
388    ///
389    /// let prompts = vec![
390    ///     PromptBuilder::new("greet")
391    ///         .description("Greet someone")
392    ///         .user_message("Hello!"),
393    ///     PromptBuilder::new("farewell")
394    ///         .description("Say goodbye")
395    ///         .user_message("Goodbye!"),
396    /// ];
397    ///
398    /// let router = McpRouter::new().prompts(prompts);
399    /// ```
400    pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
401        prompts
402            .into_iter()
403            .fold(self, |router, prompt| router.prompt(prompt))
404    }
405
406    /// Register a completion handler for `completion/complete` requests.
407    ///
408    /// The handler receives `CompleteParams` containing the reference (prompt or resource)
409    /// and the argument being completed, and should return completion suggestions.
410    ///
411    /// # Example
412    ///
413    /// ```rust
414    /// use tower_mcp::{McpRouter, CompleteResult};
415    /// use tower_mcp::protocol::{CompleteParams, CompletionReference};
416    ///
417    /// let router = McpRouter::new()
418    ///     .completion_handler(|params: CompleteParams| async move {
419    ///         // Provide completions based on the reference and argument
420    ///         match params.reference {
421    ///             CompletionReference::Prompt { name } => {
422    ///                 // Return prompt argument completions
423    ///                 Ok(CompleteResult::new(vec!["option1".to_string(), "option2".to_string()]))
424    ///             }
425    ///             CompletionReference::Resource { uri } => {
426    ///                 // Return resource URI completions
427    ///                 Ok(CompleteResult::new(vec![]))
428    ///             }
429    ///         }
430    ///     });
431    /// ```
432    pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
433    where
434        F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
435        Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
436    {
437        Arc::make_mut(&mut self.inner).completion_handler =
438            Some(Arc::new(move |params| Box::pin(handler(params))));
439        self
440    }
441
442    /// Get access to the session state
443    pub fn session(&self) -> &SessionState {
444        &self.session
445    }
446
447    /// Send a log message notification to the client
448    ///
449    /// This sends a `notifications/message` notification with the given parameters.
450    /// Returns `true` if the notification was sent, `false` if no notification channel
451    /// is configured.
452    ///
453    /// # Example
454    ///
455    /// ```rust,ignore
456    /// use tower_mcp::protocol::{LogLevel, LoggingMessageParams};
457    ///
458    /// // Simple info message
459    /// router.log(LoggingMessageParams::new(LogLevel::Info).with_data(
460    ///     serde_json::json!({"message": "Operation completed"})
461    /// ));
462    ///
463    /// // Error with logger name
464    /// router.log(LoggingMessageParams::new(LogLevel::Error)
465    ///     .with_logger("database")
466    ///     .with_data(serde_json::json!({"error": "Connection failed"})));
467    /// ```
468    pub fn log(&self, params: LoggingMessageParams) -> bool {
469        if let Some(tx) = &self.inner.notification_tx
470            && tx.try_send(ServerNotification::LogMessage(params)).is_ok()
471        {
472            return true;
473        }
474        false
475    }
476
477    /// Send an info-level log message
478    ///
479    /// Convenience method for sending an info log with optional data.
480    pub fn log_info(&self, message: &str) -> bool {
481        self.log(
482            LoggingMessageParams::new(LogLevel::Info)
483                .with_data(serde_json::json!({ "message": message })),
484        )
485    }
486
487    /// Send a warning-level log message
488    pub fn log_warning(&self, message: &str) -> bool {
489        self.log(
490            LoggingMessageParams::new(LogLevel::Warning)
491                .with_data(serde_json::json!({ "message": message })),
492        )
493    }
494
495    /// Send an error-level log message
496    pub fn log_error(&self, message: &str) -> bool {
497        self.log(
498            LoggingMessageParams::new(LogLevel::Error)
499                .with_data(serde_json::json!({ "message": message })),
500        )
501    }
502
503    /// Send a debug-level log message
504    pub fn log_debug(&self, message: &str) -> bool {
505        self.log(
506            LoggingMessageParams::new(LogLevel::Debug)
507                .with_data(serde_json::json!({ "message": message })),
508        )
509    }
510
511    /// Check if a resource URI is currently subscribed
512    pub fn is_subscribed(&self, uri: &str) -> bool {
513        if let Ok(subs) = self.inner.subscriptions.read() {
514            return subs.contains(uri);
515        }
516        false
517    }
518
519    /// Get a list of all subscribed resource URIs
520    pub fn subscribed_uris(&self) -> Vec<String> {
521        if let Ok(subs) = self.inner.subscriptions.read() {
522            return subs.iter().cloned().collect();
523        }
524        Vec::new()
525    }
526
527    /// Subscribe to a resource URI
528    fn subscribe(&self, uri: &str) -> bool {
529        if let Ok(mut subs) = self.inner.subscriptions.write() {
530            return subs.insert(uri.to_string());
531        }
532        false
533    }
534
535    /// Unsubscribe from a resource URI
536    fn unsubscribe(&self, uri: &str) -> bool {
537        if let Ok(mut subs) = self.inner.subscriptions.write() {
538            return subs.remove(uri);
539        }
540        false
541    }
542
543    /// Notify clients that a subscribed resource has been updated
544    ///
545    /// Only sends the notification if the resource is currently subscribed.
546    /// Returns `true` if the notification was sent.
547    pub fn notify_resource_updated(&self, uri: &str) -> bool {
548        // Only notify if the resource is subscribed
549        if !self.is_subscribed(uri) {
550            return false;
551        }
552
553        if let Some(tx) = &self.inner.notification_tx
554            && tx
555                .try_send(ServerNotification::ResourceUpdated {
556                    uri: uri.to_string(),
557                })
558                .is_ok()
559        {
560            return true;
561        }
562        false
563    }
564
565    /// Notify clients that the list of available resources has changed
566    ///
567    /// Returns `true` if the notification was sent.
568    pub fn notify_resources_list_changed(&self) -> bool {
569        if let Some(tx) = &self.inner.notification_tx
570            && tx
571                .try_send(ServerNotification::ResourcesListChanged)
572                .is_ok()
573        {
574            return true;
575        }
576        false
577    }
578
579    /// Get server capabilities based on registered handlers
580    fn capabilities(&self) -> ServerCapabilities {
581        let has_resources =
582            !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
583
584        ServerCapabilities {
585            tools: if self.inner.tools.is_empty() {
586                None
587            } else {
588                Some(ToolsCapability::default())
589            },
590            resources: if has_resources {
591                Some(ResourcesCapability {
592                    subscribe: true,
593                    ..Default::default()
594                })
595            } else {
596                None
597            },
598            prompts: if self.inner.prompts.is_empty() {
599                None
600            } else {
601                Some(PromptsCapability::default())
602            },
603            // Always advertise logging capability when notification channel is configured
604            logging: if self.inner.notification_tx.is_some() {
605                Some(LoggingCapability::default())
606            } else {
607                None
608            },
609            // Tasks capability is always available
610            tasks: Some(TasksCapability::default()),
611            // Completions capability when a handler is registered
612            completions: if self.inner.completion_handler.is_some() {
613                Some(CompletionsCapability::default())
614            } else {
615                None
616            },
617        }
618    }
619
620    /// Handle an MCP request
621    async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
622        // Enforce session state - reject requests before initialization
623        let method = request.method_name();
624        if !self.session.is_request_allowed(method) {
625            tracing::warn!(
626                method = %method,
627                phase = ?self.session.phase(),
628                "Request rejected: session not initialized"
629            );
630            return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
631                "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
632                method
633            ))));
634        }
635
636        match request {
637            McpRequest::Initialize(params) => {
638                tracing::info!(
639                    client = %params.client_info.name,
640                    version = %params.client_info.version,
641                    "Client initializing"
642                );
643
644                // Protocol version negotiation: respond with same version if supported,
645                // otherwise respond with our latest supported version
646                let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
647                    .contains(&params.protocol_version.as_str())
648                {
649                    params.protocol_version
650                } else {
651                    crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
652                };
653
654                // Transition session state to Initializing
655                self.session.mark_initializing();
656
657                Ok(McpResponse::Initialize(InitializeResult {
658                    protocol_version,
659                    capabilities: self.capabilities(),
660                    server_info: Implementation {
661                        name: self.inner.server_name.clone(),
662                        version: self.inner.server_version.clone(),
663                        title: self.inner.server_title.clone(),
664                        description: self.inner.server_description.clone(),
665                        icons: self.inner.server_icons.clone(),
666                        website_url: self.inner.server_website_url.clone(),
667                    },
668                    instructions: self.inner.instructions.clone(),
669                }))
670            }
671
672            McpRequest::ListTools(_params) => {
673                let tools: Vec<ToolDefinition> =
674                    self.inner.tools.values().map(|t| t.definition()).collect();
675
676                Ok(McpResponse::ListTools(ListToolsResult {
677                    tools,
678                    next_cursor: None,
679                }))
680            }
681
682            McpRequest::CallTool(params) => {
683                let tool =
684                    self.inner.tools.get(&params.name).ok_or_else(|| {
685                        Error::JsonRpc(JsonRpcError::method_not_found(&params.name))
686                    })?;
687
688                // Extract progress token from request metadata
689                let progress_token = params.meta.and_then(|m| m.progress_token);
690                let ctx = self.create_context(request_id, progress_token);
691
692                tracing::debug!(tool = %params.name, "Calling tool");
693                let result = tool.call_with_context(ctx, params.arguments).await?;
694
695                Ok(McpResponse::CallTool(result))
696            }
697
698            McpRequest::ListResources(_params) => {
699                let resources: Vec<ResourceDefinition> = self
700                    .inner
701                    .resources
702                    .values()
703                    .map(|r| r.definition())
704                    .collect();
705
706                Ok(McpResponse::ListResources(ListResourcesResult {
707                    resources,
708                    next_cursor: None,
709                }))
710            }
711
712            McpRequest::ListResourceTemplates(_params) => {
713                let resource_templates: Vec<ResourceTemplateDefinition> = self
714                    .inner
715                    .resource_templates
716                    .iter()
717                    .map(|t| t.definition())
718                    .collect();
719
720                Ok(McpResponse::ListResourceTemplates(
721                    ListResourceTemplatesResult {
722                        resource_templates,
723                        next_cursor: None,
724                    },
725                ))
726            }
727
728            McpRequest::ReadResource(params) => {
729                // First, try to find a static resource
730                if let Some(resource) = self.inner.resources.get(&params.uri) {
731                    tracing::debug!(uri = %params.uri, "Reading static resource");
732                    let result = resource.read().await?;
733                    return Ok(McpResponse::ReadResource(result));
734                }
735
736                // If no static resource found, try to match against templates
737                for template in &self.inner.resource_templates {
738                    if let Some(variables) = template.match_uri(&params.uri) {
739                        tracing::debug!(
740                            uri = %params.uri,
741                            template = %template.uri_template,
742                            "Reading resource via template"
743                        );
744                        let result = template.read(&params.uri, variables).await?;
745                        return Ok(McpResponse::ReadResource(result));
746                    }
747                }
748
749                // No match found
750                Err(Error::JsonRpc(JsonRpcError::resource_not_found(
751                    &params.uri,
752                )))
753            }
754
755            McpRequest::SubscribeResource(params) => {
756                // Verify the resource exists
757                if !self.inner.resources.contains_key(&params.uri) {
758                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
759                        &params.uri,
760                    )));
761                }
762
763                tracing::debug!(uri = %params.uri, "Subscribing to resource");
764                self.subscribe(&params.uri);
765
766                Ok(McpResponse::SubscribeResource(EmptyResult {}))
767            }
768
769            McpRequest::UnsubscribeResource(params) => {
770                // Verify the resource exists
771                if !self.inner.resources.contains_key(&params.uri) {
772                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
773                        &params.uri,
774                    )));
775                }
776
777                tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
778                self.unsubscribe(&params.uri);
779
780                Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
781            }
782
783            McpRequest::ListPrompts(_params) => {
784                let prompts: Vec<PromptDefinition> = self
785                    .inner
786                    .prompts
787                    .values()
788                    .map(|p| p.definition())
789                    .collect();
790
791                Ok(McpResponse::ListPrompts(ListPromptsResult {
792                    prompts,
793                    next_cursor: None,
794                }))
795            }
796
797            McpRequest::GetPrompt(params) => {
798                let prompt = self.inner.prompts.get(&params.name).ok_or_else(|| {
799                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
800                        "Prompt not found: {}",
801                        params.name
802                    )))
803                })?;
804
805                tracing::debug!(name = %params.name, "Getting prompt");
806                let result = prompt.get(params.arguments).await?;
807
808                Ok(McpResponse::GetPrompt(result))
809            }
810
811            McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
812
813            McpRequest::EnqueueTask(params) => {
814                // Verify the tool exists
815                let tool = self.inner.tools.get(&params.tool_name).ok_or_else(|| {
816                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
817                        "Tool not found: {}",
818                        params.tool_name
819                    )))
820                })?;
821
822                // Create the task
823                let (task_id, cancellation_token) = self.inner.task_store.create_task(
824                    &params.tool_name,
825                    params.arguments.clone(),
826                    params.ttl,
827                );
828
829                tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
830
831                // Create a context for the async task execution
832                let ctx = self.create_context(request_id, None);
833
834                // Spawn the task execution in the background
835                let task_store = self.inner.task_store.clone();
836                let tool = tool.clone();
837                let arguments = params.arguments;
838                let task_id_clone = task_id.clone();
839
840                tokio::spawn(async move {
841                    // Check for cancellation before starting
842                    if cancellation_token.is_cancelled() {
843                        tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
844                        return;
845                    }
846
847                    // Execute the tool
848                    match tool.call_with_context(ctx, arguments).await {
849                        Ok(result) => {
850                            if cancellation_token.is_cancelled() {
851                                tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
852                            } else {
853                                task_store.complete_task(&task_id_clone, result);
854                                tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
855                            }
856                        }
857                        Err(e) => {
858                            task_store.fail_task(&task_id_clone, &e.to_string());
859                            tracing::warn!(task_id = %task_id_clone, error = %e, "Task failed");
860                        }
861                    }
862                });
863
864                Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
865                    task_id,
866                    status: TaskStatus::Working,
867                    poll_interval: Some(2),
868                }))
869            }
870
871            McpRequest::ListTasks(params) => {
872                let tasks = self.inner.task_store.list_tasks(params.status);
873
874                Ok(McpResponse::ListTasks(ListTasksResult {
875                    tasks,
876                    next_cursor: None,
877                }))
878            }
879
880            McpRequest::GetTaskInfo(params) => {
881                let task = self
882                    .inner
883                    .task_store
884                    .get_task(&params.task_id)
885                    .ok_or_else(|| {
886                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
887                            "Task not found: {}",
888                            params.task_id
889                        )))
890                    })?;
891
892                Ok(McpResponse::GetTaskInfo(task))
893            }
894
895            McpRequest::GetTaskResult(params) => {
896                let (status, result, error) = self
897                    .inner
898                    .task_store
899                    .get_task_full(&params.task_id)
900                    .ok_or_else(|| {
901                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
902                            "Task not found: {}",
903                            params.task_id
904                        )))
905                    })?;
906
907                Ok(McpResponse::GetTaskResult(GetTaskResultResult {
908                    task_id: params.task_id,
909                    status,
910                    result,
911                    error,
912                }))
913            }
914
915            McpRequest::CancelTask(params) => {
916                let status = self
917                    .inner
918                    .task_store
919                    .cancel_task(&params.task_id, params.reason.as_deref())
920                    .ok_or_else(|| {
921                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
922                            "Task not found: {}",
923                            params.task_id
924                        )))
925                    })?;
926
927                let cancelled = status == TaskStatus::Cancelled;
928
929                Ok(McpResponse::CancelTask(CancelTaskResult {
930                    cancelled,
931                    status,
932                }))
933            }
934
935            McpRequest::SetLoggingLevel(params) => {
936                // Store the log level for filtering outgoing log notifications
937                // For now, we just accept the request - actual filtering would be
938                // implemented in the notification sending logic
939                tracing::debug!(level = ?params.level, "Client set logging level");
940                Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
941            }
942
943            McpRequest::Complete(params) => {
944                tracing::debug!(
945                    reference = ?params.reference,
946                    argument = %params.argument.name,
947                    "Completion request"
948                );
949
950                // Delegate to registered completion handler if available
951                if let Some(ref handler) = self.inner.completion_handler {
952                    let result = handler(params).await?;
953                    Ok(McpResponse::Complete(result))
954                } else {
955                    // No completion handler registered, return empty completions
956                    Ok(McpResponse::Complete(CompleteResult::new(vec![])))
957                }
958            }
959
960            McpRequest::Unknown { method, .. } => {
961                Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
962            }
963        }
964    }
965
966    /// Handle an MCP notification (no response expected)
967    pub fn handle_notification(&self, notification: McpNotification) {
968        match notification {
969            McpNotification::Initialized => {
970                if self.session.mark_initialized() {
971                    tracing::info!("Session initialized, entering operation phase");
972                } else {
973                    tracing::warn!(
974                        "Received initialized notification in unexpected state: {:?}",
975                        self.session.phase()
976                    );
977                }
978            }
979            McpNotification::Cancelled(params) => {
980                if self.cancel_request(&params.request_id) {
981                    tracing::info!(
982                        request_id = ?params.request_id,
983                        reason = ?params.reason,
984                        "Request cancelled"
985                    );
986                } else {
987                    tracing::debug!(
988                        request_id = ?params.request_id,
989                        reason = ?params.reason,
990                        "Cancellation requested for unknown request"
991                    );
992                }
993            }
994            McpNotification::Progress(params) => {
995                tracing::trace!(
996                    token = ?params.progress_token,
997                    progress = params.progress,
998                    total = ?params.total,
999                    "Progress notification"
1000                );
1001                // Progress notifications from client are unusual but valid
1002            }
1003            McpNotification::RootsListChanged => {
1004                tracing::info!("Client roots list changed");
1005                // Server should re-request roots if needed
1006                // This is handled by the application layer
1007            }
1008            McpNotification::Unknown { method, .. } => {
1009                tracing::debug!(method = %method, "Unknown notification received");
1010            }
1011        }
1012    }
1013}
1014
1015impl Default for McpRouter {
1016    fn default() -> Self {
1017        Self::new()
1018    }
1019}
1020
1021// =============================================================================
1022// Tower Service implementation
1023// =============================================================================
1024
/// Type-keyed storage that middleware uses to hand data downstream.
///
/// Each Rust type occupies at most one slot. Stored values live behind
/// `Arc<dyn Any + Send + Sync>`, so cloning an `Extensions` only bumps
/// reference counts rather than deep-copying values. This matters for batch
/// requests, where several `RouterRequest`s are built from one HTTP request.
///
/// # Example
///
/// ```rust
/// use tower_mcp::Extensions;
///
/// let mut ext = Extensions::new();
/// ext.insert(42u32);
/// assert_eq!(ext.get::<u32>(), Some(&42));
/// ```
#[derive(Default, Clone)]
pub struct Extensions {
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Build an extensions map with no entries.
    pub fn new() -> Self {
        Extensions {
            map: HashMap::new(),
        }
    }

    /// Store `val` in the slot for its type `T`.
    ///
    /// A value of type `T` stored earlier is discarded.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = TypeId::of::<T>();
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Borrow the value stored for type `T`.
    ///
    /// Returns `None` when nothing of type `T` has been inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let slot = self.map.get(&TypeId::of::<T>())?;
        slot.downcast_ref::<T>()
    }
}

impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Stored values are type-erased, so only the number of entries is shown.
        let len = self.map.len();
        f.debug_struct("Extensions").field("len", &len).finish()
    }
}
1075
/// Request type for the tower Service implementation
///
/// Pairs a decoded [`McpRequest`] with the JSON-RPC id it arrived under.
#[derive(Debug)]
pub struct RouterRequest {
    /// JSON-RPC id of the request; `McpRouter`'s `Service::call` copies it
    /// into the matching [`RouterResponse::id`].
    pub id: RequestId,
    /// The decoded MCP request to dispatch.
    pub inner: McpRequest,
    /// Type-map for passing data (e.g., `TokenClaims`) through middleware.
    ///
    /// NOTE(review): `McpRouter`'s `Service::call` does not read this field,
    /// so data stored here is only visible to middleware that inspects the
    /// request before it reaches the router.
    pub extensions: Extensions,
}
1084
/// Response type for the tower Service implementation
#[derive(Debug)]
pub struct RouterResponse {
    /// Id of the request this response answers.
    pub id: RequestId,
    /// Handler outcome: the MCP response on success, or the JSON-RPC error to
    /// report. [`RouterResponse::into_jsonrpc`] turns it into the wire format.
    pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
1091
1092impl RouterResponse {
1093    /// Convert to JSON-RPC response
1094    pub fn into_jsonrpc(self) -> JsonRpcResponse {
1095        match self.inner {
1096            Ok(response) => match serde_json::to_value(response) {
1097                Ok(result) => JsonRpcResponse::result(self.id, result),
1098                Err(e) => {
1099                    tracing::error!(error = %e, "Failed to serialize response");
1100                    JsonRpcResponse::error(
1101                        Some(self.id),
1102                        JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1103                    )
1104                }
1105            },
1106            Err(error) => JsonRpcResponse::error(Some(self.id), error),
1107        }
1108    }
1109}
1110
1111impl Service<RouterRequest> for McpRouter {
1112    type Response = RouterResponse;
1113    type Error = std::convert::Infallible; // Errors are in the response
1114    type Future =
1115        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1116
1117    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1118        Poll::Ready(Ok(()))
1119    }
1120
1121    fn call(&mut self, req: RouterRequest) -> Self::Future {
1122        let router = self.clone();
1123        let request_id = req.id.clone();
1124        Box::pin(async move {
1125            let result = router.handle(req.id, req.inner).await;
1126            // Clean up tracking after request completes
1127            router.complete_request(&request_id);
1128            Ok(RouterResponse {
1129                id: request_id,
1130                // Map tower-mcp errors to JSON-RPC errors:
1131                // - Error::JsonRpc: forwarded as-is (preserves original code)
1132                // - Error::Tool: mapped to -32603 (Internal Error)
1133                // - All others: mapped to -32603 (Internal Error)
1134                inner: result.map_err(|e| match e {
1135                    Error::JsonRpc(err) => err,
1136                    Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1137                    e => JsonRpcError::internal_error(e.to_string()),
1138                }),
1139            })
1140        })
1141    }
1142}
1143
1144#[cfg(test)]
1145mod tests {
1146    use super::*;
1147    use crate::jsonrpc::JsonRpcService;
1148    use crate::tool::ToolBuilder;
1149    use schemars::JsonSchema;
1150    use serde::Deserialize;
1151    use tower::ServiceExt;
1152
    /// Arguments for the `add` test tool. The field names are the JSON keys
    /// (`"a"`, `"b"`) the tests send in `tools/call` requests.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i64,
        b: i64,
    }
1158
1159    /// Helper to initialize a router for testing
1160    async fn init_router(router: &mut McpRouter) {
1161        // Send initialize request
1162        let init_req = RouterRequest {
1163            id: RequestId::Number(0),
1164            inner: McpRequest::Initialize(InitializeParams {
1165                protocol_version: "2025-11-25".to_string(),
1166                capabilities: ClientCapabilities {
1167                    roots: None,
1168                    sampling: None,
1169                    elicitation: None,
1170                },
1171                client_info: Implementation {
1172                    name: "test".to_string(),
1173                    version: "1.0".to_string(),
1174                    ..Default::default()
1175                },
1176            }),
1177            extensions: Extensions::new(),
1178        };
1179        let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1180        // Send initialized notification
1181        router.handle_notification(McpNotification::Initialized);
1182    }
1183
1184    #[tokio::test]
1185    async fn test_router_list_tools() {
1186        let add_tool = ToolBuilder::new("add")
1187            .description("Add two numbers")
1188            .handler(|input: AddInput| async move {
1189                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1190            })
1191            .build()
1192            .expect("valid tool name");
1193
1194        let mut router = McpRouter::new().tool(add_tool);
1195
1196        // Initialize session first
1197        init_router(&mut router).await;
1198
1199        let req = RouterRequest {
1200            id: RequestId::Number(1),
1201            inner: McpRequest::ListTools(ListToolsParams::default()),
1202            extensions: Extensions::new(),
1203        };
1204
1205        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1206
1207        match resp.inner {
1208            Ok(McpResponse::ListTools(result)) => {
1209                assert_eq!(result.tools.len(), 1);
1210                assert_eq!(result.tools[0].name, "add");
1211            }
1212            _ => panic!("Expected ListTools response"),
1213        }
1214    }
1215
1216    #[tokio::test]
1217    async fn test_router_call_tool() {
1218        let add_tool = ToolBuilder::new("add")
1219            .description("Add two numbers")
1220            .handler(|input: AddInput| async move {
1221                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1222            })
1223            .build()
1224            .expect("valid tool name");
1225
1226        let mut router = McpRouter::new().tool(add_tool);
1227
1228        // Initialize session first
1229        init_router(&mut router).await;
1230
1231        let req = RouterRequest {
1232            id: RequestId::Number(1),
1233            inner: McpRequest::CallTool(CallToolParams {
1234                name: "add".to_string(),
1235                arguments: serde_json::json!({"a": 2, "b": 3}),
1236                meta: None,
1237            }),
1238            extensions: Extensions::new(),
1239        };
1240
1241        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1242
1243        match resp.inner {
1244            Ok(McpResponse::CallTool(result)) => {
1245                assert!(!result.is_error);
1246                // Check the text content
1247                match &result.content[0] {
1248                    Content::Text { text, .. } => assert_eq!(text, "5"),
1249                    _ => panic!("Expected text content"),
1250                }
1251            }
1252            _ => panic!("Expected CallTool response"),
1253        }
1254    }
1255
1256    /// Helper to initialize a JsonRpcService for testing
1257    async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1258        let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1259            "protocolVersion": "2025-11-25",
1260            "capabilities": {},
1261            "clientInfo": { "name": "test", "version": "1.0" }
1262        }));
1263        let _ = service.call_single(init_req).await.unwrap();
1264        router.handle_notification(McpNotification::Initialized);
1265    }
1266
1267    #[tokio::test]
1268    async fn test_jsonrpc_service() {
1269        let add_tool = ToolBuilder::new("add")
1270            .description("Add two numbers")
1271            .handler(|input: AddInput| async move {
1272                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1273            })
1274            .build()
1275            .expect("valid tool name");
1276
1277        let router = McpRouter::new().tool(add_tool);
1278        let mut service = JsonRpcService::new(router.clone());
1279
1280        // Initialize session first
1281        init_jsonrpc_service(&mut service, &router).await;
1282
1283        let req = JsonRpcRequest::new(1, "tools/list");
1284
1285        let resp = service.call_single(req).await.unwrap();
1286
1287        match resp {
1288            JsonRpcResponse::Result(r) => {
1289                assert_eq!(r.id, RequestId::Number(1));
1290                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1291                assert_eq!(tools.len(), 1);
1292            }
1293            JsonRpcResponse::Error(_) => panic!("Expected success response"),
1294        }
1295    }
1296
1297    #[tokio::test]
1298    async fn test_batch_request() {
1299        let add_tool = ToolBuilder::new("add")
1300            .description("Add two numbers")
1301            .handler(|input: AddInput| async move {
1302                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1303            })
1304            .build()
1305            .expect("valid tool name");
1306
1307        let router = McpRouter::new().tool(add_tool);
1308        let mut service = JsonRpcService::new(router.clone());
1309
1310        // Initialize session first
1311        init_jsonrpc_service(&mut service, &router).await;
1312
1313        // Create a batch of requests
1314        let requests = vec![
1315            JsonRpcRequest::new(1, "tools/list"),
1316            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1317                "name": "add",
1318                "arguments": {"a": 10, "b": 20}
1319            })),
1320            JsonRpcRequest::new(3, "ping"),
1321        ];
1322
1323        let responses = service.call_batch(requests).await.unwrap();
1324
1325        assert_eq!(responses.len(), 3);
1326
1327        // Check first response (tools/list)
1328        match &responses[0] {
1329            JsonRpcResponse::Result(r) => {
1330                assert_eq!(r.id, RequestId::Number(1));
1331                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1332                assert_eq!(tools.len(), 1);
1333            }
1334            JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1335        }
1336
1337        // Check second response (tools/call)
1338        match &responses[1] {
1339            JsonRpcResponse::Result(r) => {
1340                assert_eq!(r.id, RequestId::Number(2));
1341                let content = r.result.get("content").unwrap().as_array().unwrap();
1342                let text = content[0].get("text").unwrap().as_str().unwrap();
1343                assert_eq!(text, "30");
1344            }
1345            JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1346        }
1347
1348        // Check third response (ping)
1349        match &responses[2] {
1350            JsonRpcResponse::Result(r) => {
1351                assert_eq!(r.id, RequestId::Number(3));
1352            }
1353            JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1354        }
1355    }
1356
1357    #[tokio::test]
1358    async fn test_empty_batch_error() {
1359        let router = McpRouter::new();
1360        let mut service = JsonRpcService::new(router);
1361
1362        let result = service.call_batch(vec![]).await;
1363        assert!(result.is_err());
1364    }
1365
1366    // =========================================================================
1367    // Progress Token Tests
1368    // =========================================================================
1369
1370    #[tokio::test]
1371    async fn test_progress_token_extraction() {
1372        use crate::context::{RequestContext, ServerNotification, notification_channel};
1373        use crate::protocol::ProgressToken;
1374        use std::sync::Arc;
1375        use std::sync::atomic::{AtomicBool, Ordering};
1376
1377        // Track whether progress was reported
1378        let progress_reported = Arc::new(AtomicBool::new(false));
1379        let progress_ref = progress_reported.clone();
1380
1381        // Create a tool that reports progress
1382        let tool = ToolBuilder::new("progress_tool")
1383            .description("Tool that reports progress")
1384            .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1385                let reported = progress_ref.clone();
1386                async move {
1387                    // Report progress - this should work if token was extracted
1388                    ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1389                        .await;
1390                    reported.store(true, Ordering::SeqCst);
1391                    Ok(CallToolResult::text("done"))
1392                }
1393            })
1394            .build()
1395            .expect("valid tool name");
1396
1397        // Set up notification channel
1398        let (tx, mut rx) = notification_channel(10);
1399        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1400        let mut service = JsonRpcService::new(router.clone());
1401
1402        // Initialize
1403        init_jsonrpc_service(&mut service, &router).await;
1404
1405        // Call tool WITH progress token in _meta
1406        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1407            "name": "progress_tool",
1408            "arguments": {"a": 1, "b": 2},
1409            "_meta": {
1410                "progressToken": "test-token-123"
1411            }
1412        }));
1413
1414        let resp = service.call_single(req).await.unwrap();
1415
1416        // Verify the tool was called successfully
1417        match resp {
1418            JsonRpcResponse::Result(_) => {}
1419            JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1420        }
1421
1422        // Verify progress was reported by handler
1423        assert!(progress_reported.load(Ordering::SeqCst));
1424
1425        // Verify progress notification was sent through channel
1426        let notification = rx.try_recv().expect("Expected progress notification");
1427        match notification {
1428            ServerNotification::Progress(params) => {
1429                assert_eq!(
1430                    params.progress_token,
1431                    ProgressToken::String("test-token-123".to_string())
1432                );
1433                assert_eq!(params.progress, 50.0);
1434                assert_eq!(params.total, Some(100.0));
1435                assert_eq!(params.message.as_deref(), Some("Halfway"));
1436            }
1437            _ => panic!("Expected Progress notification"),
1438        }
1439    }
1440
1441    #[tokio::test]
1442    async fn test_tool_call_without_progress_token() {
1443        use crate::context::{RequestContext, notification_channel};
1444        use std::sync::Arc;
1445        use std::sync::atomic::{AtomicBool, Ordering};
1446
1447        let progress_attempted = Arc::new(AtomicBool::new(false));
1448        let progress_ref = progress_attempted.clone();
1449
1450        let tool = ToolBuilder::new("no_token_tool")
1451            .description("Tool that tries to report progress without token")
1452            .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1453                let attempted = progress_ref.clone();
1454                async move {
1455                    // Try to report progress - should be a no-op without token
1456                    ctx.report_progress(50.0, Some(100.0), None).await;
1457                    attempted.store(true, Ordering::SeqCst);
1458                    Ok(CallToolResult::text("done"))
1459                }
1460            })
1461            .build()
1462            .expect("valid tool name");
1463
1464        let (tx, mut rx) = notification_channel(10);
1465        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1466        let mut service = JsonRpcService::new(router.clone());
1467
1468        init_jsonrpc_service(&mut service, &router).await;
1469
1470        // Call tool WITHOUT progress token
1471        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1472            "name": "no_token_tool",
1473            "arguments": {"a": 1, "b": 2}
1474        }));
1475
1476        let resp = service.call_single(req).await.unwrap();
1477        assert!(matches!(resp, JsonRpcResponse::Result(_)));
1478
1479        // Handler was called
1480        assert!(progress_attempted.load(Ordering::SeqCst));
1481
1482        // But no notification was sent (no progress token)
1483        assert!(rx.try_recv().is_err());
1484    }
1485
1486    #[tokio::test]
1487    async fn test_batch_errors_returned_not_dropped() {
1488        let add_tool = ToolBuilder::new("add")
1489            .description("Add two numbers")
1490            .handler(|input: AddInput| async move {
1491                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1492            })
1493            .build()
1494            .expect("valid tool name");
1495
1496        let router = McpRouter::new().tool(add_tool);
1497        let mut service = JsonRpcService::new(router.clone());
1498
1499        init_jsonrpc_service(&mut service, &router).await;
1500
1501        // Create a batch with one valid and one invalid request
1502        let requests = vec![
1503            // Valid request
1504            JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1505                "name": "add",
1506                "arguments": {"a": 10, "b": 20}
1507            })),
1508            // Invalid request - tool doesn't exist
1509            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1510                "name": "nonexistent_tool",
1511                "arguments": {}
1512            })),
1513            // Another valid request
1514            JsonRpcRequest::new(3, "ping"),
1515        ];
1516
1517        let responses = service.call_batch(requests).await.unwrap();
1518
1519        // All three requests should have responses (errors are not dropped)
1520        assert_eq!(responses.len(), 3);
1521
1522        // First should be success
1523        match &responses[0] {
1524            JsonRpcResponse::Result(r) => {
1525                assert_eq!(r.id, RequestId::Number(1));
1526            }
1527            JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1528        }
1529
1530        // Second should be an error (tool not found)
1531        match &responses[1] {
1532            JsonRpcResponse::Error(e) => {
1533                assert_eq!(e.id, Some(RequestId::Number(2)));
1534                // Error should indicate method not found
1535                assert!(e.error.message.contains("not found") || e.error.code == -32601);
1536            }
1537            JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1538        }
1539
1540        // Third should be success
1541        match &responses[2] {
1542            JsonRpcResponse::Result(r) => {
1543                assert_eq!(r.id, RequestId::Number(3));
1544            }
1545            JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1546        }
1547    }
1548
1549    // =========================================================================
1550    // Resource Template Tests
1551    // =========================================================================
1552
1553    #[tokio::test]
1554    async fn test_list_resource_templates() {
1555        use crate::resource::ResourceTemplateBuilder;
1556        use std::collections::HashMap;
1557
1558        let template = ResourceTemplateBuilder::new("file:///{path}")
1559            .name("Project Files")
1560            .description("Access project files")
1561            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1562                Ok(ReadResourceResult {
1563                    contents: vec![ResourceContent {
1564                        uri,
1565                        mime_type: None,
1566                        text: None,
1567                        blob: None,
1568                    }],
1569                })
1570            });
1571
1572        let mut router = McpRouter::new().resource_template(template);
1573
1574        // Initialize session
1575        init_router(&mut router).await;
1576
1577        let req = RouterRequest {
1578            id: RequestId::Number(1),
1579            inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1580            extensions: Extensions::new(),
1581        };
1582
1583        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1584
1585        match resp.inner {
1586            Ok(McpResponse::ListResourceTemplates(result)) => {
1587                assert_eq!(result.resource_templates.len(), 1);
1588                assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1589                assert_eq!(result.resource_templates[0].name, "Project Files");
1590            }
1591            _ => panic!("Expected ListResourceTemplates response"),
1592        }
1593    }
1594
1595    #[tokio::test]
1596    async fn test_read_resource_via_template() {
1597        use crate::resource::ResourceTemplateBuilder;
1598        use std::collections::HashMap;
1599
1600        let template = ResourceTemplateBuilder::new("db://users/{id}")
1601            .name("User Records")
1602            .handler(|uri: String, vars: HashMap<String, String>| async move {
1603                let id = vars.get("id").unwrap().clone();
1604                Ok(ReadResourceResult {
1605                    contents: vec![ResourceContent {
1606                        uri,
1607                        mime_type: Some("application/json".to_string()),
1608                        text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1609                        blob: None,
1610                    }],
1611                })
1612            });
1613
1614        let mut router = McpRouter::new().resource_template(template);
1615
1616        // Initialize session
1617        init_router(&mut router).await;
1618
1619        // Read a resource that matches the template
1620        let req = RouterRequest {
1621            id: RequestId::Number(1),
1622            inner: McpRequest::ReadResource(ReadResourceParams {
1623                uri: "db://users/123".to_string(),
1624            }),
1625            extensions: Extensions::new(),
1626        };
1627
1628        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1629
1630        match resp.inner {
1631            Ok(McpResponse::ReadResource(result)) => {
1632                assert_eq!(result.contents.len(), 1);
1633                assert_eq!(result.contents[0].uri, "db://users/123");
1634                assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1635            }
1636            _ => panic!("Expected ReadResource response"),
1637        }
1638    }
1639
1640    #[tokio::test]
1641    async fn test_static_resource_takes_precedence_over_template() {
1642        use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1643        use std::collections::HashMap;
1644
1645        // Template that would match the same URI
1646        let template = ResourceTemplateBuilder::new("file:///{path}")
1647            .name("Files Template")
1648            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1649                Ok(ReadResourceResult {
1650                    contents: vec![ResourceContent {
1651                        uri,
1652                        mime_type: None,
1653                        text: Some("from template".to_string()),
1654                        blob: None,
1655                    }],
1656                })
1657            });
1658
1659        // Static resource with exact URI
1660        let static_resource = ResourceBuilder::new("file:///README.md")
1661            .name("README")
1662            .text("from static resource");
1663
1664        let mut router = McpRouter::new()
1665            .resource_template(template)
1666            .resource(static_resource);
1667
1668        // Initialize session
1669        init_router(&mut router).await;
1670
1671        // Read the static resource - should NOT go through template
1672        let req = RouterRequest {
1673            id: RequestId::Number(1),
1674            inner: McpRequest::ReadResource(ReadResourceParams {
1675                uri: "file:///README.md".to_string(),
1676            }),
1677            extensions: Extensions::new(),
1678        };
1679
1680        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1681
1682        match resp.inner {
1683            Ok(McpResponse::ReadResource(result)) => {
1684                // Should get static resource, not template
1685                assert_eq!(
1686                    result.contents[0].text.as_deref(),
1687                    Some("from static resource")
1688                );
1689            }
1690            _ => panic!("Expected ReadResource response"),
1691        }
1692    }
1693
1694    #[tokio::test]
1695    async fn test_resource_not_found_when_no_match() {
1696        use crate::resource::ResourceTemplateBuilder;
1697        use std::collections::HashMap;
1698
1699        let template = ResourceTemplateBuilder::new("db://users/{id}")
1700            .name("Users")
1701            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1702                Ok(ReadResourceResult {
1703                    contents: vec![ResourceContent {
1704                        uri,
1705                        mime_type: None,
1706                        text: None,
1707                        blob: None,
1708                    }],
1709                })
1710            });
1711
1712        let mut router = McpRouter::new().resource_template(template);
1713
1714        // Initialize session
1715        init_router(&mut router).await;
1716
1717        // Try to read a URI that doesn't match any resource or template
1718        let req = RouterRequest {
1719            id: RequestId::Number(1),
1720            inner: McpRequest::ReadResource(ReadResourceParams {
1721                uri: "db://posts/123".to_string(),
1722            }),
1723            extensions: Extensions::new(),
1724        };
1725
1726        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1727
1728        match resp.inner {
1729            Err(err) => {
1730                assert!(err.message.contains("not found"));
1731            }
1732            Ok(_) => panic!("Expected error for non-matching URI"),
1733        }
1734    }
1735
1736    #[tokio::test]
1737    async fn test_capabilities_include_resources_with_only_templates() {
1738        use crate::resource::ResourceTemplateBuilder;
1739        use std::collections::HashMap;
1740
1741        let template = ResourceTemplateBuilder::new("file:///{path}")
1742            .name("Files")
1743            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1744                Ok(ReadResourceResult {
1745                    contents: vec![ResourceContent {
1746                        uri,
1747                        mime_type: None,
1748                        text: None,
1749                        blob: None,
1750                    }],
1751                })
1752            });
1753
1754        let mut router = McpRouter::new().resource_template(template);
1755
1756        // Send initialize request and check capabilities
1757        let init_req = RouterRequest {
1758            id: RequestId::Number(0),
1759            inner: McpRequest::Initialize(InitializeParams {
1760                protocol_version: "2025-11-25".to_string(),
1761                capabilities: ClientCapabilities {
1762                    roots: None,
1763                    sampling: None,
1764                    elicitation: None,
1765                },
1766                client_info: Implementation {
1767                    name: "test".to_string(),
1768                    version: "1.0".to_string(),
1769                    ..Default::default()
1770                },
1771            }),
1772            extensions: Extensions::new(),
1773        };
1774        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1775
1776        match resp.inner {
1777            Ok(McpResponse::Initialize(result)) => {
1778                // Should have resources capability even though only templates registered
1779                assert!(result.capabilities.resources.is_some());
1780            }
1781            _ => panic!("Expected Initialize response"),
1782        }
1783    }
1784
1785    // =========================================================================
1786    // Logging Notification Tests
1787    // =========================================================================
1788
1789    #[tokio::test]
1790    async fn test_log_sends_notification() {
1791        use crate::context::notification_channel;
1792
1793        let (tx, mut rx) = notification_channel(10);
1794        let router = McpRouter::new().with_notification_sender(tx);
1795
1796        // Send an info log
1797        let sent = router.log_info("Test message");
1798        assert!(sent);
1799
1800        // Should receive the notification
1801        let notification = rx.try_recv().unwrap();
1802        match notification {
1803            ServerNotification::LogMessage(params) => {
1804                assert_eq!(params.level, LogLevel::Info);
1805                let data = params.data.unwrap();
1806                assert_eq!(
1807                    data.get("message").unwrap().as_str().unwrap(),
1808                    "Test message"
1809                );
1810            }
1811            _ => panic!("Expected LogMessage notification"),
1812        }
1813    }
1814
1815    #[tokio::test]
1816    async fn test_log_with_custom_params() {
1817        use crate::context::notification_channel;
1818
1819        let (tx, mut rx) = notification_channel(10);
1820        let router = McpRouter::new().with_notification_sender(tx);
1821
1822        // Send a custom log message
1823        let params = LoggingMessageParams::new(LogLevel::Error)
1824            .with_logger("database")
1825            .with_data(serde_json::json!({
1826                "error": "Connection failed",
1827                "host": "localhost"
1828            }));
1829
1830        let sent = router.log(params);
1831        assert!(sent);
1832
1833        let notification = rx.try_recv().unwrap();
1834        match notification {
1835            ServerNotification::LogMessage(params) => {
1836                assert_eq!(params.level, LogLevel::Error);
1837                assert_eq!(params.logger.as_deref(), Some("database"));
1838                let data = params.data.unwrap();
1839                assert_eq!(
1840                    data.get("error").unwrap().as_str().unwrap(),
1841                    "Connection failed"
1842                );
1843            }
1844            _ => panic!("Expected LogMessage notification"),
1845        }
1846    }
1847
1848    #[tokio::test]
1849    async fn test_log_without_channel_returns_false() {
1850        // Router without notification channel
1851        let router = McpRouter::new();
1852
1853        // Should return false when no channel configured
1854        assert!(!router.log_info("Test"));
1855        assert!(!router.log_warning("Test"));
1856        assert!(!router.log_error("Test"));
1857        assert!(!router.log_debug("Test"));
1858    }
1859
1860    #[tokio::test]
1861    async fn test_logging_capability_with_channel() {
1862        use crate::context::notification_channel;
1863
1864        let (tx, _rx) = notification_channel(10);
1865        let mut router = McpRouter::new().with_notification_sender(tx);
1866
1867        // Initialize and check capabilities
1868        let init_req = RouterRequest {
1869            id: RequestId::Number(0),
1870            inner: McpRequest::Initialize(InitializeParams {
1871                protocol_version: "2025-11-25".to_string(),
1872                capabilities: ClientCapabilities {
1873                    roots: None,
1874                    sampling: None,
1875                    elicitation: None,
1876                },
1877                client_info: Implementation {
1878                    name: "test".to_string(),
1879                    version: "1.0".to_string(),
1880                    ..Default::default()
1881                },
1882            }),
1883            extensions: Extensions::new(),
1884        };
1885        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1886
1887        match resp.inner {
1888            Ok(McpResponse::Initialize(result)) => {
1889                // Should have logging capability when notification channel is set
1890                assert!(result.capabilities.logging.is_some());
1891            }
1892            _ => panic!("Expected Initialize response"),
1893        }
1894    }
1895
1896    #[tokio::test]
1897    async fn test_no_logging_capability_without_channel() {
1898        let mut router = McpRouter::new();
1899
1900        // Initialize and check capabilities
1901        let init_req = RouterRequest {
1902            id: RequestId::Number(0),
1903            inner: McpRequest::Initialize(InitializeParams {
1904                protocol_version: "2025-11-25".to_string(),
1905                capabilities: ClientCapabilities {
1906                    roots: None,
1907                    sampling: None,
1908                    elicitation: None,
1909                },
1910                client_info: Implementation {
1911                    name: "test".to_string(),
1912                    version: "1.0".to_string(),
1913                    ..Default::default()
1914                },
1915            }),
1916            extensions: Extensions::new(),
1917        };
1918        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1919
1920        match resp.inner {
1921            Ok(McpResponse::Initialize(result)) => {
1922                // Should NOT have logging capability without notification channel
1923                assert!(result.capabilities.logging.is_none());
1924            }
1925            _ => panic!("Expected Initialize response"),
1926        }
1927    }
1928
1929    // =========================================================================
1930    // Task Lifecycle Tests
1931    // =========================================================================
1932
1933    #[tokio::test]
1934    async fn test_enqueue_task() {
1935        let add_tool = ToolBuilder::new("add")
1936            .description("Add two numbers")
1937            .handler(|input: AddInput| async move {
1938                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1939            })
1940            .build()
1941            .expect("valid tool name");
1942
1943        let mut router = McpRouter::new().tool(add_tool);
1944        init_router(&mut router).await;
1945
1946        let req = RouterRequest {
1947            id: RequestId::Number(1),
1948            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1949                tool_name: "add".to_string(),
1950                arguments: serde_json::json!({"a": 5, "b": 10}),
1951                ttl: None,
1952            }),
1953            extensions: Extensions::new(),
1954        };
1955
1956        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1957
1958        match resp.inner {
1959            Ok(McpResponse::EnqueueTask(result)) => {
1960                assert!(result.task_id.starts_with("task-"));
1961                assert_eq!(result.status, TaskStatus::Working);
1962            }
1963            _ => panic!("Expected EnqueueTask response"),
1964        }
1965    }
1966
1967    #[tokio::test]
1968    async fn test_list_tasks_empty() {
1969        let mut router = McpRouter::new();
1970        init_router(&mut router).await;
1971
1972        let req = RouterRequest {
1973            id: RequestId::Number(1),
1974            inner: McpRequest::ListTasks(ListTasksParams::default()),
1975            extensions: Extensions::new(),
1976        };
1977
1978        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1979
1980        match resp.inner {
1981            Ok(McpResponse::ListTasks(result)) => {
1982                assert!(result.tasks.is_empty());
1983            }
1984            _ => panic!("Expected ListTasks response"),
1985        }
1986    }
1987
1988    #[tokio::test]
1989    async fn test_task_lifecycle_complete() {
1990        let add_tool = ToolBuilder::new("add")
1991            .description("Add two numbers")
1992            .handler(|input: AddInput| async move {
1993                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1994            })
1995            .build()
1996            .expect("valid tool name");
1997
1998        let mut router = McpRouter::new().tool(add_tool);
1999        init_router(&mut router).await;
2000
2001        // Enqueue task
2002        let req = RouterRequest {
2003            id: RequestId::Number(1),
2004            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2005                tool_name: "add".to_string(),
2006                arguments: serde_json::json!({"a": 7, "b": 8}),
2007                ttl: None,
2008            }),
2009            extensions: Extensions::new(),
2010        };
2011
2012        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2013        let task_id = match resp.inner {
2014            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2015            _ => panic!("Expected EnqueueTask response"),
2016        };
2017
2018        // Wait for task to complete
2019        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2020
2021        // Get task result
2022        let req = RouterRequest {
2023            id: RequestId::Number(2),
2024            inner: McpRequest::GetTaskResult(GetTaskResultParams {
2025                task_id: task_id.clone(),
2026            }),
2027            extensions: Extensions::new(),
2028        };
2029
2030        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2031
2032        match resp.inner {
2033            Ok(McpResponse::GetTaskResult(result)) => {
2034                assert_eq!(result.task_id, task_id);
2035                assert_eq!(result.status, TaskStatus::Completed);
2036                assert!(result.result.is_some());
2037                assert!(result.error.is_none());
2038
2039                // Check the result content
2040                let tool_result = result.result.unwrap();
2041                match &tool_result.content[0] {
2042                    Content::Text { text, .. } => assert_eq!(text, "15"),
2043                    _ => panic!("Expected text content"),
2044                }
2045            }
2046            _ => panic!("Expected GetTaskResult response"),
2047        }
2048    }
2049
2050    #[tokio::test]
2051    async fn test_task_cancellation() {
2052        // Use a slow tool to test cancellation
2053        let slow_tool = ToolBuilder::new("slow")
2054            .description("Slow tool")
2055            .handler(|_input: serde_json::Value| async move {
2056                tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2057                Ok(CallToolResult::text("done"))
2058            })
2059            .build()
2060            .expect("valid tool name");
2061
2062        let mut router = McpRouter::new().tool(slow_tool);
2063        init_router(&mut router).await;
2064
2065        // Enqueue task
2066        let req = RouterRequest {
2067            id: RequestId::Number(1),
2068            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2069                tool_name: "slow".to_string(),
2070                arguments: serde_json::json!({}),
2071                ttl: None,
2072            }),
2073            extensions: Extensions::new(),
2074        };
2075
2076        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2077        let task_id = match resp.inner {
2078            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2079            _ => panic!("Expected EnqueueTask response"),
2080        };
2081
2082        // Cancel the task
2083        let req = RouterRequest {
2084            id: RequestId::Number(2),
2085            inner: McpRequest::CancelTask(CancelTaskParams {
2086                task_id: task_id.clone(),
2087                reason: Some("Test cancellation".to_string()),
2088            }),
2089            extensions: Extensions::new(),
2090        };
2091
2092        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2093
2094        match resp.inner {
2095            Ok(McpResponse::CancelTask(result)) => {
2096                assert!(result.cancelled);
2097                assert_eq!(result.status, TaskStatus::Cancelled);
2098            }
2099            _ => panic!("Expected CancelTask response"),
2100        }
2101    }
2102
2103    #[tokio::test]
2104    async fn test_get_task_info() {
2105        let add_tool = ToolBuilder::new("add")
2106            .description("Add two numbers")
2107            .handler(|input: AddInput| async move {
2108                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2109            })
2110            .build()
2111            .expect("valid tool name");
2112
2113        let mut router = McpRouter::new().tool(add_tool);
2114        init_router(&mut router).await;
2115
2116        // Enqueue task
2117        let req = RouterRequest {
2118            id: RequestId::Number(1),
2119            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2120                tool_name: "add".to_string(),
2121                arguments: serde_json::json!({"a": 1, "b": 2}),
2122                ttl: Some(600),
2123            }),
2124            extensions: Extensions::new(),
2125        };
2126
2127        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2128        let task_id = match resp.inner {
2129            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2130            _ => panic!("Expected EnqueueTask response"),
2131        };
2132
2133        // Get task info
2134        let req = RouterRequest {
2135            id: RequestId::Number(2),
2136            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2137                task_id: task_id.clone(),
2138            }),
2139            extensions: Extensions::new(),
2140        };
2141
2142        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2143
2144        match resp.inner {
2145            Ok(McpResponse::GetTaskInfo(info)) => {
2146                assert_eq!(info.task_id, task_id);
2147                assert!(info.created_at.contains('T')); // ISO 8601
2148                assert_eq!(info.ttl, Some(600));
2149            }
2150            _ => panic!("Expected GetTaskInfo response"),
2151        }
2152    }
2153
2154    #[tokio::test]
2155    async fn test_enqueue_nonexistent_tool() {
2156        let mut router = McpRouter::new();
2157        init_router(&mut router).await;
2158
2159        let req = RouterRequest {
2160            id: RequestId::Number(1),
2161            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2162                tool_name: "nonexistent".to_string(),
2163                arguments: serde_json::json!({}),
2164                ttl: None,
2165            }),
2166            extensions: Extensions::new(),
2167        };
2168
2169        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2170
2171        match resp.inner {
2172            Err(e) => {
2173                assert!(e.message.contains("not found"));
2174            }
2175            _ => panic!("Expected error response"),
2176        }
2177    }
2178
2179    #[tokio::test]
2180    async fn test_get_nonexistent_task() {
2181        let mut router = McpRouter::new();
2182        init_router(&mut router).await;
2183
2184        let req = RouterRequest {
2185            id: RequestId::Number(1),
2186            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2187                task_id: "task-999".to_string(),
2188            }),
2189            extensions: Extensions::new(),
2190        };
2191
2192        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2193
2194        match resp.inner {
2195            Err(e) => {
2196                assert!(e.message.contains("not found"));
2197            }
2198            _ => panic!("Expected error response"),
2199        }
2200    }
2201
2202    // =========================================================================
2203    // Resource Subscription Tests
2204    // =========================================================================
2205
2206    #[tokio::test]
2207    async fn test_subscribe_to_resource() {
2208        use crate::resource::ResourceBuilder;
2209
2210        let resource = ResourceBuilder::new("file:///test.txt")
2211            .name("Test File")
2212            .text("Hello");
2213
2214        let mut router = McpRouter::new().resource(resource);
2215        init_router(&mut router).await;
2216
2217        // Subscribe to the resource
2218        let req = RouterRequest {
2219            id: RequestId::Number(1),
2220            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2221                uri: "file:///test.txt".to_string(),
2222            }),
2223            extensions: Extensions::new(),
2224        };
2225
2226        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2227
2228        match resp.inner {
2229            Ok(McpResponse::SubscribeResource(_)) => {
2230                // Should be subscribed now
2231                assert!(router.is_subscribed("file:///test.txt"));
2232            }
2233            _ => panic!("Expected SubscribeResource response"),
2234        }
2235    }
2236
2237    #[tokio::test]
2238    async fn test_unsubscribe_from_resource() {
2239        use crate::resource::ResourceBuilder;
2240
2241        let resource = ResourceBuilder::new("file:///test.txt")
2242            .name("Test File")
2243            .text("Hello");
2244
2245        let mut router = McpRouter::new().resource(resource);
2246        init_router(&mut router).await;
2247
2248        // Subscribe first
2249        let req = RouterRequest {
2250            id: RequestId::Number(1),
2251            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2252                uri: "file:///test.txt".to_string(),
2253            }),
2254            extensions: Extensions::new(),
2255        };
2256        let _ = router.ready().await.unwrap().call(req).await.unwrap();
2257        assert!(router.is_subscribed("file:///test.txt"));
2258
2259        // Now unsubscribe
2260        let req = RouterRequest {
2261            id: RequestId::Number(2),
2262            inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2263                uri: "file:///test.txt".to_string(),
2264            }),
2265            extensions: Extensions::new(),
2266        };
2267
2268        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2269
2270        match resp.inner {
2271            Ok(McpResponse::UnsubscribeResource(_)) => {
2272                // Should no longer be subscribed
2273                assert!(!router.is_subscribed("file:///test.txt"));
2274            }
2275            _ => panic!("Expected UnsubscribeResource response"),
2276        }
2277    }
2278
2279    #[tokio::test]
2280    async fn test_subscribe_nonexistent_resource() {
2281        let mut router = McpRouter::new();
2282        init_router(&mut router).await;
2283
2284        let req = RouterRequest {
2285            id: RequestId::Number(1),
2286            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2287                uri: "file:///nonexistent.txt".to_string(),
2288            }),
2289            extensions: Extensions::new(),
2290        };
2291
2292        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2293
2294        match resp.inner {
2295            Err(e) => {
2296                assert!(e.message.contains("not found"));
2297            }
2298            _ => panic!("Expected error response"),
2299        }
2300    }
2301
2302    #[tokio::test]
2303    async fn test_notify_resource_updated() {
2304        use crate::context::notification_channel;
2305        use crate::resource::ResourceBuilder;
2306
2307        let (tx, mut rx) = notification_channel(10);
2308
2309        let resource = ResourceBuilder::new("file:///test.txt")
2310            .name("Test File")
2311            .text("Hello");
2312
2313        let router = McpRouter::new()
2314            .resource(resource)
2315            .with_notification_sender(tx);
2316
2317        // First, manually subscribe (simulate subscription)
2318        router.subscribe("file:///test.txt");
2319
2320        // Now notify
2321        let sent = router.notify_resource_updated("file:///test.txt");
2322        assert!(sent);
2323
2324        // Check the notification was sent
2325        let notification = rx.try_recv().unwrap();
2326        match notification {
2327            ServerNotification::ResourceUpdated { uri } => {
2328                assert_eq!(uri, "file:///test.txt");
2329            }
2330            _ => panic!("Expected ResourceUpdated notification"),
2331        }
2332    }
2333
2334    #[tokio::test]
2335    async fn test_notify_resource_updated_not_subscribed() {
2336        use crate::context::notification_channel;
2337        use crate::resource::ResourceBuilder;
2338
2339        let (tx, mut rx) = notification_channel(10);
2340
2341        let resource = ResourceBuilder::new("file:///test.txt")
2342            .name("Test File")
2343            .text("Hello");
2344
2345        let router = McpRouter::new()
2346            .resource(resource)
2347            .with_notification_sender(tx);
2348
2349        // Try to notify without subscribing
2350        let sent = router.notify_resource_updated("file:///test.txt");
2351        assert!(!sent); // Should not send because not subscribed
2352
2353        // Channel should be empty
2354        assert!(rx.try_recv().is_err());
2355    }
2356
2357    #[tokio::test]
2358    async fn test_notify_resources_list_changed() {
2359        use crate::context::notification_channel;
2360
2361        let (tx, mut rx) = notification_channel(10);
2362        let router = McpRouter::new().with_notification_sender(tx);
2363
2364        let sent = router.notify_resources_list_changed();
2365        assert!(sent);
2366
2367        let notification = rx.try_recv().unwrap();
2368        match notification {
2369            ServerNotification::ResourcesListChanged => {}
2370            _ => panic!("Expected ResourcesListChanged notification"),
2371        }
2372    }
2373
2374    #[tokio::test]
2375    async fn test_subscribed_uris() {
2376        use crate::resource::ResourceBuilder;
2377
2378        let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2379
2380        let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2381
2382        let router = McpRouter::new().resource(resource1).resource(resource2);
2383
2384        // Subscribe to both
2385        router.subscribe("file:///a.txt");
2386        router.subscribe("file:///b.txt");
2387
2388        let uris = router.subscribed_uris();
2389        assert_eq!(uris.len(), 2);
2390        assert!(uris.contains(&"file:///a.txt".to_string()));
2391        assert!(uris.contains(&"file:///b.txt".to_string()));
2392    }
2393
2394    #[tokio::test]
2395    async fn test_subscription_capability_advertised() {
2396        use crate::resource::ResourceBuilder;
2397
2398        let resource = ResourceBuilder::new("file:///test.txt")
2399            .name("Test")
2400            .text("Hello");
2401
2402        let mut router = McpRouter::new().resource(resource);
2403
2404        // Initialize and check capabilities
2405        let init_req = RouterRequest {
2406            id: RequestId::Number(0),
2407            inner: McpRequest::Initialize(InitializeParams {
2408                protocol_version: "2025-11-25".to_string(),
2409                capabilities: ClientCapabilities {
2410                    roots: None,
2411                    sampling: None,
2412                    elicitation: None,
2413                },
2414                client_info: Implementation {
2415                    name: "test".to_string(),
2416                    version: "1.0".to_string(),
2417                    ..Default::default()
2418                },
2419            }),
2420            extensions: Extensions::new(),
2421        };
2422        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2423
2424        match resp.inner {
2425            Ok(McpResponse::Initialize(result)) => {
2426                // Should have resources capability with subscribe enabled
2427                let resources_cap = result.capabilities.resources.unwrap();
2428                assert!(resources_cap.subscribe);
2429            }
2430            _ => panic!("Expected Initialize response"),
2431        }
2432    }
2433
2434    #[tokio::test]
2435    async fn test_completion_handler() {
2436        let router = McpRouter::new()
2437            .server_info("test", "1.0")
2438            .completion_handler(|params: CompleteParams| async move {
2439                // Return suggestions based on the argument value
2440                let prefix = &params.argument.value;
2441                let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2442                    .into_iter()
2443                    .filter(|s| s.starts_with(prefix))
2444                    .map(String::from)
2445                    .collect();
2446                Ok(CompleteResult::new(suggestions))
2447            });
2448
2449        // Initialize
2450        let init_req = RouterRequest {
2451            id: RequestId::Number(0),
2452            inner: McpRequest::Initialize(InitializeParams {
2453                protocol_version: "2025-11-25".to_string(),
2454                capabilities: ClientCapabilities::default(),
2455                client_info: Implementation {
2456                    name: "test".to_string(),
2457                    version: "1.0".to_string(),
2458                    ..Default::default()
2459                },
2460            }),
2461            extensions: Extensions::new(),
2462        };
2463        let resp = router
2464            .clone()
2465            .ready()
2466            .await
2467            .unwrap()
2468            .call(init_req)
2469            .await
2470            .unwrap();
2471
2472        // Check that completions capability is advertised
2473        match resp.inner {
2474            Ok(McpResponse::Initialize(result)) => {
2475                assert!(result.capabilities.completions.is_some());
2476            }
2477            _ => panic!("Expected Initialize response"),
2478        }
2479
2480        // Send initialized notification
2481        router.handle_notification(McpNotification::Initialized);
2482
2483        // Test completion request
2484        let complete_req = RouterRequest {
2485            id: RequestId::Number(1),
2486            inner: McpRequest::Complete(CompleteParams {
2487                reference: CompletionReference::prompt("test-prompt"),
2488                argument: CompletionArgument::new("query", "al"),
2489            }),
2490            extensions: Extensions::new(),
2491        };
2492        let resp = router
2493            .clone()
2494            .ready()
2495            .await
2496            .unwrap()
2497            .call(complete_req)
2498            .await
2499            .unwrap();
2500
2501        match resp.inner {
2502            Ok(McpResponse::Complete(result)) => {
2503                assert_eq!(result.completion.values, vec!["alpha"]);
2504            }
2505            _ => panic!("Expected Complete response"),
2506        }
2507    }
2508
2509    #[tokio::test]
2510    async fn test_completion_without_handler_returns_empty() {
2511        let router = McpRouter::new().server_info("test", "1.0");
2512
2513        // Initialize
2514        let init_req = RouterRequest {
2515            id: RequestId::Number(0),
2516            inner: McpRequest::Initialize(InitializeParams {
2517                protocol_version: "2025-11-25".to_string(),
2518                capabilities: ClientCapabilities::default(),
2519                client_info: Implementation {
2520                    name: "test".to_string(),
2521                    version: "1.0".to_string(),
2522                    ..Default::default()
2523                },
2524            }),
2525            extensions: Extensions::new(),
2526        };
2527        let resp = router
2528            .clone()
2529            .ready()
2530            .await
2531            .unwrap()
2532            .call(init_req)
2533            .await
2534            .unwrap();
2535
2536        // Check that completions capability is NOT advertised
2537        match resp.inner {
2538            Ok(McpResponse::Initialize(result)) => {
2539                assert!(result.capabilities.completions.is_none());
2540            }
2541            _ => panic!("Expected Initialize response"),
2542        }
2543
2544        // Send initialized notification
2545        router.handle_notification(McpNotification::Initialized);
2546
2547        // Test completion request still works but returns empty
2548        let complete_req = RouterRequest {
2549            id: RequestId::Number(1),
2550            inner: McpRequest::Complete(CompleteParams {
2551                reference: CompletionReference::prompt("test-prompt"),
2552                argument: CompletionArgument::new("query", "al"),
2553            }),
2554            extensions: Extensions::new(),
2555        };
2556        let resp = router
2557            .clone()
2558            .ready()
2559            .await
2560            .unwrap()
2561            .call(complete_req)
2562            .await
2563            .unwrap();
2564
2565        match resp.inner {
2566            Ok(McpResponse::Complete(result)) => {
2567                assert!(result.completion.values.is_empty());
2568            }
2569            _ => panic!("Expected Complete response"),
2570        }
2571    }
2572}