Skip to main content

tower_mcp/client/
mod.rs

1//! MCP Client with bidirectional communication support.
2//!
3//! Provides [`McpClient`] for connecting to MCP servers over any
4//! [`ClientTransport`]. The client runs a background message loop that
5//! handles request/response correlation, server-initiated requests
6//! (sampling, elicitation, roots), and notifications.
7//!
8//! # Example
9//!
10//! ```rust,no_run
11//! use tower_mcp::client::{McpClient, StdioClientTransport};
12//!
13//! #[tokio::main]
14//! async fn main() -> Result<(), tower_mcp::BoxError> {
15//!     let transport = StdioClientTransport::spawn("my-mcp-server", &["--flag"]).await?;
16//!     let client = McpClient::connect(transport).await?;
17//!
18//!     let server_info = client.initialize("my-client", "1.0.0").await?;
19//!     println!("Connected to: {}", server_info.server_info.name);
20//!
21//!     let tools = client.list_tools().await?;
22//!     for tool in &tools.tools {
23//!         println!("Tool: {}", tool.name);
24//!     }
25//!
26//!     let result = client.call_tool("my-tool", serde_json::json!({"arg": "value"})).await?;
27//!     println!("Result: {:?}", result);
28//!
29//!     Ok(())
30//! }
31//! ```
32
33mod channel;
34mod handler;
35#[cfg(feature = "http-client")]
36mod http;
37#[cfg(feature = "oauth-client")]
38mod oauth;
39#[cfg(feature = "oauth-client")]
40mod oauth_authcode;
41mod stdio;
42mod transport;
43
44pub use channel::ChannelTransport;
45pub use handler::{ClientHandler, NotificationHandler, ServerNotification};
46#[cfg(feature = "http-client")]
47pub use http::{HttpClientConfig, HttpClientTransport};
48#[cfg(feature = "oauth-client")]
49pub use oauth::{
50    OAuthClientCredentials, OAuthClientCredentialsBuilder, OAuthClientError, TokenProvider,
51};
52#[cfg(feature = "oauth-client")]
53pub use oauth_authcode::{OAuthAuthCodeConfig, OAuthAuthorizationCode};
54pub use stdio::StdioClientTransport;
55pub use transport::ClientTransport;
56
57use std::collections::HashMap;
58use std::sync::Arc;
59use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
60
61use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
62use tokio::task::JoinHandle;
63
64use crate::error::{Error, Result};
65use crate::protocol::{
66    CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
67    CompletionArgument, CompletionReference, ElicitationCapability, GetPromptParams,
68    GetPromptResult, Implementation, InitializeParams, InitializeResult, JsonRpcNotification,
69    JsonRpcRequest, ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams,
70    ListResourceTemplatesResult, ListResourcesParams, ListResourcesResult, ListRootsResult,
71    ListToolsParams, ListToolsResult, PromptDefinition, ReadResourceParams, ReadResourceResult,
72    RequestId, ResourceDefinition, ResourceTemplateDefinition, Root, RootsCapability,
73    SamplingCapability, ToolDefinition, notifications,
74};
75use tower_mcp_types::JsonRpcError;
76
77/// Internal command sent from McpClient methods to the background loop.
78enum LoopCommand {
79    /// Send a JSON-RPC request and await a response.
80    Request {
81        method: String,
82        params: serde_json::Value,
83        response_tx: oneshot::Sender<Result<serde_json::Value>>,
84    },
85    /// Send a JSON-RPC notification (no response expected).
86    Notify {
87        method: String,
88        params: serde_json::Value,
89    },
90    /// Reset the transport's session state for re-initialization.
91    ResetSession { done_tx: oneshot::Sender<()> },
92    /// Graceful shutdown.
93    Shutdown,
94}
95
96/// MCP client with a background message loop.
97///
98/// Unlike previous versions, this type is not generic over the transport.
99/// The transport is consumed during [`connect()`](Self::connect) and moved
100/// into a background Tokio task that handles message multiplexing.
101///
102/// All public methods take `&self`, enabling concurrent use from multiple
103/// tasks.
104///
105/// # Construction
106///
107/// ```rust,no_run
108/// use tower_mcp::client::{McpClient, StdioClientTransport};
109///
110/// # async fn example() -> Result<(), tower_mcp::BoxError> {
111/// // Simple: no handler for server-initiated requests
112/// let transport = StdioClientTransport::spawn("server", &[]).await?;
113/// let client = McpClient::connect(transport).await?;
114///
115/// // With configuration
116/// use tower_mcp::protocol::Root;
117/// let transport = StdioClientTransport::spawn("server", &[]).await?;
118/// let client = McpClient::builder()
119///     .with_roots(vec![Root::new("file:///project")])
120///     .connect_simple(transport)
121///     .await?;
122/// # Ok(())
123/// # }
124/// ```
125pub struct McpClient {
126    /// Channel to send commands to the background loop.
127    command_tx: mpsc::Sender<LoopCommand>,
128    /// Background task handle.
129    task: Option<JoinHandle<()>>,
130    /// Whether `initialize()` has been called successfully.
131    initialized: AtomicBool,
132    /// Server info (set after successful initialization).
133    server_info: RwLock<Option<InitializeResult>>,
134    /// Client capabilities declared during initialization.
135    capabilities: ClientCapabilities,
136    /// Current roots (shared with the loop for roots/list responses).
137    roots: Arc<RwLock<Vec<Root>>>,
138    /// Whether the transport is still connected.
139    connected: Arc<AtomicBool>,
140    /// Whether the transport supports session recovery.
141    supports_session_recovery: bool,
142    /// Stored init params for session recovery re-initialization.
143    init_params: RwLock<Option<(String, String)>>,
144    /// Lock to prevent concurrent session recovery attempts.
145    recovery_lock: Mutex<()>,
146}
147
148/// Builder for configuring and connecting an [`McpClient`].
149///
150/// # Example
151///
152/// ```rust,no_run
153/// use tower_mcp::client::{McpClient, StdioClientTransport};
154/// use tower_mcp::protocol::Root;
155///
156/// # async fn example() -> Result<(), tower_mcp::BoxError> {
157/// let transport = StdioClientTransport::spawn("server", &[]).await?;
158/// let handler = (); // Use a real ClientHandler for bidirectional support
159/// let client = McpClient::builder()
160///     .with_roots(vec![Root::new("file:///project")])
161///     .with_sampling()
162///     .connect(transport, handler)
163///     .await?;
164/// # Ok(())
165/// # }
166/// ```
167pub struct McpClientBuilder {
168    capabilities: ClientCapabilities,
169    roots: Vec<Root>,
170}
171
172impl McpClientBuilder {
173    /// Create a new builder with default settings.
174    pub fn new() -> Self {
175        Self {
176            capabilities: ClientCapabilities::default(),
177            roots: Vec::new(),
178        }
179    }
180
181    /// Configure roots for this client.
182    ///
183    /// The client will declare roots support during initialization and
184    /// respond to `roots/list` requests with these roots.
185    pub fn with_roots(mut self, roots: Vec<Root>) -> Self {
186        self.roots = roots;
187        self.capabilities.roots = Some(RootsCapability { list_changed: true });
188        self
189    }
190
191    /// Configure custom capabilities for this client.
192    pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
193        self.capabilities = capabilities;
194        self
195    }
196
197    /// Declare sampling support.
198    ///
199    /// Sets the sampling capability so the server knows this client can
200    /// handle `sampling/createMessage` requests. The handler passed to
201    /// [`connect()`](Self::connect) should override
202    /// [`handle_create_message()`](ClientHandler::handle_create_message).
203    pub fn with_sampling(mut self) -> Self {
204        self.capabilities.sampling = Some(SamplingCapability::default());
205        self
206    }
207
208    /// Declare elicitation support.
209    ///
210    /// Sets the elicitation capability so the server knows this client can
211    /// handle `elicitation/create` requests. The handler passed to
212    /// [`connect()`](Self::connect) should override
213    /// [`handle_elicit()`](ClientHandler::handle_elicit).
214    pub fn with_elicitation(mut self) -> Self {
215        self.capabilities.elicitation = Some(ElicitationCapability::default());
216        self
217    }
218
219    /// Connect to a server using the given transport and handler.
220    ///
221    /// Spawns a background task to handle message I/O. The transport is
222    /// consumed and owned by the background task.
223    pub async fn connect<T, H>(self, transport: T, handler: H) -> Result<McpClient>
224    where
225        T: ClientTransport,
226        H: ClientHandler,
227    {
228        McpClient::connect_inner(transport, handler, self.capabilities, self.roots).await
229    }
230
231    /// Connect to a server without a handler.
232    ///
233    /// All server-initiated requests will be rejected with `method_not_found`.
234    pub async fn connect_simple<T: ClientTransport>(self, transport: T) -> Result<McpClient> {
235        self.connect(transport, ()).await
236    }
237}
238
239impl Default for McpClientBuilder {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
245impl McpClient {
246    /// Connect with default settings and no handler.
247    ///
248    /// Shorthand for `McpClient::builder().connect_simple(transport)`.
249    pub async fn connect<T: ClientTransport>(transport: T) -> Result<Self> {
250        McpClientBuilder::new().connect_simple(transport).await
251    }
252
253    /// Connect with a handler for server-initiated requests.
254    pub async fn connect_with_handler<T, H>(transport: T, handler: H) -> Result<Self>
255    where
256        T: ClientTransport,
257        H: ClientHandler,
258    {
259        McpClientBuilder::new().connect(transport, handler).await
260    }
261
262    /// Create a builder for advanced configuration.
263    pub fn builder() -> McpClientBuilder {
264        McpClientBuilder::new()
265    }
266
267    /// Internal connect implementation.
268    async fn connect_inner<T, H>(
269        transport: T,
270        handler: H,
271        capabilities: ClientCapabilities,
272        roots: Vec<Root>,
273    ) -> Result<Self>
274    where
275        T: ClientTransport,
276        H: ClientHandler,
277    {
278        let supports_session_recovery = transport.supports_session_recovery();
279        let (command_tx, command_rx) = mpsc::channel::<LoopCommand>(64);
280        let connected = Arc::new(AtomicBool::new(true));
281        let roots = Arc::new(RwLock::new(roots));
282
283        let loop_connected = connected.clone();
284        let loop_roots = roots.clone();
285
286        let task = tokio::spawn(async move {
287            message_loop(transport, handler, command_rx, loop_connected, loop_roots).await;
288        });
289
290        Ok(Self {
291            command_tx,
292            task: Some(task),
293            initialized: AtomicBool::new(false),
294            server_info: RwLock::new(None),
295            capabilities,
296            roots,
297            connected,
298            supports_session_recovery,
299            init_params: RwLock::new(None),
300            recovery_lock: Mutex::new(()),
301        })
302    }
303
304    /// Check if the client has been initialized.
305    pub fn is_initialized(&self) -> bool {
306        self.initialized.load(Ordering::Acquire)
307    }
308
309    /// Check if the transport is still connected.
310    pub fn is_connected(&self) -> bool {
311        self.connected.load(Ordering::Acquire)
312    }
313
314    /// Get the server info (available after initialization).
315    pub async fn server_info(&self) -> Option<InitializeResult> {
316        self.server_info.read().await.clone()
317    }
318
319    /// Get the server info synchronously (best-effort, non-blocking).
320    ///
321    /// Returns `None` if the lock is currently held by a writer or if
322    /// initialization hasn't completed. Prefer [`server_info()`](Self::server_info)
323    /// in async contexts.
324    pub fn server_info_blocking(&self) -> Option<InitializeResult> {
325        self.server_info.try_read().ok()?.clone()
326    }
327
328    /// Initialize the MCP connection.
329    ///
330    /// Sends the `initialize` request and `notifications/initialized` notification.
331    /// Must be called before any other operations.
332    pub async fn initialize(
333        &self,
334        client_name: &str,
335        client_version: &str,
336    ) -> Result<InitializeResult> {
337        let params = InitializeParams {
338            protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
339            capabilities: self.capabilities.clone(),
340            client_info: Implementation {
341                name: client_name.to_string(),
342                version: client_version.to_string(),
343                ..Default::default()
344            },
345            meta: None,
346        };
347
348        let result: InitializeResult = self.send_request("initialize", &params).await?;
349        *self.server_info.write().await = Some(result.clone());
350
351        // Store init params for potential session recovery
352        *self.init_params.write().await =
353            Some((client_name.to_string(), client_version.to_string()));
354
355        // Send initialized notification
356        self.send_notification("notifications/initialized", &serde_json::json!({}))
357            .await?;
358        self.initialized.store(true, Ordering::Release);
359
360        Ok(result)
361    }
362
363    /// List available tools.
364    pub async fn list_tools(&self) -> Result<ListToolsResult> {
365        self.ensure_initialized()?;
366        self.send_request(
367            "tools/list",
368            &ListToolsParams {
369                cursor: None,
370                meta: None,
371            },
372        )
373        .await
374    }
375
376    /// Call a tool.
377    pub async fn call_tool(
378        &self,
379        name: &str,
380        arguments: serde_json::Value,
381    ) -> Result<CallToolResult> {
382        self.ensure_initialized()?;
383        let params = CallToolParams {
384            name: name.to_string(),
385            arguments,
386            meta: None,
387            task: None,
388        };
389        self.send_request("tools/call", &params).await
390    }
391
392    /// List available resources.
393    pub async fn list_resources(&self) -> Result<ListResourcesResult> {
394        self.ensure_initialized()?;
395        self.send_request(
396            "resources/list",
397            &ListResourcesParams {
398                cursor: None,
399                meta: None,
400            },
401        )
402        .await
403    }
404
405    /// Read a resource.
406    pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
407        self.ensure_initialized()?;
408        let params = ReadResourceParams {
409            uri: uri.to_string(),
410            meta: None,
411        };
412        self.send_request("resources/read", &params).await
413    }
414
415    /// List available prompts.
416    pub async fn list_prompts(&self) -> Result<ListPromptsResult> {
417        self.ensure_initialized()?;
418        self.send_request(
419            "prompts/list",
420            &ListPromptsParams {
421                cursor: None,
422                meta: None,
423            },
424        )
425        .await
426    }
427
428    /// List tools with an optional pagination cursor.
429    pub async fn list_tools_with_cursor(&self, cursor: Option<String>) -> Result<ListToolsResult> {
430        self.ensure_initialized()?;
431        self.send_request("tools/list", &ListToolsParams { cursor, meta: None })
432            .await
433    }
434
435    /// List resources with an optional pagination cursor.
436    pub async fn list_resources_with_cursor(
437        &self,
438        cursor: Option<String>,
439    ) -> Result<ListResourcesResult> {
440        self.ensure_initialized()?;
441        self.send_request(
442            "resources/list",
443            &ListResourcesParams { cursor, meta: None },
444        )
445        .await
446    }
447
448    /// List resource templates.
449    pub async fn list_resource_templates(&self) -> Result<ListResourceTemplatesResult> {
450        self.ensure_initialized()?;
451        self.send_request(
452            "resources/templates/list",
453            &ListResourceTemplatesParams {
454                cursor: None,
455                meta: None,
456            },
457        )
458        .await
459    }
460
461    /// List resource templates with an optional pagination cursor.
462    pub async fn list_resource_templates_with_cursor(
463        &self,
464        cursor: Option<String>,
465    ) -> Result<ListResourceTemplatesResult> {
466        self.ensure_initialized()?;
467        self.send_request(
468            "resources/templates/list",
469            &ListResourceTemplatesParams { cursor, meta: None },
470        )
471        .await
472    }
473
474    /// List prompts with an optional pagination cursor.
475    pub async fn list_prompts_with_cursor(
476        &self,
477        cursor: Option<String>,
478    ) -> Result<ListPromptsResult> {
479        self.ensure_initialized()?;
480        self.send_request("prompts/list", &ListPromptsParams { cursor, meta: None })
481            .await
482    }
483
484    /// List all tools, following pagination cursors until exhausted.
485    pub async fn list_all_tools(&self) -> Result<Vec<ToolDefinition>> {
486        let mut all = Vec::new();
487        let mut cursor = None;
488        loop {
489            let result = self.list_tools_with_cursor(cursor).await?;
490            all.extend(result.tools);
491            match result.next_cursor {
492                Some(c) => cursor = Some(c),
493                None => break,
494            }
495        }
496        Ok(all)
497    }
498
499    /// List all resources, following pagination cursors until exhausted.
500    pub async fn list_all_resources(&self) -> Result<Vec<ResourceDefinition>> {
501        let mut all = Vec::new();
502        let mut cursor = None;
503        loop {
504            let result = self.list_resources_with_cursor(cursor).await?;
505            all.extend(result.resources);
506            match result.next_cursor {
507                Some(c) => cursor = Some(c),
508                None => break,
509            }
510        }
511        Ok(all)
512    }
513
514    /// List all resource templates, following pagination cursors until exhausted.
515    pub async fn list_all_resource_templates(&self) -> Result<Vec<ResourceTemplateDefinition>> {
516        let mut all = Vec::new();
517        let mut cursor = None;
518        loop {
519            let result = self.list_resource_templates_with_cursor(cursor).await?;
520            all.extend(result.resource_templates);
521            match result.next_cursor {
522                Some(c) => cursor = Some(c),
523                None => break,
524            }
525        }
526        Ok(all)
527    }
528
529    /// List all prompts, following pagination cursors until exhausted.
530    pub async fn list_all_prompts(&self) -> Result<Vec<PromptDefinition>> {
531        let mut all = Vec::new();
532        let mut cursor = None;
533        loop {
534            let result = self.list_prompts_with_cursor(cursor).await?;
535            all.extend(result.prompts);
536            match result.next_cursor {
537                Some(c) => cursor = Some(c),
538                None => break,
539            }
540        }
541        Ok(all)
542    }
543
544    /// Call a tool and return the concatenated text content.
545    ///
546    /// Returns the text from all [`Text`](crate::protocol::Content::Text) items joined together.
547    /// If the tool result indicates an error (`is_error` is true), returns
548    /// an error with the text content as the message.
549    ///
550    /// For more control over the result, use [`call_tool()`](Self::call_tool).
551    pub async fn call_tool_text(&self, name: &str, arguments: serde_json::Value) -> Result<String> {
552        let result = self.call_tool(name, arguments).await?;
553        if result.is_error {
554            return Err(Error::Internal(result.all_text()));
555        }
556        Ok(result.all_text())
557    }
558
559    /// Get a prompt.
560    pub async fn get_prompt(
561        &self,
562        name: &str,
563        arguments: Option<std::collections::HashMap<String, String>>,
564    ) -> Result<GetPromptResult> {
565        self.ensure_initialized()?;
566        let params = GetPromptParams {
567            name: name.to_string(),
568            arguments: arguments.unwrap_or_default(),
569            meta: None,
570        };
571        self.send_request("prompts/get", &params).await
572    }
573
574    /// Ping the server.
575    pub async fn ping(&self) -> Result<()> {
576        let _: serde_json::Value = self.send_request("ping", &serde_json::json!({})).await?;
577        Ok(())
578    }
579
580    /// Request completion suggestions from the server.
581    pub async fn complete(
582        &self,
583        reference: CompletionReference,
584        argument_name: &str,
585        argument_value: &str,
586    ) -> Result<CompleteResult> {
587        self.ensure_initialized()?;
588        let params = CompleteParams {
589            reference,
590            argument: CompletionArgument::new(argument_name, argument_value),
591            context: None,
592            meta: None,
593        };
594        self.send_request("completion/complete", &params).await
595    }
596
597    /// Request completion for a prompt argument.
598    pub async fn complete_prompt_arg(
599        &self,
600        prompt_name: &str,
601        argument_name: &str,
602        argument_value: &str,
603    ) -> Result<CompleteResult> {
604        self.complete(
605            CompletionReference::prompt(prompt_name),
606            argument_name,
607            argument_value,
608        )
609        .await
610    }
611
612    /// Request completion for a resource URI.
613    pub async fn complete_resource_uri(
614        &self,
615        resource_uri: &str,
616        argument_name: &str,
617        argument_value: &str,
618    ) -> Result<CompleteResult> {
619        self.complete(
620            CompletionReference::resource(resource_uri),
621            argument_name,
622            argument_value,
623        )
624        .await
625    }
626
627    /// Send a raw typed request to the server.
628    pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
629        &self,
630        method: &str,
631        params: &P,
632    ) -> Result<R> {
633        self.send_request(method, params).await
634    }
635
636    /// Send a raw typed notification to the server.
637    pub async fn notify<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
638        self.send_notification(method, params).await
639    }
640
641    /// Get the current roots.
642    pub async fn roots(&self) -> Vec<Root> {
643        self.roots.read().await.clone()
644    }
645
646    /// Set roots and notify the server if initialized.
647    pub async fn set_roots(&self, roots: Vec<Root>) -> Result<()> {
648        *self.roots.write().await = roots;
649        if self.is_initialized() {
650            self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
651                .await?;
652        }
653        Ok(())
654    }
655
656    /// Add a root and notify the server if initialized.
657    pub async fn add_root(&self, root: Root) -> Result<()> {
658        self.roots.write().await.push(root);
659        if self.is_initialized() {
660            self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
661                .await?;
662        }
663        Ok(())
664    }
665
666    /// Remove a root by URI and notify the server if initialized.
667    pub async fn remove_root(&self, uri: &str) -> Result<bool> {
668        let mut roots = self.roots.write().await;
669        let initial_len = roots.len();
670        roots.retain(|r| r.uri != uri);
671        let removed = roots.len() < initial_len;
672        drop(roots);
673
674        if removed && self.is_initialized() {
675            self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
676                .await?;
677        }
678        Ok(removed)
679    }
680
681    /// Get the roots list result (for responding to server's roots/list request).
682    pub async fn list_roots(&self) -> ListRootsResult {
683        ListRootsResult {
684            roots: self.roots.read().await.clone(),
685            meta: None,
686        }
687    }
688
689    /// Gracefully shut down the client and close the transport.
690    pub async fn shutdown(mut self) -> Result<()> {
691        let _ = self.command_tx.send(LoopCommand::Shutdown).await;
692        if let Some(task) = self.task.take() {
693            let _ = task.await;
694        }
695        Ok(())
696    }
697
698    // --- Internal helpers ---
699
700    async fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
701        &self,
702        method: &str,
703        params: &P,
704    ) -> Result<R> {
705        match self.send_request_once(method, params).await {
706            Err(Error::SessionExpired)
707                if self.supports_session_recovery && method != "initialize" =>
708            {
709                tracing::info!(method = %method, "Session expired, attempting recovery");
710                self.recover_session().await?;
711                self.send_request_once(method, params).await
712            }
713            other => other,
714        }
715    }
716
717    async fn send_request_once<P: serde::Serialize, R: serde::de::DeserializeOwned>(
718        &self,
719        method: &str,
720        params: &P,
721    ) -> Result<R> {
722        self.ensure_connected()?;
723        let params_value = serde_json::to_value(params)
724            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
725
726        let (response_tx, response_rx) = oneshot::channel();
727        self.command_tx
728            .send(LoopCommand::Request {
729                method: method.to_string(),
730                params: params_value,
731                response_tx,
732            })
733            .await
734            .map_err(|_| Error::Transport("Connection closed".to_string()))?;
735
736        let result = response_rx
737            .await
738            .map_err(|_| Error::Transport("Connection closed".to_string()))??;
739
740        serde_json::from_value(result)
741            .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
742    }
743
744    /// Recover from a session expiry by resetting the transport and re-initializing.
745    async fn recover_session(&self) -> Result<()> {
746        // Serialize recovery attempts
747        let _guard = self.recovery_lock.lock().await;
748
749        // Check if another task already recovered while we waited
750        // (the init_params being present means we were initialized before)
751        let init_params = self.init_params.read().await.clone();
752        let (client_name, client_version) = match init_params {
753            Some(params) => params,
754            None => {
755                return Err(Error::Transport(
756                    "Cannot recover: never initialized".to_string(),
757                ));
758            }
759        };
760
761        // Tell the message loop to reset the transport
762        let (done_tx, done_rx) = oneshot::channel();
763        self.command_tx
764            .send(LoopCommand::ResetSession { done_tx })
765            .await
766            .map_err(|_| Error::Transport("Connection closed".to_string()))?;
767        done_rx
768            .await
769            .map_err(|_| Error::Transport("Connection closed during recovery".to_string()))?;
770
771        // Clear initialized state
772        self.initialized.store(false, Ordering::Release);
773        *self.server_info.write().await = None;
774
775        // Re-initialize (using send_request_once to avoid recursion)
776        tracing::info!("Re-initializing session after expiry");
777        let params = InitializeParams {
778            protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
779            capabilities: self.capabilities.clone(),
780            client_info: Implementation {
781                name: client_name,
782                version: client_version,
783                ..Default::default()
784            },
785            meta: None,
786        };
787
788        let result: InitializeResult = self.send_request_once("initialize", &params).await?;
789        *self.server_info.write().await = Some(result);
790
791        self.send_notification("notifications/initialized", &serde_json::json!({}))
792            .await?;
793        self.initialized.store(true, Ordering::Release);
794
795        Ok(())
796    }
797
798    async fn send_notification<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
799        self.ensure_connected()?;
800        let params_value = serde_json::to_value(params)
801            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
802
803        self.command_tx
804            .send(LoopCommand::Notify {
805                method: method.to_string(),
806                params: params_value,
807            })
808            .await
809            .map_err(|_| Error::Transport("Connection closed".to_string()))?;
810
811        Ok(())
812    }
813
814    fn ensure_connected(&self) -> Result<()> {
815        if !self.connected.load(Ordering::Acquire) {
816            return Err(Error::Transport("Connection closed".to_string()));
817        }
818        Ok(())
819    }
820
821    fn ensure_initialized(&self) -> Result<()> {
822        if !self.initialized.load(Ordering::Acquire) {
823            return Err(Error::Transport("Client not initialized".to_string()));
824        }
825        Ok(())
826    }
827}
828
829impl Drop for McpClient {
830    fn drop(&mut self) {
831        if let Some(task) = self.task.take() {
832            task.abort();
833        }
834    }
835}
836
837// =============================================================================
838// Background Message Loop
839// =============================================================================
840
841/// A pending request waiting for a response from the server.
842struct PendingRequest {
843    response_tx: oneshot::Sender<Result<serde_json::Value>>,
844}
845
846/// Background message loop that multiplexes incoming/outgoing messages.
847async fn message_loop<T: ClientTransport, H: ClientHandler>(
848    mut transport: T,
849    handler: H,
850    mut command_rx: mpsc::Receiver<LoopCommand>,
851    connected: Arc<AtomicBool>,
852    roots: Arc<RwLock<Vec<Root>>>,
853) {
854    let handler = Arc::new(handler);
855    let mut pending_requests: HashMap<RequestId, PendingRequest> = HashMap::new();
856    let next_id = AtomicI64::new(1);
857
858    loop {
859        tokio::select! {
860            // Commands from McpClient methods
861            command = command_rx.recv() => {
862                match command {
863                    Some(LoopCommand::Request { method, params, response_tx }) => {
864                        let id = RequestId::Number(next_id.fetch_add(1, Ordering::Relaxed));
865
866                        let request = JsonRpcRequest::new(id.clone(), &method)
867                            .with_params(params);
868                        let json = match serde_json::to_string(&request) {
869                            Ok(j) => j,
870                            Err(e) => {
871                                let _ = response_tx.send(Err(Error::Transport(
872                                    format!("Serialization failed: {}", e)
873                                )));
874                                continue;
875                            }
876                        };
877
878                        tracing::debug!(method = %method, id = ?id, "Sending request");
879                        pending_requests.insert(id, PendingRequest { response_tx });
880
881                        if let Err(e) = transport.send(&json).await {
882                            tracing::error!(error = %e, "Transport send error");
883                            fail_all_pending(&mut pending_requests, &format!("Transport error: {}", e));
884                            break;
885                        }
886                    }
887                    Some(LoopCommand::Notify { method, params }) => {
888                        let notification = JsonRpcNotification::new(&method)
889                            .with_params(params);
890                        if let Ok(json) = serde_json::to_string(&notification) {
891                            tracing::debug!(method = %method, "Sending notification");
892                            let _ = transport.send(&json).await;
893                        }
894                    }
895                    Some(LoopCommand::ResetSession { done_tx }) => {
896                        tracing::info!("Resetting transport session for re-initialization");
897                        transport.reset_session().await;
898                        // Fail any pending requests with session expired
899                        for (_, pending) in pending_requests.drain() {
900                            let _ = pending.response_tx.send(Err(Error::SessionExpired));
901                        }
902                        let _ = done_tx.send(());
903                    }
904                    Some(LoopCommand::Shutdown) | None => {
905                        tracing::debug!("Message loop shutting down");
906                        break;
907                    }
908                }
909            }
910
911            // Incoming messages from the server
912            result = transport.recv() => {
913                match result {
914                    Ok(Some(line)) => {
915                        handle_incoming(
916                            &line,
917                            &mut pending_requests,
918                            &handler,
919                            &roots,
920                            &mut transport,
921                        ).await;
922                    }
923                    Ok(None) => {
924                        tracing::info!("Transport closed (EOF)");
925                        break;
926                    }
927                    Err(e) => {
928                        tracing::error!(error = %e, "Transport receive error");
929                        break;
930                    }
931                }
932            }
933        }
934    }
935
936    // Cleanup
937    connected.store(false, Ordering::Release);
938    fail_all_pending(&mut pending_requests, "Connection closed");
939    let _ = transport.close().await;
940}
941
942/// Handle a single incoming message from the server.
943async fn handle_incoming<T: ClientTransport, H: ClientHandler>(
944    line: &str,
945    pending_requests: &mut HashMap<RequestId, PendingRequest>,
946    handler: &Arc<H>,
947    roots: &Arc<RwLock<Vec<Root>>>,
948    transport: &mut T,
949) {
950    let parsed: serde_json::Value = match serde_json::from_str(line) {
951        Ok(v) => v,
952        Err(e) => {
953            tracing::warn!(error = %e, "Failed to parse incoming message");
954            return;
955        }
956    };
957
958    // Case 1: Response to one of our pending requests (has result or error, no method)
959    if parsed.get("method").is_none()
960        && (parsed.get("result").is_some() || parsed.get("error").is_some())
961    {
962        // Check for session-level errors (id: null with -32005) that affect
963        // all pending requests, not just a specific one.
964        if let Some(error) = parsed.get("error") {
965            let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(0) as i32;
966            let id_missing_or_null = parsed.get("id").is_none_or(|id| id.is_null());
967            if code == -32005 && id_missing_or_null {
968                tracing::warn!(
969                    "Session expired (-32005 with null id), failing all pending requests"
970                );
971                for (_, pending) in pending_requests.drain() {
972                    let _ = pending.response_tx.send(Err(Error::SessionExpired));
973                }
974                return;
975            }
976        }
977
978        handle_response(&parsed, pending_requests);
979        return;
980    }
981
982    // Case 2: Server-initiated request (has id + method)
983    if parsed.get("id").is_some() && parsed.get("method").is_some() {
984        let id = parse_request_id(&parsed);
985        let method = parsed["method"].as_str().unwrap_or("");
986        let params = parsed.get("params").cloned();
987
988        let result = dispatch_server_request(handler, roots, method, params).await;
989
990        // Send response back to the server
991        let response = match result {
992            Ok(value) => {
993                if let Some(id) = id {
994                    serde_json::json!({
995                        "jsonrpc": "2.0",
996                        "id": id,
997                        "result": value
998                    })
999                } else {
1000                    return;
1001                }
1002            }
1003            Err(error) => {
1004                serde_json::json!({
1005                    "jsonrpc": "2.0",
1006                    "id": id,
1007                    "error": {
1008                        "code": error.code,
1009                        "message": error.message
1010                    }
1011                })
1012            }
1013        };
1014
1015        if let Ok(json) = serde_json::to_string(&response) {
1016            let _ = transport.send(&json).await;
1017        }
1018        return;
1019    }
1020
1021    // Case 3: Server notification (has method, no id)
1022    if parsed.get("method").is_some() && parsed.get("id").is_none() {
1023        let method = parsed["method"].as_str().unwrap_or("");
1024        let params = parsed.get("params").cloned();
1025        let notification = parse_server_notification(method, params);
1026        handler.on_notification(notification).await;
1027    }
1028}
1029
1030/// Handle a JSON-RPC response by routing to the pending request.
1031fn handle_response(
1032    parsed: &serde_json::Value,
1033    pending_requests: &mut HashMap<RequestId, PendingRequest>,
1034) {
1035    let id = match parse_request_id(parsed) {
1036        Some(id) => id,
1037        None => {
1038            tracing::warn!("Response without id");
1039            return;
1040        }
1041    };
1042
1043    let pending = match pending_requests.remove(&id) {
1044        Some(p) => p,
1045        None => {
1046            tracing::warn!(id = ?id, "Response for unknown request");
1047            return;
1048        }
1049    };
1050
1051    tracing::debug!(id = ?id, "Received response");
1052
1053    if let Some(error) = parsed.get("error") {
1054        let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
1055        let message = error
1056            .get("message")
1057            .and_then(|m| m.as_str())
1058            .unwrap_or("Unknown error")
1059            .to_string();
1060        let data = error.get("data").cloned();
1061
1062        // -32005 = SessionNotFound: signal session expiry for recovery
1063        if code == -32005 {
1064            let _ = pending.response_tx.send(Err(Error::SessionExpired));
1065            return;
1066        }
1067
1068        let json_rpc_error = JsonRpcError {
1069            code,
1070            message,
1071            data,
1072        };
1073        let _ = pending
1074            .response_tx
1075            .send(Err(Error::JsonRpc(json_rpc_error)));
1076    } else if let Some(result) = parsed.get("result") {
1077        let _ = pending.response_tx.send(Ok(result.clone()));
1078    } else {
1079        let _ = pending
1080            .response_tx
1081            .send(Err(Error::Transport("Invalid response".to_string())));
1082    }
1083}
1084
1085/// Dispatch a server-initiated request to the handler.
1086async fn dispatch_server_request<H: ClientHandler>(
1087    handler: &Arc<H>,
1088    roots: &Arc<RwLock<Vec<Root>>>,
1089    method: &str,
1090    params: Option<serde_json::Value>,
1091) -> std::result::Result<serde_json::Value, JsonRpcError> {
1092    match method {
1093        "sampling/createMessage" => {
1094            let p = serde_json::from_value(params.unwrap_or_default())
1095                .map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
1096            let result = handler.handle_create_message(p).await?;
1097            serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
1098        }
1099        "elicitation/create" => {
1100            let p = serde_json::from_value(params.unwrap_or_default())
1101                .map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
1102            let result = handler.handle_elicit(p).await?;
1103            serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
1104        }
1105        "roots/list" => {
1106            // Use client-configured roots if available, otherwise delegate to handler
1107            let roots_list = roots.read().await;
1108            if !roots_list.is_empty() {
1109                let result = ListRootsResult {
1110                    roots: roots_list.clone(),
1111                    meta: None,
1112                };
1113                return serde_json::to_value(result)
1114                    .map_err(|e| JsonRpcError::internal_error(e.to_string()));
1115            }
1116            drop(roots_list);
1117
1118            let result = handler.handle_list_roots().await?;
1119            serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
1120        }
1121        "ping" => Ok(serde_json::json!({})),
1122        _ => Err(JsonRpcError::method_not_found(method)),
1123    }
1124}
1125
1126/// Parse a request ID from a JSON-RPC message.
1127fn parse_request_id(parsed: &serde_json::Value) -> Option<RequestId> {
1128    parsed.get("id").and_then(|id| {
1129        if let Some(n) = id.as_i64() {
1130            Some(RequestId::Number(n))
1131        } else {
1132            id.as_str().map(|s| RequestId::String(s.to_string()))
1133        }
1134    })
1135}
1136
1137/// Parse a server notification into the typed enum.
1138fn parse_server_notification(
1139    method: &str,
1140    params: Option<serde_json::Value>,
1141) -> ServerNotification {
1142    match method {
1143        notifications::PROGRESS => {
1144            if let Some(params) = params
1145                && let Ok(p) = serde_json::from_value(params)
1146            {
1147                return ServerNotification::Progress(p);
1148            }
1149            ServerNotification::Unknown {
1150                method: method.to_string(),
1151                params: None,
1152            }
1153        }
1154        notifications::MESSAGE => {
1155            if let Some(params) = params
1156                && let Ok(p) = serde_json::from_value(params)
1157            {
1158                return ServerNotification::LogMessage(p);
1159            }
1160            ServerNotification::Unknown {
1161                method: method.to_string(),
1162                params: None,
1163            }
1164        }
1165        notifications::RESOURCE_UPDATED => {
1166            if let Some(params) = &params
1167                && let Some(uri) = params.get("uri").and_then(|u| u.as_str())
1168            {
1169                return ServerNotification::ResourceUpdated {
1170                    uri: uri.to_string(),
1171                };
1172            }
1173            ServerNotification::Unknown {
1174                method: method.to_string(),
1175                params,
1176            }
1177        }
1178        notifications::RESOURCES_LIST_CHANGED => ServerNotification::ResourcesListChanged,
1179        notifications::TOOLS_LIST_CHANGED => ServerNotification::ToolsListChanged,
1180        notifications::PROMPTS_LIST_CHANGED => ServerNotification::PromptsListChanged,
1181        _ => ServerNotification::Unknown {
1182            method: method.to_string(),
1183            params,
1184        },
1185    }
1186}
1187
1188/// Fail all pending requests with the given error message.
1189fn fail_all_pending(pending: &mut HashMap<RequestId, PendingRequest>, reason: &str) {
1190    for (_, req) in pending.drain() {
1191        let _ = req
1192            .response_tx
1193            .send(Err(Error::Transport(reason.to_string())));
1194    }
1195}
1196
1197#[cfg(test)]
1198mod tests {
1199    use super::*;
1200    use async_trait::async_trait;
1201    use std::sync::Mutex;
1202
1203    /// Mock transport for testing that auto-responds to requests.
1204    ///
1205    /// When the client sends a request via `send()`, the mock extracts the
1206    /// request ID, pairs it with the next preconfigured response, and feeds
1207    /// it back through a channel that `recv()` awaits on. This ensures
1208    /// `recv()` blocks when no messages are available (instead of returning
1209    /// EOF), keeping the background message loop alive.
1210    struct MockTransport {
1211        /// Pre-configured response payloads (result values, not full envelopes).
1212        responses: Arc<Mutex<Vec<serde_json::Value>>>,
1213        /// Index of the next response to use.
1214        response_idx: Arc<std::sync::atomic::AtomicUsize>,
1215        /// Channel sender for feeding responses back to `recv()`.
1216        incoming_tx: mpsc::Sender<String>,
1217        /// Channel receiver for `recv()` to await on.
1218        incoming_rx: mpsc::Receiver<String>,
1219        /// Collected outgoing messages from `send()`.
1220        outgoing: Arc<Mutex<Vec<String>>>,
1221        connected: Arc<AtomicBool>,
1222    }
1223
1224    #[allow(dead_code)]
1225    impl MockTransport {
1226        fn new() -> Self {
1227            let (tx, rx) = mpsc::channel(32);
1228            Self {
1229                responses: Arc::new(Mutex::new(Vec::new())),
1230                response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1231                incoming_tx: tx,
1232                incoming_rx: rx,
1233                outgoing: Arc::new(Mutex::new(Vec::new())),
1234                connected: Arc::new(AtomicBool::new(true)),
1235            }
1236        }
1237
1238        /// Create a mock that auto-responds with the given result payloads.
1239        ///
1240        /// When `send()` receives a JSON-RPC request, it extracts the request
1241        /// ID and pairs it with the next response from this list, sending the
1242        /// complete JSON-RPC response through the channel for `recv()`.
1243        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
1244            let (tx, rx) = mpsc::channel(32);
1245            Self {
1246                responses: Arc::new(Mutex::new(responses)),
1247                response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1248                incoming_tx: tx,
1249                incoming_rx: rx,
1250                outgoing: Arc::new(Mutex::new(Vec::new())),
1251                connected: Arc::new(AtomicBool::new(true)),
1252            }
1253        }
1254    }
1255
1256    #[async_trait]
1257    impl ClientTransport for MockTransport {
1258        async fn send(&mut self, message: &str) -> Result<()> {
1259            self.outgoing.lock().unwrap().push(message.to_string());
1260
1261            // Parse the outgoing message to extract the request ID
1262            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(message) {
1263                // Only respond to requests (messages with an id and method)
1264                if let Some(id) = parsed.get("id") {
1265                    let idx = self
1266                        .response_idx
1267                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1268                    let responses = self.responses.lock().unwrap();
1269                    if let Some(result) = responses.get(idx) {
1270                        let response = serde_json::json!({
1271                            "jsonrpc": "2.0",
1272                            "id": id,
1273                            "result": result
1274                        });
1275                        let _ = self.incoming_tx.try_send(response.to_string());
1276                    }
1277                }
1278            }
1279
1280            Ok(())
1281        }
1282
1283        async fn recv(&mut self) -> Result<Option<String>> {
1284            // Await on the channel -- blocks until a message is available
1285            // or the sender is dropped (returns None = EOF).
1286            match self.incoming_rx.recv().await {
1287                Some(msg) => Ok(Some(msg)),
1288                None => Ok(None),
1289            }
1290        }
1291
1292        fn is_connected(&self) -> bool {
1293            self.connected.load(Ordering::Relaxed)
1294        }
1295
1296        async fn close(&mut self) -> Result<()> {
1297            self.connected.store(false, Ordering::Relaxed);
1298            Ok(())
1299        }
1300    }
1301
1302    fn mock_initialize_response() -> serde_json::Value {
1303        serde_json::json!({
1304            "protocolVersion": "2025-11-25",
1305            "serverInfo": {
1306                "name": "test-server",
1307                "version": "1.0.0"
1308            },
1309            "capabilities": {
1310                "tools": {}
1311            }
1312        })
1313    }
1314
1315    #[tokio::test]
1316    async fn test_client_not_initialized() {
1317        let client = McpClient::connect(MockTransport::with_responses(vec![]))
1318            .await
1319            .unwrap();
1320
1321        let result = client.list_tools().await;
1322        assert!(result.is_err());
1323        assert!(result.unwrap_err().to_string().contains("not initialized"));
1324    }
1325
1326    #[tokio::test]
1327    async fn test_client_initialize() {
1328        let client = McpClient::connect(MockTransport::with_responses(vec![
1329            mock_initialize_response(),
1330        ]))
1331        .await
1332        .unwrap();
1333
1334        assert!(!client.is_initialized());
1335
1336        let result = client.initialize("test-client", "1.0.0").await;
1337        assert!(result.is_ok());
1338        assert!(client.is_initialized());
1339
1340        let server_info = client.server_info().await.unwrap();
1341        assert_eq!(server_info.server_info.name, "test-server");
1342    }
1343
1344    #[tokio::test]
1345    async fn test_list_tools() {
1346        let client = McpClient::connect(MockTransport::with_responses(vec![
1347            mock_initialize_response(),
1348            serde_json::json!({
1349                "tools": [
1350                    {
1351                        "name": "test_tool",
1352                        "description": "A test tool",
1353                        "inputSchema": {
1354                            "type": "object",
1355                            "properties": {}
1356                        }
1357                    }
1358                ]
1359            }),
1360        ]))
1361        .await
1362        .unwrap();
1363
1364        client.initialize("test-client", "1.0.0").await.unwrap();
1365        let tools = client.list_tools().await.unwrap();
1366
1367        assert_eq!(tools.tools.len(), 1);
1368        assert_eq!(tools.tools[0].name, "test_tool");
1369    }
1370
1371    #[tokio::test]
1372    async fn test_call_tool() {
1373        let client = McpClient::connect(MockTransport::with_responses(vec![
1374            mock_initialize_response(),
1375            serde_json::json!({
1376                "content": [
1377                    {
1378                        "type": "text",
1379                        "text": "Tool result"
1380                    }
1381                ]
1382            }),
1383        ]))
1384        .await
1385        .unwrap();
1386
1387        client.initialize("test-client", "1.0.0").await.unwrap();
1388        let result = client
1389            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
1390            .await
1391            .unwrap();
1392
1393        assert!(!result.content.is_empty());
1394    }
1395
1396    #[tokio::test]
1397    async fn test_list_resources() {
1398        let client = McpClient::connect(MockTransport::with_responses(vec![
1399            mock_initialize_response(),
1400            serde_json::json!({
1401                "resources": [
1402                    {
1403                        "uri": "file://test.txt",
1404                        "name": "Test File"
1405                    }
1406                ]
1407            }),
1408        ]))
1409        .await
1410        .unwrap();
1411
1412        client.initialize("test-client", "1.0.0").await.unwrap();
1413        let resources = client.list_resources().await.unwrap();
1414
1415        assert_eq!(resources.resources.len(), 1);
1416        assert_eq!(resources.resources[0].uri, "file://test.txt");
1417    }
1418
1419    #[tokio::test]
1420    async fn test_read_resource() {
1421        let client = McpClient::connect(MockTransport::with_responses(vec![
1422            mock_initialize_response(),
1423            serde_json::json!({
1424                "contents": [
1425                    {
1426                        "uri": "file://test.txt",
1427                        "text": "File contents"
1428                    }
1429                ]
1430            }),
1431        ]))
1432        .await
1433        .unwrap();
1434
1435        client.initialize("test-client", "1.0.0").await.unwrap();
1436        let result = client.read_resource("file://test.txt").await.unwrap();
1437
1438        assert_eq!(result.contents.len(), 1);
1439        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
1440    }
1441
1442    #[tokio::test]
1443    async fn test_list_prompts() {
1444        let client = McpClient::connect(MockTransport::with_responses(vec![
1445            mock_initialize_response(),
1446            serde_json::json!({
1447                "prompts": [
1448                    {
1449                        "name": "test_prompt",
1450                        "description": "A test prompt"
1451                    }
1452                ]
1453            }),
1454        ]))
1455        .await
1456        .unwrap();
1457
1458        client.initialize("test-client", "1.0.0").await.unwrap();
1459        let prompts = client.list_prompts().await.unwrap();
1460
1461        assert_eq!(prompts.prompts.len(), 1);
1462        assert_eq!(prompts.prompts[0].name, "test_prompt");
1463    }
1464
1465    #[tokio::test]
1466    async fn test_get_prompt() {
1467        let client = McpClient::connect(MockTransport::with_responses(vec![
1468            mock_initialize_response(),
1469            serde_json::json!({
1470                "messages": [
1471                    {
1472                        "role": "user",
1473                        "content": {
1474                            "type": "text",
1475                            "text": "Prompt message"
1476                        }
1477                    }
1478                ]
1479            }),
1480        ]))
1481        .await
1482        .unwrap();
1483
1484        client.initialize("test-client", "1.0.0").await.unwrap();
1485        let result = client.get_prompt("test_prompt", None).await.unwrap();
1486
1487        assert_eq!(result.messages.len(), 1);
1488    }
1489
1490    #[tokio::test]
1491    async fn test_ping() {
1492        let client = McpClient::connect(MockTransport::with_responses(vec![
1493            mock_initialize_response(),
1494            serde_json::json!({}),
1495        ]))
1496        .await
1497        .unwrap();
1498
1499        client.initialize("test-client", "1.0.0").await.unwrap();
1500        let result = client.ping().await;
1501        assert!(result.is_ok());
1502    }
1503
1504    #[tokio::test]
1505    async fn test_with_roots() {
1506        let roots = vec![Root::new("file:///test")];
1507        let client = McpClient::builder()
1508            .with_roots(roots)
1509            .connect_simple(MockTransport::with_responses(vec![]))
1510            .await
1511            .unwrap();
1512
1513        let current_roots = client.roots().await;
1514        assert_eq!(current_roots.len(), 1);
1515    }
1516
1517    #[tokio::test]
1518    async fn test_roots_management() {
1519        let client = McpClient::connect(MockTransport::with_responses(vec![
1520            mock_initialize_response(),
1521        ]))
1522        .await
1523        .unwrap();
1524
1525        // Initially no roots
1526        assert!(client.roots().await.is_empty());
1527
1528        // Add a root before initialization (no notification sent)
1529        client.add_root(Root::new("file:///project")).await.unwrap();
1530        assert_eq!(client.roots().await.len(), 1);
1531
1532        // Initialize
1533        client.initialize("test-client", "1.0.0").await.unwrap();
1534
1535        // Remove a root
1536        let removed = client.remove_root("file:///project").await.unwrap();
1537        assert!(removed);
1538        assert!(client.roots().await.is_empty());
1539
1540        // Try to remove non-existent root
1541        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
1542        assert!(!not_removed);
1543    }
1544
1545    #[tokio::test]
1546    async fn test_list_roots() {
1547        let roots = vec![
1548            Root::new("file:///project1"),
1549            Root::with_name("file:///project2", "Project 2"),
1550        ];
1551        let client = McpClient::builder()
1552            .with_roots(roots)
1553            .connect_simple(MockTransport::with_responses(vec![]))
1554            .await
1555            .unwrap();
1556
1557        let result = client.list_roots().await;
1558        assert_eq!(result.roots.len(), 2);
1559        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
1560    }
1561
1562    #[test]
1563    fn test_builder_with_sampling() {
1564        let builder = McpClientBuilder::new().with_sampling();
1565        assert!(builder.capabilities.sampling.is_some());
1566    }
1567
1568    #[test]
1569    fn test_builder_with_elicitation() {
1570        let builder = McpClientBuilder::new().with_elicitation();
1571        assert!(builder.capabilities.elicitation.is_some());
1572    }
1573
1574    #[test]
1575    fn test_builder_chaining() {
1576        let builder = McpClientBuilder::new()
1577            .with_sampling()
1578            .with_elicitation()
1579            .with_roots(vec![Root::new("file:///project")]);
1580        assert!(builder.capabilities.sampling.is_some());
1581        assert!(builder.capabilities.elicitation.is_some());
1582        assert!(builder.capabilities.roots.is_some());
1583    }
1584
1585    #[tokio::test]
1586    async fn test_bidirectional_sampling_round_trip() {
1587        use crate::protocol::{
1588            ContentRole, CreateMessageParams, CreateMessageResult, SamplingContent,
1589            SamplingContentOrArray,
1590        };
1591
1592        // A handler that records whether handle_create_message was called
1593        struct RecordingHandler {
1594            called: Arc<AtomicBool>,
1595        }
1596
1597        #[async_trait]
1598        impl ClientHandler for RecordingHandler {
1599            async fn handle_create_message(
1600                &self,
1601                _params: CreateMessageParams,
1602            ) -> std::result::Result<CreateMessageResult, tower_mcp_types::JsonRpcError>
1603            {
1604                self.called.store(true, Ordering::SeqCst);
1605                Ok(CreateMessageResult {
1606                    content: SamplingContentOrArray::Single(SamplingContent::Text {
1607                        text: "test response".to_string(),
1608                        annotations: None,
1609                        meta: None,
1610                    }),
1611                    model: "test-model".to_string(),
1612                    role: ContentRole::Assistant,
1613                    stop_reason: Some("end_turn".to_string()),
1614                    meta: None,
1615                })
1616            }
1617        }
1618
1619        let called = Arc::new(AtomicBool::new(false));
1620        let handler = RecordingHandler {
1621            called: called.clone(),
1622        };
1623
1624        // Build a mock transport, keeping a clone of incoming_tx so we can
1625        // inject a server-initiated request after the transport is consumed.
1626        let (inject_tx, rx) = mpsc::channel::<String>(32);
1627        let responses = vec![mock_initialize_response()];
1628        let inject_tx_clone = inject_tx.clone();
1629
1630        let transport = MockTransport {
1631            responses: Arc::new(Mutex::new(responses)),
1632            response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1633            incoming_tx: inject_tx,
1634            incoming_rx: rx,
1635            outgoing: Arc::new(Mutex::new(Vec::new())),
1636            connected: Arc::new(AtomicBool::new(true)),
1637        };
1638
1639        let client = McpClient::builder()
1640            .with_sampling()
1641            .connect(transport, handler)
1642            .await
1643            .unwrap();
1644
1645        // Initialize the client (this sends initialize request + notification)
1646        client.initialize("test-client", "1.0.0").await.unwrap();
1647
1648        // Inject a server-initiated sampling/createMessage request
1649        let sampling_request = serde_json::json!({
1650            "jsonrpc": "2.0",
1651            "id": 100,
1652            "method": "sampling/createMessage",
1653            "params": {
1654                "messages": [
1655                    {
1656                        "role": "user",
1657                        "content": {
1658                            "type": "text",
1659                            "text": "Hello"
1660                        }
1661                    }
1662                ],
1663                "maxTokens": 100
1664            }
1665        });
1666        inject_tx_clone
1667            .send(sampling_request.to_string())
1668            .await
1669            .unwrap();
1670
1671        // Give the background loop time to process
1672        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1673
1674        // Verify the handler was called
1675        assert!(
1676            called.load(Ordering::SeqCst),
1677            "handle_create_message should have been called"
1678        );
1679    }
1680
1681    #[tokio::test]
1682    async fn test_list_resource_templates() {
1683        let client = McpClient::connect(MockTransport::with_responses(vec![
1684            mock_initialize_response(),
1685            serde_json::json!({
1686                "resourceTemplates": [
1687                    {
1688                        "uriTemplate": "file:///{path}",
1689                        "name": "File Template",
1690                        "description": "A file template"
1691                    }
1692                ]
1693            }),
1694        ]))
1695        .await
1696        .unwrap();
1697
1698        client.initialize("test-client", "1.0.0").await.unwrap();
1699        let result = client.list_resource_templates().await.unwrap();
1700
1701        assert_eq!(result.resource_templates.len(), 1);
1702        assert_eq!(result.resource_templates[0].name, "File Template");
1703    }
1704
1705    #[tokio::test]
1706    async fn test_list_all_tools_single_page() {
1707        let client = McpClient::connect(MockTransport::with_responses(vec![
1708            mock_initialize_response(),
1709            serde_json::json!({
1710                "tools": [
1711                    {
1712                        "name": "tool_a",
1713                        "description": "Tool A",
1714                        "inputSchema": { "type": "object", "properties": {} }
1715                    },
1716                    {
1717                        "name": "tool_b",
1718                        "description": "Tool B",
1719                        "inputSchema": { "type": "object", "properties": {} }
1720                    }
1721                ]
1722            }),
1723        ]))
1724        .await
1725        .unwrap();
1726
1727        client.initialize("test-client", "1.0.0").await.unwrap();
1728        let tools = client.list_all_tools().await.unwrap();
1729
1730        assert_eq!(tools.len(), 2);
1731        assert_eq!(tools[0].name, "tool_a");
1732        assert_eq!(tools[1].name, "tool_b");
1733    }
1734
1735    #[tokio::test]
1736    async fn test_list_all_tools_paginated() {
1737        let client = McpClient::connect(MockTransport::with_responses(vec![
1738            mock_initialize_response(),
1739            // First page with a next_cursor
1740            serde_json::json!({
1741                "tools": [
1742                    {
1743                        "name": "tool_a",
1744                        "description": "Tool A",
1745                        "inputSchema": { "type": "object", "properties": {} }
1746                    }
1747                ],
1748                "nextCursor": "page2"
1749            }),
1750            // Second page with no next_cursor
1751            serde_json::json!({
1752                "tools": [
1753                    {
1754                        "name": "tool_b",
1755                        "description": "Tool B",
1756                        "inputSchema": { "type": "object", "properties": {} }
1757                    }
1758                ]
1759            }),
1760        ]))
1761        .await
1762        .unwrap();
1763
1764        client.initialize("test-client", "1.0.0").await.unwrap();
1765        let tools = client.list_all_tools().await.unwrap();
1766
1767        assert_eq!(tools.len(), 2);
1768        assert_eq!(tools[0].name, "tool_a");
1769        assert_eq!(tools[1].name, "tool_b");
1770    }
1771
1772    #[tokio::test]
1773    async fn test_call_tool_text_success() {
1774        let client = McpClient::connect(MockTransport::with_responses(vec![
1775            mock_initialize_response(),
1776            serde_json::json!({
1777                "content": [
1778                    { "type": "text", "text": "Hello " },
1779                    { "type": "text", "text": "World" }
1780                ]
1781            }),
1782        ]))
1783        .await
1784        .unwrap();
1785
1786        client.initialize("test-client", "1.0.0").await.unwrap();
1787        let text = client
1788            .call_tool_text("test_tool", serde_json::json!({}))
1789            .await
1790            .unwrap();
1791
1792        assert_eq!(text, "Hello World");
1793    }
1794
1795    #[tokio::test]
1796    async fn test_call_tool_text_error() {
1797        let client = McpClient::connect(MockTransport::with_responses(vec![
1798            mock_initialize_response(),
1799            serde_json::json!({
1800                "content": [
1801                    { "type": "text", "text": "something went wrong" }
1802                ],
1803                "isError": true
1804            }),
1805        ]))
1806        .await
1807        .unwrap();
1808
1809        client.initialize("test-client", "1.0.0").await.unwrap();
1810        let result = client
1811            .call_tool_text("test_tool", serde_json::json!({}))
1812            .await;
1813
1814        assert!(result.is_err());
1815        let err = result.unwrap_err();
1816        assert!(
1817            err.to_string().contains("something went wrong"),
1818            "Error message should contain tool error text, got: {}",
1819            err
1820        );
1821    }
1822
1823    #[tokio::test]
1824    async fn test_server_notification_parsing() {
1825        let notification = parse_server_notification("notifications/tools/list_changed", None);
1826        assert!(matches!(notification, ServerNotification::ToolsListChanged));
1827
1828        let notification = parse_server_notification("notifications/resources/list_changed", None);
1829        assert!(matches!(
1830            notification,
1831            ServerNotification::ResourcesListChanged
1832        ));
1833
1834        let notification = parse_server_notification(
1835            "notifications/resources/updated",
1836            Some(serde_json::json!({"uri": "file:///test"})),
1837        );
1838        match notification {
1839            ServerNotification::ResourceUpdated { uri } => {
1840                assert_eq!(uri, "file:///test");
1841            }
1842            _ => panic!("Expected ResourceUpdated"),
1843        }
1844
1845        let notification =
1846            parse_server_notification("custom/notification", Some(serde_json::json!({"data": 42})));
1847        match notification {
1848            ServerNotification::Unknown { method, params } => {
1849                assert_eq!(method, "custom/notification");
1850                assert!(params.is_some());
1851            }
1852            _ => panic!("Expected Unknown"),
1853        }
1854    }
1855}