1use std::process::Stdio;
36use std::sync::atomic::{AtomicI64, Ordering};
37
38use async_trait::async_trait;
39use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
40use tokio::process::{Child, Command};
41
42use crate::error::{Error, Result};
43use crate::protocol::{
44 CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
45 CompletionArgument, CompletionReference, GetPromptParams, GetPromptResult, Implementation,
46 InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse, ListPromptsParams,
47 ListPromptsResult, ListResourcesParams, ListResourcesResult, ListRootsResult, ListToolsParams,
48 ListToolsResult, ReadResourceParams, ReadResourceResult, Root, RootsCapability, notifications,
49};
50
51#[async_trait]
53pub trait ClientTransport: Send {
54 async fn request(
56 &mut self,
57 method: &str,
58 params: serde_json::Value,
59 ) -> Result<serde_json::Value>;
60
61 async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()>;
63
64 fn is_connected(&self) -> bool;
66
67 async fn close(self: Box<Self>) -> Result<()>;
69}
70
71pub struct McpClient<T: ClientTransport> {
73 transport: T,
74 initialized: bool,
75 server_info: Option<InitializeResult>,
76 capabilities: ClientCapabilities,
78 roots: Vec<Root>,
80}
81
82impl<T: ClientTransport> McpClient<T> {
83 pub fn new(transport: T) -> Self {
85 Self {
86 transport,
87 initialized: false,
88 server_info: None,
89 capabilities: ClientCapabilities::default(),
90 roots: Vec::new(),
91 }
92 }
93
94 pub fn with_roots(mut self, roots: Vec<Root>) -> Self {
113 self.roots = roots;
114 self.capabilities.roots = Some(RootsCapability { list_changed: true });
115 self
116 }
117
118 pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
137 self.capabilities = capabilities;
138 self
139 }
140
141 pub fn server_info(&self) -> Option<&InitializeResult> {
143 self.server_info.as_ref()
144 }
145
146 pub fn is_initialized(&self) -> bool {
148 self.initialized
149 }
150
151 pub fn roots(&self) -> &[Root] {
153 &self.roots
154 }
155
156 pub async fn set_roots(&mut self, roots: Vec<Root>) -> Result<()> {
160 self.roots = roots;
161 if self.initialized {
162 self.notify_roots_changed().await?;
163 }
164 Ok(())
165 }
166
167 pub async fn add_root(&mut self, root: Root) -> Result<()> {
169 self.roots.push(root);
170 if self.initialized {
171 self.notify_roots_changed().await?;
172 }
173 Ok(())
174 }
175
176 pub async fn remove_root(&mut self, uri: &str) -> Result<bool> {
178 let initial_len = self.roots.len();
179 self.roots.retain(|r| r.uri != uri);
180 let removed = self.roots.len() < initial_len;
181 if removed && self.initialized {
182 self.notify_roots_changed().await?;
183 }
184 Ok(removed)
185 }
186
187 async fn notify_roots_changed(&mut self) -> Result<()> {
189 self.transport
190 .notify(notifications::ROOTS_LIST_CHANGED, serde_json::json!({}))
191 .await
192 }
193
194 pub fn list_roots(&self) -> ListRootsResult {
198 ListRootsResult {
199 roots: self.roots.clone(),
200 }
201 }
202
203 pub async fn initialize(
205 &mut self,
206 client_name: &str,
207 client_version: &str,
208 ) -> Result<&InitializeResult> {
209 let params = InitializeParams {
210 protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
211 capabilities: self.capabilities.clone(),
212 client_info: Implementation {
213 name: client_name.to_string(),
214 version: client_version.to_string(),
215 ..Default::default()
216 },
217 };
218
219 let result: InitializeResult = self.request("initialize", ¶ms).await?;
220 self.server_info = Some(result);
221
222 self.transport
224 .notify("notifications/initialized", serde_json::json!({}))
225 .await?;
226
227 self.initialized = true;
228
229 Ok(self.server_info.as_ref().unwrap())
230 }
231
232 pub async fn list_tools(&mut self) -> Result<ListToolsResult> {
234 self.ensure_initialized()?;
235 self.request("tools/list", &ListToolsParams { cursor: None })
236 .await
237 }
238
239 pub async fn call_tool(
241 &mut self,
242 name: &str,
243 arguments: serde_json::Value,
244 ) -> Result<CallToolResult> {
245 self.ensure_initialized()?;
246 let params = CallToolParams {
247 name: name.to_string(),
248 arguments,
249 meta: None,
250 };
251 self.request("tools/call", ¶ms).await
252 }
253
254 pub async fn list_resources(&mut self) -> Result<ListResourcesResult> {
256 self.ensure_initialized()?;
257 self.request("resources/list", &ListResourcesParams { cursor: None })
258 .await
259 }
260
261 pub async fn read_resource(&mut self, uri: &str) -> Result<ReadResourceResult> {
263 self.ensure_initialized()?;
264 let params = ReadResourceParams {
265 uri: uri.to_string(),
266 };
267 self.request("resources/read", ¶ms).await
268 }
269
270 pub async fn list_prompts(&mut self) -> Result<ListPromptsResult> {
272 self.ensure_initialized()?;
273 self.request("prompts/list", &ListPromptsParams { cursor: None })
274 .await
275 }
276
277 pub async fn get_prompt(
279 &mut self,
280 name: &str,
281 arguments: Option<std::collections::HashMap<String, String>>,
282 ) -> Result<GetPromptResult> {
283 self.ensure_initialized()?;
284 let params = GetPromptParams {
285 name: name.to_string(),
286 arguments: arguments.unwrap_or_default(),
287 };
288 self.request("prompts/get", ¶ms).await
289 }
290
291 pub async fn ping(&mut self) -> Result<()> {
293 let _: serde_json::Value = self.request("ping", &serde_json::json!({})).await?;
294 Ok(())
295 }
296
297 pub async fn complete(
301 &mut self,
302 reference: CompletionReference,
303 argument_name: &str,
304 argument_value: &str,
305 ) -> Result<CompleteResult> {
306 self.ensure_initialized()?;
307 let params = CompleteParams {
308 reference,
309 argument: CompletionArgument::new(argument_name, argument_value),
310 };
311 self.request("completion/complete", ¶ms).await
312 }
313
314 pub async fn complete_prompt_arg(
316 &mut self,
317 prompt_name: &str,
318 argument_name: &str,
319 argument_value: &str,
320 ) -> Result<CompleteResult> {
321 self.complete(
322 CompletionReference::prompt(prompt_name),
323 argument_name,
324 argument_value,
325 )
326 .await
327 }
328
329 pub async fn complete_resource_uri(
331 &mut self,
332 resource_uri: &str,
333 argument_name: &str,
334 argument_value: &str,
335 ) -> Result<CompleteResult> {
336 self.complete(
337 CompletionReference::resource(resource_uri),
338 argument_name,
339 argument_value,
340 )
341 .await
342 }
343
344 pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
346 &mut self,
347 method: &str,
348 params: &P,
349 ) -> Result<R> {
350 let params_value = serde_json::to_value(params)
351 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
352
353 let result = self.transport.request(method, params_value).await?;
354
355 serde_json::from_value(result)
356 .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
357 }
358
359 pub async fn notify<P: serde::Serialize>(&mut self, method: &str, params: &P) -> Result<()> {
361 let params_value = serde_json::to_value(params)
362 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
363
364 self.transport.notify(method, params_value).await
365 }
366
367 fn ensure_initialized(&self) -> Result<()> {
368 if !self.initialized {
369 return Err(Error::Transport("Client not initialized".to_string()));
370 }
371 Ok(())
372 }
373}
374
375pub struct StdioClientTransport {
381 child: Option<Child>,
382 stdin: tokio::process::ChildStdin,
383 stdout: BufReader<tokio::process::ChildStdout>,
384 request_id: AtomicI64,
385}
386
387impl StdioClientTransport {
388 pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
390 let mut cmd = Command::new(program);
391 cmd.args(args)
392 .stdin(Stdio::piped())
393 .stdout(Stdio::piped())
394 .stderr(Stdio::inherit());
395
396 let mut child = cmd
397 .spawn()
398 .map_err(|e| Error::Transport(format!("Failed to spawn {}: {}", program, e)))?;
399
400 let stdin = child
401 .stdin
402 .take()
403 .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
404 let stdout = child
405 .stdout
406 .take()
407 .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
408
409 tracing::info!(program = %program, "Spawned MCP server process");
410
411 Ok(Self {
412 child: Some(child),
413 stdin,
414 stdout: BufReader::new(stdout),
415 request_id: AtomicI64::new(1),
416 })
417 }
418
419 pub fn from_child(mut child: Child) -> Result<Self> {
421 let stdin = child
422 .stdin
423 .take()
424 .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
425 let stdout = child
426 .stdout
427 .take()
428 .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
429
430 Ok(Self {
431 child: Some(child),
432 stdin,
433 stdout: BufReader::new(stdout),
434 request_id: AtomicI64::new(1),
435 })
436 }
437
438 async fn send_line(&mut self, line: &str) -> Result<()> {
439 self.stdin
440 .write_all(line.as_bytes())
441 .await
442 .map_err(|e| Error::Transport(format!("Failed to write: {}", e)))?;
443 self.stdin
444 .write_all(b"\n")
445 .await
446 .map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
447 self.stdin
448 .flush()
449 .await
450 .map_err(|e| Error::Transport(format!("Failed to flush: {}", e)))?;
451 Ok(())
452 }
453
454 async fn read_line(&mut self) -> Result<String> {
455 let mut line = String::new();
456 self.stdout
457 .read_line(&mut line)
458 .await
459 .map_err(|e| Error::Transport(format!("Failed to read: {}", e)))?;
460
461 if line.is_empty() {
462 return Err(Error::Transport("Connection closed".to_string()));
463 }
464
465 Ok(line)
466 }
467}
468
469#[async_trait]
470impl ClientTransport for StdioClientTransport {
471 async fn request(
472 &mut self,
473 method: &str,
474 params: serde_json::Value,
475 ) -> Result<serde_json::Value> {
476 let id = self.request_id.fetch_add(1, Ordering::Relaxed);
477 let request = JsonRpcRequest::new(id, method).with_params(params);
478
479 let request_json = serde_json::to_string(&request)
480 .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
481
482 tracing::debug!(method = %method, id = %id, "Sending request");
483 self.send_line(&request_json).await?;
484
485 let response_line = self.read_line().await?;
486 tracing::debug!(response = %response_line.trim(), "Received response");
487
488 let response: JsonRpcResponse = serde_json::from_str(response_line.trim())
489 .map_err(|e| Error::Transport(format!("Failed to parse response: {}", e)))?;
490
491 match response {
492 JsonRpcResponse::Result(r) => Ok(r.result),
493 JsonRpcResponse::Error(e) => Err(Error::JsonRpc(e.error)),
494 }
495 }
496
497 async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
498 let notification = serde_json::json!({
499 "jsonrpc": "2.0",
500 "method": method,
501 "params": params
502 });
503
504 let json = serde_json::to_string(¬ification)
505 .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
506
507 tracing::debug!(method = %method, "Sending notification");
508 self.send_line(&json).await
509 }
510
511 fn is_connected(&self) -> bool {
512 self.child.is_some()
514 }
515
516 async fn close(mut self: Box<Self>) -> Result<()> {
517 drop(self.stdin);
519
520 if let Some(mut child) = self.child.take() {
521 let result =
523 tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
524
525 match result {
526 Ok(Ok(status)) => {
527 tracing::info!(status = ?status, "Child process exited");
528 }
529 Ok(Err(e)) => {
530 tracing::error!(error = %e, "Error waiting for child");
531 }
532 Err(_) => {
533 tracing::warn!("Timeout waiting for child, killing");
534 let _ = child.kill().await;
535 }
536 }
537 }
538
539 Ok(())
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Shared log of `(method, params)` pairs seen by the mock transport.
    type MessageLog = Arc<Mutex<Vec<(String, serde_json::Value)>>>;

    /// In-memory transport that replays canned responses and records all traffic.
    ///
    /// The `Arc`s let a test keep a handle on the logs after the transport has been moved
    /// into an `McpClient`.
    struct MockTransport {
        responses: Arc<Mutex<VecDeque<serde_json::Value>>>,
        requests: MessageLog,
        notifications: MessageLog,
        connected: bool,
    }

    impl MockTransport {
        /// A mock with no queued responses, so every `request` fails.
        fn new() -> Self {
            Self::with_responses(Vec::new())
        }

        /// A mock whose successive `request` calls return `responses` in order.
        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(responses.into())),
                requests: Arc::default(),
                notifications: Arc::default(),
                connected: true,
            }
        }

        // Kept for ad-hoc debugging; tests normally read the shared logs directly.
        #[allow(dead_code)]
        fn get_requests(&self) -> Vec<(String, serde_json::Value)> {
            self.requests.lock().unwrap().clone()
        }

        // Kept for ad-hoc debugging; tests normally read the shared logs directly.
        #[allow(dead_code)]
        fn get_notifications(&self) -> Vec<(String, serde_json::Value)> {
            self.notifications.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ClientTransport for MockTransport {
        async fn request(
            &mut self,
            method: &str,
            params: serde_json::Value,
        ) -> Result<serde_json::Value> {
            self.requests
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| Error::Transport("No more mock responses".to_string()))
        }

        async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
            self.notifications
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            Ok(())
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        async fn close(self: Box<Self>) -> Result<()> {
            Ok(())
        }
    }

    /// A minimal successful `initialize` result advertising a `tools` capability.
    fn mock_initialize_response() -> serde_json::Value {
        serde_json::json!({
            "protocolVersion": "2025-11-25",
            "serverInfo": {
                "name": "test-server",
                "version": "1.0.0"
            },
            "capabilities": {
                "tools": {}
            }
        })
    }

    /// Method names from a message log, in the order they were sent.
    fn methods(log: &MessageLog) -> Vec<String> {
        log.lock().unwrap().iter().map(|(m, _)| m.clone()).collect()
    }

    #[tokio::test]
    async fn test_client_not_initialized() {
        let mut client = McpClient::new(MockTransport::new());

        let result = client.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not initialized"));
    }

    #[tokio::test]
    async fn test_client_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let requests = transport.requests.clone();
        let sent = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        assert!(!client.is_initialized());
        assert!(client.server_info().is_none());

        let result = client.initialize("test-client", "1.0.0").await;
        assert!(result.is_ok());
        assert!(client.is_initialized());

        let server_info = client.server_info().unwrap();
        assert_eq!(server_info.server_info.name, "test-server");

        // The handshake is one `initialize` request followed by the `initialized` notification.
        assert_eq!(methods(&requests), vec!["initialize"]);
        assert_eq!(methods(&sent), vec!["notifications/initialized"]);
    }

    #[tokio::test]
    async fn test_initialize_failure_leaves_client_uninitialized() {
        let transport = MockTransport::new();
        let sent = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        assert!(client.initialize("test-client", "1.0.0").await.is_err());
        assert!(!client.is_initialized());
        assert!(client.server_info().is_none());
        // A failed `initialize` request must not be followed by `initialized`.
        assert!(sent.lock().unwrap().is_empty());
        assert!(client.list_tools().await.is_err());
    }

    #[tokio::test]
    async fn test_list_tools() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "tools": [
                    {
                        "name": "test_tool",
                        "description": "A test tool",
                        "inputSchema": {
                            "type": "object",
                            "properties": {}
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let tools = client.list_tools().await.unwrap();

        assert_eq!(tools.tools.len(), 1);
        assert_eq!(tools.tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_call_tool() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "content": [
                    {
                        "type": "text",
                        "text": "Tool result"
                    }
                ]
            }),
        ]);
        let requests = transport.requests.clone();
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await
            .unwrap();

        assert!(!result.content.is_empty());
        assert_eq!(methods(&requests), vec!["initialize", "tools/call"]);
    }

    #[tokio::test]
    async fn test_list_resources() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "resources": [
                    {
                        "uri": "file://test.txt",
                        "name": "Test File"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let resources = client.list_resources().await.unwrap();

        assert_eq!(resources.resources.len(), 1);
        assert_eq!(resources.resources[0].uri, "file://test.txt");
    }

    #[tokio::test]
    async fn test_read_resource() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "contents": [
                    {
                        "uri": "file://test.txt",
                        "text": "File contents"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.read_resource("file://test.txt").await.unwrap();

        assert_eq!(result.contents.len(), 1);
        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
    }

    #[tokio::test]
    async fn test_list_prompts() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "prompts": [
                    {
                        "name": "test_prompt",
                        "description": "A test prompt"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let prompts = client.list_prompts().await.unwrap();

        assert_eq!(prompts.prompts.len(), 1);
        assert_eq!(prompts.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "messages": [
                    {
                        "role": "user",
                        "content": {
                            "type": "text",
                            "text": "Prompt message"
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.get_prompt("test_prompt", None).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_ping() {
        let transport =
            MockTransport::with_responses(vec![mock_initialize_response(), serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.ping().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ping_does_not_require_initialization() {
        let transport = MockTransport::with_responses(vec![serde_json::json!({})]);
        let requests = transport.requests.clone();
        let mut client = McpClient::new(transport);

        assert!(client.ping().await.is_ok());
        assert!(!client.is_initialized());
        assert_eq!(methods(&requests), vec!["ping"]);
    }

    #[tokio::test]
    async fn test_complete_requires_initialization() {
        let mut client = McpClient::new(MockTransport::new());

        let err = client
            .complete_prompt_arg("test_prompt", "arg", "va")
            .await
            .err()
            .expect("completion before initialize must fail");
        assert!(err.to_string().contains("not initialized"));
    }

    #[tokio::test]
    async fn test_roots_management() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let sent = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        assert!(client.roots().is_empty());

        // Before initialization, root changes are recorded but not announced.
        client.add_root(Root::new("file:///project")).await.unwrap();
        assert_eq!(client.roots().len(), 1);
        assert!(sent.lock().unwrap().is_empty());

        client.initialize("test-client", "1.0.0").await.unwrap();

        // After initialization, each effective change sends a list-changed notification.
        client.add_root(Root::new("file:///other")).await.unwrap();
        assert_eq!(client.roots().len(), 2);
        let sent_methods = methods(&sent);
        assert_eq!(sent_methods.len(), 2);
        assert_eq!(sent_methods[0], "notifications/initialized");
        assert_eq!(sent_methods[1], notifications::ROOTS_LIST_CHANGED);

        let removed = client.remove_root("file:///project").await.unwrap();
        assert!(removed);
        assert_eq!(client.roots().len(), 1);
        assert_eq!(sent.lock().unwrap().len(), 3);

        // Removing an unknown root changes nothing and sends nothing.
        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
        assert!(!not_removed);
        assert_eq!(client.roots().len(), 1);
        assert_eq!(sent.lock().unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_with_roots() {
        let roots = vec![Root::new("file:///test")];
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::new(transport).with_roots(roots);

        assert_eq!(client.roots().len(), 1);
        assert!(client.capabilities.roots.is_some());
    }

    #[tokio::test]
    async fn test_with_capabilities() {
        let capabilities = ClientCapabilities {
            sampling: Some(Default::default()),
            ..Default::default()
        };

        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::new(transport).with_capabilities(capabilities);

        assert!(client.capabilities.sampling.is_some());
    }

    #[tokio::test]
    async fn test_list_roots() {
        let roots = vec![
            Root::new("file:///project1"),
            Root::with_name("file:///project2", "Project 2"),
        ];
        let transport = MockTransport::new();
        let client = McpClient::new(transport).with_roots(roots);

        let result = client.list_roots();
        assert_eq!(result.roots.len(), 2);
        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
    }
}
880}