// tower_mcp/client/mod.rs

//! MCP Client with bidirectional communication support.
//!
//! Provides [`McpClient`] for connecting to MCP servers over any
//! [`ClientTransport`]. The client runs a background message loop that
//! handles request/response correlation, server-initiated requests
//! (sampling, elicitation, roots), and notifications.
//!
//! # Example
//!
//! ```rust,no_run
//! use tower_mcp::client::{McpClient, StdioClientTransport};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), tower_mcp::BoxError> {
//!     let transport = StdioClientTransport::spawn("my-mcp-server", &["--flag"]).await?;
//!     let client = McpClient::connect(transport).await?;
//!
//!     let server_info = client.initialize("my-client", "1.0.0").await?;
//!     println!("Connected to: {}", server_info.server_info.name);
//!
//!     let tools = client.list_tools().await?;
//!     for tool in &tools.tools {
//!         println!("Tool: {}", tool.name);
//!     }
//!
//!     let result = client.call_tool("my-tool", serde_json::json!({"arg": "value"})).await?;
//!     println!("Result: {:?}", result);
//!
//!     Ok(())
//! }
//! ```

33mod channel;
34mod handler;
35#[cfg(feature = "http-client")]
36mod http;
37#[cfg(feature = "oauth-client")]
38mod oauth;
39mod stdio;
40mod transport;
41
42pub use channel::ChannelTransport;
43pub use handler::{ClientHandler, NotificationHandler, ServerNotification};
44#[cfg(feature = "http-client")]
45pub use http::{HttpClientConfig, HttpClientTransport};
46#[cfg(feature = "oauth-client")]
47pub use oauth::{
48    OAuthClientCredentials, OAuthClientCredentialsBuilder, OAuthClientError, TokenProvider,
49};
50pub use stdio::StdioClientTransport;
51pub use transport::ClientTransport;
52
53use std::collections::HashMap;
54use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
55use std::sync::{Arc, OnceLock};
56
57use tokio::sync::{RwLock, mpsc, oneshot};
58use tokio::task::JoinHandle;
59
60use crate::error::{Error, Result};
61use crate::protocol::{
62    CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
63    CompletionArgument, CompletionReference, ElicitationCapability, GetPromptParams,
64    GetPromptResult, Implementation, InitializeParams, InitializeResult, JsonRpcNotification,
65    JsonRpcRequest, ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams,
66    ListResourceTemplatesResult, ListResourcesParams, ListResourcesResult, ListRootsResult,
67    ListToolsParams, ListToolsResult, PromptDefinition, ReadResourceParams, ReadResourceResult,
68    RequestId, ResourceDefinition, ResourceTemplateDefinition, Root, RootsCapability,
69    SamplingCapability, ToolDefinition, notifications,
70};
71use tower_mcp_types::JsonRpcError;
72
73/// Internal command sent from McpClient methods to the background loop.
74enum LoopCommand {
75    /// Send a JSON-RPC request and await a response.
76    Request {
77        method: String,
78        params: serde_json::Value,
79        response_tx: oneshot::Sender<Result<serde_json::Value>>,
80    },
81    /// Send a JSON-RPC notification (no response expected).
82    Notify {
83        method: String,
84        params: serde_json::Value,
85    },
86    /// Graceful shutdown.
87    Shutdown,
88}
89
90/// MCP client with a background message loop.
91///
92/// Unlike previous versions, this type is not generic over the transport.
93/// The transport is consumed during [`connect()`](Self::connect) and moved
94/// into a background Tokio task that handles message multiplexing.
95///
96/// All public methods take `&self`, enabling concurrent use from multiple
97/// tasks.
98///
99/// # Construction
100///
101/// ```rust,no_run
102/// use tower_mcp::client::{McpClient, StdioClientTransport};
103///
104/// # async fn example() -> Result<(), tower_mcp::BoxError> {
105/// // Simple: no handler for server-initiated requests
106/// let transport = StdioClientTransport::spawn("server", &[]).await?;
107/// let client = McpClient::connect(transport).await?;
108///
109/// // With configuration
110/// use tower_mcp::protocol::Root;
111/// let transport = StdioClientTransport::spawn("server", &[]).await?;
112/// let client = McpClient::builder()
113///     .with_roots(vec![Root::new("file:///project")])
114///     .connect_simple(transport)
115///     .await?;
116/// # Ok(())
117/// # }
118/// ```
119pub struct McpClient {
120    /// Channel to send commands to the background loop.
121    command_tx: mpsc::Sender<LoopCommand>,
122    /// Background task handle.
123    task: Option<JoinHandle<()>>,
124    /// Whether `initialize()` has been called successfully.
125    initialized: AtomicBool,
126    /// Server info (set after successful initialization).
127    server_info: OnceLock<InitializeResult>,
128    /// Client capabilities declared during initialization.
129    capabilities: ClientCapabilities,
130    /// Current roots (shared with the loop for roots/list responses).
131    roots: Arc<RwLock<Vec<Root>>>,
132    /// Whether the transport is still connected.
133    connected: Arc<AtomicBool>,
134}
135
136/// Builder for configuring and connecting an [`McpClient`].
137///
138/// # Example
139///
140/// ```rust,no_run
141/// use tower_mcp::client::{McpClient, StdioClientTransport};
142/// use tower_mcp::protocol::Root;
143///
144/// # async fn example() -> Result<(), tower_mcp::BoxError> {
145/// let transport = StdioClientTransport::spawn("server", &[]).await?;
146/// let handler = (); // Use a real ClientHandler for bidirectional support
147/// let client = McpClient::builder()
148///     .with_roots(vec![Root::new("file:///project")])
149///     .with_sampling()
150///     .connect(transport, handler)
151///     .await?;
152/// # Ok(())
153/// # }
154/// ```
155pub struct McpClientBuilder {
156    capabilities: ClientCapabilities,
157    roots: Vec<Root>,
158}
159
160impl McpClientBuilder {
161    /// Create a new builder with default settings.
162    pub fn new() -> Self {
163        Self {
164            capabilities: ClientCapabilities::default(),
165            roots: Vec::new(),
166        }
167    }
168
169    /// Configure roots for this client.
170    ///
171    /// The client will declare roots support during initialization and
172    /// respond to `roots/list` requests with these roots.
173    pub fn with_roots(mut self, roots: Vec<Root>) -> Self {
174        self.roots = roots;
175        self.capabilities.roots = Some(RootsCapability { list_changed: true });
176        self
177    }
178
179    /// Configure custom capabilities for this client.
180    pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
181        self.capabilities = capabilities;
182        self
183    }
184
185    /// Declare sampling support.
186    ///
187    /// Sets the sampling capability so the server knows this client can
188    /// handle `sampling/createMessage` requests. The handler passed to
189    /// [`connect()`](Self::connect) should override
190    /// [`handle_create_message()`](ClientHandler::handle_create_message).
191    pub fn with_sampling(mut self) -> Self {
192        self.capabilities.sampling = Some(SamplingCapability::default());
193        self
194    }
195
196    /// Declare elicitation support.
197    ///
198    /// Sets the elicitation capability so the server knows this client can
199    /// handle `elicitation/create` requests. The handler passed to
200    /// [`connect()`](Self::connect) should override
201    /// [`handle_elicit()`](ClientHandler::handle_elicit).
202    pub fn with_elicitation(mut self) -> Self {
203        self.capabilities.elicitation = Some(ElicitationCapability::default());
204        self
205    }
206
207    /// Connect to a server using the given transport and handler.
208    ///
209    /// Spawns a background task to handle message I/O. The transport is
210    /// consumed and owned by the background task.
211    pub async fn connect<T, H>(self, transport: T, handler: H) -> Result<McpClient>
212    where
213        T: ClientTransport,
214        H: ClientHandler,
215    {
216        McpClient::connect_inner(transport, handler, self.capabilities, self.roots).await
217    }
218
219    /// Connect to a server without a handler.
220    ///
221    /// All server-initiated requests will be rejected with `method_not_found`.
222    pub async fn connect_simple<T: ClientTransport>(self, transport: T) -> Result<McpClient> {
223        self.connect(transport, ()).await
224    }
225}
226
227impl Default for McpClientBuilder {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233impl McpClient {
234    /// Connect with default settings and no handler.
235    ///
236    /// Shorthand for `McpClient::builder().connect_simple(transport)`.
237    pub async fn connect<T: ClientTransport>(transport: T) -> Result<Self> {
238        McpClientBuilder::new().connect_simple(transport).await
239    }
240
241    /// Connect with a handler for server-initiated requests.
242    pub async fn connect_with_handler<T, H>(transport: T, handler: H) -> Result<Self>
243    where
244        T: ClientTransport,
245        H: ClientHandler,
246    {
247        McpClientBuilder::new().connect(transport, handler).await
248    }
249
250    /// Create a builder for advanced configuration.
251    pub fn builder() -> McpClientBuilder {
252        McpClientBuilder::new()
253    }
254
255    /// Internal connect implementation.
256    async fn connect_inner<T, H>(
257        transport: T,
258        handler: H,
259        capabilities: ClientCapabilities,
260        roots: Vec<Root>,
261    ) -> Result<Self>
262    where
263        T: ClientTransport,
264        H: ClientHandler,
265    {
266        let (command_tx, command_rx) = mpsc::channel::<LoopCommand>(64);
267        let connected = Arc::new(AtomicBool::new(true));
268        let roots = Arc::new(RwLock::new(roots));
269
270        let loop_connected = connected.clone();
271        let loop_roots = roots.clone();
272
273        let task = tokio::spawn(async move {
274            message_loop(transport, handler, command_rx, loop_connected, loop_roots).await;
275        });
276
277        Ok(Self {
278            command_tx,
279            task: Some(task),
280            initialized: AtomicBool::new(false),
281            server_info: OnceLock::new(),
282            capabilities,
283            roots,
284            connected,
285        })
286    }
287
288    /// Check if the client has been initialized.
289    pub fn is_initialized(&self) -> bool {
290        self.initialized.load(Ordering::Acquire)
291    }
292
293    /// Check if the transport is still connected.
294    pub fn is_connected(&self) -> bool {
295        self.connected.load(Ordering::Acquire)
296    }
297
298    /// Get the server info (available after initialization).
299    pub fn server_info(&self) -> Option<&InitializeResult> {
300        self.server_info.get()
301    }
302
303    /// Initialize the MCP connection.
304    ///
305    /// Sends the `initialize` request and `notifications/initialized` notification.
306    /// Must be called before any other operations.
307    pub async fn initialize(
308        &self,
309        client_name: &str,
310        client_version: &str,
311    ) -> Result<&InitializeResult> {
312        let params = InitializeParams {
313            protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
314            capabilities: self.capabilities.clone(),
315            client_info: Implementation {
316                name: client_name.to_string(),
317                version: client_version.to_string(),
318                ..Default::default()
319            },
320            meta: None,
321        };
322
323        let result: InitializeResult = self.send_request("initialize", &params).await?;
324        let _ = self.server_info.set(result);
325
326        // Send initialized notification
327        self.send_notification("notifications/initialized", &serde_json::json!({}))
328            .await?;
329        self.initialized.store(true, Ordering::Release);
330
331        Ok(self.server_info.get().unwrap())
332    }
333
334    /// List available tools.
335    pub async fn list_tools(&self) -> Result<ListToolsResult> {
336        self.ensure_initialized()?;
337        self.send_request(
338            "tools/list",
339            &ListToolsParams {
340                cursor: None,
341                meta: None,
342            },
343        )
344        .await
345    }
346
347    /// Call a tool.
348    pub async fn call_tool(
349        &self,
350        name: &str,
351        arguments: serde_json::Value,
352    ) -> Result<CallToolResult> {
353        self.ensure_initialized()?;
354        let params = CallToolParams {
355            name: name.to_string(),
356            arguments,
357            meta: None,
358            task: None,
359        };
360        self.send_request("tools/call", &params).await
361    }
362
363    /// List available resources.
364    pub async fn list_resources(&self) -> Result<ListResourcesResult> {
365        self.ensure_initialized()?;
366        self.send_request(
367            "resources/list",
368            &ListResourcesParams {
369                cursor: None,
370                meta: None,
371            },
372        )
373        .await
374    }
375
376    /// Read a resource.
377    pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
378        self.ensure_initialized()?;
379        let params = ReadResourceParams {
380            uri: uri.to_string(),
381            meta: None,
382        };
383        self.send_request("resources/read", &params).await
384    }
385
386    /// List available prompts.
387    pub async fn list_prompts(&self) -> Result<ListPromptsResult> {
388        self.ensure_initialized()?;
389        self.send_request(
390            "prompts/list",
391            &ListPromptsParams {
392                cursor: None,
393                meta: None,
394            },
395        )
396        .await
397    }
398
399    /// List tools with an optional pagination cursor.
400    pub async fn list_tools_with_cursor(&self, cursor: Option<String>) -> Result<ListToolsResult> {
401        self.ensure_initialized()?;
402        self.send_request("tools/list", &ListToolsParams { cursor, meta: None })
403            .await
404    }
405
406    /// List resources with an optional pagination cursor.
407    pub async fn list_resources_with_cursor(
408        &self,
409        cursor: Option<String>,
410    ) -> Result<ListResourcesResult> {
411        self.ensure_initialized()?;
412        self.send_request(
413            "resources/list",
414            &ListResourcesParams { cursor, meta: None },
415        )
416        .await
417    }
418
419    /// List resource templates.
420    pub async fn list_resource_templates(&self) -> Result<ListResourceTemplatesResult> {
421        self.ensure_initialized()?;
422        self.send_request(
423            "resources/templates/list",
424            &ListResourceTemplatesParams {
425                cursor: None,
426                meta: None,
427            },
428        )
429        .await
430    }
431
432    /// List resource templates with an optional pagination cursor.
433    pub async fn list_resource_templates_with_cursor(
434        &self,
435        cursor: Option<String>,
436    ) -> Result<ListResourceTemplatesResult> {
437        self.ensure_initialized()?;
438        self.send_request(
439            "resources/templates/list",
440            &ListResourceTemplatesParams { cursor, meta: None },
441        )
442        .await
443    }
444
445    /// List prompts with an optional pagination cursor.
446    pub async fn list_prompts_with_cursor(
447        &self,
448        cursor: Option<String>,
449    ) -> Result<ListPromptsResult> {
450        self.ensure_initialized()?;
451        self.send_request("prompts/list", &ListPromptsParams { cursor, meta: None })
452            .await
453    }
454
455    /// List all tools, following pagination cursors until exhausted.
456    pub async fn list_all_tools(&self) -> Result<Vec<ToolDefinition>> {
457        let mut all = Vec::new();
458        let mut cursor = None;
459        loop {
460            let result = self.list_tools_with_cursor(cursor).await?;
461            all.extend(result.tools);
462            match result.next_cursor {
463                Some(c) => cursor = Some(c),
464                None => break,
465            }
466        }
467        Ok(all)
468    }
469
470    /// List all resources, following pagination cursors until exhausted.
471    pub async fn list_all_resources(&self) -> Result<Vec<ResourceDefinition>> {
472        let mut all = Vec::new();
473        let mut cursor = None;
474        loop {
475            let result = self.list_resources_with_cursor(cursor).await?;
476            all.extend(result.resources);
477            match result.next_cursor {
478                Some(c) => cursor = Some(c),
479                None => break,
480            }
481        }
482        Ok(all)
483    }
484
485    /// List all resource templates, following pagination cursors until exhausted.
486    pub async fn list_all_resource_templates(&self) -> Result<Vec<ResourceTemplateDefinition>> {
487        let mut all = Vec::new();
488        let mut cursor = None;
489        loop {
490            let result = self.list_resource_templates_with_cursor(cursor).await?;
491            all.extend(result.resource_templates);
492            match result.next_cursor {
493                Some(c) => cursor = Some(c),
494                None => break,
495            }
496        }
497        Ok(all)
498    }
499
500    /// List all prompts, following pagination cursors until exhausted.
501    pub async fn list_all_prompts(&self) -> Result<Vec<PromptDefinition>> {
502        let mut all = Vec::new();
503        let mut cursor = None;
504        loop {
505            let result = self.list_prompts_with_cursor(cursor).await?;
506            all.extend(result.prompts);
507            match result.next_cursor {
508                Some(c) => cursor = Some(c),
509                None => break,
510            }
511        }
512        Ok(all)
513    }
514
515    /// Call a tool and return the concatenated text content.
516    ///
517    /// Returns the text from all [`Text`](crate::protocol::Content::Text) items joined together.
518    /// If the tool result indicates an error (`is_error` is true), returns
519    /// an error with the text content as the message.
520    ///
521    /// For more control over the result, use [`call_tool()`](Self::call_tool).
522    pub async fn call_tool_text(&self, name: &str, arguments: serde_json::Value) -> Result<String> {
523        let result = self.call_tool(name, arguments).await?;
524        if result.is_error {
525            return Err(Error::Internal(result.all_text()));
526        }
527        Ok(result.all_text())
528    }
529
530    /// Get a prompt.
531    pub async fn get_prompt(
532        &self,
533        name: &str,
534        arguments: Option<std::collections::HashMap<String, String>>,
535    ) -> Result<GetPromptResult> {
536        self.ensure_initialized()?;
537        let params = GetPromptParams {
538            name: name.to_string(),
539            arguments: arguments.unwrap_or_default(),
540            meta: None,
541        };
542        self.send_request("prompts/get", &params).await
543    }
544
545    /// Ping the server.
546    pub async fn ping(&self) -> Result<()> {
547        let _: serde_json::Value = self.send_request("ping", &serde_json::json!({})).await?;
548        Ok(())
549    }
550
551    /// Request completion suggestions from the server.
552    pub async fn complete(
553        &self,
554        reference: CompletionReference,
555        argument_name: &str,
556        argument_value: &str,
557    ) -> Result<CompleteResult> {
558        self.ensure_initialized()?;
559        let params = CompleteParams {
560            reference,
561            argument: CompletionArgument::new(argument_name, argument_value),
562            context: None,
563            meta: None,
564        };
565        self.send_request("completion/complete", &params).await
566    }
567
568    /// Request completion for a prompt argument.
569    pub async fn complete_prompt_arg(
570        &self,
571        prompt_name: &str,
572        argument_name: &str,
573        argument_value: &str,
574    ) -> Result<CompleteResult> {
575        self.complete(
576            CompletionReference::prompt(prompt_name),
577            argument_name,
578            argument_value,
579        )
580        .await
581    }
582
583    /// Request completion for a resource URI.
584    pub async fn complete_resource_uri(
585        &self,
586        resource_uri: &str,
587        argument_name: &str,
588        argument_value: &str,
589    ) -> Result<CompleteResult> {
590        self.complete(
591            CompletionReference::resource(resource_uri),
592            argument_name,
593            argument_value,
594        )
595        .await
596    }
597
598    /// Send a raw typed request to the server.
599    pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
600        &self,
601        method: &str,
602        params: &P,
603    ) -> Result<R> {
604        self.send_request(method, params).await
605    }
606
607    /// Send a raw typed notification to the server.
608    pub async fn notify<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
609        self.send_notification(method, params).await
610    }
611
612    /// Get the current roots.
613    pub async fn roots(&self) -> Vec<Root> {
614        self.roots.read().await.clone()
615    }
616
617    /// Set roots and notify the server if initialized.
618    pub async fn set_roots(&self, roots: Vec<Root>) -> Result<()> {
619        *self.roots.write().await = roots;
620        if self.is_initialized() {
621            self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
622                .await?;
623        }
624        Ok(())
625    }
626
627    /// Add a root and notify the server if initialized.
628    pub async fn add_root(&self, root: Root) -> Result<()> {
629        self.roots.write().await.push(root);
630        if self.is_initialized() {
631            self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
632                .await?;
633        }
634        Ok(())
635    }
636
637    /// Remove a root by URI and notify the server if initialized.
638    pub async fn remove_root(&self, uri: &str) -> Result<bool> {
639        let mut roots = self.roots.write().await;
640        let initial_len = roots.len();
641        roots.retain(|r| r.uri != uri);
642        let removed = roots.len() < initial_len;
643        drop(roots);
644
645        if removed && self.is_initialized() {
646            self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
647                .await?;
648        }
649        Ok(removed)
650    }
651
652    /// Get the roots list result (for responding to server's roots/list request).
653    pub async fn list_roots(&self) -> ListRootsResult {
654        ListRootsResult {
655            roots: self.roots.read().await.clone(),
656            meta: None,
657        }
658    }
659
660    /// Gracefully shut down the client and close the transport.
661    pub async fn shutdown(mut self) -> Result<()> {
662        let _ = self.command_tx.send(LoopCommand::Shutdown).await;
663        if let Some(task) = self.task.take() {
664            let _ = task.await;
665        }
666        Ok(())
667    }
668
    // --- Internal helpers ---
