1use super::base_client::BaseMCPClient;
11use super::model::*;
12use super::{ResourceCache, SubscriptionManager};
13use crate::desktop::window_uri::{is_window_uri, WindowURI};
14use async_trait::async_trait;
15use serde_json;
16use std::process::Stdio;
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::io::{AsyncBufReadExt, BufReader};
20use tokio::process::{Child, Command};
21use tokio::sync::Mutex;
22use tracing::{debug, error, info, warn};
23
24pub struct StdioMCPClient {
26 base: BaseMCPClient<StdioServerParameters>,
28 child_process: Arc<Mutex<Option<Child>>>,
30 session_id: Arc<Mutex<Option<String>>>,
32 subscription_manager: SubscriptionManager,
34 resource_cache: ResourceCache,
36}
37
38impl std::fmt::Debug for StdioMCPClient {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("StdioMCPClient")
41 .field("command", &self.base.params.command)
42 .field("args", &self.base.params.args)
43 .field("state", &self.base.state())
44 .finish()
45 }
46}
47
48impl StdioMCPClient {
49 pub fn new(params: StdioServerParameters) -> Self {
51 Self {
52 base: BaseMCPClient::new(params),
53 child_process: Arc::new(Mutex::new(None)),
54 session_id: Arc::new(Mutex::new(None)),
55 subscription_manager: SubscriptionManager::new(),
56 resource_cache: ResourceCache::new(Duration::from_secs(60)), }
58 }
59
60 async fn start_child_process(
62 &self,
63 params: &StdioServerParameters,
64 ) -> Result<Child, MCPClientError> {
65 let mut cmd = Command::new(¶ms.command);
66
67 cmd.args(¶ms.args);
69
70 for (key, value) in ¶ms.env {
72 cmd.env(key, value);
73 }
74
75 if let Some(cwd) = ¶ms.cwd {
77 cmd.current_dir(cwd);
78 }
79
80 cmd.stdin(Stdio::piped())
82 .stdout(Stdio::piped())
83 .stderr(Stdio::piped());
84
85 debug!("Starting command: {} {:?}", params.command, params.args);
86
87 let child = cmd.spawn().map_err(|e| {
88 MCPClientError::ConnectionError(format!("Failed to start process: {}", e))
89 })?;
90
91 Ok(child)
92 }
93
94 async fn send_notification(
97 &self,
98 notification: &serde_json::Value,
99 ) -> Result<(), MCPClientError> {
100 let mut child = self.child_process.lock().await;
101 if let Some(ref mut process) = *child {
102 if let Some(stdin) = process.stdin.as_mut() {
103 let notification_str = serde_json::to_string(notification)?;
104 use tokio::io::AsyncWriteExt;
105 stdin.write_all(notification_str.as_bytes()).await?;
106 stdin.write_all(b"\n").await?;
107 stdin.flush().await?;
108
109 debug!("Sent notification: {}", notification_str);
110 info!("Sent notification to MCP server: {}", notification_str);
111 return Ok(());
112 }
113 }
114 Err(MCPClientError::ConnectionError(
115 "Process not available".to_string(),
116 ))
117 }
118
119 async fn send_request(
120 &self,
121 request: &serde_json::Value,
122 ) -> Result<serde_json::Value, MCPClientError> {
123 let mut child = self.child_process.lock().await;
124 if let Some(ref mut process) = *child {
125 if let Some(stdin) = process.stdin.as_mut() {
126 let request_str = serde_json::to_string(request)?;
127 use tokio::io::AsyncWriteExt;
128 stdin.write_all(request_str.as_bytes()).await?;
129 stdin.write_all(b"\n").await?;
130 stdin.flush().await?;
131
132 debug!("Sent request: {}", request_str);
133 info!("Sent request to MCP server: {}", request_str);
134
135 if let Some(stdout) = process.stdout.as_mut() {
137 let mut reader = BufReader::new(stdout);
138 let mut line = String::new();
139
140 info!("Waiting for response from MCP server...");
141
142 return match tokio::time::timeout(
144 std::time::Duration::from_secs(30),
145 reader.read_line(&mut line),
146 )
147 .await
148 {
149 Ok(Ok(0)) => {
150 error!("Process closed stdout without response");
151 Err(MCPClientError::ConnectionError(
152 "Process closed stdout".to_string(),
153 ))
154 }
155 Ok(Ok(_)) => {
156 info!("Received raw response: {}", line.trim());
157 debug!("Received response: {}", line.trim());
158 let response: serde_json::Value = serde_json::from_str(line.trim())
159 .map_err(|e| {
160 error!("Failed to parse JSON response: {}", e);
161 MCPClientError::ProtocolError(format!("Invalid JSON: {}", e))
162 })?;
163 info!("Parsed JSON response: {}", response);
164 Ok(response)
165 }
166 Ok(Err(e)) => Err(MCPClientError::ConnectionError(format!(
167 "Failed to read response: {}",
168 e
169 ))),
170 Err(_) => Err(MCPClientError::TimeoutError(
171 "No response received within timeout".to_string(),
172 )),
173 };
174 }
175 }
176 }
177
178 Err(MCPClientError::ConnectionError(
179 "Process not running".to_string(),
180 ))
181 }
182
183 async fn initialize_session(&self) -> Result<(), MCPClientError> {
185 let init_request = serde_json::json!({
186 "jsonrpc": "2.0",
187 "id": 1,
188 "method": "initialize",
189 "params": {
190 "protocolVersion": "2024-11-05",
191 "capabilities": {
192 "tools": {},
193 "resources": {}
194 },
195 "clientInfo": {
196 "name": "a2c-smcp-rust",
197 "version": "0.1.0"
198 }
199 }
200 });
201
202 let response = self.send_request(&init_request).await?;
203
204 if let Some(error) = response.get("error") {
206 return Err(MCPClientError::ProtocolError(format!(
207 "Initialize error: {}",
208 error
209 )));
210 }
211
212 if let Some(result) = response.get("result") {
213 if let Some(session_id) = result.get("sessionId").and_then(|v| v.as_str()) {
214 *self.session_id.lock().await = Some(session_id.to_string());
215 }
216 }
217
218 let initialized_notification = serde_json::json!({
220 "jsonrpc": "2.0",
221 "method": "notifications/initialized"
222 });
223
224 self.send_notification(&initialized_notification).await?;
226
227 info!("Session initialized successfully");
228 Ok(())
229 }
230
231 pub async fn is_subscribed(&self, uri: &str) -> bool {
235 self.subscription_manager.is_subscribed(uri).await
236 }
237
238 pub async fn get_subscriptions(&self) -> Vec<String> {
240 self.subscription_manager.get_subscriptions().await
241 }
242
243 pub async fn subscription_count(&self) -> usize {
245 self.subscription_manager.subscription_count().await
246 }
247
248 pub async fn get_cached_resource(&self, uri: &str) -> Option<serde_json::Value> {
252 self.resource_cache.get(uri).await
253 }
254
255 pub async fn has_cache(&self, uri: &str) -> bool {
257 self.resource_cache.contains(uri).await
258 }
259
260 pub async fn cache_size(&self) -> usize {
262 self.resource_cache.size().await
263 }
264
265 pub async fn cleanup_cache(&self) -> usize {
267 self.resource_cache.cleanup_expired().await
268 }
269
270 pub async fn clear_cache(&self) {
272 self.resource_cache.clear().await
273 }
274
275 pub async fn cache_keys(&self) -> Vec<String> {
277 self.resource_cache.keys().await
278 }
279}
280
281#[async_trait]
282impl MCPClientProtocol for StdioMCPClient {
283 fn state(&self) -> ClientState {
284 self.base.state()
285 }
286
287 async fn connect(&self) -> Result<(), MCPClientError> {
288 if !self.base.can_connect().await {
290 return Err(MCPClientError::ConnectionError(format!(
291 "Cannot connect in state: {}",
292 self.base.get_state().await
293 )));
294 }
295
296 let params = self.base.params.clone();
298
299 let child = self.start_child_process(¶ms).await?;
301 *self.child_process.lock().await = Some(child);
302
303 self.initialize_session().await?;
305
306 self.base.update_state(ClientState::Connected).await;
308 info!("STDIO client connected successfully");
309
310 Ok(())
311 }
312
313 async fn disconnect(&self) -> Result<(), MCPClientError> {
314 if !self.base.can_disconnect().await {
316 return Err(MCPClientError::ConnectionError(format!(
317 "Cannot disconnect in state: {}",
318 self.base.get_state().await
319 )));
320 }
321
322 let mut child = self.child_process.lock().await;
324 if let Some(mut process) = child.take() {
325 let shutdown_request = serde_json::json!({
327 "jsonrpc": "2.0",
328 "id": 2,
329 "method": "shutdown"
330 });
331
332 if let Some(stdin) = process.stdin.as_mut() {
334 let request_str = serde_json::to_string(&shutdown_request)?;
335 use tokio::io::AsyncWriteExt;
336 if let Err(e) = stdin.write_all(request_str.as_bytes()).await {
337 warn!("Failed to send shutdown request: {}", e);
338 } else {
339 let _ = stdin.write_all(b"\n").await;
340 let _ = stdin.flush().await;
341 }
342 }
343
344 let exit_notification = serde_json::json!({
346 "jsonrpc": "2.0",
347 "method": "exit"
348 });
349
350 if let Some(stdin) = process.stdin.as_mut() {
351 let request_str = serde_json::to_string(&exit_notification)?;
352 use tokio::io::AsyncWriteExt;
353 if let Err(e) = stdin.write_all(request_str.as_bytes()).await {
354 warn!("Failed to send exit notification: {}", e);
355 } else {
356 let _ = stdin.write_all(b"\n").await;
357 let _ = stdin.flush().await;
358 }
359 }
360
361 drop(child);
363
364 match tokio::time::timeout(std::time::Duration::from_secs(5), process.wait()).await {
366 Ok(Ok(status)) => {
367 debug!("Process exited with status: {}", status);
368 }
369 Ok(Err(e)) => {
370 error!("Error waiting for process: {}", e);
371 }
372 Err(_) => {
373 warn!("Process did not exit within timeout, killing it");
374 if let Err(e) = process.kill().await {
375 error!("Failed to kill process: {}", e);
376 }
377 }
378 }
379 } else {
380 drop(child);
382 }
383
384 *self.session_id.lock().await = None;
386
387 self.base.update_state(ClientState::Disconnected).await;
389 info!("STDIO client disconnected successfully");
390
391 Ok(())
392 }
393
394 async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError> {
395 if self.base.get_state().await != ClientState::Connected {
396 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
397 }
398
399 let request = serde_json::json!({
400 "jsonrpc": "2.0",
401 "id": 3,
402 "method": "tools/list"
403 });
404
405 let response = self.send_request(&request).await?;
406 info!("Received list_tools response: {}", response);
407
408 if let Some(error) = response.get("error") {
409 return Err(MCPClientError::ProtocolError(format!(
410 "List tools error: {}",
411 error
412 )));
413 }
414
415 if let Some(result) = response.get("result") {
416 info!("Result field: {}", result);
417 if let Some(tools) = result.get("tools").and_then(|v| v.as_array()) {
418 info!("Found {} tools", tools.len());
419 let mut tool_list = Vec::new();
420 for (i, tool) in tools.iter().enumerate() {
421 info!("Tool {}: {}", i, tool);
422 if let Ok(parsed_tool) = serde_json::from_value::<Tool>(tool.clone()) {
423 tool_list.push(parsed_tool);
424 } else {
425 warn!("Failed to parse tool {}: {}", i, tool);
426 }
427 }
428 return Ok(tool_list);
429 } else {
430 warn!("No tools array found in result");
431 }
432 } else {
433 warn!("No result field found in response");
434 }
435
436 Ok(vec![])
437 }
438
439 async fn call_tool(
440 &self,
441 tool_name: &str,
442 params: serde_json::Value,
443 ) -> Result<CallToolResult, MCPClientError> {
444 if self.base.get_state().await != ClientState::Connected {
445 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
446 }
447
448 let request = serde_json::json!({
449 "jsonrpc": "2.0",
450 "id": 4,
451 "method": "tools/call",
452 "params": {
453 "name": tool_name,
454 "arguments": params
455 }
456 });
457
458 let response = self.send_request(&request).await?;
459
460 if let Some(error) = response.get("error") {
461 return Err(MCPClientError::ProtocolError(format!(
462 "Call tool error: {}",
463 error
464 )));
465 }
466
467 if let Some(result) = response.get("result") {
468 let call_result: CallToolResult = serde_json::from_value(result.clone())?;
469 return Ok(call_result);
470 }
471
472 Err(MCPClientError::ProtocolError(
473 "Invalid response".to_string(),
474 ))
475 }
476
477 async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError> {
478 if self.base.get_state().await != ClientState::Connected {
479 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
480 }
481
482 let mut all_resources = Vec::new();
484 let mut cursor: Option<String> = None;
485
486 loop {
487 let mut request = serde_json::json!({
488 "jsonrpc": "2.0",
489 "id": 5,
490 "method": "resources/list"
491 });
492
493 if let Some(ref c) = cursor {
495 request["params"] = serde_json::json!({ "cursor": c });
496 }
497
498 let response = self.send_request(&request).await?;
499
500 if let Some(error) = response.get("error") {
501 return Err(MCPClientError::ProtocolError(format!(
502 "List resources error: {}",
503 error
504 )));
505 }
506
507 if let Some(result) = response.get("result") {
508 if let Some(resources) = result.get("resources").and_then(|v| v.as_array()) {
510 for resource in resources {
511 if let Ok(parsed_resource) =
512 serde_json::from_value::<Resource>(resource.clone())
513 {
514 all_resources.push(parsed_resource);
515 }
516 }
517 }
518
519 cursor = result
521 .get("nextCursor")
522 .and_then(|v| v.as_str())
523 .map(|s| s.to_string());
524
525 if cursor.is_none() {
526 break;
527 }
528 } else {
529 break;
530 }
531 }
532
533 let mut filtered_resources: Vec<(Resource, i32)> = Vec::new();
535
536 for resource in all_resources {
537 if !is_window_uri(&resource.uri) {
538 continue;
539 }
540
541 let priority = if let Ok(uri) = WindowURI::new(&resource.uri) {
543 uri.priority().unwrap_or(0)
544 } else {
545 0
546 };
547
548 filtered_resources.push((resource, priority));
549 }
550
551 filtered_resources.sort_by(|a, b| b.1.cmp(&a.1));
553
554 Ok(filtered_resources.into_iter().map(|(r, _)| r).collect())
556 }
557
558 async fn get_window_detail(
559 &self,
560 resource: Resource,
561 ) -> Result<ReadResourceResult, MCPClientError> {
562 if self.base.get_state().await != ClientState::Connected {
563 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
564 }
565
566 let request = serde_json::json!({
567 "jsonrpc": "2.0",
568 "id": 6,
569 "method": "resources/read",
570 "params": {
571 "uri": resource.uri
572 }
573 });
574
575 let response = self.send_request(&request).await?;
576
577 if let Some(error) = response.get("error") {
578 return Err(MCPClientError::ProtocolError(format!(
579 "Read resource error: {}",
580 error
581 )));
582 }
583
584 if let Some(result) = response.get("result") {
585 let read_result: ReadResourceResult = serde_json::from_value(result.clone())?;
586 return Ok(read_result);
587 }
588
589 Err(MCPClientError::ProtocolError(
590 "Invalid response".to_string(),
591 ))
592 }
593
594 async fn subscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
595 if self.base.get_state().await != ClientState::Connected {
596 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
597 }
598
599 let request = serde_json::json!({
600 "jsonrpc": "2.0",
601 "id": 7,
602 "method": "resources/subscribe",
603 "params": {
604 "uri": resource.uri
605 }
606 });
607
608 let response = self.send_request(&request).await?;
609
610 if let Some(error) = response.get("error") {
611 return Err(MCPClientError::ProtocolError(format!(
612 "Subscribe resource error: {}",
613 error
614 )));
615 }
616
617 let _ = self
619 .subscription_manager
620 .add_subscription(resource.uri.clone())
621 .await;
622
623 match self.get_window_detail(resource.clone()).await {
625 Ok(result) => {
626 if !result.contents.is_empty() {
628 if let Ok(json_value) = serde_json::to_value(&result.contents[0]) {
630 self.resource_cache
631 .set(resource.uri.clone(), json_value, None)
632 .await;
633 info!("Subscribed and cached: {}", resource.uri);
634 }
635 }
636 }
637 Err(e) => {
638 warn!("Failed to fetch resource data after subscription: {:?}", e);
639 }
640 }
641
642 Ok(())
643 }
644
645 async fn unsubscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
646 if self.base.get_state().await != ClientState::Connected {
647 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
648 }
649
650 let request = serde_json::json!({
651 "jsonrpc": "2.0",
652 "id": 8,
653 "method": "resources/unsubscribe",
654 "params": {
655 "uri": resource.uri
656 }
657 });
658
659 let response = self.send_request(&request).await?;
660
661 if let Some(error) = response.get("error") {
662 return Err(MCPClientError::ProtocolError(format!(
663 "Unsubscribe resource error: {}",
664 error
665 )));
666 }
667
668 let _ = self
670 .subscription_manager
671 .remove_subscription(&resource.uri)
672 .await;
673
674 self.resource_cache.remove(&resource.uri).await;
676 info!("Unsubscribed and removed cache: {}", resource.uri);
677
678 Ok(())
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;
    use tokio::time::{sleep, Duration};

    /// Builds server parameters with an empty environment.
    fn server_params(command: &str, args: &[&str], cwd: Option<&str>) -> StdioServerParameters {
        StdioServerParameters {
            command: command.to_string(),
            args: args.iter().map(|arg| arg.to_string()).collect(),
            env: HashMap::new(),
            cwd: cwd.map(str::to_string),
        }
    }

    /// `echo test`: the fixture for tests that never talk to a real MCP server.
    fn echo_test_params() -> StdioServerParameters {
        server_params("echo", &["test"], None)
    }

    /// Asserts that `outcome` failed with `MCPClientError::ConnectionError`.
    fn assert_connection_error<T>(outcome: Result<T, MCPClientError>) {
        assert!(matches!(outcome, Err(MCPClientError::ConnectionError(_))));
    }

    #[tokio::test]
    async fn test_stdio_client_creation() {
        let client = StdioMCPClient::new(server_params("echo", &["hello"], None));

        assert_eq!(client.state(), ClientState::Initialized);
        assert_eq!(client.base.params.command, "echo");
    }

    #[tokio::test]
    async fn test_stdio_client_with_env() {
        let mut params = server_params("echo", &["test"], Some("/tmp"));
        params
            .env
            .insert("TEST_VAR".to_string(), "test_value".to_string());
        params.env.insert("PATH".to_string(), "/usr/bin".to_string());

        let client = StdioMCPClient::new(params);

        assert_eq!(
            client.base.params.env.get("TEST_VAR"),
            Some(&"test_value".to_string())
        );
        assert_eq!(client.base.params.cwd, Some("/tmp".to_string()));
    }

    #[tokio::test]
    async fn test_session_id_management() {
        let client = StdioMCPClient::new(echo_test_params());

        assert!(client.session_id.lock().await.is_none());

        *client.session_id.lock().await = Some("session123".to_string());
        assert_eq!(client.session_id.lock().await.as_deref(), Some("session123"));
    }

    #[tokio::test]
    async fn test_start_child_process_with_echo() {
        let client = StdioMCPClient::new(server_params("echo", &["hello world"], None));

        let spawned = client.start_child_process(&client.base.params).await;
        assert!(spawned.is_ok());
        let mut echo_child = spawned.unwrap();

        sleep(Duration::from_millis(100)).await;
        let _ = echo_child.kill().await;
    }

    #[tokio::test]
    async fn test_start_child_process_with_invalid_command() {
        let params = server_params("nonexistent_command_12345", &[], None);
        let client = StdioMCPClient::new(params.clone());

        assert_connection_error(client.start_child_process(&params).await);
    }

    #[tokio::test]
    async fn test_send_request_without_process() {
        let client = StdioMCPClient::new(echo_test_params());
        let request = json!({"jsonrpc": "2.0", "method": "test"});

        assert_connection_error(client.send_request(&request).await);
    }

    #[tokio::test]
    async fn test_connect_state_checks() {
        let client = StdioMCPClient::new(echo_test_params());
        client.base.update_state(ClientState::Connected).await;

        assert_connection_error(client.connect().await);
    }

    #[tokio::test]
    async fn test_disconnect_state_checks() {
        let client = StdioMCPClient::new(echo_test_params());

        assert_connection_error(client.disconnect().await);
    }

    #[tokio::test]
    async fn test_list_tools_requires_connection() {
        let client = StdioMCPClient::new(echo_test_params());

        assert_connection_error(client.list_tools().await);
    }

    #[tokio::test]
    async fn test_call_tool_requires_connection() {
        let client = StdioMCPClient::new(echo_test_params());

        assert_connection_error(client.call_tool("test_tool", json!({})).await);
    }

    #[tokio::test]
    async fn test_list_windows_requires_connection() {
        let client = StdioMCPClient::new(echo_test_params());

        assert_connection_error(client.list_windows().await);
    }

    #[tokio::test]
    async fn test_get_window_detail_requires_connection() {
        let client = StdioMCPClient::new(echo_test_params());
        let window = Resource {
            uri: "window://123".to_string(),
            name: "Test Window".to_string(),
            description: None,
            mime_type: None,
        };

        assert_connection_error(client.get_window_detail(window).await);
    }

    #[tokio::test]
    async fn test_initialize_session_request_format() {
        let client = StdioMCPClient::new(echo_test_params());

        assert!(client.initialize_session().await.is_err());
    }

    #[tokio::test]
    async fn test_disconnect_cleanup() {
        let client = StdioMCPClient::new(echo_test_params());
        *client.session_id.lock().await = Some("session123".to_string());
        client.base.update_state(ClientState::Connected).await;

        let _ = client.disconnect().await;

        assert!(client.session_id.lock().await.is_none());
        assert_eq!(client.base.get_state().await, ClientState::Disconnected);
    }

    #[tokio::test]
    async fn test_child_process_cleanup() {
        let params = server_params("sleep", &["10"], None);
        let client = StdioMCPClient::new(params.clone());

        let sleeper = client.start_child_process(&params).await.unwrap();
        *client.child_process.lock().await = Some(sleeper);
        client.base.update_state(ClientState::Connected).await;

        assert!(client.child_process.lock().await.is_some());

        let _ = client.disconnect().await;

        assert!(client.child_process.lock().await.is_none());
    }

    #[tokio::test]
    async fn test_error_handling_in_list_tools() {
        let client = StdioMCPClient::new(echo_test_params());
        client.base.update_state(ClientState::Connected).await;

        assert!(client.list_tools().await.is_err());
    }

    #[tokio::test]
    async fn test_error_handling_in_call_tool() {
        let client = StdioMCPClient::new(echo_test_params());
        client.base.update_state(ClientState::Connected).await;

        let outcome = client
            .call_tool("test_tool", json!({"param": "value"}))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_start_child_process_with_working_directory() {
        let params = server_params("pwd", &[], Some("/tmp"));
        let client = StdioMCPClient::new(params.clone());

        let spawned = client.start_child_process(&params).await;
        assert!(spawned.is_ok());

        let mut pwd_child = spawned.unwrap();
        let _ = pwd_child.wait().await;
    }

    #[tokio::test]
    async fn test_stdio_client_debug_format() {
        let client = StdioMCPClient::new(echo_test_params());

        let rendered = format!("{:?}", client);
        assert!(rendered.contains("StdioMCPClient"));
    }
}