Skip to main content

tower_mcp/
router.rs

1//! MCP Router - routes requests to tools, resources, and prompts
2//!
3//! The router implements Tower's `Service` trait, making it composable with
4//! standard tower middleware.
5
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11
12use tower_service::Service;
13
14use crate::async_task::TaskStore;
15use crate::context::{
16    CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
17    ServerNotification,
18};
19use crate::error::{Error, JsonRpcError, Result};
20use crate::prompt::Prompt;
21use crate::protocol::*;
22use crate::resource::{Resource, ResourceTemplate};
23use crate::session::SessionState;
24use crate::tool::Tool;
25
26/// Type alias for completion handler function
27pub type CompletionHandler = Arc<
28    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
29        + Send
30        + Sync,
31>;
32
33/// MCP Router that dispatches requests to registered handlers
34///
35/// Implements `tower::Service<McpRequest>` for middleware composition.
36///
37/// # Example
38///
39/// ```rust
40/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
41/// use schemars::JsonSchema;
42/// use serde::Deserialize;
43///
44/// #[derive(Debug, Deserialize, JsonSchema)]
45/// struct Input { value: String }
46///
47/// let tool = ToolBuilder::new("echo")
48///     .description("Echo input")
49///     .handler(|i: Input| async move { Ok(CallToolResult::text(i.value)) })
50///     .build()
51///     .unwrap();
52///
53/// let router = McpRouter::new()
54///     .server_info("my-server", "1.0.0")
55///     .tool(tool);
56/// ```
57#[derive(Clone)]
58pub struct McpRouter {
59    inner: Arc<McpRouterInner>,
60    session: SessionState,
61}
62
63impl std::fmt::Debug for McpRouter {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("McpRouter")
66            .field("server_name", &self.inner.server_name)
67            .field("server_version", &self.inner.server_version)
68            .field("tools_count", &self.inner.tools.len())
69            .field("resources_count", &self.inner.resources.len())
70            .field("prompts_count", &self.inner.prompts.len())
71            .field("session_phase", &self.session.phase())
72            .finish()
73    }
74}
75
76/// Inner configuration that is shared across clones
77#[derive(Clone)]
78struct McpRouterInner {
79    server_name: String,
80    server_version: String,
81    instructions: Option<String>,
82    tools: HashMap<String, Arc<Tool>>,
83    resources: HashMap<String, Arc<Resource>>,
84    /// Resource templates for dynamic resource matching (keyed by uri_template)
85    resource_templates: Vec<Arc<ResourceTemplate>>,
86    prompts: HashMap<String, Arc<Prompt>>,
87    /// In-flight requests for cancellation tracking (shared across clones)
88    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
89    /// Channel for sending notifications to connected clients
90    notification_tx: Option<NotificationSender>,
91    /// Handle for sending requests to the client (for sampling, etc.)
92    client_requester: Option<ClientRequesterHandle>,
93    /// Task store for async operations
94    task_store: TaskStore,
95    /// Subscribed resource URIs
96    subscriptions: Arc<RwLock<HashSet<String>>>,
97    /// Handler for completion requests
98    completion_handler: Option<CompletionHandler>,
99}
100
101impl McpRouter {
102    /// Create a new MCP router
103    pub fn new() -> Self {
104        Self {
105            inner: Arc::new(McpRouterInner {
106                server_name: "tower-mcp".to_string(),
107                server_version: env!("CARGO_PKG_VERSION").to_string(),
108                instructions: None,
109                tools: HashMap::new(),
110                resources: HashMap::new(),
111                resource_templates: Vec::new(),
112                prompts: HashMap::new(),
113                in_flight: Arc::new(RwLock::new(HashMap::new())),
114                notification_tx: None,
115                client_requester: None,
116                task_store: TaskStore::new(),
117                subscriptions: Arc::new(RwLock::new(HashSet::new())),
118                completion_handler: None,
119            }),
120            session: SessionState::new(),
121        }
122    }
123
124    /// Get access to the task store for async operations
125    pub fn task_store(&self) -> &TaskStore {
126        &self.inner.task_store
127    }
128
129    /// Set the notification sender for progress reporting
130    ///
131    /// This is typically called by the transport layer to receive notifications.
132    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
133        Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
134        self
135    }
136
137    /// Get the notification sender (if configured)
138    pub fn notification_sender(&self) -> Option<&NotificationSender> {
139        self.inner.notification_tx.as_ref()
140    }
141
142    /// Set the client requester for server-to-client requests (sampling, etc.)
143    ///
144    /// This is typically called by bidirectional transports (WebSocket, stdio)
145    /// to enable tool handlers to send requests to the client.
146    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
147        Arc::make_mut(&mut self.inner).client_requester = Some(requester);
148        self
149    }
150
151    /// Get the client requester (if configured)
152    pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
153        self.inner.client_requester.as_ref()
154    }
155
156    /// Create a request context for tracking a request
157    ///
158    /// This registers the request for cancellation tracking and sets up
159    /// progress reporting and client requests if configured.
160    pub fn create_context(
161        &self,
162        request_id: RequestId,
163        progress_token: Option<ProgressToken>,
164    ) -> RequestContext {
165        let ctx = RequestContext::new(request_id.clone());
166
167        // Set up progress token if provided
168        let ctx = if let Some(token) = progress_token {
169            ctx.with_progress_token(token)
170        } else {
171            ctx
172        };
173
174        // Set up notification sender if configured
175        let ctx = if let Some(tx) = &self.inner.notification_tx {
176            ctx.with_notification_sender(tx.clone())
177        } else {
178            ctx
179        };
180
181        // Set up client requester if configured (for sampling support)
182        let ctx = if let Some(requester) = &self.inner.client_requester {
183            ctx.with_client_requester(requester.clone())
184        } else {
185            ctx
186        };
187
188        // Register for cancellation tracking
189        let token = ctx.cancellation_token();
190        if let Ok(mut in_flight) = self.inner.in_flight.write() {
191            in_flight.insert(request_id, token);
192        }
193
194        ctx
195    }
196
197    /// Remove a request from tracking (called when request completes)
198    pub fn complete_request(&self, request_id: &RequestId) {
199        if let Ok(mut in_flight) = self.inner.in_flight.write() {
200            in_flight.remove(request_id);
201        }
202    }
203
204    /// Cancel a tracked request
205    fn cancel_request(&self, request_id: &RequestId) -> bool {
206        if let Ok(in_flight) = self.inner.in_flight.read()
207            && let Some(token) = in_flight.get(request_id)
208        {
209            token.cancel();
210            return true;
211        }
212        false
213    }
214
215    /// Set server info
216    pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
217        let inner = Arc::make_mut(&mut self.inner);
218        inner.server_name = name.into();
219        inner.server_version = version.into();
220        self
221    }
222
223    /// Set instructions for LLMs describing how to use this server
224    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
225        Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
226        self
227    }
228
229    /// Register a tool
230    pub fn tool(mut self, tool: Tool) -> Self {
231        Arc::make_mut(&mut self.inner)
232            .tools
233            .insert(tool.name.clone(), Arc::new(tool));
234        self
235    }
236
237    /// Register a resource
238    pub fn resource(mut self, resource: Resource) -> Self {
239        Arc::make_mut(&mut self.inner)
240            .resources
241            .insert(resource.uri.clone(), Arc::new(resource));
242        self
243    }
244
245    /// Register a resource template
246    ///
247    /// Resource templates allow dynamic resources to be matched by URI pattern.
248    /// When a client requests a resource URI that doesn't match any static
249    /// resource, the router tries to match it against registered templates.
250    ///
251    /// # Example
252    ///
253    /// ```rust
254    /// use tower_mcp::{McpRouter, ResourceTemplateBuilder};
255    /// use tower_mcp::protocol::{ReadResourceResult, ResourceContent};
256    /// use std::collections::HashMap;
257    ///
258    /// let template = ResourceTemplateBuilder::new("file:///{path}")
259    ///     .name("Project Files")
260    ///     .handler(|uri: String, vars: HashMap<String, String>| async move {
261    ///         let path = vars.get("path").unwrap_or(&String::new()).clone();
262    ///         Ok(ReadResourceResult {
263    ///             contents: vec![ResourceContent {
264    ///                 uri,
265    ///                 mime_type: Some("text/plain".to_string()),
266    ///                 text: Some(format!("Contents of {}", path)),
267    ///                 blob: None,
268    ///             }],
269    ///         })
270    ///     });
271    ///
272    /// let router = McpRouter::new()
273    ///     .resource_template(template);
274    /// ```
275    pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
276        Arc::make_mut(&mut self.inner)
277            .resource_templates
278            .push(Arc::new(template));
279        self
280    }
281
282    /// Register a prompt
283    pub fn prompt(mut self, prompt: Prompt) -> Self {
284        Arc::make_mut(&mut self.inner)
285            .prompts
286            .insert(prompt.name.clone(), Arc::new(prompt));
287        self
288    }
289
290    /// Register a completion handler for `completion/complete` requests.
291    ///
292    /// The handler receives `CompleteParams` containing the reference (prompt or resource)
293    /// and the argument being completed, and should return completion suggestions.
294    ///
295    /// # Example
296    ///
297    /// ```rust
298    /// use tower_mcp::{McpRouter, CompleteResult};
299    /// use tower_mcp::protocol::{CompleteParams, CompletionReference};
300    ///
301    /// let router = McpRouter::new()
302    ///     .completion_handler(|params: CompleteParams| async move {
303    ///         // Provide completions based on the reference and argument
304    ///         match params.reference {
305    ///             CompletionReference::Prompt { name } => {
306    ///                 // Return prompt argument completions
307    ///                 Ok(CompleteResult::new(vec!["option1".to_string(), "option2".to_string()]))
308    ///             }
309    ///             CompletionReference::Resource { uri } => {
310    ///                 // Return resource URI completions
311    ///                 Ok(CompleteResult::new(vec![]))
312    ///             }
313    ///         }
314    ///     });
315    /// ```
316    pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
317    where
318        F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
319        Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
320    {
321        Arc::make_mut(&mut self.inner).completion_handler =
322            Some(Arc::new(move |params| Box::pin(handler(params))));
323        self
324    }
325
326    /// Get access to the session state
327    pub fn session(&self) -> &SessionState {
328        &self.session
329    }
330
331    /// Send a log message notification to the client
332    ///
333    /// This sends a `notifications/message` notification with the given parameters.
334    /// Returns `true` if the notification was sent, `false` if no notification channel
335    /// is configured.
336    ///
337    /// # Example
338    ///
339    /// ```rust,ignore
340    /// use tower_mcp::protocol::{LogLevel, LoggingMessageParams};
341    ///
342    /// // Simple info message
343    /// router.log(LoggingMessageParams::new(LogLevel::Info).with_data(
344    ///     serde_json::json!({"message": "Operation completed"})
345    /// ));
346    ///
347    /// // Error with logger name
348    /// router.log(LoggingMessageParams::new(LogLevel::Error)
349    ///     .with_logger("database")
350    ///     .with_data(serde_json::json!({"error": "Connection failed"})));
351    /// ```
352    pub fn log(&self, params: LoggingMessageParams) -> bool {
353        if let Some(tx) = &self.inner.notification_tx
354            && tx.try_send(ServerNotification::LogMessage(params)).is_ok()
355        {
356            return true;
357        }
358        false
359    }
360
361    /// Send an info-level log message
362    ///
363    /// Convenience method for sending an info log with optional data.
364    pub fn log_info(&self, message: &str) -> bool {
365        self.log(
366            LoggingMessageParams::new(LogLevel::Info)
367                .with_data(serde_json::json!({ "message": message })),
368        )
369    }
370
371    /// Send a warning-level log message
372    pub fn log_warning(&self, message: &str) -> bool {
373        self.log(
374            LoggingMessageParams::new(LogLevel::Warning)
375                .with_data(serde_json::json!({ "message": message })),
376        )
377    }
378
379    /// Send an error-level log message
380    pub fn log_error(&self, message: &str) -> bool {
381        self.log(
382            LoggingMessageParams::new(LogLevel::Error)
383                .with_data(serde_json::json!({ "message": message })),
384        )
385    }
386
387    /// Send a debug-level log message
388    pub fn log_debug(&self, message: &str) -> bool {
389        self.log(
390            LoggingMessageParams::new(LogLevel::Debug)
391                .with_data(serde_json::json!({ "message": message })),
392        )
393    }
394
395    /// Check if a resource URI is currently subscribed
396    pub fn is_subscribed(&self, uri: &str) -> bool {
397        if let Ok(subs) = self.inner.subscriptions.read() {
398            return subs.contains(uri);
399        }
400        false
401    }
402
403    /// Get a list of all subscribed resource URIs
404    pub fn subscribed_uris(&self) -> Vec<String> {
405        if let Ok(subs) = self.inner.subscriptions.read() {
406            return subs.iter().cloned().collect();
407        }
408        Vec::new()
409    }
410
411    /// Subscribe to a resource URI
412    fn subscribe(&self, uri: &str) -> bool {
413        if let Ok(mut subs) = self.inner.subscriptions.write() {
414            return subs.insert(uri.to_string());
415        }
416        false
417    }
418
419    /// Unsubscribe from a resource URI
420    fn unsubscribe(&self, uri: &str) -> bool {
421        if let Ok(mut subs) = self.inner.subscriptions.write() {
422            return subs.remove(uri);
423        }
424        false
425    }
426
427    /// Notify clients that a subscribed resource has been updated
428    ///
429    /// Only sends the notification if the resource is currently subscribed.
430    /// Returns `true` if the notification was sent.
431    pub fn notify_resource_updated(&self, uri: &str) -> bool {
432        // Only notify if the resource is subscribed
433        if !self.is_subscribed(uri) {
434            return false;
435        }
436
437        if let Some(tx) = &self.inner.notification_tx
438            && tx
439                .try_send(ServerNotification::ResourceUpdated {
440                    uri: uri.to_string(),
441                })
442                .is_ok()
443        {
444            return true;
445        }
446        false
447    }
448
449    /// Notify clients that the list of available resources has changed
450    ///
451    /// Returns `true` if the notification was sent.
452    pub fn notify_resources_list_changed(&self) -> bool {
453        if let Some(tx) = &self.inner.notification_tx
454            && tx
455                .try_send(ServerNotification::ResourcesListChanged)
456                .is_ok()
457        {
458            return true;
459        }
460        false
461    }
462
463    /// Get server capabilities based on registered handlers
464    fn capabilities(&self) -> ServerCapabilities {
465        let has_resources =
466            !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
467
468        ServerCapabilities {
469            tools: if self.inner.tools.is_empty() {
470                None
471            } else {
472                Some(ToolsCapability::default())
473            },
474            resources: if has_resources {
475                Some(ResourcesCapability {
476                    subscribe: true,
477                    ..Default::default()
478                })
479            } else {
480                None
481            },
482            prompts: if self.inner.prompts.is_empty() {
483                None
484            } else {
485                Some(PromptsCapability::default())
486            },
487            // Always advertise logging capability when notification channel is configured
488            logging: if self.inner.notification_tx.is_some() {
489                Some(LoggingCapability::default())
490            } else {
491                None
492            },
493            // Tasks capability is always available
494            tasks: Some(TasksCapability::default()),
495            // Completions capability when a handler is registered
496            completions: if self.inner.completion_handler.is_some() {
497                Some(CompletionsCapability::default())
498            } else {
499                None
500            },
501        }
502    }
503
504    /// Handle an MCP request
505    async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
506        // Enforce session state - reject requests before initialization
507        let method = request.method_name();
508        if !self.session.is_request_allowed(method) {
509            tracing::warn!(
510                method = %method,
511                phase = ?self.session.phase(),
512                "Request rejected: session not initialized"
513            );
514            return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
515                "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
516                method
517            ))));
518        }
519
520        match request {
521            McpRequest::Initialize(params) => {
522                tracing::info!(
523                    client = %params.client_info.name,
524                    version = %params.client_info.version,
525                    "Client initializing"
526                );
527
528                // Protocol version negotiation: respond with same version if supported,
529                // otherwise respond with our latest supported version
530                let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
531                    .contains(&params.protocol_version.as_str())
532                {
533                    params.protocol_version
534                } else {
535                    crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
536                };
537
538                // Transition session state to Initializing
539                self.session.mark_initializing();
540
541                Ok(McpResponse::Initialize(InitializeResult {
542                    protocol_version,
543                    capabilities: self.capabilities(),
544                    server_info: Implementation {
545                        name: self.inner.server_name.clone(),
546                        version: self.inner.server_version.clone(),
547                    },
548                    instructions: self.inner.instructions.clone(),
549                }))
550            }
551
552            McpRequest::ListTools(_params) => {
553                let tools: Vec<ToolDefinition> =
554                    self.inner.tools.values().map(|t| t.definition()).collect();
555
556                Ok(McpResponse::ListTools(ListToolsResult {
557                    tools,
558                    next_cursor: None,
559                }))
560            }
561
562            McpRequest::CallTool(params) => {
563                let tool =
564                    self.inner.tools.get(&params.name).ok_or_else(|| {
565                        Error::JsonRpc(JsonRpcError::method_not_found(&params.name))
566                    })?;
567
568                // Extract progress token from request metadata
569                let progress_token = params.meta.and_then(|m| m.progress_token);
570                let ctx = self.create_context(request_id, progress_token);
571
572                tracing::debug!(tool = %params.name, "Calling tool");
573                let result = tool.call_with_context(ctx, params.arguments).await?;
574
575                Ok(McpResponse::CallTool(result))
576            }
577
578            McpRequest::ListResources(_params) => {
579                let resources: Vec<ResourceDefinition> = self
580                    .inner
581                    .resources
582                    .values()
583                    .map(|r| r.definition())
584                    .collect();
585
586                Ok(McpResponse::ListResources(ListResourcesResult {
587                    resources,
588                    next_cursor: None,
589                }))
590            }
591
592            McpRequest::ListResourceTemplates(_params) => {
593                let resource_templates: Vec<ResourceTemplateDefinition> = self
594                    .inner
595                    .resource_templates
596                    .iter()
597                    .map(|t| t.definition())
598                    .collect();
599
600                Ok(McpResponse::ListResourceTemplates(
601                    ListResourceTemplatesResult {
602                        resource_templates,
603                        next_cursor: None,
604                    },
605                ))
606            }
607
608            McpRequest::ReadResource(params) => {
609                // First, try to find a static resource
610                if let Some(resource) = self.inner.resources.get(&params.uri) {
611                    tracing::debug!(uri = %params.uri, "Reading static resource");
612                    let result = resource.read().await?;
613                    return Ok(McpResponse::ReadResource(result));
614                }
615
616                // If no static resource found, try to match against templates
617                for template in &self.inner.resource_templates {
618                    if let Some(variables) = template.match_uri(&params.uri) {
619                        tracing::debug!(
620                            uri = %params.uri,
621                            template = %template.uri_template,
622                            "Reading resource via template"
623                        );
624                        let result = template.read(&params.uri, variables).await?;
625                        return Ok(McpResponse::ReadResource(result));
626                    }
627                }
628
629                // No match found
630                Err(Error::JsonRpc(JsonRpcError::resource_not_found(
631                    &params.uri,
632                )))
633            }
634
635            McpRequest::SubscribeResource(params) => {
636                // Verify the resource exists
637                if !self.inner.resources.contains_key(&params.uri) {
638                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
639                        &params.uri,
640                    )));
641                }
642
643                tracing::debug!(uri = %params.uri, "Subscribing to resource");
644                self.subscribe(&params.uri);
645
646                Ok(McpResponse::SubscribeResource(EmptyResult {}))
647            }
648
649            McpRequest::UnsubscribeResource(params) => {
650                // Verify the resource exists
651                if !self.inner.resources.contains_key(&params.uri) {
652                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
653                        &params.uri,
654                    )));
655                }
656
657                tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
658                self.unsubscribe(&params.uri);
659
660                Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
661            }
662
663            McpRequest::ListPrompts(_params) => {
664                let prompts: Vec<PromptDefinition> = self
665                    .inner
666                    .prompts
667                    .values()
668                    .map(|p| p.definition())
669                    .collect();
670
671                Ok(McpResponse::ListPrompts(ListPromptsResult {
672                    prompts,
673                    next_cursor: None,
674                }))
675            }
676
677            McpRequest::GetPrompt(params) => {
678                let prompt = self.inner.prompts.get(&params.name).ok_or_else(|| {
679                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
680                        "Prompt not found: {}",
681                        params.name
682                    )))
683                })?;
684
685                tracing::debug!(name = %params.name, "Getting prompt");
686                let result = prompt.get(params.arguments).await?;
687
688                Ok(McpResponse::GetPrompt(result))
689            }
690
691            McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
692
693            McpRequest::EnqueueTask(params) => {
694                // Verify the tool exists
695                let tool = self.inner.tools.get(&params.tool_name).ok_or_else(|| {
696                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
697                        "Tool not found: {}",
698                        params.tool_name
699                    )))
700                })?;
701
702                // Create the task
703                let (task_id, cancellation_token) = self.inner.task_store.create_task(
704                    &params.tool_name,
705                    params.arguments.clone(),
706                    params.ttl,
707                );
708
709                tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
710
711                // Create a context for the async task execution
712                let ctx = self.create_context(request_id, None);
713
714                // Spawn the task execution in the background
715                let task_store = self.inner.task_store.clone();
716                let tool = tool.clone();
717                let arguments = params.arguments;
718                let task_id_clone = task_id.clone();
719
720                tokio::spawn(async move {
721                    // Check for cancellation before starting
722                    if cancellation_token.is_cancelled() {
723                        tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
724                        return;
725                    }
726
727                    // Execute the tool
728                    match tool.call_with_context(ctx, arguments).await {
729                        Ok(result) => {
730                            if cancellation_token.is_cancelled() {
731                                tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
732                            } else {
733                                task_store.complete_task(&task_id_clone, result);
734                                tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
735                            }
736                        }
737                        Err(e) => {
738                            task_store.fail_task(&task_id_clone, &e.to_string());
739                            tracing::warn!(task_id = %task_id_clone, error = %e, "Task failed");
740                        }
741                    }
742                });
743
744                Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
745                    task_id,
746                    status: TaskStatus::Working,
747                    poll_interval: Some(2),
748                }))
749            }
750
751            McpRequest::ListTasks(params) => {
752                let tasks = self.inner.task_store.list_tasks(params.status);
753
754                Ok(McpResponse::ListTasks(ListTasksResult {
755                    tasks,
756                    next_cursor: None,
757                }))
758            }
759
760            McpRequest::GetTaskInfo(params) => {
761                let task = self
762                    .inner
763                    .task_store
764                    .get_task(&params.task_id)
765                    .ok_or_else(|| {
766                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
767                            "Task not found: {}",
768                            params.task_id
769                        )))
770                    })?;
771
772                Ok(McpResponse::GetTaskInfo(task))
773            }
774
775            McpRequest::GetTaskResult(params) => {
776                let (status, result, error) = self
777                    .inner
778                    .task_store
779                    .get_task_full(&params.task_id)
780                    .ok_or_else(|| {
781                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
782                            "Task not found: {}",
783                            params.task_id
784                        )))
785                    })?;
786
787                Ok(McpResponse::GetTaskResult(GetTaskResultResult {
788                    task_id: params.task_id,
789                    status,
790                    result,
791                    error,
792                }))
793            }
794
795            McpRequest::CancelTask(params) => {
796                let status = self
797                    .inner
798                    .task_store
799                    .cancel_task(&params.task_id, params.reason.as_deref())
800                    .ok_or_else(|| {
801                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
802                            "Task not found: {}",
803                            params.task_id
804                        )))
805                    })?;
806
807                let cancelled = status == TaskStatus::Cancelled;
808
809                Ok(McpResponse::CancelTask(CancelTaskResult {
810                    cancelled,
811                    status,
812                }))
813            }
814
815            McpRequest::SetLoggingLevel(params) => {
816                // Store the log level for filtering outgoing log notifications
817                // For now, we just accept the request - actual filtering would be
818                // implemented in the notification sending logic
819                tracing::debug!(level = ?params.level, "Client set logging level");
820                Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
821            }
822
823            McpRequest::Complete(params) => {
824                tracing::debug!(
825                    reference = ?params.reference,
826                    argument = %params.argument.name,
827                    "Completion request"
828                );
829
830                // Delegate to registered completion handler if available
831                if let Some(ref handler) = self.inner.completion_handler {
832                    let result = handler(params).await?;
833                    Ok(McpResponse::Complete(result))
834                } else {
835                    // No completion handler registered, return empty completions
836                    Ok(McpResponse::Complete(CompleteResult::new(vec![])))
837                }
838            }
839
840            McpRequest::Unknown { method, .. } => {
841                Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
842            }
843        }
844    }
845
846    /// Handle an MCP notification (no response expected)
847    pub fn handle_notification(&self, notification: McpNotification) {
848        match notification {
849            McpNotification::Initialized => {
850                if self.session.mark_initialized() {
851                    tracing::info!("Session initialized, entering operation phase");
852                } else {
853                    tracing::warn!(
854                        "Received initialized notification in unexpected state: {:?}",
855                        self.session.phase()
856                    );
857                }
858            }
859            McpNotification::Cancelled(params) => {
860                if self.cancel_request(&params.request_id) {
861                    tracing::info!(
862                        request_id = ?params.request_id,
863                        reason = ?params.reason,
864                        "Request cancelled"
865                    );
866                } else {
867                    tracing::debug!(
868                        request_id = ?params.request_id,
869                        reason = ?params.reason,
870                        "Cancellation requested for unknown request"
871                    );
872                }
873            }
874            McpNotification::Progress(params) => {
875                tracing::trace!(
876                    token = ?params.progress_token,
877                    progress = params.progress,
878                    total = ?params.total,
879                    "Progress notification"
880                );
881                // Progress notifications from client are unusual but valid
882            }
883            McpNotification::RootsListChanged => {
884                tracing::info!("Client roots list changed");
885                // Server should re-request roots if needed
886                // This is handled by the application layer
887            }
888            McpNotification::Unknown { method, .. } => {
889                tracing::debug!(method = %method, "Unknown notification received");
890            }
891        }
892    }
893}
894
895impl Default for McpRouter {
896    fn default() -> Self {
897        Self::new()
898    }
899}
900
901// =============================================================================
902// Tower Service implementation
903// =============================================================================
904
905/// Request type for the tower Service implementation
906#[derive(Debug)]
907pub struct RouterRequest {
908    pub id: RequestId,
909    pub inner: McpRequest,
910}
911
912/// Response type for the tower Service implementation
913#[derive(Debug)]
914pub struct RouterResponse {
915    pub id: RequestId,
916    pub inner: std::result::Result<McpResponse, JsonRpcError>,
917}
918
919impl RouterResponse {
920    /// Convert to JSON-RPC response
921    pub fn into_jsonrpc(self) -> JsonRpcResponse {
922        match self.inner {
923            Ok(response) => match serde_json::to_value(response) {
924                Ok(result) => JsonRpcResponse::result(self.id, result),
925                Err(e) => {
926                    tracing::error!(error = %e, "Failed to serialize response");
927                    JsonRpcResponse::error(
928                        Some(self.id),
929                        JsonRpcError::internal_error(format!("Serialization error: {}", e)),
930                    )
931                }
932            },
933            Err(error) => JsonRpcResponse::error(Some(self.id), error),
934        }
935    }
936}
937
938impl Service<RouterRequest> for McpRouter {
939    type Response = RouterResponse;
940    type Error = std::convert::Infallible; // Errors are in the response
941    type Future =
942        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
943
944    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
945        Poll::Ready(Ok(()))
946    }
947
948    fn call(&mut self, req: RouterRequest) -> Self::Future {
949        let router = self.clone();
950        let request_id = req.id.clone();
951        Box::pin(async move {
952            let result = router.handle(req.id, req.inner).await;
953            // Clean up tracking after request completes
954            router.complete_request(&request_id);
955            Ok(RouterResponse {
956                id: request_id,
957                inner: result.map_err(|e| match e {
958                    Error::JsonRpc(err) => err,
959                    Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
960                    e => JsonRpcError::internal_error(e.to_string()),
961                }),
962            })
963        })
964    }
965}
966
967#[cfg(test)]
968mod tests {
969    use super::*;
970    use crate::jsonrpc::JsonRpcService;
971    use crate::tool::ToolBuilder;
972    use schemars::JsonSchema;
973    use serde::Deserialize;
974    use tower::ServiceExt;
975
    // Input for the "add"-style test tools: two integers supplied as the tool
    // call's `arguments` object (e.g. `{"a": 2, "b": 3}`). Deserialized with
    // serde and described to clients through the derived `JsonSchema`.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i64,
        b: i64,
    }
