Skip to main content

tower_mcp/
client.rs

1//! MCP Client implementation
2//!
3//! Provides client functionality for connecting to MCP servers.
4//!
5//! # Example
6//!
7//! ```rust,no_run
8//! use tower_mcp::client::{McpClient, StdioClientTransport};
9//!
10//! #[tokio::main]
11//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
12//!     // Connect to an MCP server via stdio
13//!     let transport = StdioClientTransport::spawn("my-mcp-server", &["--flag"]).await?;
14//!     let mut client = McpClient::new(transport);
15//!
16//!     // Initialize the connection
17//!     let server_info = client.initialize("my-client", "1.0.0").await?;
18//!     println!("Connected to: {}", server_info.server_info.name);
19//!
20//!     // List available tools
21//!     let tools = client.list_tools().await?;
22//!     for tool in &tools.tools {
23//!         println!("Tool: {}", tool.name);
24//!     }
25//!
26//!     // Call a tool
27//!     let result = client.call_tool("my-tool", serde_json::json!({"arg": "value"})).await?;
28//!     println!("Result: {:?}", result);
29//!
30//!     Ok(())
31//! }
32//! ```
33
34use std::process::Stdio;
35use std::sync::atomic::{AtomicI64, Ordering};
36
37use async_trait::async_trait;
38use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
39use tokio::process::{Child, Command};
40
41use crate::error::{Error, Result};
42use crate::protocol::{
43    CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
44    CompletionArgument, CompletionReference, GetPromptParams, GetPromptResult, Implementation,
45    InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse, ListPromptsParams,
46    ListPromptsResult, ListResourcesParams, ListResourcesResult, ListRootsResult, ListToolsParams,
47    ListToolsResult, ReadResourceParams, ReadResourceResult, Root, RootsCapability, notifications,
48};
49
/// Trait for MCP client transports.
///
/// A transport carries JSON-RPC traffic between [`McpClient`] and a server.
/// The client hands over a method name plus already-serialized JSON
/// parameters, and expects back a raw JSON value that it deserializes itself.
#[async_trait]
pub trait ClientTransport: Send {
    /// Send a request and receive a response.
    ///
    /// On success, returns the JSON value that [`McpClient`] deserializes into
    /// the typed response for `method`. The stdio implementation returns the
    /// JSON-RPC `result` member and maps error responses to `Error::JsonRpc`.
    async fn request(
        &mut self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<serde_json::Value>;

    /// Send a notification (no response expected).
    async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()>;

    /// Check if the transport is still connected.
    ///
    /// This is a best-effort check. For example, the stdio implementation
    /// only reports whether it still holds a child-process handle.
    fn is_connected(&self) -> bool;

    /// Close the transport.
    ///
    /// Takes `self: Box<Self>`, so the boxed transport is consumed and cannot
    /// be used afterwards.
    async fn close(self: Box<Self>) -> Result<()>;
}
69
/// MCP Client for connecting to MCP servers.
///
/// Wraps a [`ClientTransport`] and tracks:
/// - the initialization handshake;
/// - the server's `initialize` response;
/// - the capabilities this client advertises;
/// - the roots (URIs) it exposes to the server.
pub struct McpClient<T: ClientTransport> {
    /// Underlying transport used for every request and notification.
    transport: T,
    /// Set once `initialize` has sent both the `initialize` request and the
    /// `notifications/initialized` notification successfully.
    initialized: bool,
    /// The server's response to `initialize`, if one has been received.
    server_info: Option<InitializeResult>,
    /// Client capabilities to declare during initialization.
    capabilities: ClientCapabilities,
    /// Roots available to the server.
    roots: Vec<Root>,
}
80
81impl<T: ClientTransport> McpClient<T> {
82    /// Create a new MCP client with the given transport
83    pub fn new(transport: T) -> Self {
84        Self {
85            transport,
86            initialized: false,
87            server_info: None,
88            capabilities: ClientCapabilities::default(),
89            roots: Vec::new(),
90        }
91    }
92
93    /// Create a new MCP client with roots capability
94    ///
95    /// The client will declare roots support during initialization.
96    pub fn with_roots(transport: T, roots: Vec<Root>) -> Self {
97        Self {
98            transport,
99            initialized: false,
100            server_info: None,
101            capabilities: ClientCapabilities {
102                roots: Some(RootsCapability { list_changed: true }),
103                ..Default::default()
104            },
105            roots,
106        }
107    }
108
109    /// Create a new MCP client with custom capabilities
110    pub fn with_capabilities(transport: T, capabilities: ClientCapabilities) -> Self {
111        Self {
112            transport,
113            initialized: false,
114            server_info: None,
115            capabilities,
116            roots: Vec::new(),
117        }
118    }
119
120    /// Get the server info (available after initialization)
121    pub fn server_info(&self) -> Option<&InitializeResult> {
122        self.server_info.as_ref()
123    }
124
125    /// Check if the client is initialized
126    pub fn is_initialized(&self) -> bool {
127        self.initialized
128    }
129
130    /// Get the current roots
131    pub fn roots(&self) -> &[Root] {
132        &self.roots
133    }
134
135    /// Set roots and notify the server if initialized
136    ///
137    /// If the client is already initialized, sends a roots list changed notification.
138    pub async fn set_roots(&mut self, roots: Vec<Root>) -> Result<()> {
139        self.roots = roots;
140        if self.initialized {
141            self.notify_roots_changed().await?;
142        }
143        Ok(())
144    }
145
146    /// Add a root and notify the server if initialized
147    pub async fn add_root(&mut self, root: Root) -> Result<()> {
148        self.roots.push(root);
149        if self.initialized {
150            self.notify_roots_changed().await?;
151        }
152        Ok(())
153    }
154
155    /// Remove a root by URI and notify the server if initialized
156    pub async fn remove_root(&mut self, uri: &str) -> Result<bool> {
157        let initial_len = self.roots.len();
158        self.roots.retain(|r| r.uri != uri);
159        let removed = self.roots.len() < initial_len;
160        if removed && self.initialized {
161            self.notify_roots_changed().await?;
162        }
163        Ok(removed)
164    }
165
166    /// Send roots list changed notification to the server
167    async fn notify_roots_changed(&mut self) -> Result<()> {
168        self.transport
169            .notify(notifications::ROOTS_LIST_CHANGED, serde_json::json!({}))
170            .await
171    }
172
173    /// Get the roots list result (for responding to server's roots/list request)
174    ///
175    /// Returns a result suitable for responding to a roots/list request from the server.
176    pub fn list_roots(&self) -> ListRootsResult {
177        ListRootsResult {
178            roots: self.roots.clone(),
179        }
180    }
181
182    /// Initialize the MCP connection
183    pub async fn initialize(
184        &mut self,
185        client_name: &str,
186        client_version: &str,
187    ) -> Result<&InitializeResult> {
188        let params = InitializeParams {
189            protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
190            capabilities: self.capabilities.clone(),
191            client_info: Implementation {
192                name: client_name.to_string(),
193                version: client_version.to_string(),
194            },
195        };
196
197        let result: InitializeResult = self.request("initialize", &params).await?;
198        self.server_info = Some(result);
199
200        // Send initialized notification
201        self.transport
202            .notify("notifications/initialized", serde_json::json!({}))
203            .await?;
204
205        self.initialized = true;
206
207        Ok(self.server_info.as_ref().unwrap())
208    }
209
210    /// List available tools
211    pub async fn list_tools(&mut self) -> Result<ListToolsResult> {
212        self.ensure_initialized()?;
213        self.request("tools/list", &ListToolsParams { cursor: None })
214            .await
215    }
216
217    /// Call a tool
218    pub async fn call_tool(
219        &mut self,
220        name: &str,
221        arguments: serde_json::Value,
222    ) -> Result<CallToolResult> {
223        self.ensure_initialized()?;
224        let params = CallToolParams {
225            name: name.to_string(),
226            arguments,
227            meta: None,
228        };
229        self.request("tools/call", &params).await
230    }
231
232    /// List available resources
233    pub async fn list_resources(&mut self) -> Result<ListResourcesResult> {
234        self.ensure_initialized()?;
235        self.request("resources/list", &ListResourcesParams { cursor: None })
236            .await
237    }
238
239    /// Read a resource
240    pub async fn read_resource(&mut self, uri: &str) -> Result<ReadResourceResult> {
241        self.ensure_initialized()?;
242        let params = ReadResourceParams {
243            uri: uri.to_string(),
244        };
245        self.request("resources/read", &params).await
246    }
247
248    /// List available prompts
249    pub async fn list_prompts(&mut self) -> Result<ListPromptsResult> {
250        self.ensure_initialized()?;
251        self.request("prompts/list", &ListPromptsParams { cursor: None })
252            .await
253    }
254
255    /// Get a prompt
256    pub async fn get_prompt(
257        &mut self,
258        name: &str,
259        arguments: Option<std::collections::HashMap<String, String>>,
260    ) -> Result<GetPromptResult> {
261        self.ensure_initialized()?;
262        let params = GetPromptParams {
263            name: name.to_string(),
264            arguments: arguments.unwrap_or_default(),
265        };
266        self.request("prompts/get", &params).await
267    }
268
269    /// Ping the server
270    pub async fn ping(&mut self) -> Result<()> {
271        let _: serde_json::Value = self.request("ping", &serde_json::json!({})).await?;
272        Ok(())
273    }
274
275    /// Request completion suggestions from the server
276    ///
277    /// This is used to get autocomplete suggestions for prompt arguments or resource URIs.
278    pub async fn complete(
279        &mut self,
280        reference: CompletionReference,
281        argument_name: &str,
282        argument_value: &str,
283    ) -> Result<CompleteResult> {
284        self.ensure_initialized()?;
285        let params = CompleteParams {
286            reference,
287            argument: CompletionArgument::new(argument_name, argument_value),
288        };
289        self.request("completion/complete", &params).await
290    }
291
292    /// Request completion for a prompt argument
293    pub async fn complete_prompt_arg(
294        &mut self,
295        prompt_name: &str,
296        argument_name: &str,
297        argument_value: &str,
298    ) -> Result<CompleteResult> {
299        self.complete(
300            CompletionReference::prompt(prompt_name),
301            argument_name,
302            argument_value,
303        )
304        .await
305    }
306
307    /// Request completion for a resource URI
308    pub async fn complete_resource_uri(
309        &mut self,
310        resource_uri: &str,
311        argument_name: &str,
312        argument_value: &str,
313    ) -> Result<CompleteResult> {
314        self.complete(
315            CompletionReference::resource(resource_uri),
316            argument_name,
317            argument_value,
318        )
319        .await
320    }
321
322    /// Send a raw request
323    pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
324        &mut self,
325        method: &str,
326        params: &P,
327    ) -> Result<R> {
328        let params_value = serde_json::to_value(params)
329            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
330
331        let result = self.transport.request(method, params_value).await?;
332
333        serde_json::from_value(result)
334            .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
335    }
336
337    /// Send a notification
338    pub async fn notify<P: serde::Serialize>(&mut self, method: &str, params: &P) -> Result<()> {
339        let params_value = serde_json::to_value(params)
340            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
341
342        self.transport.notify(method, params_value).await
343    }
344
345    fn ensure_initialized(&self) -> Result<()> {
346        if !self.initialized {
347            return Err(Error::Transport("Client not initialized".to_string()));
348        }
349        Ok(())
350    }
351}
352
353// ============================================================================
354// Stdio Client Transport
355// ============================================================================
356
/// Client transport that communicates with a subprocess via stdio.
///
/// Messages are exchanged as newline-delimited JSON. Each request or
/// notification is written as one line to the child's stdin, and incoming
/// messages are read line by line from its stdout.
pub struct StdioClientTransport {
    /// Handle to the server process; taken (left as `None`) by `close`.
    child: Option<Child>,
    /// Write side of the pipe to the child's stdin.
    stdin: tokio::process::ChildStdin,
    /// Buffered read side of the child's stdout.
    stdout: BufReader<tokio::process::ChildStdout>,
    /// Next JSON-RPC request id. Starts at 1 and is incremented per request.
    request_id: AtomicI64,
}
364
365impl StdioClientTransport {
366    /// Spawn a new subprocess and connect to it
367    pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
368        let mut cmd = Command::new(program);
369        cmd.args(args)
370            .stdin(Stdio::piped())
371            .stdout(Stdio::piped())
372            .stderr(Stdio::inherit());
373
374        let mut child = cmd
375            .spawn()
376            .map_err(|e| Error::Transport(format!("Failed to spawn {}: {}", program, e)))?;
377
378        let stdin = child
379            .stdin
380            .take()
381            .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
382        let stdout = child
383            .stdout
384            .take()
385            .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
386
387        tracing::info!(program = %program, "Spawned MCP server process");
388
389        Ok(Self {
390            child: Some(child),
391            stdin,
392            stdout: BufReader::new(stdout),
393            request_id: AtomicI64::new(1),
394        })
395    }
396
397    /// Create from an existing child process
398    pub fn from_child(mut child: Child) -> Result<Self> {
399        let stdin = child
400            .stdin
401            .take()
402            .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
403        let stdout = child
404            .stdout
405            .take()
406            .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
407
408        Ok(Self {
409            child: Some(child),
410            stdin,
411            stdout: BufReader::new(stdout),
412            request_id: AtomicI64::new(1),
413        })
414    }
415
416    async fn send_line(&mut self, line: &str) -> Result<()> {
417        self.stdin
418            .write_all(line.as_bytes())
419            .await
420            .map_err(|e| Error::Transport(format!("Failed to write: {}", e)))?;
421        self.stdin
422            .write_all(b"\n")
423            .await
424            .map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
425        self.stdin
426            .flush()
427            .await
428            .map_err(|e| Error::Transport(format!("Failed to flush: {}", e)))?;
429        Ok(())
430    }
431
432    async fn read_line(&mut self) -> Result<String> {
433        let mut line = String::new();
434        self.stdout
435            .read_line(&mut line)
436            .await
437            .map_err(|e| Error::Transport(format!("Failed to read: {}", e)))?;
438
439        if line.is_empty() {
440            return Err(Error::Transport("Connection closed".to_string()));
441        }
442
443        Ok(line)
444    }
445}
446
447#[async_trait]
448impl ClientTransport for StdioClientTransport {
449    async fn request(
450        &mut self,
451        method: &str,
452        params: serde_json::Value,
453    ) -> Result<serde_json::Value> {
454        let id = self.request_id.fetch_add(1, Ordering::Relaxed);
455        let request = JsonRpcRequest::new(id, method).with_params(params);
456
457        let request_json = serde_json::to_string(&request)
458            .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
459
460        tracing::debug!(method = %method, id = %id, "Sending request");
461        self.send_line(&request_json).await?;
462
463        let response_line = self.read_line().await?;
464        tracing::debug!(response = %response_line.trim(), "Received response");
465
466        let response: JsonRpcResponse = serde_json::from_str(response_line.trim())
467            .map_err(|e| Error::Transport(format!("Failed to parse response: {}", e)))?;
468
469        match response {
470            JsonRpcResponse::Result(r) => Ok(r.result),
471            JsonRpcResponse::Error(e) => Err(Error::JsonRpc(e.error)),
472        }
473    }
474
475    async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
476        let notification = serde_json::json!({
477            "jsonrpc": "2.0",
478            "method": method,
479            "params": params
480        });
481
482        let json = serde_json::to_string(&notification)
483            .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
484
485        tracing::debug!(method = %method, "Sending notification");
486        self.send_line(&json).await
487    }
488
489    fn is_connected(&self) -> bool {
490        // Assume connected if we have a child process handle
491        self.child.is_some()
492    }
493
494    async fn close(mut self: Box<Self>) -> Result<()> {
495        // Close stdin to signal EOF
496        drop(self.stdin);
497
498        if let Some(mut child) = self.child.take() {
499            // Wait for process with timeout
500            let result =
501                tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
502
503            match result {
504                Ok(Ok(status)) => {
505                    tracing::info!(status = ?status, "Child process exited");
506                }
507                Ok(Err(e)) => {
508                    tracing::error!(error = %e, "Error waiting for child");
509                }
510                Err(_) => {
511                    tracing::warn!("Timeout waiting for child, killing");
512                    let _ = child.kill().await;
513                }
514            }
515        }
516
517        Ok(())
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// `(method, params)` pairs recorded by [`MockTransport`].
    type Recorded = Vec<(String, serde_json::Value)>;

    /// Mock transport that returns preconfigured responses and records every
    /// request and notification it is asked to send.
    struct MockTransport {
        /// Responses handed out in FIFO order, one per `request` call.
        responses: Arc<Mutex<VecDeque<serde_json::Value>>>,
        requests: Arc<Mutex<Recorded>>,
        notifications: Arc<Mutex<Recorded>>,
        connected: bool,
    }

    impl MockTransport {
        fn new() -> Self {
            Self::with_responses(Vec::new())
        }

        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(responses.into())),
                requests: Arc::new(Mutex::new(Vec::new())),
                notifications: Arc::new(Mutex::new(Vec::new())),
                connected: true,
            }
        }

        fn get_requests(&self) -> Recorded {
            self.requests.lock().unwrap().clone()
        }

        fn get_notifications(&self) -> Recorded {
            self.notifications.lock().unwrap().clone()
        }

        /// Method names of the recorded notifications, in send order.
        fn notification_methods(&self) -> Vec<String> {
            self.get_notifications()
                .into_iter()
                .map(|(method, _)| method)
                .collect()
        }
    }

    #[async_trait]
    impl ClientTransport for MockTransport {
        async fn request(
            &mut self,
            method: &str,
            params: serde_json::Value,
        ) -> Result<serde_json::Value> {
            self.requests
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| Error::Transport("No more mock responses".to_string()))
        }

        async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
            self.notifications
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            Ok(())
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        async fn close(self: Box<Self>) -> Result<()> {
            Ok(())
        }
    }

    fn mock_initialize_response() -> serde_json::Value {
        serde_json::json!({
            "protocolVersion": "2025-03-26",
            "serverInfo": {
                "name": "test-server",
                "version": "1.0.0"
            },
            "capabilities": {
                "tools": {}
            }
        })
    }

    #[tokio::test]
    async fn test_client_not_initialized() {
        let mut client = McpClient::new(MockTransport::new());

        // Should fail because not initialized
        let result = client.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not initialized"));

        // The guard must reject the call before anything reaches the wire.
        assert!(client.transport.get_requests().is_empty());
    }

    #[tokio::test]
    async fn test_client_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        assert!(!client.is_initialized());

        let result = client.initialize("test-client", "1.0.0").await;
        assert!(result.is_ok());
        assert!(client.is_initialized());

        let server_info = client.server_info().unwrap();
        assert_eq!(server_info.server_info.name, "test-server");

        let requests = client.transport.get_requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].0, "initialize");
        assert_eq!(
            client.transport.notification_methods(),
            vec!["notifications/initialized"]
        );
    }

    #[tokio::test]
    async fn test_list_tools() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "tools": [
                    {
                        "name": "test_tool",
                        "description": "A test tool",
                        "inputSchema": {
                            "type": "object",
                            "properties": {}
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let tools = client.list_tools().await.unwrap();

        assert_eq!(tools.tools.len(), 1);
        assert_eq!(tools.tools[0].name, "test_tool");
        assert_eq!(client.transport.get_requests()[1].0, "tools/list");
    }

    #[tokio::test]
    async fn test_call_tool() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "content": [
                    {
                        "type": "text",
                        "text": "Tool result"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await
            .unwrap();

        assert!(!result.content.is_empty());

        let requests = client.transport.get_requests();
        assert_eq!(requests.len(), 2);
        let (method, params) = &requests[1];
        assert_eq!(method, "tools/call");
        assert_eq!(params["name"], "test_tool");
        assert_eq!(params["arguments"], serde_json::json!({"arg": "value"}));
    }

    #[tokio::test]
    async fn test_list_resources() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "resources": [
                    {
                        "uri": "file://test.txt",
                        "name": "Test File"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let resources = client.list_resources().await.unwrap();

        assert_eq!(resources.resources.len(), 1);
        assert_eq!(resources.resources[0].uri, "file://test.txt");
    }

    #[tokio::test]
    async fn test_read_resource() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "contents": [
                    {
                        "uri": "file://test.txt",
                        "text": "File contents"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.read_resource("file://test.txt").await.unwrap();

        assert_eq!(result.contents.len(), 1);
        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
    }

    #[tokio::test]
    async fn test_list_prompts() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "prompts": [
                    {
                        "name": "test_prompt",
                        "description": "A test prompt"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let prompts = client.list_prompts().await.unwrap();

        assert_eq!(prompts.prompts.len(), 1);
        assert_eq!(prompts.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "messages": [
                    {
                        "role": "user",
                        "content": {
                            "type": "text",
                            "text": "Prompt message"
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.get_prompt("test_prompt", None).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_ping() {
        let transport =
            MockTransport::with_responses(vec![mock_initialize_response(), serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.ping().await;

        assert!(result.is_ok());
        assert_eq!(client.transport.get_requests()[1].0, "ping");
    }

    #[tokio::test]
    async fn test_roots_management() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        // Initially no roots
        assert!(client.roots().is_empty());

        // Add a root before initialization (no notification)
        client.add_root(Root::new("file:///project")).await.unwrap();
        assert_eq!(client.roots().len(), 1);
        assert!(client.transport.get_notifications().is_empty());

        // Initialize
        client.initialize("test-client", "1.0.0").await.unwrap();

        // Add another root after initialization (should notify)
        client.add_root(Root::new("file:///other")).await.unwrap();
        assert_eq!(client.roots().len(), 2);
        assert_eq!(
            client.transport.notification_methods(),
            vec!["notifications/initialized", notifications::ROOTS_LIST_CHANGED]
        );

        // Removing an existing root notifies again
        let removed = client.remove_root("file:///project").await.unwrap();
        assert!(removed);
        assert_eq!(client.roots().len(), 1);
        assert_eq!(client.roots()[0].uri, "file:///other");
        assert_eq!(client.transport.get_notifications().len(), 3);

        // Removing a non-existent root is a no-op and sends nothing
        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
        assert!(!not_removed);
        assert_eq!(client.transport.get_notifications().len(), 3);
    }

    #[tokio::test]
    async fn test_set_roots_notifies_only_after_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        client.set_roots(vec![Root::new("file:///a")]).await.unwrap();
        assert!(client.transport.get_notifications().is_empty());

        client.initialize("test-client", "1.0.0").await.unwrap();
        client
            .set_roots(vec![Root::new("file:///b"), Root::new("file:///c")])
            .await
            .unwrap();

        assert_eq!(client.roots().len(), 2);
        assert_eq!(
            client
                .transport
                .notification_methods()
                .last()
                .map(String::as_str),
            Some(notifications::ROOTS_LIST_CHANGED)
        );
    }

    #[test]
    fn test_with_roots() {
        let roots = vec![Root::new("file:///test")];
        let client = McpClient::with_roots(MockTransport::new(), roots);

        assert_eq!(client.roots().len(), 1);
        assert_eq!(
            client.capabilities.roots.as_ref().map(|r| r.list_changed),
            Some(true)
        );
        assert!(!client.is_initialized());
    }

    #[test]
    fn test_with_capabilities() {
        let capabilities = ClientCapabilities {
            sampling: Some(Default::default()),
            ..Default::default()
        };

        let client = McpClient::with_capabilities(MockTransport::new(), capabilities);

        assert!(client.capabilities.sampling.is_some());
        assert!(client.roots().is_empty());
    }

    #[test]
    fn test_list_roots() {
        let roots = vec![
            Root::new("file:///project1"),
            Root::with_name("file:///project2", "Project 2"),
        ];
        let client = McpClient::with_roots(MockTransport::new(), roots);

        let result = client.list_roots();
        assert_eq!(result.roots.len(), 2);
        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
    }
}