// smcp_computer/mcp_clients/stdio_client.rs
use super::base_client::BaseMCPClient;
use super::model::*;
use super::{ResourceCache, SubscriptionManager};
use crate::desktop::window_uri::{is_window_uri, WindowURI};
use async_trait::async_trait;
use rmcp::model::{
    CallToolRequestParam, ClientInfo, Implementation, ReadResourceRequestParam,
    SubscribeRequestParam, UnsubscribeRequestParam,
};
use rmcp::service::{RunningService, ServiceExt};
use rmcp::transport::TokioChildProcess;
use rmcp::RoleClient;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{ChildStderr, Command};
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};

29const CONNECT_TIMEOUT_SECS: u64 = 30;
32
33pub struct StdioMCPClient {
35 base: BaseMCPClient<StdioServerParameters>,
37 running_service: Arc<Mutex<Option<RunningService<RoleClient, ClientInfo>>>>,
39 child_stderr: Arc<Mutex<Option<ChildStderr>>>,
41 subscription_manager: SubscriptionManager,
43 resource_cache: ResourceCache,
45}
46
47impl std::fmt::Debug for StdioMCPClient {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("StdioMCPClient")
50 .field("command", &self.base.params.command)
51 .field("args", &self.base.params.args)
52 .field("state", &self.base.state())
53 .finish()
54 }
55}
56
57impl StdioMCPClient {
58 pub fn new(params: StdioServerParameters) -> Self {
60 Self {
61 base: BaseMCPClient::new(params),
62 running_service: Arc::new(Mutex::new(None)),
63 child_stderr: Arc::new(Mutex::new(None)),
64 subscription_manager: SubscriptionManager::new(),
65 resource_cache: ResourceCache::new(Duration::from_secs(60)),
66 }
67 }
68
69 pub async fn is_subscribed(&self, uri: &str) -> bool {
73 self.subscription_manager.is_subscribed(uri).await
74 }
75
76 pub async fn get_subscriptions(&self) -> Vec<String> {
78 self.subscription_manager.get_subscriptions().await
79 }
80
81 pub async fn subscription_count(&self) -> usize {
83 self.subscription_manager.subscription_count().await
84 }
85
86 pub async fn get_cached_resource(&self, uri: &str) -> Option<serde_json::Value> {
90 self.resource_cache.get(uri).await
91 }
92
93 pub async fn has_cache(&self, uri: &str) -> bool {
95 self.resource_cache.contains(uri).await
96 }
97
98 pub async fn cache_size(&self) -> usize {
100 self.resource_cache.size().await
101 }
102
103 pub async fn cleanup_cache(&self) -> usize {
105 self.resource_cache.cleanup_expired().await
106 }
107
108 pub async fn clear_cache(&self) {
110 self.resource_cache.clear().await
111 }
112
113 pub async fn cache_keys(&self) -> Vec<String> {
115 self.resource_cache.keys().await
116 }
117
118 async fn get_service(
121 &self,
122 ) -> Result<
123 tokio::sync::MutexGuard<'_, Option<RunningService<RoleClient, ClientInfo>>>,
124 MCPClientError,
125 > {
126 let guard = self.running_service.lock().await;
127 if guard.is_none() {
128 return Err(MCPClientError::ConnectionError(
129 "Service not available".to_string(),
130 ));
131 }
132 Ok(guard)
133 }
134}
135
136#[async_trait]
137impl MCPClientProtocol for StdioMCPClient {
138 fn state(&self) -> ClientState {
139 self.base.state()
140 }
141
142 async fn connect(&self) -> Result<(), MCPClientError> {
143 if !self.base.can_connect().await {
144 return Err(MCPClientError::ConnectionError(format!(
145 "Cannot connect in state: {}",
146 self.base.get_state().await
147 )));
148 }
149
150 let params = &self.base.params;
151
152 let mut cmd = Command::new(¶ms.command);
153 cmd.args(¶ms.args);
154 for (key, value) in ¶ms.env {
155 cmd.env(key, value);
156 }
157 if let Some(cwd) = ¶ms.cwd {
158 cmd.current_dir(cwd);
159 }
160
161 debug!("Starting command: {} {:?}", params.command, params.args);
162
163 let (transport, stderr) = TokioChildProcess::builder(cmd)
164 .stderr(Stdio::piped())
165 .spawn()
166 .map_err(|e| {
167 MCPClientError::ConnectionError(format!("Failed to start process: {}", e))
168 })?;
169
170 *self.child_stderr.lock().await = stderr;
171
172 let client_info = ClientInfo {
173 protocol_version: Default::default(),
174 capabilities: Default::default(),
175 client_info: Implementation {
176 name: "a2c-smcp-rust".to_string(),
177 title: None,
178 version: env!("CARGO_PKG_VERSION").to_string(),
179 icons: None,
180 website_url: None,
181 },
182 };
183
184 let service = tokio::time::timeout(
185 Duration::from_secs(CONNECT_TIMEOUT_SECS),
186 client_info.serve(transport),
187 )
188 .await
189 .map_err(|_| {
190 MCPClientError::TimeoutError(format!(
191 "STDIO connect timed out after {}s",
192 CONNECT_TIMEOUT_SECS
193 ))
194 })?
195 .map_err(|e| MCPClientError::ConnectionError(format!("Initialize failed: {}", e)))?;
196
197 *self.running_service.lock().await = Some(service);
198 self.base.update_state(ClientState::Connected).await;
199 info!("STDIO client connected successfully");
200
201 Ok(())
202 }
203
204 async fn disconnect(&self) -> Result<(), MCPClientError> {
205 if !self.base.can_disconnect().await {
206 return Err(MCPClientError::ConnectionError(format!(
207 "Cannot disconnect in state: {}",
208 self.base.get_state().await
209 )));
210 }
211
212 let service = self.running_service.lock().await.take();
213 if let Some(service) = service {
214 match service.cancel().await {
215 Ok(reason) => {
216 debug!("Service stopped with reason: {:?}", reason);
217 }
218 Err(e) => {
219 error!("Error stopping service: {}", e);
220 }
221 }
222 }
223
224 *self.child_stderr.lock().await = None;
226
227 self.base.update_state(ClientState::Disconnected).await;
228 info!("STDIO client disconnected successfully");
229
230 Ok(())
231 }
232
233 async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError> {
234 if self.base.get_state().await != ClientState::Connected {
235 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
236 }
237
238 let guard = self.get_service().await?;
239 let service = guard.as_ref().unwrap();
240
241 let tools = service
242 .list_all_tools()
243 .await
244 .map_err(|e| MCPClientError::ProtocolError(format!("List tools error: {}", e)))?;
245
246 info!("Found {} tools", tools.len());
247 Ok(tools)
248 }
249
250 async fn call_tool(
251 &self,
252 tool_name: &str,
253 params: serde_json::Value,
254 ) -> Result<CallToolResult, MCPClientError> {
255 if self.base.get_state().await != ClientState::Connected {
256 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
257 }
258
259 let guard = self.get_service().await?;
260 let service = guard.as_ref().unwrap();
261
262 let result = service
263 .call_tool(CallToolRequestParam {
264 name: tool_name.to_string().into(),
265 arguments: params.as_object().cloned(),
266 })
267 .await
268 .map_err(|e| MCPClientError::ProtocolError(format!("Call tool error: {}", e)))?;
269
270 Ok(result)
271 }
272
273 async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError> {
274 if self.base.get_state().await != ClientState::Connected {
275 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
276 }
277
278 let guard = self.get_service().await?;
279 let service = guard.as_ref().unwrap();
280
281 let all_resources = service
282 .list_all_resources()
283 .await
284 .map_err(|e| MCPClientError::ProtocolError(format!("List resources error: {}", e)))?;
285
286 let mut filtered_resources: Vec<(Resource, i32)> = Vec::new();
288
289 for resource in all_resources {
290 if !is_window_uri(&resource.uri) {
291 continue;
292 }
293
294 let priority = if let Ok(uri) = WindowURI::new(&resource.uri) {
295 uri.priority().unwrap_or(0)
296 } else {
297 0
298 };
299
300 filtered_resources.push((resource, priority));
301 }
302
303 filtered_resources.sort_by(|a, b| b.1.cmp(&a.1));
304
305 Ok(filtered_resources.into_iter().map(|(r, _)| r).collect())
306 }
307
308 async fn get_window_detail(
309 &self,
310 resource: Resource,
311 ) -> Result<ReadResourceResult, MCPClientError> {
312 if self.base.get_state().await != ClientState::Connected {
313 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
314 }
315
316 let guard = self.get_service().await?;
317 let service = guard.as_ref().unwrap();
318
319 let result = service
320 .read_resource(ReadResourceRequestParam {
321 uri: resource.uri.clone(),
322 })
323 .await
324 .map_err(|e| MCPClientError::ProtocolError(format!("Read resource error: {}", e)))?;
325
326 Ok(result)
327 }
328
329 async fn subscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
330 if self.base.get_state().await != ClientState::Connected {
331 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
332 }
333
334 let guard = self.get_service().await?;
335 let service = guard.as_ref().unwrap();
336
337 service
338 .subscribe(SubscribeRequestParam {
339 uri: resource.uri.clone(),
340 })
341 .await
342 .map_err(|e| {
343 MCPClientError::ProtocolError(format!("Subscribe resource error: {}", e))
344 })?;
345
346 drop(guard);
347
348 let _ = self
350 .subscription_manager
351 .add_subscription(resource.uri.clone())
352 .await;
353
354 match self.get_window_detail(resource.clone()).await {
356 Ok(result) => {
357 if !result.contents.is_empty() {
358 if let Ok(json_value) = serde_json::to_value(&result.contents[0]) {
359 self.resource_cache
360 .set(resource.uri.clone(), json_value, None)
361 .await;
362 info!("Subscribed and cached: {}", resource.uri);
363 }
364 }
365 }
366 Err(e) => {
367 warn!("Failed to fetch resource data after subscription: {:?}", e);
368 }
369 }
370
371 Ok(())
372 }
373
374 async fn unsubscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
375 if self.base.get_state().await != ClientState::Connected {
376 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
377 }
378
379 let guard = self.get_service().await?;
380 let service = guard.as_ref().unwrap();
381
382 service
383 .unsubscribe(UnsubscribeRequestParam {
384 uri: resource.uri.clone(),
385 })
386 .await
387 .map_err(|e| {
388 MCPClientError::ProtocolError(format!("Unsubscribe resource error: {}", e))
389 })?;
390
391 drop(guard);
392
393 let _ = self
395 .subscription_manager
396 .remove_subscription(&resource.uri)
397 .await;
398
399 self.resource_cache.remove(&resource.uri).await;
401 info!("Unsubscribed and removed cache: {}", resource.uri);
402
403 Ok(())
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Parameters for a harmless `echo <arg>` command with no extra env and no cwd.
    fn echo_params(arg: &str) -> StdioServerParameters {
        StdioServerParameters {
            command: "echo".to_string(),
            args: vec![arg.to_string()],
            env: HashMap::new(),
            cwd: None,
        }
    }

    /// A client that has never been connected.
    fn echo_client() -> StdioMCPClient {
        StdioMCPClient::new(echo_params("test"))
    }

    /// Asserts that `result` failed with `MCPClientError::ConnectionError`.
    fn assert_connection_error<T>(result: Result<T, MCPClientError>) {
        assert!(matches!(result, Err(MCPClientError::ConnectionError(_))));
    }

    /// A window resource usable as an argument to the window APIs.
    fn test_window() -> Resource {
        make_resource("window://123", "Test Window", None, None)
    }

    #[tokio::test]
    async fn test_stdio_client_creation() {
        let client = StdioMCPClient::new(echo_params("hello"));
        assert_eq!(client.state(), ClientState::Initialized);
        assert_eq!(client.base.params.command, "echo");
    }

    #[tokio::test]
    async fn test_stdio_client_with_env() {
        let mut env = HashMap::new();
        env.insert("TEST_VAR".to_string(), "test_value".to_string());
        env.insert("PATH".to_string(), "/usr/bin".to_string());

        let params = StdioServerParameters {
            command: "echo".to_string(),
            args: vec!["test".to_string()],
            env,
            cwd: Some("/tmp".to_string()),
        };

        let client = StdioMCPClient::new(params);
        assert_eq!(
            client.base.params.env.get("TEST_VAR"),
            Some(&"test_value".to_string())
        );
        assert_eq!(client.base.params.cwd, Some("/tmp".to_string()));
    }

    #[tokio::test]
    async fn test_fresh_client_has_no_cache_or_subscriptions() {
        let client = echo_client();
        assert_eq!(client.cache_size().await, 0);
        assert_eq!(client.subscription_count().await, 0);
        assert!(!client.is_subscribed("window://123").await);
        assert!(client.get_cached_resource("window://123").await.is_none());
    }

    #[tokio::test]
    async fn test_connect_state_checks() {
        let client = echo_client();

        // Connecting while already connected must be rejected before anything is spawned.
        client.base.update_state(ClientState::Connected).await;
        assert_connection_error(client.connect().await);
    }

    #[tokio::test]
    async fn test_disconnect_state_checks() {
        let client = echo_client();

        // A freshly created client was never connected, so disconnect is rejected.
        assert_connection_error(client.disconnect().await);
    }

    #[tokio::test]
    async fn test_list_tools_requires_connection() {
        assert_connection_error(echo_client().list_tools().await);
    }

    #[tokio::test]
    async fn test_call_tool_requires_connection() {
        let result = echo_client()
            .call_tool("test_tool", serde_json::json!({}))
            .await;
        assert_connection_error(result);
    }

    #[tokio::test]
    async fn test_list_windows_requires_connection() {
        assert_connection_error(echo_client().list_windows().await);
    }

    #[tokio::test]
    async fn test_get_window_detail_requires_connection() {
        assert_connection_error(echo_client().get_window_detail(test_window()).await);
    }

    #[tokio::test]
    async fn test_subscribe_window_requires_connection() {
        let client = echo_client();
        assert_connection_error(client.subscribe_window(test_window()).await);
        assert_eq!(client.subscription_count().await, 0);
    }

    #[tokio::test]
    async fn test_unsubscribe_window_requires_connection() {
        assert_connection_error(echo_client().unsubscribe_window(test_window()).await);
    }

    #[tokio::test]
    async fn test_disconnect_cleanup() {
        let client = echo_client();

        // Pretend to be connected without a running service; disconnect must still clean up.
        client.base.update_state(ClientState::Connected).await;

        assert!(client.disconnect().await.is_ok());

        assert!(client.running_service.lock().await.is_none());
        assert_eq!(client.base.get_state().await, ClientState::Disconnected);
    }

    #[tokio::test]
    async fn test_stdio_client_debug_format() {
        let debug_str = format!("{:?}", echo_client());
        assert!(debug_str.contains("StdioMCPClient"));
        assert!(debug_str.contains("echo"));
    }
}