// File: smcp_computer/mcp_clients/base_client.rs

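/**
 * 文件名 / File name: base_client
 * 作者 / Author: JQQ
 * 创建日期 / Created: 2025/12/15
 * 最后修改日期 / Last modified: 2025/12/15
 * 版权 / Copyright: 2023 JQQ. All rights reserved.
 * 依赖 / Dependencies: tokio, async-trait, serde_json
 * 描述 / Description: MCP客户端基础抽象类,提供状态管理和会话生命周期管理
 *     (base abstraction for MCP clients, providing state management and
 *     session lifecycle management)
 */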
10use super::model::*;
11use crate::errors::ComputerError;
12use async_trait::async_trait;
13use std::sync::Arc;
14use tokio::sync::{watch, Mutex, RwLock};
15use tokio::task::JoinHandle;
16use tokio::time::{timeout, Duration};
17use tracing::{debug, error, info, warn};
18
19/// MCP客户端基础实现 / Base MCP client implementation
20pub struct BaseMCPClient<P> {
21    /// 服务器参数 / Server parameters
22    pub params: P,
23    /// 当前状态 / Current state
24    state: Arc<RwLock<ClientState>>,
25    /// 状态变化通知 / State change notification
26    state_notifier: watch::Sender<ClientState>,
27    /// 会话保持任务句柄 / Session keep-alive task handle
28    keep_alive_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
29    /// 关闭信号 / Shutdown signal
30    shutdown_tx: Arc<Mutex<Option<watch::Sender<bool>>>>,
31    /// 状态变化回调 / State change callback
32    state_change_callback: Option<Box<dyn Fn(ClientState, ClientState) + Send + Sync>>,
33}
34
35impl<P> BaseMCPClient<P>
36where
37    P: Send + Sync + 'static + std::clone::Clone,
38{
39    /// 创建新的基础客户端 / Create new base client
40    pub fn new(params: P) -> Self {
41        let (state_tx, _) = watch::channel(ClientState::Initialized);
42        let state = Arc::new(RwLock::new(ClientState::Initialized));
43        let (shutdown_tx, _) = watch::channel(false);
44
45        Self {
46            params,
47            state,
48            state_notifier: state_tx,
49            keep_alive_handle: Arc::new(Mutex::new(None)),
50            shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
51            state_change_callback: None,
52        }
53    }
54
55    /// 设置状态变化回调 / Set state change callback
56    pub fn set_state_change_callback<F>(&mut self, callback: F)
57    where
58        F: Fn(ClientState, ClientState) + Send + Sync + 'static,
59    {
60        self.state_change_callback = Some(Box::new(callback));
61    }
62
63    /// 获取当前状态 / Get current state
64    pub async fn get_state(&self) -> ClientState {
65        *self.state.read().await
66    }
67
68    /// 获取状态变化通知器 / Get state change notifier
69    pub fn get_state_notifier(&self) -> watch::Receiver<ClientState> {
70        self.state_notifier.subscribe()
71    }
72
73    /// 更新状态 / Update state
74    pub async fn update_state(&self, new_state: ClientState) {
75        let mut state = self.state.write().await;
76        let old_state = *state;
77        *state = new_state;
78
79        // 通知状态变化 / Notify state change
80        let _ = self.state_notifier.send(new_state);
81
82        // 调用回调 / Call callback
83        if let Some(ref callback) = self.state_change_callback {
84            callback(old_state, new_state);
85        }
86
87        debug!("State transition: {} -> {}", old_state, new_state);
88    }
89
90    /// 启动会话保持任务 / Start session keep-alive task
91    #[allow(dead_code)]
92    async fn start_keep_alive<T>(&self, session_creator: impl Fn(P) -> T + Send + Sync + 'static)
93    where
94        T: std::future::Future<Output = Result<(), MCPClientError>> + Send + 'static,
95    {
96        let params = self.params.clone();
97        let mut shutdown_rx = self.create_shutdown_receiver().await;
98        let state = self.state.clone();
99
100        let handle = tokio::spawn(async move {
101            debug!("Session keep-alive task started");
102
103            // 创建会话 / Create session
104            let session_future = session_creator(params);
105
106            tokio::select! {
107                result = session_future => {
108                    match result {
109                        Ok(_) => {
110                            debug!("Session completed successfully");
111                            { *state.write().await = ClientState::Disconnected; }
112                        }
113                        Err(e) => {
114                            error!("Session failed: {}", e);
115                            { *state.write().await = ClientState::Error; }
116                        }
117                    }
118                }
119                shutdown_rx = shutdown_rx.changed() => {
120                    if shutdown_rx.is_ok() {
121                        debug!("Session keep-alive task received shutdown signal");
122                    }
123                }
124            }
125
126            debug!("Session keep-alive task ended");
127        });
128
129        *self.keep_alive_handle.lock().await = Some(handle);
130    }
131
132    /// 停止会话保持任务 / Stop session keep-alive task
133    async fn stop_keep_alive(&self) -> Result<(), ComputerError> {
134        // 发送关闭信号 / Send shutdown signal
135        let mut shutdown_tx = self.shutdown_tx.lock().await;
136        if let Some(tx) = shutdown_tx.take() {
137            let _ = tx.send(true);
138        }
139
140        // 等待任务结束 / Wait for task to end
141        let mut handle = self.keep_alive_handle.lock().await;
142        if let Some(h) = handle.take() {
143            match h.await {
144                Ok(_) => debug!("Keep-alive task stopped successfully"),
145                Err(e) => warn!("Keep-alive task stopped with error: {}", e),
146            }
147        }
148
149        // 重新创建关闭信号 / Recreate shutdown signal
150        let (tx, _) = watch::channel(false);
151        *shutdown_tx = Some(tx);
152
153        Ok(())
154    }
155
156    /// 创建关闭信号接收器 / Create shutdown signal receiver
157    #[allow(dead_code)]
158    async fn create_shutdown_receiver(&self) -> watch::Receiver<bool> {
159        let shutdown_tx = self.shutdown_tx.lock().await;
160        shutdown_tx.as_ref().unwrap().subscribe()
161    }
162
163    /// 检查是否可以连接 / Check if can connect
164    pub async fn can_connect(&self) -> bool {
165        matches!(
166            self.get_state().await,
167            ClientState::Initialized | ClientState::Disconnected
168        )
169    }
170
171    /// 检查是否可以断开 / Check if can disconnect
172    pub async fn can_disconnect(&self) -> bool {
173        matches!(self.get_state().await, ClientState::Connected)
174    }
175
176    /// 执行带超时的操作 / Execute operation with timeout
177    #[allow(dead_code)]
178    async fn execute_with_timeout<F, T>(
179        &self,
180        future: F,
181        timeout_secs: u64,
182    ) -> Result<T, MCPClientError>
183    where
184        F: std::future::Future<Output = Result<T, MCPClientError>>,
185    {
186        match timeout(Duration::from_secs(timeout_secs), future).await {
187            Ok(result) => result,
188            Err(_) => Err(MCPClientError::TimeoutError(format!(
189                "Operation timed out after {} seconds",
190                timeout_secs
191            ))),
192        }
193    }
194}
195
196#[async_trait]
197impl<P> MCPClientProtocol for BaseMCPClient<P>
198where
199    P: Send + Sync + Clone + 'static,
200{
201    fn state(&self) -> ClientState {
202        // 使用 try_read 避免阻塞
203        if let Ok(state_guard) = self.state.try_read() {
204            *state_guard
205        } else {
206            // 如果锁被占用,返回一个默认值或尝试阻塞读取
207            // 在测试环境中,我们通常可以假设锁不会被长时间占用
208            tokio::task::block_in_place(|| {
209                tokio::runtime::Handle::current().block_on(async { self.get_state().await })
210            })
211        }
212    }
213
214    async fn connect(&self) -> Result<(), MCPClientError> {
215        if !self.can_connect().await {
216            return Err(MCPClientError::ConnectionError(format!(
217                "Cannot connect in state: {}",
218                self.get_state().await
219            )));
220        }
221
222        self.update_state(ClientState::Connected).await;
223        info!("Connected successfully");
224        Ok(())
225    }
226
227    async fn disconnect(&self) -> Result<(), MCPClientError> {
228        if !self.can_disconnect().await {
229            return Err(MCPClientError::ConnectionError(format!(
230                "Cannot disconnect in state: {}",
231                self.get_state().await
232            )));
233        }
234
235        self.stop_keep_alive()
236            .await
237            .map_err(|e| MCPClientError::Other(e.to_string()))?;
238        self.update_state(ClientState::Disconnected).await;
239        info!("Disconnected successfully");
240        Ok(())
241    }
242
243    async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError> {
244        if self.get_state().await != ClientState::Connected {
245            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
246        }
247        // 基础实现返回空列表,子类需要重写
248        // Base implementation returns empty list, subclasses need to override
249        Ok(vec![])
250    }
251
252    async fn call_tool(
253        &self,
254        _tool_name: &str,
255        _params: serde_json::Value,
256    ) -> Result<CallToolResult, MCPClientError> {
257        if self.get_state().await != ClientState::Connected {
258            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
259        }
260        // 基础实现返回错误,子类需要重写
261        // Base implementation returns error, subclasses need to override
262        Err(MCPClientError::ProtocolError("Not implemented".to_string()))
263    }
264
265    async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError> {
266        if self.get_state().await != ClientState::Connected {
267            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
268        }
269        // 基础实现返回空列表,子类需要重写
270        // Base implementation returns empty list, subclasses need to override
271        Ok(vec![])
272    }
273
274    async fn get_window_detail(
275        &self,
276        _resource: Resource,
277    ) -> Result<ReadResourceResult, MCPClientError> {
278        if self.get_state().await != ClientState::Connected {
279            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
280        }
281        // 基础实现返回错误,子类需要重写
282        // Base implementation returns error, subclasses need to override
283        Err(MCPClientError::ProtocolError("Not implemented".to_string()))
284    }
285
286    async fn subscribe_window(&self, _resource: Resource) -> Result<(), MCPClientError> {
287        if self.get_state().await != ClientState::Connected {
288            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
289        }
290        // 基础实现返回错误,子类需要重写
291        // Base implementation returns error, subclasses need to override
292        Err(MCPClientError::ProtocolError("Not implemented".to_string()))
293    }
294
295    async fn unsubscribe_window(&self, _resource: Resource) -> Result<(), MCPClientError> {
296        if self.get_state().await != ClientState::Connected {
297            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
298        }
299        // 基础实现返回错误,子类需要重写
300        // Base implementation returns error, subclasses need to override
301        Err(MCPClientError::ProtocolError("Not implemented".to_string()))
302    }
303}
304
/// Named transitions of the client state machine.
///
/// NOTE(review): `StateTransition::is_valid` additionally accepts
/// `Disconnected -> Connected` and `Disconnected -> Initialized`, which have no
/// variant here. Confirm whether those variants are missing or intentionally
/// omitted.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StateTransition {
    /// Named for `Initialized -> Connected`.
    InitializeToConnected,
    /// Named for `Connected -> Disconnected`.
    ConnectedToDisconnected,
    /// Named for a move from any state into `Error`.
    AnyToError,
    /// Named for `Error -> Initialized`.
    ErrorToInitialized,
}
313
314impl StateTransition {
315    /// 检查状态转换是否有效 / Check if state transition is valid
316    pub fn is_valid(from: ClientState, to: ClientState) -> bool {
317        matches!(
318            (from, to),
319            (ClientState::Initialized, ClientState::Connected)
320                | (ClientState::Connected, ClientState::Disconnected)
321                | (_, ClientState::Error)
322                | (ClientState::Error, ClientState::Initialized)
323                | (ClientState::Disconnected, ClientState::Connected)
324                | (ClientState::Disconnected, ClientState::Initialized)
325        )
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::time::{sleep, Duration};

    /// Builds a throwaway resource for the window-related protocol methods.
    fn sample_resource() -> crate::mcp_clients::Resource {
        crate::mcp_clients::Resource {
            uri: "test://".to_string(),
            name: "test".to_string(),
            description: None,
            mime_type: None,
        }
    }

    #[test]
    fn test_state_transition_validity() {
        let valid = [
            (ClientState::Initialized, ClientState::Connected),
            (ClientState::Connected, ClientState::Disconnected),
            (ClientState::Connected, ClientState::Error),
            (ClientState::Error, ClientState::Initialized),
            (ClientState::Disconnected, ClientState::Connected),
            (ClientState::Disconnected, ClientState::Initialized),
        ];
        for (from, to) in valid {
            assert!(
                StateTransition::is_valid(from, to),
                "{} -> {} should be valid",
                from,
                to
            );
        }

        let invalid = [
            (ClientState::Connected, ClientState::Initialized),
            (ClientState::Initialized, ClientState::Disconnected),
            (ClientState::Connected, ClientState::Connected),
        ];
        for (from, to) in invalid {
            assert!(
                !StateTransition::is_valid(from, to),
                "{} -> {} should be invalid",
                from,
                to
            );
        }
    }

    #[tokio::test]
    async fn test_base_client_state_management() {
        let client = BaseMCPClient::new("test");
        assert_eq!(client.get_state().await, ClientState::Initialized);
        // The synchronous accessor must agree with the async one when the lock
        // is uncontended.
        assert_eq!(client.state(), ClientState::Initialized);

        // Test state change notification
        let mut rx = client.get_state_notifier();
        assert_eq!(*rx.borrow_and_update(), ClientState::Initialized);
    }

    #[tokio::test]
    async fn test_state_change_callback() {
        let mut client = BaseMCPClient::new("test");
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();

        client.set_state_change_callback(move |from, to| {
            call_count_clone.fetch_add(1, Ordering::SeqCst);
            println!("State changed from {} to {}", from, to);
        });

        // Trigger a state change.
        client.update_state(ClientState::Connected).await;
        assert_eq!(client.get_state().await, ClientState::Connected);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Trigger another state change.
        client.update_state(ClientState::Disconnected).await;
        assert_eq!(client.get_state().await, ClientState::Disconnected);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_can_connect() {
        let client = BaseMCPClient::new("test");

        // Can connect in the initial state.
        assert!(client.can_connect().await);

        // Cannot connect again once connected.
        client.update_state(ClientState::Connected).await;
        assert!(!client.can_connect().await);

        // Can reconnect after a disconnect.
        client.update_state(ClientState::Disconnected).await;
        assert!(client.can_connect().await);

        // Cannot connect from the error state.
        client.update_state(ClientState::Error).await;
        assert!(!client.can_connect().await);
    }

    #[tokio::test]
    async fn test_can_disconnect() {
        let client = BaseMCPClient::new("test");

        // Cannot disconnect in the initial state.
        assert!(!client.can_disconnect().await);

        // Can disconnect once connected.
        client.update_state(ClientState::Connected).await;
        assert!(client.can_disconnect().await);

        // Cannot disconnect twice.
        client.update_state(ClientState::Disconnected).await;
        assert!(!client.can_disconnect().await);
    }

    #[tokio::test]
    async fn test_create_shutdown_receiver() {
        let client = BaseMCPClient::new("test");

        let mut rx = client.create_shutdown_receiver().await;

        // The initial value is `false`.
        assert!(!*rx.borrow_and_update());

        // Send the shutdown signal.
        {
            let shutdown_tx = client.shutdown_tx.lock().await;
            if let Some(tx) = shutdown_tx.as_ref() {
                let _ = tx.send(true);
            }
        }

        // `watch` publishes synchronously, so no waiting is needed.
        assert!(rx.has_changed().unwrap_or(false));
        assert!(*rx.borrow_and_update());
    }

    #[tokio::test]
    async fn test_execute_with_timeout_success() {
        let client = BaseMCPClient::new("test");

        let future = async {
            sleep(Duration::from_millis(10)).await;
            Ok::<String, MCPClientError>("success".to_string())
        };

        let result = client.execute_with_timeout(future, 1).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_execute_with_timeout_failure() {
        let client = BaseMCPClient::new("test");

        // A zero-second budget expires right away, so the test does not have
        // to wait out a real one-second timeout.
        let future = async {
            sleep(Duration::from_secs(2)).await;
            Ok::<String, MCPClientError>("success".to_string())
        };

        let result = client.execute_with_timeout(future, 0).await;
        assert!(matches!(result, Err(MCPClientError::TimeoutError(_))));
    }

    #[tokio::test]
    async fn test_start_keep_alive() {
        let client = BaseMCPClient::new("test");

        // A session that would outlive the test unless it is shut down.
        let session_creator = |_params: &str| async {
            sleep(Duration::from_secs(10)).await;
            Ok::<(), MCPClientError>(())
        };

        client.start_keep_alive(session_creator).await;
        assert!(client.keep_alive_handle.lock().await.is_some());

        // Let the spawned task start running before asking it to stop.
        tokio::task::yield_now().await;

        client.stop_keep_alive().await.unwrap();

        // The shutdown branch won: the session never completed, so the state
        // is untouched, and the stored handle has been consumed.
        assert_eq!(client.get_state().await, ClientState::Initialized);
        assert!(client.keep_alive_handle.lock().await.is_none());
    }

    #[tokio::test]
    async fn test_start_keep_alive_with_error() {
        let client = BaseMCPClient::new("test");

        let session_creator = |_params: &str| async {
            Err::<(), MCPClientError>(MCPClientError::ConnectionError(
                "Failed to create session".to_string(),
            ))
        };

        client.start_keep_alive(session_creator).await;

        // Join the task directly instead of sleeping for a guessed duration.
        let handle = client
            .keep_alive_handle
            .lock()
            .await
            .take()
            .expect("start_keep_alive should store a task handle");
        handle.await.expect("keep-alive task should not panic");

        assert_eq!(client.get_state().await, ClientState::Error);

        // Nothing is left to join, but stopping must still succeed.
        client.stop_keep_alive().await.unwrap();
    }

    #[tokio::test]
    async fn test_protocol_connect_state_check() {
        let client = BaseMCPClient::new("test");

        // Connecting while already connected must fail.
        client.update_state(ClientState::Connected).await;
        let result = client.connect().await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MCPClientError::ConnectionError(_)
        ));
    }

    #[tokio::test]
    async fn test_protocol_disconnect_state_check() {
        let client = BaseMCPClient::new("test");

        // Disconnecting while not connected must fail.
        let result = client.disconnect().await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MCPClientError::ConnectionError(_)
        ));
    }

    #[tokio::test]
    async fn test_protocol_connect_disconnect_roundtrip() {
        let client = BaseMCPClient::new("test");

        client
            .connect()
            .await
            .expect("connect from Initialized should succeed");
        assert_eq!(client.get_state().await, ClientState::Connected);

        client
            .disconnect()
            .await
            .expect("disconnect from Connected should succeed");
        assert_eq!(client.get_state().await, ClientState::Disconnected);

        // Reconnecting after a disconnect is allowed.
        client
            .connect()
            .await
            .expect("reconnect from Disconnected should succeed");
        assert_eq!(client.get_state().await, ClientState::Connected);
    }

    #[tokio::test]
    async fn test_protocol_methods_require_connection() {
        let client = BaseMCPClient::new("test");

        // Every protocol method must fail while not connected.
        assert!(client.list_tools().await.is_err());
        assert!(client
            .call_tool("test", serde_json::json!({}))
            .await
            .is_err());
        assert!(client.list_windows().await.is_err());
        assert!(client.get_window_detail(sample_resource()).await.is_err());
        assert!(client.subscribe_window(sample_resource()).await.is_err());
        assert!(client.unsubscribe_window(sample_resource()).await.is_err());
    }

    #[tokio::test]
    async fn test_protocol_base_defaults_when_connected() {
        let client = BaseMCPClient::new("test");
        client.update_state(ClientState::Connected).await;

        // Listing methods return empty collections in the base implementation.
        assert!(client.list_tools().await.unwrap().is_empty());
        assert!(client.list_windows().await.unwrap().is_empty());

        // The remaining methods report "not implemented".
        assert!(matches!(
            client.call_tool("test", serde_json::json!({})).await,
            Err(MCPClientError::ProtocolError(_))
        ));
        assert!(matches!(
            client.get_window_detail(sample_resource()).await,
            Err(MCPClientError::ProtocolError(_))
        ));
        assert!(matches!(
            client.subscribe_window(sample_resource()).await,
            Err(MCPClientError::ProtocolError(_))
        ));
        assert!(matches!(
            client.unsubscribe_window(sample_resource()).await,
            Err(MCPClientError::ProtocolError(_))
        ));
    }

    #[tokio::test]
    async fn test_multiple_state_change_listeners() {
        let client = BaseMCPClient::new("test");

        let mut rx1 = client.get_state_notifier();
        let mut rx2 = client.get_state_notifier();
        let mut rx3 = client.get_state_notifier();

        client.update_state(ClientState::Connected).await;

        // Every listener observes the update.
        assert_eq!(*rx1.borrow_and_update(), ClientState::Connected);
        assert_eq!(*rx2.borrow_and_update(), ClientState::Connected);
        assert_eq!(*rx3.borrow_and_update(), ClientState::Connected);
    }

    #[test]
    fn test_client_state_display() {
        assert_eq!(ClientState::Initialized.to_string(), "initialized");
        assert_eq!(ClientState::Connected.to_string(), "connected");
        assert_eq!(ClientState::Disconnected.to_string(), "disconnected");
        assert_eq!(ClientState::Error.to_string(), "error");
    }
}