Skip to main content

tower_mcp/
client.rs

//! MCP client implementation.
//!
//! Provides [`McpClient`], a client for talking to MCP servers over any
//! [`ClientTransport`], plus [`StdioClientTransport`] for servers that run as a
//! subprocess and speak newline-delimited JSON-RPC over stdio.
//!
//! # Example
//!
//! ```rust,no_run
//! use tower_mcp::BoxError;
//! use tower_mcp::client::{McpClient, StdioClientTransport};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), BoxError> {
//!     // Connect to an MCP server via stdio
//!     let transport = StdioClientTransport::spawn("my-mcp-server", &["--flag"]).await?;
//!     let mut client = McpClient::new(transport);
//!
//!     // Initialize the connection
//!     let server_info = client.initialize("my-client", "1.0.0").await?;
//!     println!("Connected to: {}", server_info.server_info.name);
//!
//!     // List available tools
//!     let tools = client.list_tools().await?;
//!     for tool in &tools.tools {
//!         println!("Tool: {}", tool.name);
//!     }
//!
//!     // Call a tool
//!     let result = client.call_tool("my-tool", serde_json::json!({"arg": "value"})).await?;
//!     println!("Result: {:?}", result);
//!
//!     Ok(())
//! }
//! ```
34
35use std::process::Stdio;
36use std::sync::atomic::{AtomicI64, Ordering};
37
38use async_trait::async_trait;
39use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
40use tokio::process::{Child, Command};
41
42use crate::error::{Error, Result};
43use crate::protocol::{
44    CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
45    CompletionArgument, CompletionReference, GetPromptParams, GetPromptResult, Implementation,
46    InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse, ListPromptsParams,
47    ListPromptsResult, ListResourcesParams, ListResourcesResult, ListRootsResult, ListToolsParams,
48    ListToolsResult, ReadResourceParams, ReadResourceResult, Root, RootsCapability, notifications,
49};
50
51/// Trait for MCP client transports
52#[async_trait]
53pub trait ClientTransport: Send {
54    /// Send a request and receive a response
55    async fn request(
56        &mut self,
57        method: &str,
58        params: serde_json::Value,
59    ) -> Result<serde_json::Value>;
60
61    /// Send a notification (no response expected)
62    async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()>;
63
64    /// Check if the transport is still connected
65    fn is_connected(&self) -> bool;
66
67    /// Close the transport
68    async fn close(self: Box<Self>) -> Result<()>;
69}
70
71/// MCP Client for connecting to MCP servers
72pub struct McpClient<T: ClientTransport> {
73    transport: T,
74    initialized: bool,
75    server_info: Option<InitializeResult>,
76    /// Client capabilities to declare during initialization
77    capabilities: ClientCapabilities,
78    /// Roots available to the server
79    roots: Vec<Root>,
80}
81
82impl<T: ClientTransport> McpClient<T> {
83    /// Create a new MCP client with the given transport
84    pub fn new(transport: T) -> Self {
85        Self {
86            transport,
87            initialized: false,
88            server_info: None,
89            capabilities: ClientCapabilities::default(),
90            roots: Vec::new(),
91        }
92    }
93
94    /// Configure roots for this client.
95    ///
96    /// The client will declare roots support during initialization and
97    /// provide these roots when requested by the server.
98    ///
99    /// # Example
100    ///
101    /// ```rust,no_run
102    /// use tower_mcp::client::{McpClient, StdioClientTransport};
103    /// use tower_mcp::protocol::Root;
104    ///
105    /// # async fn example() -> Result<(), tower_mcp::BoxError> {
106    /// let transport = StdioClientTransport::spawn("server", &[]).await?;
107    /// let client = McpClient::new(transport)
108    ///     .with_roots(vec![Root { uri: "file:///project".into(), name: Some("Project".into()), meta: None }]);
109    /// # Ok(())
110    /// # }
111    /// ```
112    pub fn with_roots(mut self, roots: Vec<Root>) -> Self {
113        self.roots = roots;
114        self.capabilities.roots = Some(RootsCapability { list_changed: true });
115        self
116    }
117
118    /// Configure custom capabilities for this client.
119    ///
120    /// # Example
121    ///
122    /// ```rust,no_run
123    /// use tower_mcp::client::{McpClient, StdioClientTransport};
124    /// use tower_mcp::protocol::ClientCapabilities;
125    ///
126    /// # async fn example() -> Result<(), tower_mcp::BoxError> {
127    /// let transport = StdioClientTransport::spawn("server", &[]).await?;
128    /// let client = McpClient::new(transport)
129    ///     .with_capabilities(ClientCapabilities {
130    ///         sampling: Some(Default::default()),
131    ///         ..Default::default()
132    ///     });
133    /// # Ok(())
134    /// # }
135    /// ```
136    pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
137        self.capabilities = capabilities;
138        self
139    }
140
141    /// Get the server info (available after initialization)
142    pub fn server_info(&self) -> Option<&InitializeResult> {
143        self.server_info.as_ref()
144    }
145
146    /// Check if the client is initialized
147    pub fn is_initialized(&self) -> bool {
148        self.initialized
149    }
150
151    /// Get the current roots
152    pub fn roots(&self) -> &[Root] {
153        &self.roots
154    }
155
156    /// Set roots and notify the server if initialized
157    ///
158    /// If the client is already initialized, sends a roots list changed notification.
159    pub async fn set_roots(&mut self, roots: Vec<Root>) -> Result<()> {
160        self.roots = roots;
161        if self.initialized {
162            self.notify_roots_changed().await?;
163        }
164        Ok(())
165    }
166
167    /// Add a root and notify the server if initialized
168    pub async fn add_root(&mut self, root: Root) -> Result<()> {
169        self.roots.push(root);
170        if self.initialized {
171            self.notify_roots_changed().await?;
172        }
173        Ok(())
174    }
175
176    /// Remove a root by URI and notify the server if initialized
177    pub async fn remove_root(&mut self, uri: &str) -> Result<bool> {
178        let initial_len = self.roots.len();
179        self.roots.retain(|r| r.uri != uri);
180        let removed = self.roots.len() < initial_len;
181        if removed && self.initialized {
182            self.notify_roots_changed().await?;
183        }
184        Ok(removed)
185    }
186
187    /// Send roots list changed notification to the server
188    async fn notify_roots_changed(&mut self) -> Result<()> {
189        self.transport
190            .notify(notifications::ROOTS_LIST_CHANGED, serde_json::json!({}))
191            .await
192    }
193
194    /// Get the roots list result (for responding to server's roots/list request)
195    ///
196    /// Returns a result suitable for responding to a roots/list request from the server.
197    pub fn list_roots(&self) -> ListRootsResult {
198        ListRootsResult {
199            roots: self.roots.clone(),
200            meta: None,
201        }
202    }
203
204    /// Initialize the MCP connection
205    pub async fn initialize(
206        &mut self,
207        client_name: &str,
208        client_version: &str,
209    ) -> Result<&InitializeResult> {
210        let params = InitializeParams {
211            protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
212            capabilities: self.capabilities.clone(),
213            client_info: Implementation {
214                name: client_name.to_string(),
215                version: client_version.to_string(),
216                ..Default::default()
217            },
218            meta: None,
219        };
220
221        let result: InitializeResult = self.request("initialize", &params).await?;
222        self.server_info = Some(result);
223
224        // Send initialized notification
225        self.transport
226            .notify("notifications/initialized", serde_json::json!({}))
227            .await?;
228
229        self.initialized = true;
230
231        Ok(self.server_info.as_ref().unwrap())
232    }
233
234    /// List available tools
235    pub async fn list_tools(&mut self) -> Result<ListToolsResult> {
236        self.ensure_initialized()?;
237        self.request(
238            "tools/list",
239            &ListToolsParams {
240                cursor: None,
241                meta: None,
242            },
243        )
244        .await
245    }
246
247    /// Call a tool
248    pub async fn call_tool(
249        &mut self,
250        name: &str,
251        arguments: serde_json::Value,
252    ) -> Result<CallToolResult> {
253        self.ensure_initialized()?;
254        let params = CallToolParams {
255            name: name.to_string(),
256            arguments,
257            meta: None,
258            task: None,
259        };
260        self.request("tools/call", &params).await
261    }
262
263    /// List available resources
264    pub async fn list_resources(&mut self) -> Result<ListResourcesResult> {
265        self.ensure_initialized()?;
266        self.request(
267            "resources/list",
268            &ListResourcesParams {
269                cursor: None,
270                meta: None,
271            },
272        )
273        .await
274    }
275
276    /// Read a resource
277    pub async fn read_resource(&mut self, uri: &str) -> Result<ReadResourceResult> {
278        self.ensure_initialized()?;
279        let params = ReadResourceParams {
280            uri: uri.to_string(),
281            meta: None,
282        };
283        self.request("resources/read", &params).await
284    }
285
286    /// List available prompts
287    pub async fn list_prompts(&mut self) -> Result<ListPromptsResult> {
288        self.ensure_initialized()?;
289        self.request(
290            "prompts/list",
291            &ListPromptsParams {
292                cursor: None,
293                meta: None,
294            },
295        )
296        .await
297    }
298
299    /// Get a prompt
300    pub async fn get_prompt(
301        &mut self,
302        name: &str,
303        arguments: Option<std::collections::HashMap<String, String>>,
304    ) -> Result<GetPromptResult> {
305        self.ensure_initialized()?;
306        let params = GetPromptParams {
307            name: name.to_string(),
308            arguments: arguments.unwrap_or_default(),
309            meta: None,
310        };
311        self.request("prompts/get", &params).await
312    }
313
314    /// Ping the server
315    pub async fn ping(&mut self) -> Result<()> {
316        let _: serde_json::Value = self.request("ping", &serde_json::json!({})).await?;
317        Ok(())
318    }
319
320    /// Request completion suggestions from the server
321    ///
322    /// This is used to get autocomplete suggestions for prompt arguments or resource URIs.
323    pub async fn complete(
324        &mut self,
325        reference: CompletionReference,
326        argument_name: &str,
327        argument_value: &str,
328    ) -> Result<CompleteResult> {
329        self.ensure_initialized()?;
330        let params = CompleteParams {
331            reference,
332            argument: CompletionArgument::new(argument_name, argument_value),
333            context: None,
334            meta: None,
335        };
336        self.request("completion/complete", &params).await
337    }
338
339    /// Request completion for a prompt argument
340    pub async fn complete_prompt_arg(
341        &mut self,
342        prompt_name: &str,
343        argument_name: &str,
344        argument_value: &str,
345    ) -> Result<CompleteResult> {
346        self.complete(
347            CompletionReference::prompt(prompt_name),
348            argument_name,
349            argument_value,
350        )
351        .await
352    }
353
354    /// Request completion for a resource URI
355    pub async fn complete_resource_uri(
356        &mut self,
357        resource_uri: &str,
358        argument_name: &str,
359        argument_value: &str,
360    ) -> Result<CompleteResult> {
361        self.complete(
362            CompletionReference::resource(resource_uri),
363            argument_name,
364            argument_value,
365        )
366        .await
367    }
368
369    /// Send a raw request
370    pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
371        &mut self,
372        method: &str,
373        params: &P,
374    ) -> Result<R> {
375        let params_value = serde_json::to_value(params)
376            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
377
378        let result = self.transport.request(method, params_value).await?;
379
380        serde_json::from_value(result)
381            .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
382    }
383
384    /// Send a notification
385    pub async fn notify<P: serde::Serialize>(&mut self, method: &str, params: &P) -> Result<()> {
386        let params_value = serde_json::to_value(params)
387            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
388
389        self.transport.notify(method, params_value).await
390    }
391
392    fn ensure_initialized(&self) -> Result<()> {
393        if !self.initialized {
394            return Err(Error::Transport("Client not initialized".to_string()));
395        }
396        Ok(())
397    }
398}
399
400// ============================================================================
401// Stdio Client Transport
402// ============================================================================
403
404/// Client transport that communicates with a subprocess via stdio
405pub struct StdioClientTransport {
406    child: Option<Child>,
407    stdin: tokio::process::ChildStdin,
408    stdout: BufReader<tokio::process::ChildStdout>,
409    request_id: AtomicI64,
410}
411
412impl StdioClientTransport {
413    /// Spawn a new subprocess and connect to it
414    pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
415        let mut cmd = Command::new(program);
416        cmd.args(args)
417            .stdin(Stdio::piped())
418            .stdout(Stdio::piped())
419            .stderr(Stdio::inherit());
420
421        let mut child = cmd
422            .spawn()
423            .map_err(|e| Error::Transport(format!("Failed to spawn {}: {}", program, e)))?;
424
425        let stdin = child
426            .stdin
427            .take()
428            .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
429        let stdout = child
430            .stdout
431            .take()
432            .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
433
434        tracing::info!(program = %program, "Spawned MCP server process");
435
436        Ok(Self {
437            child: Some(child),
438            stdin,
439            stdout: BufReader::new(stdout),
440            request_id: AtomicI64::new(1),
441        })
442    }
443
444    /// Create from an existing child process
445    pub fn from_child(mut child: Child) -> Result<Self> {
446        let stdin = child
447            .stdin
448            .take()
449            .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
450        let stdout = child
451            .stdout
452            .take()
453            .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
454
455        Ok(Self {
456            child: Some(child),
457            stdin,
458            stdout: BufReader::new(stdout),
459            request_id: AtomicI64::new(1),
460        })
461    }
462
463    async fn send_line(&mut self, line: &str) -> Result<()> {
464        self.stdin
465            .write_all(line.as_bytes())
466            .await
467            .map_err(|e| Error::Transport(format!("Failed to write: {}", e)))?;
468        self.stdin
469            .write_all(b"\n")
470            .await
471            .map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
472        self.stdin
473            .flush()
474            .await
475            .map_err(|e| Error::Transport(format!("Failed to flush: {}", e)))?;
476        Ok(())
477    }
478
479    async fn read_line(&mut self) -> Result<String> {
480        let mut line = String::new();
481        self.stdout
482            .read_line(&mut line)
483            .await
484            .map_err(|e| Error::Transport(format!("Failed to read: {}", e)))?;
485
486        if line.is_empty() {
487            return Err(Error::Transport("Connection closed".to_string()));
488        }
489
490        Ok(line)
491    }
492}
493
494#[async_trait]
495impl ClientTransport for StdioClientTransport {
496    async fn request(
497        &mut self,
498        method: &str,
499        params: serde_json::Value,
500    ) -> Result<serde_json::Value> {
501        let id = self.request_id.fetch_add(1, Ordering::Relaxed);
502        let request = JsonRpcRequest::new(id, method).with_params(params);
503
504        let request_json = serde_json::to_string(&request)
505            .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
506
507        tracing::debug!(method = %method, id = %id, "Sending request");
508        self.send_line(&request_json).await?;
509
510        let response_line = self.read_line().await?;
511        tracing::debug!(response = %response_line.trim(), "Received response");
512
513        let response: JsonRpcResponse = serde_json::from_str(response_line.trim())
514            .map_err(|e| Error::Transport(format!("Failed to parse response: {}", e)))?;
515
516        match response {
517            JsonRpcResponse::Result(r) => Ok(r.result),
518            JsonRpcResponse::Error(e) => Err(Error::JsonRpc(e.error)),
519        }
520    }
521
522    async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
523        let notification = serde_json::json!({
524            "jsonrpc": "2.0",
525            "method": method,
526            "params": params
527        });
528
529        let json = serde_json::to_string(&notification)
530            .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
531
532        tracing::debug!(method = %method, "Sending notification");
533        self.send_line(&json).await
534    }
535
536    fn is_connected(&self) -> bool {
537        // Assume connected if we have a child process handle
538        self.child.is_some()
539    }
540
541    async fn close(mut self: Box<Self>) -> Result<()> {
542        // Close stdin to signal EOF
543        drop(self.stdin);
544
545        if let Some(mut child) = self.child.take() {
546            // Wait for process with timeout
547            let result =
548                tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
549
550            match result {
551                Ok(Ok(status)) => {
552                    tracing::info!(status = ?status, "Child process exited");
553                }
554                Ok(Err(e)) => {
555                    tracing::error!(error = %e, "Error waiting for child");
556                }
557                Err(_) => {
558                    tracing::warn!("Timeout waiting for child, killing");
559                    let _ = child.kill().await;
560                }
561            }
562        }
563
564        Ok(())
565    }
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Mock transport that returns preconfigured responses and records all
    /// outgoing traffic so tests can assert on it.
    struct MockTransport {
        /// Responses handed out, in order, to successive requests.
        responses: Arc<Mutex<VecDeque<serde_json::Value>>>,
        /// Every request sent, as `(method, params)`.
        requests: Arc<Mutex<Vec<(String, serde_json::Value)>>>,
        /// Every notification sent, as `(method, params)`.
        notifications: Arc<Mutex<Vec<(String, serde_json::Value)>>>,
        /// Value reported by `is_connected`.
        connected: bool,
    }

    impl MockTransport {
        /// A mock with no queued responses; every request fails.
        fn new() -> Self {
            Self::with_responses(Vec::new())
        }

        /// A mock that answers successive requests with `responses`, in order.
        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(responses.into())),
                requests: Arc::new(Mutex::new(Vec::new())),
                notifications: Arc::new(Mutex::new(Vec::new())),
                connected: true,
            }
        }

        /// Snapshot of all requests sent so far.
        fn get_requests(&self) -> Vec<(String, serde_json::Value)> {
            self.requests.lock().unwrap().clone()
        }

        /// Snapshot of all notifications sent so far.
        fn get_notifications(&self) -> Vec<(String, serde_json::Value)> {
            self.notifications.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ClientTransport for MockTransport {
        async fn request(
            &mut self,
            method: &str,
            params: serde_json::Value,
        ) -> Result<serde_json::Value> {
            self.requests
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| Error::Transport("No more mock responses".to_string()))
        }

        async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
            self.notifications
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            Ok(())
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        async fn close(self: Box<Self>) -> Result<()> {
            Ok(())
        }
    }

    fn mock_initialize_response() -> serde_json::Value {
        serde_json::json!({
            "protocolVersion": "2025-11-25",
            "serverInfo": {
                "name": "test-server",
                "version": "1.0.0"
            },
            "capabilities": {
                "tools": {}
            }
        })
    }

    #[tokio::test]
    async fn test_client_not_initialized() {
        let mut client = McpClient::new(MockTransport::new());

        // Should fail because not initialized
        let result = client.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not initialized"));

        // The guard must short-circuit before anything reaches the transport.
        assert!(client.transport.get_requests().is_empty());
    }

    #[tokio::test]
    async fn test_client_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        assert!(!client.is_initialized());

        let result = client.initialize("test-client", "1.0.0").await;
        assert!(result.is_ok());
        assert!(client.is_initialized());

        let server_info = client.server_info().unwrap();
        assert_eq!(server_info.server_info.name, "test-server");

        // Handshake: one `initialize` request, then the `initialized` notification.
        let requests = client.transport.get_requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].0, "initialize");

        let notifications = client.transport.get_notifications();
        assert_eq!(notifications.len(), 1);
        assert_eq!(notifications[0].0, "notifications/initialized");
    }

    #[tokio::test]
    async fn test_list_tools() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "tools": [
                    {
                        "name": "test_tool",
                        "description": "A test tool",
                        "inputSchema": {
                            "type": "object",
                            "properties": {}
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let tools = client.list_tools().await.unwrap();

        assert_eq!(tools.tools.len(), 1);
        assert_eq!(tools.tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_call_tool() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "content": [
                    {
                        "type": "text",
                        "text": "Tool result"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await
            .unwrap();

        assert!(!result.content.is_empty());

        // The tool name and arguments must be forwarded unchanged.
        let requests = client.transport.get_requests();
        let (method, params) = requests.last().unwrap();
        assert_eq!(method, "tools/call");
        assert_eq!(params["name"], "test_tool");
        assert_eq!(params["arguments"], serde_json::json!({"arg": "value"}));
    }

    #[tokio::test]
    async fn test_list_resources() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "resources": [
                    {
                        "uri": "file://test.txt",
                        "name": "Test File"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let resources = client.list_resources().await.unwrap();

        assert_eq!(resources.resources.len(), 1);
        assert_eq!(resources.resources[0].uri, "file://test.txt");
    }

    #[tokio::test]
    async fn test_read_resource() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "contents": [
                    {
                        "uri": "file://test.txt",
                        "text": "File contents"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.read_resource("file://test.txt").await.unwrap();

        assert_eq!(result.contents.len(), 1);
        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
    }

    #[tokio::test]
    async fn test_list_prompts() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "prompts": [
                    {
                        "name": "test_prompt",
                        "description": "A test prompt"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let prompts = client.list_prompts().await.unwrap();

        assert_eq!(prompts.prompts.len(), 1);
        assert_eq!(prompts.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "messages": [
                    {
                        "role": "user",
                        "content": {
                            "type": "text",
                            "text": "Prompt message"
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.get_prompt("test_prompt", None).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_ping() {
        let transport =
            MockTransport::with_responses(vec![mock_initialize_response(), serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.ping().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ping_does_not_require_initialize() {
        let transport = MockTransport::with_responses(vec![serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        assert!(client.ping().await.is_ok());
        assert_eq!(client.transport.get_requests()[0].0, "ping");
    }

    #[tokio::test]
    async fn test_roots_management() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let notifications_log = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        // Initially no roots
        assert!(client.roots().is_empty());

        // Add a root before initialization (no notification)
        client.add_root(Root::new("file:///project")).await.unwrap();
        assert_eq!(client.roots().len(), 1);
        assert!(notifications_log.lock().unwrap().is_empty());

        // Initialize
        client.initialize("test-client", "1.0.0").await.unwrap();

        // Add another root after initialization (should notify)
        client.add_root(Root::new("file:///other")).await.unwrap();
        assert_eq!(client.roots().len(), 2);
        {
            let sent = notifications_log.lock().unwrap();
            assert_eq!(sent.len(), 2); // initialized + roots changed
            assert_eq!(sent[0].0, "notifications/initialized");
            assert_eq!(sent[1].0, notifications::ROOTS_LIST_CHANGED);
        }

        // Remove a root (should notify again)
        let removed = client.remove_root("file:///project").await.unwrap();
        assert!(removed);
        assert_eq!(client.roots().len(), 1);
        assert_eq!(notifications_log.lock().unwrap().len(), 3);

        // Removing a non-existent root changes nothing and sends nothing
        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
        assert!(!not_removed);
        assert_eq!(notifications_log.lock().unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_with_roots() {
        let roots = vec![Root::new("file:///test")];
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::new(transport).with_roots(roots);

        assert_eq!(client.roots().len(), 1);
        assert!(client.capabilities.roots.is_some());
    }

    #[tokio::test]
    async fn test_with_capabilities() {
        let capabilities = ClientCapabilities {
            sampling: Some(Default::default()),
            ..Default::default()
        };

        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::new(transport).with_capabilities(capabilities);

        assert!(client.capabilities.sampling.is_some());
    }

    #[tokio::test]
    async fn test_list_roots() {
        let roots = vec![
            Root::new("file:///project1"),
            Root::with_name("file:///project2", "Project 2"),
        ];
        let transport = MockTransport::new();
        let client = McpClient::new(transport).with_roots(roots);

        let result = client.list_roots();
        assert_eq!(result.roots.len(), 2);
        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
    }
}