// smcp_computer/mcp_clients/stdio_client.rs

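/*
 * 文件名 (File): stdio_client
 * 作者 (Author): JQQ
 * 创建日期 (Created): 2025/12/15
 * 最后修改日期 (Last modified): 2025/12/15
 * 版权 (Copyright): 2023 JQQ. All rights reserved.
 * 依赖 (Dependencies): tokio, rmcp
 * 描述 (Description): STDIO类型的MCP客户端实现,委托 rmcp SDK
 *   (STDIO-transport MCP client implementation that delegates to the rmcp SDK)
 */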
use super::base_client::BaseMCPClient;
use super::model::*;
use super::{ResourceCache, SubscriptionManager};
use crate::desktop::window_uri::{is_window_uri, WindowURI};
use async_trait::async_trait;
use rmcp::model::{
    CallToolRequestParam, ClientInfo, Implementation, ReadResourceRequestParam,
    SubscribeRequestParam, UnsubscribeRequestParam,
};
use rmcp::service::{Peer, RunningService, ServiceExt};
use rmcp::transport::TokioChildProcess;
use rmcp::RoleClient;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{ChildStderr, Command};
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};

/// Upper bound, in seconds, on the MCP initialize handshake with a freshly
/// spawned STDIO server (连接超时时间, 秒).
const CONNECT_TIMEOUT_SECS: u64 = 30;

33/// STDIO MCP客户端 / STDIO MCP client
34pub struct StdioMCPClient {
35    /// 基础客户端 / Base client
36    base: BaseMCPClient<StdioServerParameters>,
37    /// rmcp 运行服务 / rmcp running service
38    running_service: Arc<Mutex<Option<RunningService<RoleClient, ClientInfo>>>>,
39    /// 子进程 stderr / Child process stderr
40    child_stderr: Arc<Mutex<Option<ChildStderr>>>,
41    /// 订阅管理器 / Subscription manager
42    subscription_manager: SubscriptionManager,
43    /// 资源缓存 / Resource cache
44    resource_cache: ResourceCache,
45}
46
47impl std::fmt::Debug for StdioMCPClient {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("StdioMCPClient")
50            .field("command", &self.base.params.command)
51            .field("args", &self.base.params.args)
52            .field("state", &self.base.state())
53            .finish()
54    }
55}
56
57impl StdioMCPClient {
58    /// 创建新的STDIO客户端 / Create new STDIO client
59    pub fn new(params: StdioServerParameters) -> Self {
60        Self {
61            base: BaseMCPClient::new(params),
62            running_service: Arc::new(Mutex::new(None)),
63            child_stderr: Arc::new(Mutex::new(None)),
64            subscription_manager: SubscriptionManager::new(),
65            resource_cache: ResourceCache::new(Duration::from_secs(60)),
66        }
67    }
68
69    // ========== 订阅管理 API / Subscription Management API ==========
70
71    /// 检查是否已订阅指定资源
72    pub async fn is_subscribed(&self, uri: &str) -> bool {
73        self.subscription_manager.is_subscribed(uri).await
74    }
75
76    /// 获取所有订阅的 URI 列表
77    pub async fn get_subscriptions(&self) -> Vec<String> {
78        self.subscription_manager.get_subscriptions().await
79    }
80
81    /// 获取订阅数量
82    pub async fn subscription_count(&self) -> usize {
83        self.subscription_manager.subscription_count().await
84    }
85
86    // ========== 资源缓存 API / Resource Cache API ==========
87
88    /// 获取缓存的资源数据
89    pub async fn get_cached_resource(&self, uri: &str) -> Option<serde_json::Value> {
90        self.resource_cache.get(uri).await
91    }
92
93    /// 检查资源是否已缓存
94    pub async fn has_cache(&self, uri: &str) -> bool {
95        self.resource_cache.contains(uri).await
96    }
97
98    /// 获取缓存大小
99    pub async fn cache_size(&self) -> usize {
100        self.resource_cache.size().await
101    }
102
103    /// 清理过期的缓存
104    pub async fn cleanup_cache(&self) -> usize {
105        self.resource_cache.cleanup_expired().await
106    }
107
108    /// 清空所有缓存
109    pub async fn clear_cache(&self) {
110        self.resource_cache.clear().await
111    }
112
113    /// 获取所有缓存的 URI 列表
114    pub async fn cache_keys(&self) -> Vec<String> {
115        self.resource_cache.keys().await
116    }
117
118    /// 获取 running service 的 guard,验证 service 可用
119    /// Get running service guard, verifying service is available
120    async fn get_service(
121        &self,
122    ) -> Result<
123        tokio::sync::MutexGuard<'_, Option<RunningService<RoleClient, ClientInfo>>>,
124        MCPClientError,
125    > {
126        let guard = self.running_service.lock().await;
127        if guard.is_none() {
128            return Err(MCPClientError::ConnectionError(
129                "Service not available".to_string(),
130            ));
131        }
132        Ok(guard)
133    }
134}
135
136#[async_trait]
137impl MCPClientProtocol for StdioMCPClient {
138    fn state(&self) -> ClientState {
139        self.base.state()
140    }
141
142    async fn connect(&self) -> Result<(), MCPClientError> {
143        if !self.base.can_connect().await {
144            return Err(MCPClientError::ConnectionError(format!(
145                "Cannot connect in state: {}",
146                self.base.get_state().await
147            )));
148        }
149
150        let params = &self.base.params;
151
152        let mut cmd = Command::new(&params.command);
153        cmd.args(&params.args);
154        for (key, value) in &params.env {
155            cmd.env(key, value);
156        }
157        if let Some(cwd) = &params.cwd {
158            cmd.current_dir(cwd);
159        }
160
161        debug!("Starting command: {} {:?}", params.command, params.args);
162
163        let (transport, stderr) = TokioChildProcess::builder(cmd)
164            .stderr(Stdio::piped())
165            .spawn()
166            .map_err(|e| {
167                MCPClientError::ConnectionError(format!("Failed to start process: {}", e))
168            })?;
169
170        *self.child_stderr.lock().await = stderr;
171
172        let client_info = ClientInfo {
173            protocol_version: Default::default(),
174            capabilities: Default::default(),
175            client_info: Implementation {
176                name: "a2c-smcp-rust".to_string(),
177                title: None,
178                version: env!("CARGO_PKG_VERSION").to_string(),
179                icons: None,
180                website_url: None,
181            },
182        };
183
184        let service = tokio::time::timeout(
185            Duration::from_secs(CONNECT_TIMEOUT_SECS),
186            client_info.serve(transport),
187        )
188        .await
189        .map_err(|_| {
190            MCPClientError::TimeoutError(format!(
191                "STDIO connect timed out after {}s",
192                CONNECT_TIMEOUT_SECS
193            ))
194        })?
195        .map_err(|e| MCPClientError::ConnectionError(format!("Initialize failed: {}", e)))?;
196
197        *self.running_service.lock().await = Some(service);
198        self.base.update_state(ClientState::Connected).await;
199        info!("STDIO client connected successfully");
200
201        Ok(())
202    }
203
204    async fn disconnect(&self) -> Result<(), MCPClientError> {
205        if !self.base.can_disconnect().await {
206            return Err(MCPClientError::ConnectionError(format!(
207                "Cannot disconnect in state: {}",
208                self.base.get_state().await
209            )));
210        }
211
212        let service = self.running_service.lock().await.take();
213        if let Some(service) = service {
214            match service.cancel().await {
215                Ok(reason) => {
216                    debug!("Service stopped with reason: {:?}", reason);
217                }
218                Err(e) => {
219                    error!("Error stopping service: {}", e);
220                }
221            }
222        }
223
224        // 清理 stderr handle
225        *self.child_stderr.lock().await = None;
226
227        self.base.update_state(ClientState::Disconnected).await;
228        info!("STDIO client disconnected successfully");
229
230        Ok(())
231    }
232
233    async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError> {
234        if self.base.get_state().await != ClientState::Connected {
235            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
236        }
237
238        let guard = self.get_service().await?;
239        let service = guard.as_ref().unwrap();
240
241        let tools = service
242            .list_all_tools()
243            .await
244            .map_err(|e| MCPClientError::ProtocolError(format!("List tools error: {}", e)))?;
245
246        info!("Found {} tools", tools.len());
247        Ok(tools)
248    }
249
250    async fn call_tool(
251        &self,
252        tool_name: &str,
253        params: serde_json::Value,
254    ) -> Result<CallToolResult, MCPClientError> {
255        if self.base.get_state().await != ClientState::Connected {
256            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
257        }
258
259        let guard = self.get_service().await?;
260        let service = guard.as_ref().unwrap();
261
262        let result = service
263            .call_tool(CallToolRequestParam {
264                name: tool_name.to_string().into(),
265                arguments: params.as_object().cloned(),
266            })
267            .await
268            .map_err(|e| MCPClientError::ProtocolError(format!("Call tool error: {}", e)))?;
269
270        Ok(result)
271    }
272
273    async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError> {
274        if self.base.get_state().await != ClientState::Connected {
275            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
276        }
277
278        let guard = self.get_service().await?;
279        let service = guard.as_ref().unwrap();
280
281        let all_resources = service
282            .list_all_resources()
283            .await
284            .map_err(|e| MCPClientError::ProtocolError(format!("List resources error: {}", e)))?;
285
286        // 过滤 window:// 资源并按 priority 排序
287        let mut filtered_resources: Vec<(Resource, i32)> = Vec::new();
288
289        for resource in all_resources {
290            if !is_window_uri(&resource.uri) {
291                continue;
292            }
293
294            let priority = if let Ok(uri) = WindowURI::new(&resource.uri) {
295                uri.priority().unwrap_or(0)
296            } else {
297                0
298            };
299
300            filtered_resources.push((resource, priority));
301        }
302
303        filtered_resources.sort_by(|a, b| b.1.cmp(&a.1));
304
305        Ok(filtered_resources.into_iter().map(|(r, _)| r).collect())
306    }
307
308    async fn get_window_detail(
309        &self,
310        resource: Resource,
311    ) -> Result<ReadResourceResult, MCPClientError> {
312        if self.base.get_state().await != ClientState::Connected {
313            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
314        }
315
316        let guard = self.get_service().await?;
317        let service = guard.as_ref().unwrap();
318
319        let result = service
320            .read_resource(ReadResourceRequestParam {
321                uri: resource.uri.clone(),
322            })
323            .await
324            .map_err(|e| MCPClientError::ProtocolError(format!("Read resource error: {}", e)))?;
325
326        Ok(result)
327    }
328
329    async fn subscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
330        if self.base.get_state().await != ClientState::Connected {
331            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
332        }
333
334        let guard = self.get_service().await?;
335        let service = guard.as_ref().unwrap();
336
337        service
338            .subscribe(SubscribeRequestParam {
339                uri: resource.uri.clone(),
340            })
341            .await
342            .map_err(|e| {
343                MCPClientError::ProtocolError(format!("Subscribe resource error: {}", e))
344            })?;
345
346        drop(guard);
347
348        // 订阅成功后,更新本地订阅状态
349        let _ = self
350            .subscription_manager
351            .add_subscription(resource.uri.clone())
352            .await;
353
354        // 立即获取并缓存资源数据
355        match self.get_window_detail(resource.clone()).await {
356            Ok(result) => {
357                if !result.contents.is_empty() {
358                    if let Ok(json_value) = serde_json::to_value(&result.contents[0]) {
359                        self.resource_cache
360                            .set(resource.uri.clone(), json_value, None)
361                            .await;
362                        info!("Subscribed and cached: {}", resource.uri);
363                    }
364                }
365            }
366            Err(e) => {
367                warn!("Failed to fetch resource data after subscription: {:?}", e);
368            }
369        }
370
371        Ok(())
372    }
373
374    async fn unsubscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
375        if self.base.get_state().await != ClientState::Connected {
376            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
377        }
378
379        let guard = self.get_service().await?;
380        let service = guard.as_ref().unwrap();
381
382        service
383            .unsubscribe(UnsubscribeRequestParam {
384                uri: resource.uri.clone(),
385            })
386            .await
387            .map_err(|e| {
388                MCPClientError::ProtocolError(format!("Unsubscribe resource error: {}", e))
389            })?;
390
391        drop(guard);
392
393        // 取消订阅成功后,移除本地订阅状态
394        let _ = self
395            .subscription_manager
396            .remove_subscription(&resource.uri)
397            .await;
398
399        // 清理缓存
400        self.resource_cache.remove(&resource.uri).await;
401        info!("Unsubscribed and removed cache: {}", resource.uri);
402
403        Ok(())
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Launch parameters for `echo <arg>` with an empty environment and no cwd.
    fn echo_params(arg: &str) -> StdioServerParameters {
        StdioServerParameters {
            command: "echo".to_string(),
            args: vec![arg.to_string()],
            env: HashMap::new(),
            cwd: None,
        }
    }

    #[tokio::test]
    async fn test_stdio_client_creation() {
        let client = StdioMCPClient::new(echo_params("hello"));
        assert_eq!(client.state(), ClientState::Initialized);
        assert_eq!(client.base.params.command, "echo");
    }

    #[tokio::test]
    async fn test_stdio_client_with_env() {
        let env = HashMap::from([
            ("TEST_VAR".to_string(), "test_value".to_string()),
            ("PATH".to_string(), "/usr/bin".to_string()),
        ]);
        let params = StdioServerParameters {
            env,
            cwd: Some("/tmp".to_string()),
            ..echo_params("test")
        };

        let client = StdioMCPClient::new(params);
        assert_eq!(
            client.base.params.env.get("TEST_VAR"),
            Some(&"test_value".to_string())
        );
        assert_eq!(client.base.params.cwd, Some("/tmp".to_string()));
    }

    #[tokio::test]
    async fn test_connect_state_checks() {
        let client = StdioMCPClient::new(echo_params("test"));

        // Connecting again while already connected must be rejected.
        client.base.update_state(ClientState::Connected).await;
        let outcome = client.connect().await;
        assert!(matches!(outcome, Err(MCPClientError::ConnectionError(_))));
    }

    #[tokio::test]
    async fn test_disconnect_state_checks() {
        let client = StdioMCPClient::new(echo_params("test"));

        // Disconnecting a client that never connected must be rejected.
        let outcome = client.disconnect().await;
        assert!(matches!(outcome, Err(MCPClientError::ConnectionError(_))));
    }

    #[tokio::test]
    async fn test_list_tools_requires_connection() {
        let client = StdioMCPClient::new(echo_params("test"));
        let outcome = client.list_tools().await;
        assert!(matches!(outcome, Err(MCPClientError::ConnectionError(_))));
    }

    #[tokio::test]
    async fn test_call_tool_requires_connection() {
        let client = StdioMCPClient::new(echo_params("test"));
        let outcome = client.call_tool("test_tool", serde_json::json!({})).await;
        assert!(matches!(outcome, Err(MCPClientError::ConnectionError(_))));
    }

    #[tokio::test]
    async fn test_list_windows_requires_connection() {
        let client = StdioMCPClient::new(echo_params("test"));
        let outcome = client.list_windows().await;
        assert!(matches!(outcome, Err(MCPClientError::ConnectionError(_))));
    }

    #[tokio::test]
    async fn test_get_window_detail_requires_connection() {
        let client = StdioMCPClient::new(echo_params("test"));
        let window = make_resource("window://123", "Test Window", None, None);

        let outcome = client.get_window_detail(window).await;
        assert!(matches!(outcome, Err(MCPClientError::ConnectionError(_))));
    }

    #[tokio::test]
    async fn test_disconnect_cleanup() {
        let client = StdioMCPClient::new(echo_params("test"));

        // Pretend a connection exists, then tear it down.
        client.base.update_state(ClientState::Connected).await;
        let _ = client.disconnect().await;

        // The session slot must be empty afterwards; the guard is a temporary.
        assert!(client.running_service.lock().await.is_none());
        assert_eq!(client.base.get_state().await, ClientState::Disconnected);
    }

    #[tokio::test]
    async fn test_stdio_client_debug_format() {
        let client = StdioMCPClient::new(echo_params("test"));
        let rendered = format!("{:?}", client);
        assert!(rendered.contains("StdioMCPClient"));
    }
}