670
671    async fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
672        &self,
673        method: &str,
674        params: &P,
675    ) -> Result<R> {
676        self.ensure_connected()?;
677        let params_value = serde_json::to_value(params)
678            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
679
680        let (response_tx, response_rx) = oneshot::channel();
681        self.command_tx
682            .send(LoopCommand::Request {
683                method: method.to_string(),
684                params: params_value,
685                response_tx,
686            })
687            .await
688            .map_err(|_| Error::Transport("Connection closed".to_string()))?;
689
690        let result = response_rx
691            .await
692            .map_err(|_| Error::Transport("Connection closed".to_string()))??;
693
694        serde_json::from_value(result)
695            .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
696    }
697
698    async fn send_notification<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
699        self.ensure_connected()?;
700        let params_value = serde_json::to_value(params)
701            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
702
703        self.command_tx
704            .send(LoopCommand::Notify {
705                method: method.to_string(),
706                params: params_value,
707            })
708            .await
709            .map_err(|_| Error::Transport("Connection closed".to_string()))?;
710
711        Ok(())
712    }
713
714    fn ensure_connected(&self) -> Result<()> {
715        if !self.connected.load(Ordering::Acquire) {
716            return Err(Error::Transport("Connection closed".to_string()));
717        }
718        Ok(())
719    }
720
721    fn ensure_initialized(&self) -> Result<()> {
722        if !self.initialized.load(Ordering::Acquire) {
723            return Err(Error::Transport("Client not initialized".to_string()));
724        }
725        Ok(())
726    }
727}
728
729impl Drop for McpClient {
730    fn drop(&mut self) {
731        if let Some(task) = self.task.take() {
732            task.abort();
733        }
734    }
735}
736
// =============================================================================
// Background Message Loop
// =============================================================================
740
741/// A pending request waiting for a response from the server.
742struct PendingRequest {
743    response_tx: oneshot::Sender<Result<serde_json::Value>>,
744}
745
746/// Background message loop that multiplexes incoming/outgoing messages.
747async fn message_loop<T: ClientTransport, H: ClientHandler>(
748    mut transport: T,
749    handler: H,
750    mut command_rx: mpsc::Receiver<LoopCommand>,
751    connected: Arc<AtomicBool>,
752    roots: Arc<RwLock<Vec<Root>>>,
753) {
754    let handler = Arc::new(handler);
755    let mut pending_requests: HashMap<RequestId, PendingRequest> = HashMap::new();
756    let next_id = AtomicI64::new(1);
757
758    loop {
759        tokio::select! {
760            // Commands from McpClient methods
761            command = command_rx.recv() => {
762                match command {
763                    Some(LoopCommand::Request { method, params, response_tx }) => {
764                        let id = RequestId::Number(next_id.fetch_add(1, Ordering::Relaxed));
765
766                        let request = JsonRpcRequest::new(id.clone(), &method)
767                            .with_params(params);
768                        let json = match serde_json::to_string(&request) {
769                            Ok(j) => j,
770                            Err(e) => {
771                                let _ = response_tx.send(Err(Error::Transport(
772                                    format!("Serialization failed: {}", e)
773                                )));
774                                continue;
775                            }
776                        };
777
778                        tracing::debug!(method = %method, id = ?id, "Sending request");
779                        pending_requests.insert(id, PendingRequest { response_tx });
780
781                        if let Err(e) = transport.send(&json).await {
782                            tracing::error!(error = %e, "Transport send error");
783                            fail_all_pending(&mut pending_requests, &format!("Transport error: {}", e));
784                            break;
785                        }
786                    }
787                    Some(LoopCommand::Notify { method, params }) => {
788                        let notification = JsonRpcNotification::new(&method)
789                            .with_params(params);
790                        if let Ok(json) = serde_json::to_string(&notification) {
791                            tracing::debug!(method = %method, "Sending notification");
792                            let _ = transport.send(&json).await;
793                        }
794                    }
795                    Some(LoopCommand::Shutdown) | None => {
796                        tracing::debug!("Message loop shutting down");
797                        break;
798                    }
799                }
800            }
801
802            // Incoming messages from the server
803            result = transport.recv() => {
804                match result {
805                    Ok(Some(line)) => {
806                        handle_incoming(
807                            &line,
808                            &mut pending_requests,
809                            &handler,
810                            &roots,
811                            &mut transport,
812                        ).await;
813                    }
814                    Ok(None) => {
815                        tracing::info!("Transport closed (EOF)");
816                        break;
817                    }
818                    Err(e) => {
819                        tracing::error!(error = %e, "Transport receive error");
820                        break;
821                    }
822                }
823            }
824        }
825    }
826
827    // Cleanup
828    connected.store(false, Ordering::Release);
829    fail_all_pending(&mut pending_requests, "Connection closed");
830    let _ = transport.close().await;
831}
832
833/// Handle a single incoming message from the server.
834async fn handle_incoming<T: ClientTransport, H: ClientHandler>(
835    line: &str,
836    pending_requests: &mut HashMap<RequestId, PendingRequest>,
837    handler: &Arc<H>,
838    roots: &Arc<RwLock<Vec<Root>>>,
839    transport: &mut T,
840) {
841    let parsed: serde_json::Value = match serde_json::from_str(line) {
842        Ok(v) => v,
843        Err(e) => {
844            tracing::warn!(error = %e, "Failed to parse incoming message");
845            return;
846        }
847    };
848
849    // Case 1: Response to one of our pending requests (has result or error, no method)
850    if parsed.get("method").is_none()
851        && (parsed.get("result").is_some() || parsed.get("error").is_some())
852    {
853        handle_response(&parsed, pending_requests);
854        return;
855    }
856
857    // Case 2: Server-initiated request (has id + method)
858    if parsed.get("id").is_some() && parsed.get("method").is_some() {
859        let id = parse_request_id(&parsed);
860        let method = parsed["method"].as_str().unwrap_or("");
861        let params = parsed.get("params").cloned();
862
863        let result = dispatch_server_request(handler, roots, method, params).await;
864
865        // Send response back to the server
866        let response = match result {
867            Ok(value) => {
868                if let Some(id) = id {
869                    serde_json::json!({
870                        "jsonrpc": "2.0",
871                        "id": id,
872                        "result": value
873                    })
874                } else {
875                    return;
876                }
877            }
878            Err(error) => {
879                serde_json::json!({
880                    "jsonrpc": "2.0",
881                    "id": id,
882                    "error": {
883                        "code": error.code,
884                        "message": error.message
885                    }
886                })
887            }
888        };
889
890        if let Ok(json) = serde_json::to_string(&response) {
891            let _ = transport.send(&json).await;
892        }
893        return;
894    }
895
896    // Case 3: Server notification (has method, no id)
897    if parsed.get("method").is_some() && parsed.get("id").is_none() {
898        let method = parsed["method"].as_str().unwrap_or("");
899        let params = parsed.get("params").cloned();
900        let notification = parse_server_notification(method, params);
901        handler.on_notification(notification).await;
902    }
903}
904
905/// Handle a JSON-RPC response by routing to the pending request.
906fn handle_response(
907    parsed: &serde_json::Value,
908    pending_requests: &mut HashMap<RequestId, PendingRequest>,
909) {
910    let id = match parse_request_id(parsed) {
911        Some(id) => id,
912        None => {
913            tracing::warn!("Response without id");
914            return;
915        }
916    };
917
918    let pending = match pending_requests.remove(&id) {
919        Some(p) => p,
920        None => {
921            tracing::warn!(id = ?id, "Response for unknown request");
922            return;
923        }
924    };
925
926    tracing::debug!(id = ?id, "Received response");
927
928    if let Some(error) = parsed.get("error") {
929        let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
930        let message = error
931            .get("message")
932            .and_then(|m| m.as_str())
933            .unwrap_or("Unknown error")
934            .to_string();
935        let data = error.get("data").cloned();
936        let json_rpc_error = JsonRpcError {
937            code,
938            message,
939            data,
940        };
941        let _ = pending
942            .response_tx
943            .send(Err(Error::JsonRpc(json_rpc_error)));
944    } else if let Some(result) = parsed.get("result") {
945        let _ = pending.response_tx.send(Ok(result.clone()));
946    } else {
947        let _ = pending
948            .response_tx
949            .send(Err(Error::Transport("Invalid response".to_string())));
950    }
951}
952
953/// Dispatch a server-initiated request to the handler.
954async fn dispatch_server_request<H: ClientHandler>(
955    handler: &Arc<H>,
956    roots: &Arc<RwLock<Vec<Root>>>,
957    method: &str,
958    params: Option<serde_json::Value>,
959) -> std::result::Result<serde_json::Value, JsonRpcError> {
960    match method {
961        "sampling/createMessage" => {
962            let p = serde_json::from_value(params.unwrap_or_default())
963                .map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
964            let result = handler.handle_create_message(p).await?;
965            serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
966        }
967        "elicitation/create" => {
968            let p = serde_json::from_value(params.unwrap_or_default())
969                .map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
970            let result = handler.handle_elicit(p).await?;
971            serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
972        }
973        "roots/list" => {
974            // Use client-configured roots if available, otherwise delegate to handler
975            let roots_list = roots.read().await;
976            if !roots_list.is_empty() {
977                let result = ListRootsResult {
978                    roots: roots_list.clone(),
979                    meta: None,
980                };
981                return serde_json::to_value(result)
982                    .map_err(|e| JsonRpcError::internal_error(e.to_string()));
983            }
984            drop(roots_list);
985
986            let result = handler.handle_list_roots().await?;
987            serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
988        }
989        "ping" => Ok(serde_json::json!({})),
990        _ => Err(JsonRpcError::method_not_found(method)),
991    }
992}
993
994/// Parse a request ID from a JSON-RPC message.
995fn parse_request_id(parsed: &serde_json::Value) -> Option<RequestId> {
996    parsed.get("id").and_then(|id| {
997        if let Some(n) = id.as_i64() {
998            Some(RequestId::Number(n))
999        } else {
1000            id.as_str().map(|s| RequestId::String(s.to_string()))
1001        }
1002    })
1003}
1004
1005/// Parse a server notification into the typed enum.
1006fn parse_server_notification(
1007    method: &str,
1008    params: Option<serde_json::Value>,
1009) -> ServerNotification {
1010    match method {
1011        notifications::PROGRESS => {
1012            if let Some(params) = params
1013                && let Ok(p) = serde_json::from_value(params)
1014            {
1015                return ServerNotification::Progress(p);
1016            }
1017            ServerNotification::Unknown {
1018                method: method.to_string(),
1019                params: None,
1020            }
1021        }
1022        notifications::MESSAGE => {
1023            if let Some(params) = params
1024                && let Ok(p) = serde_json::from_value(params)
1025            {
1026                return ServerNotification::LogMessage(p);
1027            }
1028            ServerNotification::Unknown {
1029                method: method.to_string(),
1030                params: None,
1031            }
1032        }
1033        notifications::RESOURCE_UPDATED => {
1034            if let Some(params) = &params
1035                && let Some(uri) = params.get("uri").and_then(|u| u.as_str())
1036            {
1037                return ServerNotification::ResourceUpdated {
1038                    uri: uri.to_string(),
1039                };
1040            }
1041            ServerNotification::Unknown {
1042                method: method.to_string(),
1043                params,
1044            }
1045        }
1046        notifications::RESOURCES_LIST_CHANGED => ServerNotification::ResourcesListChanged,
1047        notifications::TOOLS_LIST_CHANGED => ServerNotification::ToolsListChanged,
1048        notifications::PROMPTS_LIST_CHANGED => ServerNotification::PromptsListChanged,
1049        _ => ServerNotification::Unknown {
1050            method: method.to_string(),
1051            params,
1052        },
1053    }
1054}
1055
1056/// Fail all pending requests with the given error message.
1057fn fail_all_pending(pending: &mut HashMap<RequestId, PendingRequest>, reason: &str) {
1058    for (_, req) in pending.drain() {
1059        let _ = req
1060            .response_tx
1061            .send(Err(Error::Transport(reason.to_string())));
1062    }
1063}
1064
1065#[cfg(test)]
1066mod tests {
1067    use super::*;
1068    use async_trait::async_trait;
1069    use std::sync::Mutex;
1070
1071    /// Mock transport for testing that auto-responds to requests.
1072    ///
1073    /// When the client sends a request via `send()`, the mock extracts the
1074    /// request ID, pairs it with the next preconfigured response, and feeds
1075    /// it back through a channel that `recv()` awaits on. This ensures
1076    /// `recv()` blocks when no messages are available (instead of returning
1077    /// EOF), keeping the background message loop alive.
1078    struct MockTransport {
1079        /// Pre-configured response payloads (result values, not full envelopes).
1080        responses: Arc<Mutex<Vec<serde_json::Value>>>,
1081        /// Index of the next response to use.
1082        response_idx: Arc<std::sync::atomic::AtomicUsize>,
1083        /// Channel sender for feeding responses back to `recv()`.
1084        incoming_tx: mpsc::Sender<String>,
1085        /// Channel receiver for `recv()` to await on.
1086        incoming_rx: mpsc::Receiver<String>,
1087        /// Collected outgoing messages from `send()`.
1088        outgoing: Arc<Mutex<Vec<String>>>,
1089        connected: Arc<AtomicBool>,
1090    }
1091
1092    #[allow(dead_code)]
1093    impl MockTransport {
1094        fn new() -> Self {
1095            let (tx, rx) = mpsc::channel(32);
1096            Self {
1097                responses: Arc::new(Mutex::new(Vec::new())),
1098                response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1099                incoming_tx: tx,
1100                incoming_rx: rx,
1101                outgoing: Arc::new(Mutex::new(Vec::new())),
1102                connected: Arc::new(AtomicBool::new(true)),
1103            }
1104        }
1105
1106        /// Create a mock that auto-responds with the given result payloads.
1107        ///
1108        /// When `send()` receives a JSON-RPC request, it extracts the request
1109        /// ID and pairs it with the next response from this list, sending the
1110        /// complete JSON-RPC response through the channel for `recv()`.
1111        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
1112            let (tx, rx) = mpsc::channel(32);
1113            Self {
1114                responses: Arc::new(Mutex::new(responses)),
1115                response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1116                incoming_tx: tx,
1117                incoming_rx: rx,
1118                outgoing: Arc::new(Mutex::new(Vec::new())),
1119                connected: Arc::new(AtomicBool::new(true)),
1120            }
1121        }
1122    }
1123
1124    #[async_trait]
1125    impl ClientTransport for MockTransport {
1126        async fn send(&mut self, message: &str) -> Result<()> {
1127            self.outgoing.lock().unwrap().push(message.to_string());
1128
1129            // Parse the outgoing message to extract the request ID
1130            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(message) {
1131                // Only respond to requests (messages with an id and method)
1132                if let Some(id) = parsed.get("id") {
1133                    let idx = self
1134                        .response_idx
1135                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1136                    let responses = self.responses.lock().unwrap();
1137                    if let Some(result) = responses.get(idx) {
1138                        let response = serde_json::json!({
1139                            "jsonrpc": "2.0",
1140                            "id": id,
1141                            "result": result
1142                        });
1143                        let _ = self.incoming_tx.try_send(response.to_string());
1144                    }
1145                }
1146            }
1147
1148            Ok(())
1149        }
1150
1151        async fn recv(&mut self) -> Result<Option<String>> {
1152            // Await on the channel -- blocks until a message is available
1153            // or the sender is dropped (returns None = EOF).
1154            match self.incoming_rx.recv().await {
1155                Some(msg) => Ok(Some(msg)),
1156                None => Ok(None),
1157            }
1158        }
1159
1160        fn is_connected(&self) -> bool {
1161            self.connected.load(Ordering::Relaxed)
1162        }
1163
1164        async fn close(&mut self) -> Result<()> {
1165            self.connected.store(false, Ordering::Relaxed);
1166            Ok(())
1167        }
1168    }
1169
1170    fn mock_initialize_response() -> serde_json::Value {
1171        serde_json::json!({
1172            "protocolVersion": "2025-11-25",
1173            "serverInfo": {
1174                "name": "test-server",
1175                "version": "1.0.0"
1176            },
1177            "capabilities": {
1178                "tools": {}
1179            }
1180        })
1181    }
1182
1183    #[tokio::test]
1184    async fn test_client_not_initialized() {
1185        let client = McpClient::connect(MockTransport::with_responses(vec![]))
1186            .await
1187            .unwrap();
1188
1189        let result = client.list_tools().await;
1190        assert!(result.is_err());
1191        assert!(result.unwrap_err().to_string().contains("not initialized"));
1192    }
1193
1194    #[tokio::test]
1195    async fn test_client_initialize() {
1196        let client = McpClient::connect(MockTransport::with_responses(vec![
1197            mock_initialize_response(),
1198        ]))
1199        .await
1200        .unwrap();
1201
1202        assert!(!client.is_initialized());
1203
1204        let result = client.initialize("test-client", "1.0.0").await;
1205        assert!(result.is_ok());
1206        assert!(client.is_initialized());
1207
1208        let server_info = client.server_info().unwrap();
1209        assert_eq!(server_info.server_info.name, "test-server");
1210    }
1211
1212    #[tokio::test]
1213    async fn test_list_tools() {
1214        let client = McpClient::connect(MockTransport::with_responses(vec![
1215            mock_initialize_response(),
1216            serde_json::json!({
1217                "tools": [
1218                    {
1219                        "name": "test_tool",
1220                        "description": "A test tool",
1221                        "inputSchema": {
1222                            "type": "object",
1223                            "properties": {}
1224                        }
1225                    }
1226                ]
1227            }),
1228        ]))
1229        .await
1230        .unwrap();
1231
1232        client.initialize("test-client", "1.0.0").await.unwrap();
1233        let tools = client.list_tools().await.unwrap();
1234
1235        assert_eq!(tools.tools.len(), 1);
1236        assert_eq!(tools.tools[0].name, "test_tool");
1237    }
1238
1239    #[tokio::test]
1240    async fn test_call_tool() {
1241        let client = McpClient::connect(MockTransport::with_responses(vec![
1242            mock_initialize_response(),
1243            serde_json::json!({
1244                "content": [
1245                    {
1246                        "type": "text",
1247                        "text": "Tool result"
1248                    }
1249                ]
1250            }),
1251        ]))
1252        .await
1253        .unwrap();
1254
1255        client.initialize("test-client", "1.0.0").await.unwrap();
1256        let result = client
1257            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
1258            .await
1259            .unwrap();
1260
1261        assert!(!result.content.is_empty());
1262    }
1263
1264    #[tokio::test]
1265    async fn test_list_resources() {
1266        let client = McpClient::connect(MockTransport::with_responses(vec![
1267            mock_initialize_response(),
1268            serde_json::json!({
1269                "resources": [
1270                    {
1271                        "uri": "file://test.txt",
1272                        "name": "Test File"
1273                    }
1274                ]
1275            }),
1276        ]))
1277        .await
1278        .unwrap();
1279
1280        client.initialize("test-client", "1.0.0").await.unwrap();
1281        let resources = client.list_resources().await.unwrap();
1282
1283        assert_eq!(resources.resources.len(), 1);
1284        assert_eq!(resources.resources[0].uri, "file://test.txt");
1285    }
1286
1287    #[tokio::test]
1288    async fn test_read_resource() {
1289        let client = McpClient::connect(MockTransport::with_responses(vec![
1290            mock_initialize_response(),
1291            serde_json::json!({
1292                "contents": [
1293                    {
1294                        "uri": "file://test.txt",
1295                        "text": "File contents"
1296                    }
1297                ]
1298            }),
1299        ]))
1300        .await
1301        .unwrap();
1302
1303        client.initialize("test-client", "1.0.0").await.unwrap();
1304        let result = client.read_resource("file://test.txt").await.unwrap();
1305
1306        assert_eq!(result.contents.len(), 1);
1307        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
1308    }
1309
1310    #[tokio::test]
1311    async fn test_list_prompts() {
1312        let client = McpClient::connect(MockTransport::with_responses(vec![
1313            mock_initialize_response(),
1314            serde_json::json!({
1315                "prompts": [
1316                    {
1317                        "name": "test_prompt",
1318                        "description": "A test prompt"
1319                    }
1320                ]
1321            }),
1322        ]))
1323        .await
1324        .unwrap();
1325
1326        client.initialize("test-client", "1.0.0").await.unwrap();
1327        let prompts = client.list_prompts().await.unwrap();
1328
1329        assert_eq!(prompts.prompts.len(), 1);
1330        assert_eq!(prompts.prompts[0].name, "test_prompt");
1331    }
1332
1333    #[tokio::test]
1334    async fn test_get_prompt() {
1335        let client = McpClient::connect(MockTransport::with_responses(vec![
1336            mock_initialize_response(),
1337            serde_json::json!({
1338                "messages": [
1339                    {
1340                        "role": "user",
1341                        "content": {
1342                            "type": "text",
1343                            "text": "Prompt message"
1344                        }
1345                    }
1346                ]
1347            }),
1348        ]))
1349        .await
1350        .unwrap();
1351
1352        client.initialize("test-client", "1.0.0").await.unwrap();
1353        let result = client.get_prompt("test_prompt", None).await.unwrap();
1354
1355        assert_eq!(result.messages.len(), 1);
1356    }
1357
1358    #[tokio::test]
1359    async fn test_ping() {
1360        let client = McpClient::connect(MockTransport::with_responses(vec![
1361            mock_initialize_response(),
1362            serde_json::json!({}),
1363        ]))
1364        .await
1365        .unwrap();
1366
1367        client.initialize("test-client", "1.0.0").await.unwrap();
1368        let result = client.ping().await;
1369        assert!(result.is_ok());
1370    }
1371
1372    #[tokio::test]
1373    async fn test_with_roots() {
1374        let roots = vec![Root::new("file:///test")];
1375        let client = McpClient::builder()
1376            .with_roots(roots)
1377            .connect_simple(MockTransport::with_responses(vec![]))
1378            .await
1379            .unwrap();
1380
1381        let current_roots = client.roots().await;
1382        assert_eq!(current_roots.len(), 1);
1383    }
1384
1385    #[tokio::test]
1386    async fn test_roots_management() {
1387        let client = McpClient::connect(MockTransport::with_responses(vec![
1388            mock_initialize_response(),
1389        ]))
1390        .await
1391        .unwrap();
1392
1393        // Initially no roots
1394        assert!(client.roots().await.is_empty());
1395
1396        // Add a root before initialization (no notification sent)
1397        client.add_root(Root::new("file:///project")).await.unwrap();
1398        assert_eq!(client.roots().await.len(), 1);
1399
1400        // Initialize
1401        client.initialize("test-client", "1.0.0").await.unwrap();
1402
1403        // Remove a root
1404        let removed = client.remove_root("file:///project").await.unwrap();
1405        assert!(removed);
1406        assert!(client.roots().await.is_empty());
1407
1408        // Try to remove non-existent root
1409        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
1410        assert!(!not_removed);
1411    }
1412
1413    #[tokio::test]
1414    async fn test_list_roots() {
1415        let roots = vec![
1416            Root::new("file:///project1"),
1417            Root::with_name("file:///project2", "Project 2"),
1418        ];
1419        let client = McpClient::builder()
1420            .with_roots(roots)
1421            .connect_simple(MockTransport::with_responses(vec![]))
1422            .await
1423            .unwrap();
1424
1425        let result = client.list_roots().await;
1426        assert_eq!(result.roots.len(), 2);
1427        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
1428    }
1429
1430    #[test]
1431    fn test_builder_with_sampling() {
1432        let builder = McpClientBuilder::new().with_sampling();
1433        assert!(builder.capabilities.sampling.is_some());
1434    }
1435
1436    #[test]
1437    fn test_builder_with_elicitation() {
1438        let builder = McpClientBuilder::new().with_elicitation();
1439        assert!(builder.capabilities.elicitation.is_some());
1440    }
1441
1442    #[test]
1443    fn test_builder_chaining() {
1444        let builder = McpClientBuilder::new()
1445            .with_sampling()
1446            .with_elicitation()
1447            .with_roots(vec![Root::new("file:///project")]);
1448        assert!(builder.capabilities.sampling.is_some());
1449        assert!(builder.capabilities.elicitation.is_some());
1450        assert!(builder.capabilities.roots.is_some());
1451    }
1452
1453    #[tokio::test]
1454    async fn test_bidirectional_sampling_round_trip() {
1455        use crate::protocol::{
1456            ContentRole, CreateMessageParams, CreateMessageResult, SamplingContent,
1457            SamplingContentOrArray,
1458        };
1459
1460        // A handler that records whether handle_create_message was called
1461        struct RecordingHandler {
1462            called: Arc<AtomicBool>,
1463        }
1464
1465        #[async_trait]
1466        impl ClientHandler for RecordingHandler {
1467            async fn handle_create_message(
1468                &self,
1469                _params: CreateMessageParams,
1470            ) -> std::result::Result<CreateMessageResult, tower_mcp_types::JsonRpcError>
1471            {
1472                self.called.store(true, Ordering::SeqCst);
1473                Ok(CreateMessageResult {
1474                    content: SamplingContentOrArray::Single(SamplingContent::Text {
1475                        text: "test response".to_string(),
1476                        annotations: None,
1477                        meta: None,
1478                    }),
1479                    model: "test-model".to_string(),
1480                    role: ContentRole::Assistant,
1481                    stop_reason: Some("end_turn".to_string()),
1482                    meta: None,
1483                })
1484            }
1485        }
1486
1487        let called = Arc::new(AtomicBool::new(false));
1488        let handler = RecordingHandler {
1489            called: called.clone(),
1490        };
1491
1492        // Build a mock transport, keeping a clone of incoming_tx so we can
1493        // inject a server-initiated request after the transport is consumed.
1494        let (inject_tx, rx) = mpsc::channel::<String>(32);
1495        let responses = vec![mock_initialize_response()];
1496        let inject_tx_clone = inject_tx.clone();
1497
1498        let transport = MockTransport {
1499            responses: Arc::new(Mutex::new(responses)),
1500            response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1501            incoming_tx: inject_tx,
1502            incoming_rx: rx,
1503            outgoing: Arc::new(Mutex::new(Vec::new())),
1504            connected: Arc::new(AtomicBool::new(true)),
1505        };
1506
1507        let client = McpClient::builder()
1508            .with_sampling()
1509            .connect(transport, handler)
1510            .await
1511            .unwrap();
1512
1513        // Initialize the client (this sends initialize request + notification)
1514        client.initialize("test-client", "1.0.0").await.unwrap();
1515
1516        // Inject a server-initiated sampling/createMessage request
1517        let sampling_request = serde_json::json!({
1518            "jsonrpc": "2.0",
1519            "id": 100,
1520            "method": "sampling/createMessage",
1521            "params": {
1522                "messages": [
1523                    {
1524                        "role": "user",
1525                        "content": {
1526                            "type": "text",
1527                            "text": "Hello"
1528                        }
1529                    }
1530                ],
1531                "maxTokens": 100
1532            }
1533        });
1534        inject_tx_clone
1535            .send(sampling_request.to_string())
1536            .await
1537            .unwrap();
1538
1539        // Give the background loop time to process
1540        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1541
1542        // Verify the handler was called
1543        assert!(
1544            called.load(Ordering::SeqCst),
1545            "handle_create_message should have been called"
1546        );
1547    }
1548
1549    #[tokio::test]
1550    async fn test_list_resource_templates() {
1551        let client = McpClient::connect(MockTransport::with_responses(vec![
1552            mock_initialize_response(),
1553            serde_json::json!({
1554                "resourceTemplates": [
1555                    {
1556                        "uriTemplate": "file:///{path}",
1557                        "name": "File Template",
1558                        "description": "A file template"
1559                    }
1560                ]
1561            }),
1562        ]))
1563        .await
1564        .unwrap();
1565
1566        client.initialize("test-client", "1.0.0").await.unwrap();
1567        let result = client.list_resource_templates().await.unwrap();
1568
1569        assert_eq!(result.resource_templates.len(), 1);
1570        assert_eq!(result.resource_templates[0].name, "File Template");
1571    }
1572
1573    #[tokio::test]
1574    async fn test_list_all_tools_single_page() {
1575        let client = McpClient::connect(MockTransport::with_responses(vec![
1576            mock_initialize_response(),
1577            serde_json::json!({
1578                "tools": [
1579                    {
1580                        "name": "tool_a",
1581                        "description": "Tool A",
1582                        "inputSchema": { "type": "object", "properties": {} }
1583                    },
1584                    {
1585                        "name": "tool_b",
1586                        "description": "Tool B",
1587                        "inputSchema": { "type": "object", "properties": {} }
1588                    }
1589                ]
1590            }),
1591        ]))
1592        .await
1593        .unwrap();
1594
1595        client.initialize("test-client", "1.0.0").await.unwrap();
1596        let tools = client.list_all_tools().await.unwrap();
1597
1598        assert_eq!(tools.len(), 2);
1599        assert_eq!(tools[0].name, "tool_a");
1600        assert_eq!(tools[1].name, "tool_b");
1601    }
1602
1603    #[tokio::test]
1604    async fn test_list_all_tools_paginated() {
1605        let client = McpClient::connect(MockTransport::with_responses(vec![
1606            mock_initialize_response(),
1607            // First page with a next_cursor
1608            serde_json::json!({
1609                "tools": [
1610                    {
1611                        "name": "tool_a",
1612                        "description": "Tool A",
1613                        "inputSchema": { "type": "object", "properties": {} }
1614                    }
1615                ],
1616                "nextCursor": "page2"
1617            }),
1618            // Second page with no next_cursor
1619            serde_json::json!({
1620                "tools": [
1621                    {
1622                        "name": "tool_b",
1623                        "description": "Tool B",
1624                        "inputSchema": { "type": "object", "properties": {} }
1625                    }
1626                ]
1627            }),
1628        ]))
1629        .await
1630        .unwrap();
1631
1632        client.initialize("test-client", "1.0.0").await.unwrap();
1633        let tools = client.list_all_tools().await.unwrap();
1634
1635        assert_eq!(tools.len(), 2);
1636        assert_eq!(tools[0].name, "tool_a");
1637        assert_eq!(tools[1].name, "tool_b");
1638    }
1639
1640    #[tokio::test]
1641    async fn test_call_tool_text_success() {
1642        let client = McpClient::connect(MockTransport::with_responses(vec![
1643            mock_initialize_response(),
1644            serde_json::json!({
1645                "content": [
1646                    { "type": "text", "text": "Hello " },
1647                    { "type": "text", "text": "World" }
1648                ]
1649            }),
1650        ]))
1651        .await
1652        .unwrap();
1653
1654        client.initialize("test-client", "1.0.0").await.unwrap();
1655        let text = client
1656            .call_tool_text("test_tool", serde_json::json!({}))
1657            .await
1658            .unwrap();
1659
1660        assert_eq!(text, "Hello World");
1661    }
1662
1663    #[tokio::test]
1664    async fn test_call_tool_text_error() {
1665        let client = McpClient::connect(MockTransport::with_responses(vec![
1666            mock_initialize_response(),
1667            serde_json::json!({
1668                "content": [
1669                    { "type": "text", "text": "something went wrong" }
1670                ],
1671                "isError": true
1672            }),
1673        ]))
1674        .await
1675        .unwrap();
1676
1677        client.initialize("test-client", "1.0.0").await.unwrap();
1678        let result = client
1679            .call_tool_text("test_tool", serde_json::json!({}))
1680            .await;
1681
1682        assert!(result.is_err());
1683        let err = result.unwrap_err();
1684        assert!(
1685            err.to_string().contains("something went wrong"),
1686            "Error message should contain tool error text, got: {}",
1687            err
1688        );
1689    }
1690
1691    #[tokio::test]
1692    async fn test_server_notification_parsing() {
1693        let notification = parse_server_notification("notifications/tools/list_changed", None);
1694        assert!(matches!(notification, ServerNotification::ToolsListChanged));
1695
1696        let notification = parse_server_notification("notifications/resources/list_changed", None);
1697        assert!(matches!(
1698            notification,
1699            ServerNotification::ResourcesListChanged
1700        ));
1701
1702        let notification = parse_server_notification(
1703            "notifications/resources/updated",
1704            Some(serde_json::json!({"uri": "file:///test"})),
1705        );
1706        match notification {
1707            ServerNotification::ResourceUpdated { uri } => {
1708                assert_eq!(uri, "file:///test");
1709            }
1710            _ => panic!("Expected ResourceUpdated"),
1711        }
1712
1713        let notification =
1714            parse_server_notification("custom/notification", Some(serde_json::json!({"data": 42})));
1715        match notification {
1716            ServerNotification::Unknown { method, params } => {
1717                assert_eq!(method, "custom/notification");
1718                assert!(params.is_some());
1719            }
1720            _ => panic!("Expected Unknown"),
1721        }
1722    }
1723}