1use std::process::Stdio;
36use std::sync::atomic::{AtomicI64, Ordering};
37
38use async_trait::async_trait;
39use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
40use tokio::process::{Child, Command};
41
42use crate::error::{Error, Result};
43use crate::protocol::{
44 CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
45 CompletionArgument, CompletionReference, GetPromptParams, GetPromptResult, Implementation,
46 InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse, ListPromptsParams,
47 ListPromptsResult, ListResourcesParams, ListResourcesResult, ListRootsResult, ListToolsParams,
48 ListToolsResult, ReadResourceParams, ReadResourceResult, Root, RootsCapability, notifications,
49};
50
/// Message transport used by [`McpClient`] to exchange JSON-RPC messages with
/// an MCP server.
///
/// Sending methods take `&mut self`, so a transport performs one operation at
/// a time.
#[async_trait]
pub trait ClientTransport: Send {
    /// Sends a JSON-RPC request named `method` carrying `params` and returns the
    /// raw `result` payload of the server's response.
    ///
    /// The payload is returned untyped; [`McpClient::request`] deserializes it.
    async fn request(
        &mut self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<serde_json::Value>;

    /// Sends a JSON-RPC notification named `method` carrying `params`.
    /// No response value is returned.
    async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()>;

    /// Reports whether the transport considers itself connected.
    fn is_connected(&self) -> bool;

    /// Shuts the transport down, consuming it.
    async fn close(self: Box<Self>) -> Result<()>;
}
70
71pub struct McpClient<T: ClientTransport> {
73 transport: T,
74 initialized: bool,
75 server_info: Option<InitializeResult>,
76 capabilities: ClientCapabilities,
78 roots: Vec<Root>,
80}
81
82impl<T: ClientTransport> McpClient<T> {
83 pub fn new(transport: T) -> Self {
85 Self {
86 transport,
87 initialized: false,
88 server_info: None,
89 capabilities: ClientCapabilities::default(),
90 roots: Vec::new(),
91 }
92 }
93
94 pub fn with_roots(transport: T, roots: Vec<Root>) -> Self {
98 Self {
99 transport,
100 initialized: false,
101 server_info: None,
102 capabilities: ClientCapabilities {
103 roots: Some(RootsCapability { list_changed: true }),
104 ..Default::default()
105 },
106 roots,
107 }
108 }
109
110 pub fn with_capabilities(transport: T, capabilities: ClientCapabilities) -> Self {
112 Self {
113 transport,
114 initialized: false,
115 server_info: None,
116 capabilities,
117 roots: Vec::new(),
118 }
119 }
120
121 pub fn server_info(&self) -> Option<&InitializeResult> {
123 self.server_info.as_ref()
124 }
125
126 pub fn is_initialized(&self) -> bool {
128 self.initialized
129 }
130
131 pub fn roots(&self) -> &[Root] {
133 &self.roots
134 }
135
136 pub async fn set_roots(&mut self, roots: Vec<Root>) -> Result<()> {
140 self.roots = roots;
141 if self.initialized {
142 self.notify_roots_changed().await?;
143 }
144 Ok(())
145 }
146
147 pub async fn add_root(&mut self, root: Root) -> Result<()> {
149 self.roots.push(root);
150 if self.initialized {
151 self.notify_roots_changed().await?;
152 }
153 Ok(())
154 }
155
156 pub async fn remove_root(&mut self, uri: &str) -> Result<bool> {
158 let initial_len = self.roots.len();
159 self.roots.retain(|r| r.uri != uri);
160 let removed = self.roots.len() < initial_len;
161 if removed && self.initialized {
162 self.notify_roots_changed().await?;
163 }
164 Ok(removed)
165 }
166
167 async fn notify_roots_changed(&mut self) -> Result<()> {
169 self.transport
170 .notify(notifications::ROOTS_LIST_CHANGED, serde_json::json!({}))
171 .await
172 }
173
174 pub fn list_roots(&self) -> ListRootsResult {
178 ListRootsResult {
179 roots: self.roots.clone(),
180 }
181 }
182
183 pub async fn initialize(
185 &mut self,
186 client_name: &str,
187 client_version: &str,
188 ) -> Result<&InitializeResult> {
189 let params = InitializeParams {
190 protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
191 capabilities: self.capabilities.clone(),
192 client_info: Implementation {
193 name: client_name.to_string(),
194 version: client_version.to_string(),
195 ..Default::default()
196 },
197 };
198
199 let result: InitializeResult = self.request("initialize", ¶ms).await?;
200 self.server_info = Some(result);
201
202 self.transport
204 .notify("notifications/initialized", serde_json::json!({}))
205 .await?;
206
207 self.initialized = true;
208
209 Ok(self.server_info.as_ref().unwrap())
210 }
211
212 pub async fn list_tools(&mut self) -> Result<ListToolsResult> {
214 self.ensure_initialized()?;
215 self.request("tools/list", &ListToolsParams { cursor: None })
216 .await
217 }
218
219 pub async fn call_tool(
221 &mut self,
222 name: &str,
223 arguments: serde_json::Value,
224 ) -> Result<CallToolResult> {
225 self.ensure_initialized()?;
226 let params = CallToolParams {
227 name: name.to_string(),
228 arguments,
229 meta: None,
230 };
231 self.request("tools/call", ¶ms).await
232 }
233
234 pub async fn list_resources(&mut self) -> Result<ListResourcesResult> {
236 self.ensure_initialized()?;
237 self.request("resources/list", &ListResourcesParams { cursor: None })
238 .await
239 }
240
241 pub async fn read_resource(&mut self, uri: &str) -> Result<ReadResourceResult> {
243 self.ensure_initialized()?;
244 let params = ReadResourceParams {
245 uri: uri.to_string(),
246 };
247 self.request("resources/read", ¶ms).await
248 }
249
250 pub async fn list_prompts(&mut self) -> Result<ListPromptsResult> {
252 self.ensure_initialized()?;
253 self.request("prompts/list", &ListPromptsParams { cursor: None })
254 .await
255 }
256
257 pub async fn get_prompt(
259 &mut self,
260 name: &str,
261 arguments: Option<std::collections::HashMap<String, String>>,
262 ) -> Result<GetPromptResult> {
263 self.ensure_initialized()?;
264 let params = GetPromptParams {
265 name: name.to_string(),
266 arguments: arguments.unwrap_or_default(),
267 };
268 self.request("prompts/get", ¶ms).await
269 }
270
271 pub async fn ping(&mut self) -> Result<()> {
273 let _: serde_json::Value = self.request("ping", &serde_json::json!({})).await?;
274 Ok(())
275 }
276
277 pub async fn complete(
281 &mut self,
282 reference: CompletionReference,
283 argument_name: &str,
284 argument_value: &str,
285 ) -> Result<CompleteResult> {
286 self.ensure_initialized()?;
287 let params = CompleteParams {
288 reference,
289 argument: CompletionArgument::new(argument_name, argument_value),
290 };
291 self.request("completion/complete", ¶ms).await
292 }
293
294 pub async fn complete_prompt_arg(
296 &mut self,
297 prompt_name: &str,
298 argument_name: &str,
299 argument_value: &str,
300 ) -> Result<CompleteResult> {
301 self.complete(
302 CompletionReference::prompt(prompt_name),
303 argument_name,
304 argument_value,
305 )
306 .await
307 }
308
309 pub async fn complete_resource_uri(
311 &mut self,
312 resource_uri: &str,
313 argument_name: &str,
314 argument_value: &str,
315 ) -> Result<CompleteResult> {
316 self.complete(
317 CompletionReference::resource(resource_uri),
318 argument_name,
319 argument_value,
320 )
321 .await
322 }
323
324 pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
326 &mut self,
327 method: &str,
328 params: &P,
329 ) -> Result<R> {
330 let params_value = serde_json::to_value(params)
331 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
332
333 let result = self.transport.request(method, params_value).await?;
334
335 serde_json::from_value(result)
336 .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
337 }
338
339 pub async fn notify<P: serde::Serialize>(&mut self, method: &str, params: &P) -> Result<()> {
341 let params_value = serde_json::to_value(params)
342 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
343
344 self.transport.notify(method, params_value).await
345 }
346
347 fn ensure_initialized(&self) -> Result<()> {
348 if !self.initialized {
349 return Err(Error::Transport("Client not initialized".to_string()));
350 }
351 Ok(())
352 }
353}
354
/// [`ClientTransport`] that talks to an MCP server child process using
/// newline-delimited JSON over the child's stdin/stdout.
pub struct StdioClientTransport {
    /// Handle to the server process; taken (set to `None`) by `close`.
    child: Option<Child>,
    /// Pipe used to send JSON-RPC messages to the server.
    stdin: tokio::process::ChildStdin,
    /// Buffered pipe from which one JSON message per line is read.
    stdout: BufReader<tokio::process::ChildStdout>,
    /// Next request id; starts at 1 and is incremented for every request.
    request_id: AtomicI64,
}
366
367impl StdioClientTransport {
368 pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
370 let mut cmd = Command::new(program);
371 cmd.args(args)
372 .stdin(Stdio::piped())
373 .stdout(Stdio::piped())
374 .stderr(Stdio::inherit());
375
376 let mut child = cmd
377 .spawn()
378 .map_err(|e| Error::Transport(format!("Failed to spawn {}: {}", program, e)))?;
379
380 let stdin = child
381 .stdin
382 .take()
383 .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
384 let stdout = child
385 .stdout
386 .take()
387 .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
388
389 tracing::info!(program = %program, "Spawned MCP server process");
390
391 Ok(Self {
392 child: Some(child),
393 stdin,
394 stdout: BufReader::new(stdout),
395 request_id: AtomicI64::new(1),
396 })
397 }
398
399 pub fn from_child(mut child: Child) -> Result<Self> {
401 let stdin = child
402 .stdin
403 .take()
404 .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
405 let stdout = child
406 .stdout
407 .take()
408 .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
409
410 Ok(Self {
411 child: Some(child),
412 stdin,
413 stdout: BufReader::new(stdout),
414 request_id: AtomicI64::new(1),
415 })
416 }
417
418 async fn send_line(&mut self, line: &str) -> Result<()> {
419 self.stdin
420 .write_all(line.as_bytes())
421 .await
422 .map_err(|e| Error::Transport(format!("Failed to write: {}", e)))?;
423 self.stdin
424 .write_all(b"\n")
425 .await
426 .map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
427 self.stdin
428 .flush()
429 .await
430 .map_err(|e| Error::Transport(format!("Failed to flush: {}", e)))?;
431 Ok(())
432 }
433
434 async fn read_line(&mut self) -> Result<String> {
435 let mut line = String::new();
436 self.stdout
437 .read_line(&mut line)
438 .await
439 .map_err(|e| Error::Transport(format!("Failed to read: {}", e)))?;
440
441 if line.is_empty() {
442 return Err(Error::Transport("Connection closed".to_string()));
443 }
444
445 Ok(line)
446 }
447}
448
449#[async_trait]
450impl ClientTransport for StdioClientTransport {
451 async fn request(
452 &mut self,
453 method: &str,
454 params: serde_json::Value,
455 ) -> Result<serde_json::Value> {
456 let id = self.request_id.fetch_add(1, Ordering::Relaxed);
457 let request = JsonRpcRequest::new(id, method).with_params(params);
458
459 let request_json = serde_json::to_string(&request)
460 .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
461
462 tracing::debug!(method = %method, id = %id, "Sending request");
463 self.send_line(&request_json).await?;
464
465 let response_line = self.read_line().await?;
466 tracing::debug!(response = %response_line.trim(), "Received response");
467
468 let response: JsonRpcResponse = serde_json::from_str(response_line.trim())
469 .map_err(|e| Error::Transport(format!("Failed to parse response: {}", e)))?;
470
471 match response {
472 JsonRpcResponse::Result(r) => Ok(r.result),
473 JsonRpcResponse::Error(e) => Err(Error::JsonRpc(e.error)),
474 }
475 }
476
477 async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
478 let notification = serde_json::json!({
479 "jsonrpc": "2.0",
480 "method": method,
481 "params": params
482 });
483
484 let json = serde_json::to_string(¬ification)
485 .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
486
487 tracing::debug!(method = %method, "Sending notification");
488 self.send_line(&json).await
489 }
490
491 fn is_connected(&self) -> bool {
492 self.child.is_some()
494 }
495
496 async fn close(mut self: Box<Self>) -> Result<()> {
497 drop(self.stdin);
499
500 if let Some(mut child) = self.child.take() {
501 let result =
503 tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
504
505 match result {
506 Ok(Ok(status)) => {
507 tracing::info!(status = ?status, "Child process exited");
508 }
509 Ok(Err(e)) => {
510 tracing::error!(error = %e, "Error waiting for child");
511 }
512 Err(_) => {
513 tracing::warn!("Timeout waiting for child, killing");
514 let _ = child.kill().await;
515 }
516 }
517 }
518
519 Ok(())
520 }
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// In-memory transport that replays canned responses and records traffic.
    ///
    /// State lives behind `Arc<Mutex<..>>` so a test can keep a handle after
    /// the transport has been moved into an `McpClient`.
    struct MockTransport {
        /// Responses returned, in order, by successive `request` calls.
        responses: Arc<Mutex<VecDeque<serde_json::Value>>>,
        /// Every `(method, params)` passed to `request`.
        requests: Arc<Mutex<Vec<(String, serde_json::Value)>>>,
        /// Every `(method, params)` passed to `notify`.
        notifications: Arc<Mutex<Vec<(String, serde_json::Value)>>>,
        /// Value reported by `is_connected`.
        connected: bool,
    }

    impl MockTransport {
        /// Mock with no canned responses: every `request` fails.
        fn new() -> Self {
            Self::with_responses(Vec::new())
        }

        /// Mock that answers successive requests with `responses`, in order.
        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(responses.into())),
                requests: Arc::new(Mutex::new(Vec::new())),
                notifications: Arc::new(Mutex::new(Vec::new())),
                connected: true,
            }
        }

        /// Snapshot of the requests recorded so far.
        fn get_requests(&self) -> Vec<(String, serde_json::Value)> {
            self.requests.lock().unwrap().clone()
        }

        /// Snapshot of the notifications recorded so far.
        fn get_notifications(&self) -> Vec<(String, serde_json::Value)> {
            self.notifications.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ClientTransport for MockTransport {
        async fn request(
            &mut self,
            method: &str,
            params: serde_json::Value,
        ) -> Result<serde_json::Value> {
            self.requests
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| Error::Transport("No more mock responses".to_string()))
        }

        async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
            self.notifications
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            Ok(())
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        async fn close(self: Box<Self>) -> Result<()> {
            Ok(())
        }
    }

    /// Minimal successful `initialize` result.
    fn mock_initialize_response() -> serde_json::Value {
        serde_json::json!({
            "protocolVersion": "2025-11-25",
            "serverInfo": {
                "name": "test-server",
                "version": "1.0.0"
            },
            "capabilities": {
                "tools": {}
            }
        })
    }

    #[tokio::test]
    async fn test_client_not_initialized() {
        let mut client = McpClient::new(MockTransport::new());

        let result = client.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not initialized"));
    }

    #[tokio::test]
    async fn test_client_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        assert!(!client.is_initialized());

        let result = client.initialize("test-client", "1.0.0").await;
        assert!(result.is_ok());
        assert!(client.is_initialized());

        let server_info = client.server_info().unwrap();
        assert_eq!(server_info.server_info.name, "test-server");
    }

    #[tokio::test]
    async fn test_initialize_handshake_traffic() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();

        let requests = client.transport.get_requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].0, "initialize");

        let sent = client.transport.get_notifications();
        assert_eq!(sent.len(), 1);
        assert_eq!(sent[0].0, "notifications/initialized");
    }

    #[tokio::test]
    async fn test_initialize_failure_leaves_client_uninitialized() {
        let mut client = McpClient::new(MockTransport::new());

        assert!(client.initialize("test-client", "1.0.0").await.is_err());
        assert!(!client.is_initialized());
        assert!(client.server_info().is_none());
        assert!(client.transport.get_notifications().is_empty());
    }

    #[tokio::test]
    async fn test_list_tools() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "tools": [
                    {
                        "name": "test_tool",
                        "description": "A test tool",
                        "inputSchema": {
                            "type": "object",
                            "properties": {}
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let tools = client.list_tools().await.unwrap();

        assert_eq!(tools.tools.len(), 1);
        assert_eq!(tools.tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_call_tool() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "content": [
                    {
                        "type": "text",
                        "text": "Tool result"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await
            .unwrap();

        assert!(!result.content.is_empty());
    }

    #[tokio::test]
    async fn test_list_resources() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "resources": [
                    {
                        "uri": "file://test.txt",
                        "name": "Test File"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let resources = client.list_resources().await.unwrap();

        assert_eq!(resources.resources.len(), 1);
        assert_eq!(resources.resources[0].uri, "file://test.txt");
    }

    #[tokio::test]
    async fn test_read_resource() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "contents": [
                    {
                        "uri": "file://test.txt",
                        "text": "File contents"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.read_resource("file://test.txt").await.unwrap();

        assert_eq!(result.contents.len(), 1);
        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
    }

    #[tokio::test]
    async fn test_list_prompts() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "prompts": [
                    {
                        "name": "test_prompt",
                        "description": "A test prompt"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let prompts = client.list_prompts().await.unwrap();

        assert_eq!(prompts.prompts.len(), 1);
        assert_eq!(prompts.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "messages": [
                    {
                        "role": "user",
                        "content": {
                            "type": "text",
                            "text": "Prompt message"
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.get_prompt("test_prompt", None).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_ping() {
        let transport =
            MockTransport::with_responses(vec![mock_initialize_response(), serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.ping().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ping_before_initialize() {
        let transport = MockTransport::with_responses(vec![serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        client.ping().await.unwrap();

        let requests = client.transport.get_requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].0, "ping");
    }

    #[tokio::test]
    async fn test_roots_management() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let notifications = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        assert!(client.roots().is_empty());

        // Before initialization, changing roots must not notify the server.
        client.add_root(Root::new("file:///project")).await.unwrap();
        assert_eq!(client.roots().len(), 1);
        assert!(notifications.lock().unwrap().is_empty());

        client.initialize("test-client", "1.0.0").await.unwrap();

        // `initialized` notification plus one roots-list-changed notification.
        client.add_root(Root::new("file:///other")).await.unwrap();
        assert_eq!(client.roots().len(), 2);
        assert_eq!(notifications.lock().unwrap().len(), 2);

        let removed = client.remove_root("file:///project").await.unwrap();
        assert!(removed);
        assert_eq!(client.roots().len(), 1);

        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
        assert!(!not_removed);
    }

    #[tokio::test]
    async fn test_set_roots_notifies_only_after_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        client.set_roots(vec![Root::new("file:///a")]).await.unwrap();
        assert!(client.transport.get_notifications().is_empty());

        client.initialize("test-client", "1.0.0").await.unwrap();
        client.set_roots(vec![Root::new("file:///b")]).await.unwrap();

        let sent = client.transport.get_notifications();
        assert_eq!(sent.len(), 2);
        assert_eq!(sent[1].0, notifications::ROOTS_LIST_CHANGED);
        assert_eq!(client.roots().len(), 1);
    }

    #[tokio::test]
    async fn test_with_roots() {
        let roots = vec![Root::new("file:///test")];
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::with_roots(transport, roots);

        assert_eq!(client.roots().len(), 1);
        assert!(client.capabilities.roots.is_some());
    }

    #[tokio::test]
    async fn test_with_capabilities() {
        let capabilities = ClientCapabilities {
            sampling: Some(Default::default()),
            ..Default::default()
        };

        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::with_capabilities(transport, capabilities);

        assert!(client.capabilities.sampling.is_some());
    }

    #[tokio::test]
    async fn test_list_roots() {
        let roots = vec![
            Root::new("file:///project1"),
            Root::with_name("file:///project2", "Project 2"),
        ];
        let transport = MockTransport::new();
        let client = McpClient::with_roots(transport, roots);

        let result = client.list_roots();
        assert_eq!(result.roots.len(), 2);
        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
    }
}
860}