981
982    /// Helper to initialize a router for testing
983    async fn init_router(router: &mut McpRouter) {
984        // Send initialize request
985        let init_req = RouterRequest {
986            id: RequestId::Number(0),
987            inner: McpRequest::Initialize(InitializeParams {
988                protocol_version: "2025-03-26".to_string(),
989                capabilities: ClientCapabilities {
990                    roots: None,
991                    sampling: None,
992                    elicitation: None,
993                },
994                client_info: Implementation {
995                    name: "test".to_string(),
996                    version: "1.0".to_string(),
997                },
998            }),
999        };
1000        let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1001        // Send initialized notification
1002        router.handle_notification(McpNotification::Initialized);
1003    }
1004
1005    #[tokio::test]
1006    async fn test_router_list_tools() {
1007        let add_tool = ToolBuilder::new("add")
1008            .description("Add two numbers")
1009            .handler(|input: AddInput| async move {
1010                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1011            })
1012            .build()
1013            .expect("valid tool name");
1014
1015        let mut router = McpRouter::new().tool(add_tool);
1016
1017        // Initialize session first
1018        init_router(&mut router).await;
1019
1020        let req = RouterRequest {
1021            id: RequestId::Number(1),
1022            inner: McpRequest::ListTools(ListToolsParams::default()),
1023        };
1024
1025        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1026
1027        match resp.inner {
1028            Ok(McpResponse::ListTools(result)) => {
1029                assert_eq!(result.tools.len(), 1);
1030                assert_eq!(result.tools[0].name, "add");
1031            }
1032            _ => panic!("Expected ListTools response"),
1033        }
1034    }
1035
1036    #[tokio::test]
1037    async fn test_router_call_tool() {
1038        let add_tool = ToolBuilder::new("add")
1039            .description("Add two numbers")
1040            .handler(|input: AddInput| async move {
1041                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1042            })
1043            .build()
1044            .expect("valid tool name");
1045
1046        let mut router = McpRouter::new().tool(add_tool);
1047
1048        // Initialize session first
1049        init_router(&mut router).await;
1050
1051        let req = RouterRequest {
1052            id: RequestId::Number(1),
1053            inner: McpRequest::CallTool(CallToolParams {
1054                name: "add".to_string(),
1055                arguments: serde_json::json!({"a": 2, "b": 3}),
1056                meta: None,
1057            }),
1058        };
1059
1060        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1061
1062        match resp.inner {
1063            Ok(McpResponse::CallTool(result)) => {
1064                assert!(!result.is_error);
1065                // Check the text content
1066                match &result.content[0] {
1067                    Content::Text { text, .. } => assert_eq!(text, "5"),
1068                    _ => panic!("Expected text content"),
1069                }
1070            }
1071            _ => panic!("Expected CallTool response"),
1072        }
1073    }
1074
1075    /// Helper to initialize a JsonRpcService for testing
1076    async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1077        let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1078            "protocolVersion": "2025-03-26",
1079            "capabilities": {},
1080            "clientInfo": { "name": "test", "version": "1.0" }
1081        }));
1082        let _ = service.call_single(init_req).await.unwrap();
1083        router.handle_notification(McpNotification::Initialized);
1084    }
1085
1086    #[tokio::test]
1087    async fn test_jsonrpc_service() {
1088        let add_tool = ToolBuilder::new("add")
1089            .description("Add two numbers")
1090            .handler(|input: AddInput| async move {
1091                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1092            })
1093            .build()
1094            .expect("valid tool name");
1095
1096        let router = McpRouter::new().tool(add_tool);
1097        let mut service = JsonRpcService::new(router.clone());
1098
1099        // Initialize session first
1100        init_jsonrpc_service(&mut service, &router).await;
1101
1102        let req = JsonRpcRequest::new(1, "tools/list");
1103
1104        let resp = service.call_single(req).await.unwrap();
1105
1106        match resp {
1107            JsonRpcResponse::Result(r) => {
1108                assert_eq!(r.id, RequestId::Number(1));
1109                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1110                assert_eq!(tools.len(), 1);
1111            }
1112            JsonRpcResponse::Error(_) => panic!("Expected success response"),
1113        }
1114    }
1115
1116    #[tokio::test]
1117    async fn test_batch_request() {
1118        let add_tool = ToolBuilder::new("add")
1119            .description("Add two numbers")
1120            .handler(|input: AddInput| async move {
1121                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1122            })
1123            .build()
1124            .expect("valid tool name");
1125
1126        let router = McpRouter::new().tool(add_tool);
1127        let mut service = JsonRpcService::new(router.clone());
1128
1129        // Initialize session first
1130        init_jsonrpc_service(&mut service, &router).await;
1131
1132        // Create a batch of requests
1133        let requests = vec![
1134            JsonRpcRequest::new(1, "tools/list"),
1135            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1136                "name": "add",
1137                "arguments": {"a": 10, "b": 20}
1138            })),
1139            JsonRpcRequest::new(3, "ping"),
1140        ];
1141
1142        let responses = service.call_batch(requests).await.unwrap();
1143
1144        assert_eq!(responses.len(), 3);
1145
1146        // Check first response (tools/list)
1147        match &responses[0] {
1148            JsonRpcResponse::Result(r) => {
1149                assert_eq!(r.id, RequestId::Number(1));
1150                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1151                assert_eq!(tools.len(), 1);
1152            }
1153            JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1154        }
1155
1156        // Check second response (tools/call)
1157        match &responses[1] {
1158            JsonRpcResponse::Result(r) => {
1159                assert_eq!(r.id, RequestId::Number(2));
1160                let content = r.result.get("content").unwrap().as_array().unwrap();
1161                let text = content[0].get("text").unwrap().as_str().unwrap();
1162                assert_eq!(text, "30");
1163            }
1164            JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1165        }
1166
1167        // Check third response (ping)
1168        match &responses[2] {
1169            JsonRpcResponse::Result(r) => {
1170                assert_eq!(r.id, RequestId::Number(3));
1171            }
1172            JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1173        }
1174    }
1175
1176    #[tokio::test]
1177    async fn test_empty_batch_error() {
1178        let router = McpRouter::new();
1179        let mut service = JsonRpcService::new(router);
1180
1181        let result = service.call_batch(vec![]).await;
1182        assert!(result.is_err());
1183    }
1184
1185    // =========================================================================
1186    // Progress Token Tests
1187    // =========================================================================
1188
1189    #[tokio::test]
1190    async fn test_progress_token_extraction() {
1191        use crate::context::{RequestContext, ServerNotification, notification_channel};
1192        use crate::protocol::ProgressToken;
1193        use std::sync::Arc;
1194        use std::sync::atomic::{AtomicBool, Ordering};
1195
1196        // Track whether progress was reported
1197        let progress_reported = Arc::new(AtomicBool::new(false));
1198        let progress_ref = progress_reported.clone();
1199
1200        // Create a tool that reports progress
1201        let tool = ToolBuilder::new("progress_tool")
1202            .description("Tool that reports progress")
1203            .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1204                let reported = progress_ref.clone();
1205                async move {
1206                    // Report progress - this should work if token was extracted
1207                    ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1208                        .await;
1209                    reported.store(true, Ordering::SeqCst);
1210                    Ok(CallToolResult::text("done"))
1211                }
1212            })
1213            .build()
1214            .expect("valid tool name");
1215
1216        // Set up notification channel
1217        let (tx, mut rx) = notification_channel(10);
1218        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1219        let mut service = JsonRpcService::new(router.clone());
1220
1221        // Initialize
1222        init_jsonrpc_service(&mut service, &router).await;
1223
1224        // Call tool WITH progress token in _meta
1225        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1226            "name": "progress_tool",
1227            "arguments": {"a": 1, "b": 2},
1228            "_meta": {
1229                "progressToken": "test-token-123"
1230            }
1231        }));
1232
1233        let resp = service.call_single(req).await.unwrap();
1234
1235        // Verify the tool was called successfully
1236        match resp {
1237            JsonRpcResponse::Result(_) => {}
1238            JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1239        }
1240
1241        // Verify progress was reported by handler
1242        assert!(progress_reported.load(Ordering::SeqCst));
1243
1244        // Verify progress notification was sent through channel
1245        let notification = rx.try_recv().expect("Expected progress notification");
1246        match notification {
1247            ServerNotification::Progress(params) => {
1248                assert_eq!(
1249                    params.progress_token,
1250                    ProgressToken::String("test-token-123".to_string())
1251                );
1252                assert_eq!(params.progress, 50.0);
1253                assert_eq!(params.total, Some(100.0));
1254                assert_eq!(params.message.as_deref(), Some("Halfway"));
1255            }
1256            _ => panic!("Expected Progress notification"),
1257        }
1258    }
1259
1260    #[tokio::test]
1261    async fn test_tool_call_without_progress_token() {
1262        use crate::context::{RequestContext, notification_channel};
1263        use std::sync::Arc;
1264        use std::sync::atomic::{AtomicBool, Ordering};
1265
1266        let progress_attempted = Arc::new(AtomicBool::new(false));
1267        let progress_ref = progress_attempted.clone();
1268
1269        let tool = ToolBuilder::new("no_token_tool")
1270            .description("Tool that tries to report progress without token")
1271            .handler_with_context(move |ctx: RequestContext, _input: AddInput| {
1272                let attempted = progress_ref.clone();
1273                async move {
1274                    // Try to report progress - should be a no-op without token
1275                    ctx.report_progress(50.0, Some(100.0), None).await;
1276                    attempted.store(true, Ordering::SeqCst);
1277                    Ok(CallToolResult::text("done"))
1278                }
1279            })
1280            .build()
1281            .expect("valid tool name");
1282
1283        let (tx, mut rx) = notification_channel(10);
1284        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1285        let mut service = JsonRpcService::new(router.clone());
1286
1287        init_jsonrpc_service(&mut service, &router).await;
1288
1289        // Call tool WITHOUT progress token
1290        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1291            "name": "no_token_tool",
1292            "arguments": {"a": 1, "b": 2}
1293        }));
1294
1295        let resp = service.call_single(req).await.unwrap();
1296        assert!(matches!(resp, JsonRpcResponse::Result(_)));
1297
1298        // Handler was called
1299        assert!(progress_attempted.load(Ordering::SeqCst));
1300
1301        // But no notification was sent (no progress token)
1302        assert!(rx.try_recv().is_err());
1303    }
1304
1305    #[tokio::test]
1306    async fn test_batch_errors_returned_not_dropped() {
1307        let add_tool = ToolBuilder::new("add")
1308            .description("Add two numbers")
1309            .handler(|input: AddInput| async move {
1310                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1311            })
1312            .build()
1313            .expect("valid tool name");
1314
1315        let router = McpRouter::new().tool(add_tool);
1316        let mut service = JsonRpcService::new(router.clone());
1317
1318        init_jsonrpc_service(&mut service, &router).await;
1319
1320        // Create a batch with one valid and one invalid request
1321        let requests = vec![
1322            // Valid request
1323            JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1324                "name": "add",
1325                "arguments": {"a": 10, "b": 20}
1326            })),
1327            // Invalid request - tool doesn't exist
1328            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1329                "name": "nonexistent_tool",
1330                "arguments": {}
1331            })),
1332            // Another valid request
1333            JsonRpcRequest::new(3, "ping"),
1334        ];
1335
1336        let responses = service.call_batch(requests).await.unwrap();
1337
1338        // All three requests should have responses (errors are not dropped)
1339        assert_eq!(responses.len(), 3);
1340
1341        // First should be success
1342        match &responses[0] {
1343            JsonRpcResponse::Result(r) => {
1344                assert_eq!(r.id, RequestId::Number(1));
1345            }
1346            JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1347        }
1348
1349        // Second should be an error (tool not found)
1350        match &responses[1] {
1351            JsonRpcResponse::Error(e) => {
1352                assert_eq!(e.id, Some(RequestId::Number(2)));
1353                // Error should indicate method not found
1354                assert!(e.error.message.contains("not found") || e.error.code == -32601);
1355            }
1356            JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1357        }
1358
1359        // Third should be success
1360        match &responses[2] {
1361            JsonRpcResponse::Result(r) => {
1362                assert_eq!(r.id, RequestId::Number(3));
1363            }
1364            JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1365        }
1366    }
1367
1368    // =========================================================================
1369    // Resource Template Tests
1370    // =========================================================================
1371
1372    #[tokio::test]
1373    async fn test_list_resource_templates() {
1374        use crate::resource::ResourceTemplateBuilder;
1375        use std::collections::HashMap;
1376
1377        let template = ResourceTemplateBuilder::new("file:///{path}")
1378            .name("Project Files")
1379            .description("Access project files")
1380            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1381                Ok(ReadResourceResult {
1382                    contents: vec![ResourceContent {
1383                        uri,
1384                        mime_type: None,
1385                        text: None,
1386                        blob: None,
1387                    }],
1388                })
1389            });
1390
1391        let mut router = McpRouter::new().resource_template(template);
1392
1393        // Initialize session
1394        init_router(&mut router).await;
1395
1396        let req = RouterRequest {
1397            id: RequestId::Number(1),
1398            inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1399        };
1400
1401        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1402
1403        match resp.inner {
1404            Ok(McpResponse::ListResourceTemplates(result)) => {
1405                assert_eq!(result.resource_templates.len(), 1);
1406                assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1407                assert_eq!(result.resource_templates[0].name, "Project Files");
1408            }
1409            _ => panic!("Expected ListResourceTemplates response"),
1410        }
1411    }
1412
1413    #[tokio::test]
1414    async fn test_read_resource_via_template() {
1415        use crate::resource::ResourceTemplateBuilder;
1416        use std::collections::HashMap;
1417
1418        let template = ResourceTemplateBuilder::new("db://users/{id}")
1419            .name("User Records")
1420            .handler(|uri: String, vars: HashMap<String, String>| async move {
1421                let id = vars.get("id").unwrap().clone();
1422                Ok(ReadResourceResult {
1423                    contents: vec![ResourceContent {
1424                        uri,
1425                        mime_type: Some("application/json".to_string()),
1426                        text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1427                        blob: None,
1428                    }],
1429                })
1430            });
1431
1432        let mut router = McpRouter::new().resource_template(template);
1433
1434        // Initialize session
1435        init_router(&mut router).await;
1436
1437        // Read a resource that matches the template
1438        let req = RouterRequest {
1439            id: RequestId::Number(1),
1440            inner: McpRequest::ReadResource(ReadResourceParams {
1441                uri: "db://users/123".to_string(),
1442            }),
1443        };
1444
1445        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1446
1447        match resp.inner {
1448            Ok(McpResponse::ReadResource(result)) => {
1449                assert_eq!(result.contents.len(), 1);
1450                assert_eq!(result.contents[0].uri, "db://users/123");
1451                assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1452            }
1453            _ => panic!("Expected ReadResource response"),
1454        }
1455    }
1456
1457    #[tokio::test]
1458    async fn test_static_resource_takes_precedence_over_template() {
1459        use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1460        use std::collections::HashMap;
1461
1462        // Template that would match the same URI
1463        let template = ResourceTemplateBuilder::new("file:///{path}")
1464            .name("Files Template")
1465            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1466                Ok(ReadResourceResult {
1467                    contents: vec![ResourceContent {
1468                        uri,
1469                        mime_type: None,
1470                        text: Some("from template".to_string()),
1471                        blob: None,
1472                    }],
1473                })
1474            });
1475
1476        // Static resource with exact URI
1477        let static_resource = ResourceBuilder::new("file:///README.md")
1478            .name("README")
1479            .text("from static resource");
1480
1481        let mut router = McpRouter::new()
1482            .resource_template(template)
1483            .resource(static_resource);
1484
1485        // Initialize session
1486        init_router(&mut router).await;
1487
1488        // Read the static resource - should NOT go through template
1489        let req = RouterRequest {
1490            id: RequestId::Number(1),
1491            inner: McpRequest::ReadResource(ReadResourceParams {
1492                uri: "file:///README.md".to_string(),
1493            }),
1494        };
1495
1496        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1497
1498        match resp.inner {
1499            Ok(McpResponse::ReadResource(result)) => {
1500                // Should get static resource, not template
1501                assert_eq!(
1502                    result.contents[0].text.as_deref(),
1503                    Some("from static resource")
1504                );
1505            }
1506            _ => panic!("Expected ReadResource response"),
1507        }
1508    }
1509
1510    #[tokio::test]
1511    async fn test_resource_not_found_when_no_match() {
1512        use crate::resource::ResourceTemplateBuilder;
1513        use std::collections::HashMap;
1514
1515        let template = ResourceTemplateBuilder::new("db://users/{id}")
1516            .name("Users")
1517            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1518                Ok(ReadResourceResult {
1519                    contents: vec![ResourceContent {
1520                        uri,
1521                        mime_type: None,
1522                        text: None,
1523                        blob: None,
1524                    }],
1525                })
1526            });
1527
1528        let mut router = McpRouter::new().resource_template(template);
1529
1530        // Initialize session
1531        init_router(&mut router).await;
1532
1533        // Try to read a URI that doesn't match any resource or template
1534        let req = RouterRequest {
1535            id: RequestId::Number(1),
1536            inner: McpRequest::ReadResource(ReadResourceParams {
1537                uri: "db://posts/123".to_string(),
1538            }),
1539        };
1540
1541        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1542
1543        match resp.inner {
1544            Err(err) => {
1545                assert!(err.message.contains("not found"));
1546            }
1547            Ok(_) => panic!("Expected error for non-matching URI"),
1548        }
1549    }
1550
1551    #[tokio::test]
1552    async fn test_capabilities_include_resources_with_only_templates() {
1553        use crate::resource::ResourceTemplateBuilder;
1554        use std::collections::HashMap;
1555
1556        let template = ResourceTemplateBuilder::new("file:///{path}")
1557            .name("Files")
1558            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1559                Ok(ReadResourceResult {
1560                    contents: vec![ResourceContent {
1561                        uri,
1562                        mime_type: None,
1563                        text: None,
1564                        blob: None,
1565                    }],
1566                })
1567            });
1568
1569        let mut router = McpRouter::new().resource_template(template);
1570
1571        // Send initialize request and check capabilities
1572        let init_req = RouterRequest {
1573            id: RequestId::Number(0),
1574            inner: McpRequest::Initialize(InitializeParams {
1575                protocol_version: "2025-03-26".to_string(),
1576                capabilities: ClientCapabilities {
1577                    roots: None,
1578                    sampling: None,
1579                    elicitation: None,
1580                },
1581                client_info: Implementation {
1582                    name: "test".to_string(),
1583                    version: "1.0".to_string(),
1584                },
1585            }),
1586        };
1587        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1588
1589        match resp.inner {
1590            Ok(McpResponse::Initialize(result)) => {
1591                // Should have resources capability even though only templates registered
1592                assert!(result.capabilities.resources.is_some());
1593            }
1594            _ => panic!("Expected Initialize response"),
1595        }
1596    }
1597
1598    // =========================================================================
1599    // Logging Notification Tests
1600    // =========================================================================
1601
1602    #[tokio::test]
1603    async fn test_log_sends_notification() {
1604        use crate::context::notification_channel;
1605
1606        let (tx, mut rx) = notification_channel(10);
1607        let router = McpRouter::new().with_notification_sender(tx);
1608
1609        // Send an info log
1610        let sent = router.log_info("Test message");
1611        assert!(sent);
1612
1613        // Should receive the notification
1614        let notification = rx.try_recv().unwrap();
1615        match notification {
1616            ServerNotification::LogMessage(params) => {
1617                assert_eq!(params.level, LogLevel::Info);
1618                let data = params.data.unwrap();
1619                assert_eq!(
1620                    data.get("message").unwrap().as_str().unwrap(),
1621                    "Test message"
1622                );
1623            }
1624            _ => panic!("Expected LogMessage notification"),
1625        }
1626    }
1627
1628    #[tokio::test]
1629    async fn test_log_with_custom_params() {
1630        use crate::context::notification_channel;
1631
1632        let (tx, mut rx) = notification_channel(10);
1633        let router = McpRouter::new().with_notification_sender(tx);
1634
1635        // Send a custom log message
1636        let params = LoggingMessageParams::new(LogLevel::Error)
1637            .with_logger("database")
1638            .with_data(serde_json::json!({
1639                "error": "Connection failed",
1640                "host": "localhost"
1641            }));
1642
1643        let sent = router.log(params);
1644        assert!(sent);
1645
1646        let notification = rx.try_recv().unwrap();
1647        match notification {
1648            ServerNotification::LogMessage(params) => {
1649                assert_eq!(params.level, LogLevel::Error);
1650                assert_eq!(params.logger.as_deref(), Some("database"));
1651                let data = params.data.unwrap();
1652                assert_eq!(
1653                    data.get("error").unwrap().as_str().unwrap(),
1654                    "Connection failed"
1655                );
1656            }
1657            _ => panic!("Expected LogMessage notification"),
1658        }
1659    }
1660
1661    #[tokio::test]
1662    async fn test_log_without_channel_returns_false() {
1663        // Router without notification channel
1664        let router = McpRouter::new();
1665
1666        // Should return false when no channel configured
1667        assert!(!router.log_info("Test"));
1668        assert!(!router.log_warning("Test"));
1669        assert!(!router.log_error("Test"));
1670        assert!(!router.log_debug("Test"));
1671    }
1672
1673    #[tokio::test]
1674    async fn test_logging_capability_with_channel() {
1675        use crate::context::notification_channel;
1676
1677        let (tx, _rx) = notification_channel(10);
1678        let mut router = McpRouter::new().with_notification_sender(tx);
1679
1680        // Initialize and check capabilities
1681        let init_req = RouterRequest {
1682            id: RequestId::Number(0),
1683            inner: McpRequest::Initialize(InitializeParams {
1684                protocol_version: "2025-03-26".to_string(),
1685                capabilities: ClientCapabilities {
1686                    roots: None,
1687                    sampling: None,
1688                    elicitation: None,
1689                },
1690                client_info: Implementation {
1691                    name: "test".to_string(),
1692                    version: "1.0".to_string(),
1693                },
1694            }),
1695        };
1696        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1697
1698        match resp.inner {
1699            Ok(McpResponse::Initialize(result)) => {
1700                // Should have logging capability when notification channel is set
1701                assert!(result.capabilities.logging.is_some());
1702            }
1703            _ => panic!("Expected Initialize response"),
1704        }
1705    }
1706
1707    #[tokio::test]
1708    async fn test_no_logging_capability_without_channel() {
1709        let mut router = McpRouter::new();
1710
1711        // Initialize and check capabilities
1712        let init_req = RouterRequest {
1713            id: RequestId::Number(0),
1714            inner: McpRequest::Initialize(InitializeParams {
1715                protocol_version: "2025-03-26".to_string(),
1716                capabilities: ClientCapabilities {
1717                    roots: None,
1718                    sampling: None,
1719                    elicitation: None,
1720                },
1721                client_info: Implementation {
1722                    name: "test".to_string(),
1723                    version: "1.0".to_string(),
1724                },
1725            }),
1726        };
1727        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
1728
1729        match resp.inner {
1730            Ok(McpResponse::Initialize(result)) => {
1731                // Should NOT have logging capability without notification channel
1732                assert!(result.capabilities.logging.is_none());
1733            }
1734            _ => panic!("Expected Initialize response"),
1735        }
1736    }
1737
1738    // =========================================================================
1739    // Task Lifecycle Tests
1740    // =========================================================================
1741
1742    #[tokio::test]
1743    async fn test_enqueue_task() {
1744        let add_tool = ToolBuilder::new("add")
1745            .description("Add two numbers")
1746            .handler(|input: AddInput| async move {
1747                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1748            })
1749            .build()
1750            .expect("valid tool name");
1751
1752        let mut router = McpRouter::new().tool(add_tool);
1753        init_router(&mut router).await;
1754
1755        let req = RouterRequest {
1756            id: RequestId::Number(1),
1757            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1758                tool_name: "add".to_string(),
1759                arguments: serde_json::json!({"a": 5, "b": 10}),
1760                ttl: None,
1761            }),
1762        };
1763
1764        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1765
1766        match resp.inner {
1767            Ok(McpResponse::EnqueueTask(result)) => {
1768                assert!(result.task_id.starts_with("task-"));
1769                assert_eq!(result.status, TaskStatus::Working);
1770            }
1771            _ => panic!("Expected EnqueueTask response"),
1772        }
1773    }
1774
1775    #[tokio::test]
1776    async fn test_list_tasks_empty() {
1777        let mut router = McpRouter::new();
1778        init_router(&mut router).await;
1779
1780        let req = RouterRequest {
1781            id: RequestId::Number(1),
1782            inner: McpRequest::ListTasks(ListTasksParams::default()),
1783        };
1784
1785        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1786
1787        match resp.inner {
1788            Ok(McpResponse::ListTasks(result)) => {
1789                assert!(result.tasks.is_empty());
1790            }
1791            _ => panic!("Expected ListTasks response"),
1792        }
1793    }
1794
1795    #[tokio::test]
1796    async fn test_task_lifecycle_complete() {
1797        let add_tool = ToolBuilder::new("add")
1798            .description("Add two numbers")
1799            .handler(|input: AddInput| async move {
1800                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1801            })
1802            .build()
1803            .expect("valid tool name");
1804
1805        let mut router = McpRouter::new().tool(add_tool);
1806        init_router(&mut router).await;
1807
1808        // Enqueue task
1809        let req = RouterRequest {
1810            id: RequestId::Number(1),
1811            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1812                tool_name: "add".to_string(),
1813                arguments: serde_json::json!({"a": 7, "b": 8}),
1814                ttl: None,
1815            }),
1816        };
1817
1818        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1819        let task_id = match resp.inner {
1820            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
1821            _ => panic!("Expected EnqueueTask response"),
1822        };
1823
1824        // Wait for task to complete
1825        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1826
1827        // Get task result
1828        let req = RouterRequest {
1829            id: RequestId::Number(2),
1830            inner: McpRequest::GetTaskResult(GetTaskResultParams {
1831                task_id: task_id.clone(),
1832            }),
1833        };
1834
1835        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1836
1837        match resp.inner {
1838            Ok(McpResponse::GetTaskResult(result)) => {
1839                assert_eq!(result.task_id, task_id);
1840                assert_eq!(result.status, TaskStatus::Completed);
1841                assert!(result.result.is_some());
1842                assert!(result.error.is_none());
1843
1844                // Check the result content
1845                let tool_result = result.result.unwrap();
1846                match &tool_result.content[0] {
1847                    Content::Text { text, .. } => assert_eq!(text, "15"),
1848                    _ => panic!("Expected text content"),
1849                }
1850            }
1851            _ => panic!("Expected GetTaskResult response"),
1852        }
1853    }
1854
1855    #[tokio::test]
1856    async fn test_task_cancellation() {
1857        // Use a slow tool to test cancellation
1858        let slow_tool = ToolBuilder::new("slow")
1859            .description("Slow tool")
1860            .handler(|_input: serde_json::Value| async move {
1861                tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
1862                Ok(CallToolResult::text("done"))
1863            })
1864            .build()
1865            .expect("valid tool name");
1866
1867        let mut router = McpRouter::new().tool(slow_tool);
1868        init_router(&mut router).await;
1869
1870        // Enqueue task
1871        let req = RouterRequest {
1872            id: RequestId::Number(1),
1873            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1874                tool_name: "slow".to_string(),
1875                arguments: serde_json::json!({}),
1876                ttl: None,
1877            }),
1878        };
1879
1880        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1881        let task_id = match resp.inner {
1882            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
1883            _ => panic!("Expected EnqueueTask response"),
1884        };
1885
1886        // Cancel the task
1887        let req = RouterRequest {
1888            id: RequestId::Number(2),
1889            inner: McpRequest::CancelTask(CancelTaskParams {
1890                task_id: task_id.clone(),
1891                reason: Some("Test cancellation".to_string()),
1892            }),
1893        };
1894
1895        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1896
1897        match resp.inner {
1898            Ok(McpResponse::CancelTask(result)) => {
1899                assert!(result.cancelled);
1900                assert_eq!(result.status, TaskStatus::Cancelled);
1901            }
1902            _ => panic!("Expected CancelTask response"),
1903        }
1904    }
1905
1906    #[tokio::test]
1907    async fn test_get_task_info() {
1908        let add_tool = ToolBuilder::new("add")
1909            .description("Add two numbers")
1910            .handler(|input: AddInput| async move {
1911                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1912            })
1913            .build()
1914            .expect("valid tool name");
1915
1916        let mut router = McpRouter::new().tool(add_tool);
1917        init_router(&mut router).await;
1918
1919        // Enqueue task
1920        let req = RouterRequest {
1921            id: RequestId::Number(1),
1922            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1923                tool_name: "add".to_string(),
1924                arguments: serde_json::json!({"a": 1, "b": 2}),
1925                ttl: Some(600),
1926            }),
1927        };
1928
1929        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1930        let task_id = match resp.inner {
1931            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
1932            _ => panic!("Expected EnqueueTask response"),
1933        };
1934
1935        // Get task info
1936        let req = RouterRequest {
1937            id: RequestId::Number(2),
1938            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
1939                task_id: task_id.clone(),
1940            }),
1941        };
1942
1943        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1944
1945        match resp.inner {
1946            Ok(McpResponse::GetTaskInfo(info)) => {
1947                assert_eq!(info.task_id, task_id);
1948                assert!(info.created_at.contains('T')); // ISO 8601
1949                assert_eq!(info.ttl, Some(600));
1950            }
1951            _ => panic!("Expected GetTaskInfo response"),
1952        }
1953    }
1954
1955    #[tokio::test]
1956    async fn test_enqueue_nonexistent_tool() {
1957        let mut router = McpRouter::new();
1958        init_router(&mut router).await;
1959
1960        let req = RouterRequest {
1961            id: RequestId::Number(1),
1962            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
1963                tool_name: "nonexistent".to_string(),
1964                arguments: serde_json::json!({}),
1965                ttl: None,
1966            }),
1967        };
1968
1969        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1970
1971        match resp.inner {
1972            Err(e) => {
1973                assert!(e.message.contains("not found"));
1974            }
1975            _ => panic!("Expected error response"),
1976        }
1977    }
1978
1979    #[tokio::test]
1980    async fn test_get_nonexistent_task() {
1981        let mut router = McpRouter::new();
1982        init_router(&mut router).await;
1983
1984        let req = RouterRequest {
1985            id: RequestId::Number(1),
1986            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
1987                task_id: "task-999".to_string(),
1988            }),
1989        };
1990
1991        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1992
1993        match resp.inner {
1994            Err(e) => {
1995                assert!(e.message.contains("not found"));
1996            }
1997            _ => panic!("Expected error response"),
1998        }
1999    }
2000
2001    // =========================================================================
2002    // Resource Subscription Tests
2003    // =========================================================================
2004
2005    #[tokio::test]
2006    async fn test_subscribe_to_resource() {
2007        use crate::resource::ResourceBuilder;
2008
2009        let resource = ResourceBuilder::new("file:///test.txt")
2010            .name("Test File")
2011            .text("Hello");
2012
2013        let mut router = McpRouter::new().resource(resource);
2014        init_router(&mut router).await;
2015
2016        // Subscribe to the resource
2017        let req = RouterRequest {
2018            id: RequestId::Number(1),
2019            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2020                uri: "file:///test.txt".to_string(),
2021            }),
2022        };
2023
2024        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2025
2026        match resp.inner {
2027            Ok(McpResponse::SubscribeResource(_)) => {
2028                // Should be subscribed now
2029                assert!(router.is_subscribed("file:///test.txt"));
2030            }
2031            _ => panic!("Expected SubscribeResource response"),
2032        }
2033    }
2034
2035    #[tokio::test]
2036    async fn test_unsubscribe_from_resource() {
2037        use crate::resource::ResourceBuilder;
2038
2039        let resource = ResourceBuilder::new("file:///test.txt")
2040            .name("Test File")
2041            .text("Hello");
2042
2043        let mut router = McpRouter::new().resource(resource);
2044        init_router(&mut router).await;
2045
2046        // Subscribe first
2047        let req = RouterRequest {
2048            id: RequestId::Number(1),
2049            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2050                uri: "file:///test.txt".to_string(),
2051            }),
2052        };
2053        let _ = router.ready().await.unwrap().call(req).await.unwrap();
2054        assert!(router.is_subscribed("file:///test.txt"));
2055
2056        // Now unsubscribe
2057        let req = RouterRequest {
2058            id: RequestId::Number(2),
2059            inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2060                uri: "file:///test.txt".to_string(),
2061            }),
2062        };
2063
2064        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2065
2066        match resp.inner {
2067            Ok(McpResponse::UnsubscribeResource(_)) => {
2068                // Should no longer be subscribed
2069                assert!(!router.is_subscribed("file:///test.txt"));
2070            }
2071            _ => panic!("Expected UnsubscribeResource response"),
2072        }
2073    }
2074
2075    #[tokio::test]
2076    async fn test_subscribe_nonexistent_resource() {
2077        let mut router = McpRouter::new();
2078        init_router(&mut router).await;
2079
2080        let req = RouterRequest {
2081            id: RequestId::Number(1),
2082            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2083                uri: "file:///nonexistent.txt".to_string(),
2084            }),
2085        };
2086
2087        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2088
2089        match resp.inner {
2090            Err(e) => {
2091                assert!(e.message.contains("not found"));
2092            }
2093            _ => panic!("Expected error response"),
2094        }
2095    }
2096
2097    #[tokio::test]
2098    async fn test_notify_resource_updated() {
2099        use crate::context::notification_channel;
2100        use crate::resource::ResourceBuilder;
2101
2102        let (tx, mut rx) = notification_channel(10);
2103
2104        let resource = ResourceBuilder::new("file:///test.txt")
2105            .name("Test File")
2106            .text("Hello");
2107
2108        let router = McpRouter::new()
2109            .resource(resource)
2110            .with_notification_sender(tx);
2111
2112        // First, manually subscribe (simulate subscription)
2113        router.subscribe("file:///test.txt");
2114
2115        // Now notify
2116        let sent = router.notify_resource_updated("file:///test.txt");
2117        assert!(sent);
2118
2119        // Check the notification was sent
2120        let notification = rx.try_recv().unwrap();
2121        match notification {
2122            ServerNotification::ResourceUpdated { uri } => {
2123                assert_eq!(uri, "file:///test.txt");
2124            }
2125            _ => panic!("Expected ResourceUpdated notification"),
2126        }
2127    }
2128
2129    #[tokio::test]
2130    async fn test_notify_resource_updated_not_subscribed() {
2131        use crate::context::notification_channel;
2132        use crate::resource::ResourceBuilder;
2133
2134        let (tx, mut rx) = notification_channel(10);
2135
2136        let resource = ResourceBuilder::new("file:///test.txt")
2137            .name("Test File")
2138            .text("Hello");
2139
2140        let router = McpRouter::new()
2141            .resource(resource)
2142            .with_notification_sender(tx);
2143
2144        // Try to notify without subscribing
2145        let sent = router.notify_resource_updated("file:///test.txt");
2146        assert!(!sent); // Should not send because not subscribed
2147
2148        // Channel should be empty
2149        assert!(rx.try_recv().is_err());
2150    }
2151
2152    #[tokio::test]
2153    async fn test_notify_resources_list_changed() {
2154        use crate::context::notification_channel;
2155
2156        let (tx, mut rx) = notification_channel(10);
2157        let router = McpRouter::new().with_notification_sender(tx);
2158
2159        let sent = router.notify_resources_list_changed();
2160        assert!(sent);
2161
2162        let notification = rx.try_recv().unwrap();
2163        match notification {
2164            ServerNotification::ResourcesListChanged => {}
2165            _ => panic!("Expected ResourcesListChanged notification"),
2166        }
2167    }
2168
2169    #[tokio::test]
2170    async fn test_subscribed_uris() {
2171        use crate::resource::ResourceBuilder;
2172
2173        let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2174
2175        let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2176
2177        let router = McpRouter::new().resource(resource1).resource(resource2);
2178
2179        // Subscribe to both
2180        router.subscribe("file:///a.txt");
2181        router.subscribe("file:///b.txt");
2182
2183        let uris = router.subscribed_uris();
2184        assert_eq!(uris.len(), 2);
2185        assert!(uris.contains(&"file:///a.txt".to_string()));
2186        assert!(uris.contains(&"file:///b.txt".to_string()));
2187    }
2188
2189    #[tokio::test]
2190    async fn test_subscription_capability_advertised() {
2191        use crate::resource::ResourceBuilder;
2192
2193        let resource = ResourceBuilder::new("file:///test.txt")
2194            .name("Test")
2195            .text("Hello");
2196
2197        let mut router = McpRouter::new().resource(resource);
2198
2199        // Initialize and check capabilities
2200        let init_req = RouterRequest {
2201            id: RequestId::Number(0),
2202            inner: McpRequest::Initialize(InitializeParams {
2203                protocol_version: "2025-03-26".to_string(),
2204                capabilities: ClientCapabilities {
2205                    roots: None,
2206                    sampling: None,
2207                    elicitation: None,
2208                },
2209                client_info: Implementation {
2210                    name: "test".to_string(),
2211                    version: "1.0".to_string(),
2212                },
2213            }),
2214        };
2215        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2216
2217        match resp.inner {
2218            Ok(McpResponse::Initialize(result)) => {
2219                // Should have resources capability with subscribe enabled
2220                let resources_cap = result.capabilities.resources.unwrap();
2221                assert!(resources_cap.subscribe);
2222            }
2223            _ => panic!("Expected Initialize response"),
2224        }
2225    }
2226
2227    #[tokio::test]
2228    async fn test_completion_handler() {
2229        let router = McpRouter::new()
2230            .server_info("test", "1.0")
2231            .completion_handler(|params: CompleteParams| async move {
2232                // Return suggestions based on the argument value
2233                let prefix = &params.argument.value;
2234                let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2235                    .into_iter()
2236                    .filter(|s| s.starts_with(prefix))
2237                    .map(String::from)
2238                    .collect();
2239                Ok(CompleteResult::new(suggestions))
2240            });
2241
2242        // Initialize
2243        let init_req = RouterRequest {
2244            id: RequestId::Number(0),
2245            inner: McpRequest::Initialize(InitializeParams {
2246                protocol_version: "2025-03-26".to_string(),
2247                capabilities: ClientCapabilities::default(),
2248                client_info: Implementation {
2249                    name: "test".to_string(),
2250                    version: "1.0".to_string(),
2251                },
2252            }),
2253        };
2254        let resp = router
2255            .clone()
2256            .ready()
2257            .await
2258            .unwrap()
2259            .call(init_req)
2260            .await
2261            .unwrap();
2262
2263        // Check that completions capability is advertised
2264        match resp.inner {
2265            Ok(McpResponse::Initialize(result)) => {
2266                assert!(result.capabilities.completions.is_some());
2267            }
2268            _ => panic!("Expected Initialize response"),
2269        }
2270
2271        // Send initialized notification
2272        router.handle_notification(McpNotification::Initialized);
2273
2274        // Test completion request
2275        let complete_req = RouterRequest {
2276            id: RequestId::Number(1),
2277            inner: McpRequest::Complete(CompleteParams {
2278                reference: CompletionReference::prompt("test-prompt"),
2279                argument: CompletionArgument::new("query", "al"),
2280            }),
2281        };
2282        let resp = router
2283            .clone()
2284            .ready()
2285            .await
2286            .unwrap()
2287            .call(complete_req)
2288            .await
2289            .unwrap();
2290
2291        match resp.inner {
2292            Ok(McpResponse::Complete(result)) => {
2293                assert_eq!(result.completion.values, vec!["alpha"]);
2294            }
2295            _ => panic!("Expected Complete response"),
2296        }
2297    }
2298
2299    #[tokio::test]
2300    async fn test_completion_without_handler_returns_empty() {
2301        let router = McpRouter::new().server_info("test", "1.0");
2302
2303        // Initialize
2304        let init_req = RouterRequest {
2305            id: RequestId::Number(0),
2306            inner: McpRequest::Initialize(InitializeParams {
2307                protocol_version: "2025-03-26".to_string(),
2308                capabilities: ClientCapabilities::default(),
2309                client_info: Implementation {
2310                    name: "test".to_string(),
2311                    version: "1.0".to_string(),
2312                },
2313            }),
2314        };
2315        let resp = router
2316            .clone()
2317            .ready()
2318            .await
2319            .unwrap()
2320            .call(init_req)
2321            .await
2322            .unwrap();
2323
2324        // Check that completions capability is NOT advertised
2325        match resp.inner {
2326            Ok(McpResponse::Initialize(result)) => {
2327                assert!(result.capabilities.completions.is_none());
2328            }
2329            _ => panic!("Expected Initialize response"),
2330        }
2331
2332        // Send initialized notification
2333        router.handle_notification(McpNotification::Initialized);
2334
2335        // Test completion request still works but returns empty
2336        let complete_req = RouterRequest {
2337            id: RequestId::Number(1),
2338            inner: McpRequest::Complete(CompleteParams {
2339                reference: CompletionReference::prompt("test-prompt"),
2340                argument: CompletionArgument::new("query", "al"),
2341            }),
2342        };
2343        let resp = router
2344            .clone()
2345            .ready()
2346            .await
2347            .unwrap()
2348            .call(complete_req)
2349            .await
2350            .unwrap();
2351
2352        match resp.inner {
2353            Ok(McpResponse::Complete(result)) => {
2354                assert!(result.completion.values.is_empty());
2355            }
2356            _ => panic!("Expected Complete response"),
2357        }
2358    }
2359}