smcp_computer/mcp_clients/
base_client.rs1use super::model::*;
11use crate::errors::ComputerError;
12use async_trait::async_trait;
13use std::sync::Arc;
14use tokio::sync::{watch, Mutex, RwLock};
15use tokio::task::JoinHandle;
16use tokio::time::{timeout, Duration};
17use tracing::{debug, error, info, warn};
18
19pub struct BaseMCPClient<P> {
21 pub params: P,
23 state: Arc<RwLock<ClientState>>,
25 state_notifier: watch::Sender<ClientState>,
27 keep_alive_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
29 shutdown_tx: Arc<Mutex<Option<watch::Sender<bool>>>>,
31 state_change_callback: Option<Box<dyn Fn(ClientState, ClientState) + Send + Sync>>,
33}
34
35impl<P> BaseMCPClient<P>
36where
37 P: Send + Sync + 'static + std::clone::Clone,
38{
39 pub fn new(params: P) -> Self {
41 let (state_tx, _) = watch::channel(ClientState::Initialized);
42 let state = Arc::new(RwLock::new(ClientState::Initialized));
43 let (shutdown_tx, _) = watch::channel(false);
44
45 Self {
46 params,
47 state,
48 state_notifier: state_tx,
49 keep_alive_handle: Arc::new(Mutex::new(None)),
50 shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
51 state_change_callback: None,
52 }
53 }
54
55 pub fn set_state_change_callback<F>(&mut self, callback: F)
57 where
58 F: Fn(ClientState, ClientState) + Send + Sync + 'static,
59 {
60 self.state_change_callback = Some(Box::new(callback));
61 }
62
63 pub async fn get_state(&self) -> ClientState {
65 *self.state.read().await
66 }
67
68 pub fn get_state_notifier(&self) -> watch::Receiver<ClientState> {
70 self.state_notifier.subscribe()
71 }
72
73 pub async fn update_state(&self, new_state: ClientState) {
75 let mut state = self.state.write().await;
76 let old_state = *state;
77 *state = new_state;
78
79 let _ = self.state_notifier.send(new_state);
81
82 if let Some(ref callback) = self.state_change_callback {
84 callback(old_state, new_state);
85 }
86
87 debug!("State transition: {} -> {}", old_state, new_state);
88 }
89
90 #[allow(dead_code)]
92 async fn start_keep_alive<T>(&self, session_creator: impl Fn(P) -> T + Send + Sync + 'static)
93 where
94 T: std::future::Future<Output = Result<(), MCPClientError>> + Send + 'static,
95 {
96 let params = self.params.clone();
97 let mut shutdown_rx = self.create_shutdown_receiver().await;
98 let state = self.state.clone();
99
100 let handle = tokio::spawn(async move {
101 debug!("Session keep-alive task started");
102
103 let session_future = session_creator(params);
105
106 tokio::select! {
107 result = session_future => {
108 match result {
109 Ok(_) => {
110 debug!("Session completed successfully");
111 { *state.write().await = ClientState::Disconnected; }
112 }
113 Err(e) => {
114 error!("Session failed: {}", e);
115 { *state.write().await = ClientState::Error; }
116 }
117 }
118 }
119 shutdown_rx = shutdown_rx.changed() => {
120 if shutdown_rx.is_ok() {
121 debug!("Session keep-alive task received shutdown signal");
122 }
123 }
124 }
125
126 debug!("Session keep-alive task ended");
127 });
128
129 *self.keep_alive_handle.lock().await = Some(handle);
130 }
131
132 async fn stop_keep_alive(&self) -> Result<(), ComputerError> {
134 let mut shutdown_tx = self.shutdown_tx.lock().await;
136 if let Some(tx) = shutdown_tx.take() {
137 let _ = tx.send(true);
138 }
139
140 let mut handle = self.keep_alive_handle.lock().await;
142 if let Some(h) = handle.take() {
143 match h.await {
144 Ok(_) => debug!("Keep-alive task stopped successfully"),
145 Err(e) => warn!("Keep-alive task stopped with error: {}", e),
146 }
147 }
148
149 let (tx, _) = watch::channel(false);
151 *shutdown_tx = Some(tx);
152
153 Ok(())
154 }
155
156 #[allow(dead_code)]
158 async fn create_shutdown_receiver(&self) -> watch::Receiver<bool> {
159 let shutdown_tx = self.shutdown_tx.lock().await;
160 shutdown_tx.as_ref().unwrap().subscribe()
161 }
162
163 pub async fn can_connect(&self) -> bool {
165 matches!(
166 self.get_state().await,
167 ClientState::Initialized | ClientState::Disconnected
168 )
169 }
170
171 pub async fn can_disconnect(&self) -> bool {
173 matches!(self.get_state().await, ClientState::Connected)
174 }
175
176 #[allow(dead_code)]
178 async fn execute_with_timeout<F, T>(
179 &self,
180 future: F,
181 timeout_secs: u64,
182 ) -> Result<T, MCPClientError>
183 where
184 F: std::future::Future<Output = Result<T, MCPClientError>>,
185 {
186 match timeout(Duration::from_secs(timeout_secs), future).await {
187 Ok(result) => result,
188 Err(_) => Err(MCPClientError::TimeoutError(format!(
189 "Operation timed out after {} seconds",
190 timeout_secs
191 ))),
192 }
193 }
194}
195
196#[async_trait]
197impl<P> MCPClientProtocol for BaseMCPClient<P>
198where
199 P: Send + Sync + Clone + 'static,
200{
201 fn state(&self) -> ClientState {
202 if let Ok(state_guard) = self.state.try_read() {
204 *state_guard
205 } else {
206 tokio::task::block_in_place(|| {
209 tokio::runtime::Handle::current().block_on(async { self.get_state().await })
210 })
211 }
212 }
213
214 async fn connect(&self) -> Result<(), MCPClientError> {
215 if !self.can_connect().await {
216 return Err(MCPClientError::ConnectionError(format!(
217 "Cannot connect in state: {}",
218 self.get_state().await
219 )));
220 }
221
222 self.update_state(ClientState::Connected).await;
223 info!("Connected successfully");
224 Ok(())
225 }
226
227 async fn disconnect(&self) -> Result<(), MCPClientError> {
228 if !self.can_disconnect().await {
229 return Err(MCPClientError::ConnectionError(format!(
230 "Cannot disconnect in state: {}",
231 self.get_state().await
232 )));
233 }
234
235 self.stop_keep_alive()
236 .await
237 .map_err(|e| MCPClientError::Other(e.to_string()))?;
238 self.update_state(ClientState::Disconnected).await;
239 info!("Disconnected successfully");
240 Ok(())
241 }
242
243 async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError> {
244 if self.get_state().await != ClientState::Connected {
245 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
246 }
247 Ok(vec![])
250 }
251
252 async fn call_tool(
253 &self,
254 _tool_name: &str,
255 _params: serde_json::Value,
256 ) -> Result<CallToolResult, MCPClientError> {
257 if self.get_state().await != ClientState::Connected {
258 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
259 }
260 Err(MCPClientError::ProtocolError("Not implemented".to_string()))
263 }
264
265 async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError> {
266 if self.get_state().await != ClientState::Connected {
267 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
268 }
269 Ok(vec![])
272 }
273
274 async fn get_window_detail(
275 &self,
276 _resource: Resource,
277 ) -> Result<ReadResourceResult, MCPClientError> {
278 if self.get_state().await != ClientState::Connected {
279 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
280 }
281 Err(MCPClientError::ProtocolError("Not implemented".to_string()))
284 }
285
286 async fn subscribe_window(&self, _resource: Resource) -> Result<(), MCPClientError> {
287 if self.get_state().await != ClientState::Connected {
288 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
289 }
290 Err(MCPClientError::ProtocolError("Not implemented".to_string()))
293 }
294
295 async fn unsubscribe_window(&self, _resource: Resource) -> Result<(), MCPClientError> {
296 if self.get_state().await != ClientState::Connected {
297 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
298 }
299 Err(MCPClientError::ProtocolError("Not implemented".to_string()))
302 }
303}
304
305#[derive(Debug, Clone, Copy, PartialEq, Eq)]
307pub enum StateTransition {
308 InitializeToConnected,
309 ConnectedToDisconnected,
310 AnyToError,
311 ErrorToInitialized,
312}
313
314impl StateTransition {
315 pub fn is_valid(from: ClientState, to: ClientState) -> bool {
317 matches!(
318 (from, to),
319 (ClientState::Initialized, ClientState::Connected)
320 | (ClientState::Connected, ClientState::Disconnected)
321 | (_, ClientState::Error)
322 | (ClientState::Error, ClientState::Initialized)
323 | (ClientState::Disconnected, ClientState::Connected)
324 | (ClientState::Disconnected, ClientState::Initialized)
325 )
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// Polls `client` every 10 ms until it reports `expected`, giving up after
    /// about one second. Returns whether the expected state was observed.
    async fn wait_for_state<P>(client: &BaseMCPClient<P>, expected: ClientState) -> bool
    where
        P: Send + Sync + Clone + 'static,
    {
        for _ in 0..100 {
            if client.get_state().await == expected {
                return true;
            }
            sleep(Duration::from_millis(10)).await;
        }
        false
    }

    /// Builds a throwaway resource for the protocol-method tests.
    fn test_resource() -> crate::mcp_clients::Resource {
        crate::mcp_clients::Resource {
            uri: "test://".to_string(),
            name: "test".to_string(),
            description: None,
            mime_type: None,
        }
    }

    #[test]
    fn test_state_transition_validity() {
        assert!(StateTransition::is_valid(
            ClientState::Initialized,
            ClientState::Connected
        ));
        assert!(StateTransition::is_valid(
            ClientState::Connected,
            ClientState::Disconnected
        ));
        assert!(StateTransition::is_valid(
            ClientState::Connected,
            ClientState::Error
        ));
        assert!(StateTransition::is_valid(
            ClientState::Error,
            ClientState::Initialized
        ));
        assert!(!StateTransition::is_valid(
            ClientState::Connected,
            ClientState::Initialized
        ));
    }

    #[tokio::test]
    async fn test_base_client_state_management() {
        let client = BaseMCPClient::new("test");
        assert_eq!(client.get_state().await, ClientState::Initialized);

        let mut rx = client.get_state_notifier();
        assert_eq!(*rx.borrow_and_update(), ClientState::Initialized);
    }

    #[tokio::test]
    async fn test_state_change_callback() {
        let mut client = BaseMCPClient::new("test");
        let transitions = Arc::new(std::sync::Mutex::new(
            Vec::<(ClientState, ClientState)>::new(),
        ));
        let sink = Arc::clone(&transitions);

        client.set_state_change_callback(move |from, to| {
            sink.lock().unwrap().push((from, to));
        });

        client.update_state(ClientState::Connected).await;
        assert_eq!(client.get_state().await, ClientState::Connected);
        assert_eq!(
            *transitions.lock().unwrap(),
            vec![(ClientState::Initialized, ClientState::Connected)]
        );

        client.update_state(ClientState::Disconnected).await;
        assert_eq!(client.get_state().await, ClientState::Disconnected);
        assert_eq!(
            *transitions.lock().unwrap(),
            vec![
                (ClientState::Initialized, ClientState::Connected),
                (ClientState::Connected, ClientState::Disconnected),
            ]
        );
    }

    #[tokio::test]
    async fn test_can_connect() {
        let client = BaseMCPClient::new("test");

        assert!(client.can_connect().await);

        client.update_state(ClientState::Connected).await;
        assert!(!client.can_connect().await);

        client.update_state(ClientState::Disconnected).await;
        assert!(client.can_connect().await);

        client.update_state(ClientState::Error).await;
        assert!(!client.can_connect().await);
    }

    #[tokio::test]
    async fn test_can_disconnect() {
        let client = BaseMCPClient::new("test");

        assert!(!client.can_disconnect().await);

        client.update_state(ClientState::Connected).await;
        assert!(client.can_disconnect().await);

        client.update_state(ClientState::Disconnected).await;
        assert!(!client.can_disconnect().await);
    }

    #[tokio::test]
    async fn test_create_shutdown_receiver() {
        let client = BaseMCPClient::new("test");

        let mut rx = client.create_shutdown_receiver().await;
        assert!(!*rx.borrow_and_update());

        {
            let shutdown_tx = client.shutdown_tx.lock().await;
            if let Some(tx) = shutdown_tx.as_ref() {
                let _ = tx.send(true);
            }
        }

        // `watch::Sender::send` publishes synchronously, so no wait is needed.
        assert!(rx.has_changed().unwrap_or(false));
        assert!(*rx.borrow_and_update());
    }

    #[tokio::test]
    async fn test_execute_with_timeout_success() {
        let client = BaseMCPClient::new("test");

        let future = async {
            sleep(Duration::from_millis(100)).await;
            Ok::<String, MCPClientError>("success".to_string())
        };

        let result = client.execute_with_timeout(future, 1).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_execute_with_timeout_failure() {
        let client = BaseMCPClient::new("test");

        let future = async {
            sleep(Duration::from_secs(2)).await;
            Ok::<String, MCPClientError>("success".to_string())
        };

        let result = client.execute_with_timeout(future, 1).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MCPClientError::TimeoutError(_)
        ));
    }

    #[tokio::test]
    async fn test_start_keep_alive() {
        let client = BaseMCPClient::new("test");

        let session_creator = |_params: &str| async {
            sleep(Duration::from_millis(100)).await;
            Ok::<(), MCPClientError>(())
        };

        client.start_keep_alive(session_creator).await;
        assert!(client.keep_alive_handle.lock().await.is_some());

        sleep(Duration::from_millis(50)).await;

        client.stop_keep_alive().await.unwrap();
        assert!(client.keep_alive_handle.lock().await.is_none());
    }

    #[tokio::test]
    async fn test_start_keep_alive_with_error() {
        let client = BaseMCPClient::new("test");

        let session_creator = |_params: &str| async {
            Err::<(), MCPClientError>(MCPClientError::ConnectionError(
                "Failed to create session".to_string(),
            ))
        };

        client.start_keep_alive(session_creator).await;

        assert!(
            wait_for_state(&client, ClientState::Error).await,
            "a failed session should move the client to the Error state"
        );

        client.stop_keep_alive().await.unwrap();
    }

    #[tokio::test]
    async fn test_protocol_connect_state_check() {
        let client = BaseMCPClient::new("test");

        client.update_state(ClientState::Connected).await;
        let result = client.connect().await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MCPClientError::ConnectionError(_)
        ));
    }

    #[tokio::test]
    async fn test_protocol_disconnect_state_check() {
        let client = BaseMCPClient::new("test");

        let result = client.disconnect().await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MCPClientError::ConnectionError(_)
        ));
    }

    #[tokio::test]
    async fn test_protocol_methods_require_connection() {
        let client = BaseMCPClient::new("test");

        assert!(client.list_tools().await.is_err());
        assert!(client
            .call_tool("test", serde_json::json!({}))
            .await
            .is_err());
        assert!(client.list_windows().await.is_err());
        assert!(client.get_window_detail(test_resource()).await.is_err());
        assert!(client.subscribe_window(test_resource()).await.is_err());
        assert!(client.unsubscribe_window(test_resource()).await.is_err());
    }

    #[tokio::test]
    async fn test_multiple_state_change_listeners() {
        let client = BaseMCPClient::new("test");

        let mut rx1 = client.get_state_notifier();
        let mut rx2 = client.get_state_notifier();
        let mut rx3 = client.get_state_notifier();

        client.update_state(ClientState::Connected).await;

        assert_eq!(*rx1.borrow_and_update(), ClientState::Connected);
        assert_eq!(*rx2.borrow_and_update(), ClientState::Connected);
        assert_eq!(*rx3.borrow_and_update(), ClientState::Connected);
    }

    #[test]
    fn test_client_state_display() {
        assert_eq!(ClientState::Initialized.to_string(), "initialized");
        assert_eq!(ClientState::Connected.to_string(), "connected");
        assert_eq!(ClientState::Disconnected.to_string(), "disconnected");
        assert_eq!(ClientState::Error.to_string(), "error");
    }
}