// ultrafast_mcp_core/utils/cancellation.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, SystemTime, UNIX_EPOCH};
4use tokio::sync::RwLock;
5use tokio::time::{interval, timeout};
6
7use crate::error::{MCPError, MCPResult};
8use crate::types::notifications::{CancelledNotification, PingRequest, PingResponse};
9
10/// Request cancellation manager
11#[derive(Debug)]
12pub struct CancellationManager {
13    /// Active requests that can be cancelled
14    active_requests: Arc<RwLock<HashMap<serde_json::Value, CancellableRequest>>>,
15}
16
17/// A cancellable request
18#[derive(Debug, Clone)]
19pub struct CancellableRequest {
20    /// Request ID
21    pub id: serde_json::Value,
22
23    /// Request method
24    pub method: String,
25
26    /// Timestamp when request was created
27    pub created_at: u64,
28
29    /// Whether the request has been cancelled
30    pub cancelled: bool,
31
32    /// Cancellation reason
33    pub cancel_reason: Option<String>,
34}
35
36/// Ping manager for connection health monitoring
37#[derive(Clone)]
38pub struct PingManager {
39    /// Ping interval
40    ping_interval: Duration,
41
42    /// Ping timeout
43    ping_timeout: Duration,
44
45    /// Whether ping monitoring is enabled
46    enabled: bool,
47
48    /// Callback for sending ping requests
49    ping_sender: Option<Arc<dyn PingSender + Send + Sync>>,
50}
51
52/// Trait for sending ping requests
53#[async_trait::async_trait]
54pub trait PingSender {
55    async fn send_ping(&self, request: PingRequest) -> MCPResult<PingResponse>;
56}
57
58impl CancellationManager {
59    /// Create a new cancellation manager
60    pub fn new() -> Self {
61        Self {
62            active_requests: Arc::new(RwLock::new(HashMap::new())),
63        }
64    }
65
66    /// Register a new request for cancellation tracking
67    pub async fn register_request(&self, id: serde_json::Value, method: String) -> MCPResult<()> {
68        let request = CancellableRequest {
69            id: id.clone(),
70            method,
71            created_at: current_timestamp(),
72            cancelled: false,
73            cancel_reason: None,
74        };
75
76        let mut active = self.active_requests.write().await;
77        active.insert(id, request);
78        Ok(())
79    }
80
81    /// Cancel a request
82    pub async fn cancel_request(
83        &self,
84        id: &serde_json::Value,
85        reason: Option<String>,
86    ) -> MCPResult<bool> {
87        let mut active = self.active_requests.write().await;
88
89        let Some(request) = active.get_mut(id) else {
90            return Ok(false);
91        };
92
93        if request.cancelled {
94            return Ok(false);
95        }
96
97        request.cancelled = true;
98        request.cancel_reason = reason;
99        Ok(true)
100    }
101
102    /// Check if a request has been cancelled
103    pub async fn is_cancelled(&self, id: &serde_json::Value) -> bool {
104        let active = self.active_requests.read().await;
105        active.get(id).map(|r| r.cancelled).unwrap_or(false)
106    }
107
108    /// Remove a completed request
109    pub async fn complete_request(&self, id: &serde_json::Value) -> MCPResult<()> {
110        let mut active = self.active_requests.write().await;
111        active.remove(id);
112        Ok(())
113    }
114
115    /// Get all active requests
116    pub async fn active_requests(&self) -> Vec<CancellableRequest> {
117        let active = self.active_requests.read().await;
118        active.values().cloned().collect()
119    }
120
121    /// Clean up old requests (older than max_age)
122    pub async fn cleanup_old_requests(&self, max_age: Duration) -> MCPResult<usize> {
123        let cutoff = current_timestamp() - max_age.as_secs();
124        let mut active = self.active_requests.write().await;
125
126        let original_len = active.len();
127        active.retain(|_, request| request.created_at > cutoff);
128        let removed = original_len - active.len();
129
130        Ok(removed)
131    }
132
133    /// Handle a cancellation notification
134    pub async fn handle_cancellation(
135        &self,
136        notification: CancelledNotification,
137    ) -> MCPResult<bool> {
138        self.cancel_request(&notification.request_id, notification.reason)
139            .await
140    }
141}
142
143impl PingManager {
144    /// Create a new ping manager
145    pub fn new(ping_interval: Duration, ping_timeout: Duration) -> Self {
146        Self {
147            ping_interval,
148            ping_timeout,
149            enabled: false,
150            ping_sender: None,
151        }
152    }
153
154    /// Set the ping sender
155    pub fn with_sender(mut self, sender: Arc<dyn PingSender + Send + Sync>) -> Self {
156        self.ping_sender = Some(sender);
157        self
158    }
159
160    /// Enable ping monitoring
161    pub fn enable(&mut self) {
162        self.enabled = true;
163    }
164
165    /// Disable ping monitoring
166    pub fn disable(&mut self) {
167        self.enabled = false;
168    }
169
170    /// Start periodic ping monitoring
171    pub async fn start_monitoring(&self) -> MCPResult<()> {
172        if !self.enabled || self.ping_sender.is_none() {
173            return Err(MCPError::internal_error(
174                "Ping monitoring not properly configured".to_string(),
175            ));
176        }
177
178        let sender = self.ping_sender.as_ref().unwrap().clone();
179        let ping_interval = self.ping_interval;
180        let ping_timeout = self.ping_timeout;
181
182        tokio::spawn(async move {
183            let mut interval = interval(ping_interval);
184
185            loop {
186                interval.tick().await;
187
188                let ping_request = PingRequest::new().with_data(serde_json::json!({
189                    "timestamp": current_timestamp(),
190                    "keepalive": true
191                }));
192
193                match timeout(ping_timeout, sender.send_ping(ping_request)).await {
194                    Ok(Ok(_response)) => {
195                        // Ping successful
196                        tracing::debug!("Ping successful");
197                    }
198                    Ok(Err(e)) => {
199                        // Ping failed
200                        tracing::warn!("Ping failed: {}", e);
201                        // Could implement reconnection logic here
202                        break;
203                    }
204                    Err(_) => {
205                        // Ping timed out
206                        tracing::warn!("Ping timed out after {:?}", ping_timeout);
207                        // Could implement reconnection logic here
208                        break;
209                    }
210                }
211            }
212        });
213
214        Ok(())
215    }
216
217    /// Handle a ping request and return a pong response
218    pub async fn handle_ping(&self, request: PingRequest) -> MCPResult<PingResponse> {
219        // Echo back the data as per MCP 2025-06-18 specification
220        Ok(PingResponse { data: request.data })
221    }
222}
223
224impl Default for CancellationManager {
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
230impl Default for PingManager {
231    fn default() -> Self {
232        Self::new(Duration::from_secs(30), Duration::from_secs(5))
233    }
234}
235
236impl std::fmt::Debug for PingManager {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        f.debug_struct("PingManager")
239            .field("ping_interval", &self.ping_interval)
240            .field("ping_timeout", &self.ping_timeout)
241            .field("enabled", &self.enabled)
242            .field("ping_sender", &"<callback>")
243            .finish()
244    }
245}
246
/// Returns the current time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Walks one request through register -> cancel -> complete.
    #[tokio::test]
    async fn test_cancellation_manager() {
        let id = serde_json::json!("test-request-1");
        let tracker = CancellationManager::new();

        tracker
            .register_request(id.clone(), "test_method".to_string())
            .await
            .unwrap();

        // A freshly registered request is not cancelled.
        let cancelled_before = tracker.is_cancelled(&id).await;
        assert!(!cancelled_before);

        // The first cancellation reports success...
        let reason = Some("User requested".to_string());
        let was_cancelled = tracker.cancel_request(&id, reason).await.unwrap();
        assert!(was_cancelled);

        // ...and is visible afterwards.
        let cancelled_after = tracker.is_cancelled(&id).await;
        assert!(cancelled_after);

        // Completing the request removes it from tracking.
        tracker.complete_request(&id).await.unwrap();
        let cancelled_once_completed = tracker.is_cancelled(&id).await;
        assert!(!cancelled_once_completed);
    }

    /// Every registered request is purged once it exceeds a tiny max age.
    #[tokio::test]
    async fn test_cancellation_cleanup() {
        let tracker = CancellationManager::new();

        for i in 0..5 {
            let id = serde_json::json!(format!("test-request-{}", i));
            let registered = tracker.register_request(id, "test_method".to_string()).await;
            registered.unwrap();
        }

        // Let the requests age past the max age used below.
        sleep(Duration::from_millis(100)).await;

        let max_age = Duration::from_millis(50);
        let purged = tracker.cleanup_old_requests(max_age).await.unwrap();
        assert_eq!(purged, 5);
    }

    /// `handle_ping` echoes the request data, as per MCP 2025-06-18.
    #[tokio::test]
    async fn test_ping_manager() {
        let one_second = Duration::from_secs(1);
        let pinger = PingManager::new(one_second, one_second);

        let ping = PingRequest::new().with_data(serde_json::json!({"test": "data"}));
        let pong = pinger.handle_ping(ping).await.unwrap();

        assert_eq!(
            format!("{pong:?}"),
            "PingResponse { data: Some(Object {\"test\": String(\"data\")}) }"
        );
    }

    /// A default-constructed `PingResponse` carries no data.
    #[test]
    fn test_ping_response() {
        let pong = PingResponse::new();
        assert_eq!(format!("{pong:?}"), "PingResponse { data: None }");
    }
}
336}