ultrafast_mcp_core/utils/
cancellation.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, SystemTime, UNIX_EPOCH};
4use tokio::sync::RwLock;
5use tokio::time::{interval, timeout};
6
7use crate::error::{MCPError, MCPResult};
8use crate::types::notifications::{CancelledNotification, PingRequest, PingResponse};
9
10#[derive(Debug)]
12pub struct CancellationManager {
13 active_requests: Arc<RwLock<HashMap<serde_json::Value, CancellableRequest>>>,
15}
16
17#[derive(Debug, Clone)]
19pub struct CancellableRequest {
20 pub id: serde_json::Value,
22
23 pub method: String,
25
26 pub created_at: u64,
28
29 pub cancelled: bool,
31
32 pub cancel_reason: Option<String>,
34}
35
36#[derive(Clone)]
38pub struct PingManager {
39 ping_interval: Duration,
41
42 ping_timeout: Duration,
44
45 enabled: bool,
47
48 ping_sender: Option<Arc<dyn PingSender + Send + Sync>>,
50}
51
52#[async_trait::async_trait]
54pub trait PingSender {
55 async fn send_ping(&self, request: PingRequest) -> MCPResult<PingResponse>;
56}
57
58impl CancellationManager {
59 pub fn new() -> Self {
61 Self {
62 active_requests: Arc::new(RwLock::new(HashMap::new())),
63 }
64 }
65
66 pub async fn register_request(&self, id: serde_json::Value, method: String) -> MCPResult<()> {
68 let request = CancellableRequest {
69 id: id.clone(),
70 method,
71 created_at: current_timestamp(),
72 cancelled: false,
73 cancel_reason: None,
74 };
75
76 let mut active = self.active_requests.write().await;
77 active.insert(id, request);
78 Ok(())
79 }
80
81 pub async fn cancel_request(
83 &self,
84 id: &serde_json::Value,
85 reason: Option<String>,
86 ) -> MCPResult<bool> {
87 let mut active = self.active_requests.write().await;
88
89 let Some(request) = active.get_mut(id) else {
90 return Ok(false);
91 };
92
93 if request.cancelled {
94 return Ok(false);
95 }
96
97 request.cancelled = true;
98 request.cancel_reason = reason;
99 Ok(true)
100 }
101
102 pub async fn is_cancelled(&self, id: &serde_json::Value) -> bool {
104 let active = self.active_requests.read().await;
105 active.get(id).map(|r| r.cancelled).unwrap_or(false)
106 }
107
108 pub async fn complete_request(&self, id: &serde_json::Value) -> MCPResult<()> {
110 let mut active = self.active_requests.write().await;
111 active.remove(id);
112 Ok(())
113 }
114
115 pub async fn active_requests(&self) -> Vec<CancellableRequest> {
117 let active = self.active_requests.read().await;
118 active.values().cloned().collect()
119 }
120
121 pub async fn cleanup_old_requests(&self, max_age: Duration) -> MCPResult<usize> {
123 let cutoff = current_timestamp() - max_age.as_secs();
124 let mut active = self.active_requests.write().await;
125
126 let original_len = active.len();
127 active.retain(|_, request| request.created_at > cutoff);
128 let removed = original_len - active.len();
129
130 Ok(removed)
131 }
132
133 pub async fn handle_cancellation(
135 &self,
136 notification: CancelledNotification,
137 ) -> MCPResult<bool> {
138 self.cancel_request(¬ification.request_id, notification.reason)
139 .await
140 }
141}
142
143impl PingManager {
144 pub fn new(ping_interval: Duration, ping_timeout: Duration) -> Self {
146 Self {
147 ping_interval,
148 ping_timeout,
149 enabled: false,
150 ping_sender: None,
151 }
152 }
153
154 pub fn with_sender(mut self, sender: Arc<dyn PingSender + Send + Sync>) -> Self {
156 self.ping_sender = Some(sender);
157 self
158 }
159
160 pub fn enable(&mut self) {
162 self.enabled = true;
163 }
164
165 pub fn disable(&mut self) {
167 self.enabled = false;
168 }
169
170 pub async fn start_monitoring(&self) -> MCPResult<()> {
172 if !self.enabled || self.ping_sender.is_none() {
173 return Err(MCPError::internal_error(
174 "Ping monitoring not properly configured".to_string(),
175 ));
176 }
177
178 let sender = self.ping_sender.as_ref().unwrap().clone();
179 let ping_interval = self.ping_interval;
180 let ping_timeout = self.ping_timeout;
181
182 tokio::spawn(async move {
183 let mut interval = interval(ping_interval);
184
185 loop {
186 interval.tick().await;
187
188 let ping_request = PingRequest::new().with_data(serde_json::json!({
189 "timestamp": current_timestamp(),
190 "keepalive": true
191 }));
192
193 match timeout(ping_timeout, sender.send_ping(ping_request)).await {
194 Ok(Ok(_response)) => {
195 tracing::debug!("Ping successful");
197 }
198 Ok(Err(e)) => {
199 tracing::warn!("Ping failed: {}", e);
201 break;
203 }
204 Err(_) => {
205 tracing::warn!("Ping timed out after {:?}", ping_timeout);
207 break;
209 }
210 }
211 }
212 });
213
214 Ok(())
215 }
216
217 pub async fn handle_ping(&self, request: PingRequest) -> MCPResult<PingResponse> {
219 Ok(PingResponse { data: request.data })
221 }
222}
223
224impl Default for CancellationManager {
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
230impl Default for PingManager {
231 fn default() -> Self {
232 Self::new(Duration::from_secs(30), Duration::from_secs(5))
233 }
234}
235
236impl std::fmt::Debug for PingManager {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 f.debug_struct("PingManager")
239 .field("ping_interval", &self.ping_interval)
240 .field("ping_timeout", &self.ping_timeout)
241 .field("enabled", &self.enabled)
242 .field("ping_sender", &"<callback>")
243 .finish()
244 }
245}
246
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time earlier than the epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Walks a single request through register, cancel and complete.
    #[tokio::test]
    async fn test_cancellation_manager() {
        let id = serde_json::json!("test-request-1");
        let tracker = CancellationManager::new();

        tracker
            .register_request(id.clone(), "test_method".to_string())
            .await
            .unwrap();

        // A freshly registered request is not cancelled.
        assert!(!tracker.is_cancelled(&id).await);

        let newly_cancelled = tracker
            .cancel_request(&id, Some("User requested".to_string()))
            .await
            .unwrap();
        assert!(newly_cancelled);

        assert!(tracker.is_cancelled(&id).await);

        // Completing removes the entry, so it no longer reports as cancelled.
        tracker.complete_request(&id).await.unwrap();

        assert!(!tracker.is_cancelled(&id).await);
    }

    /// Registers several requests and checks that cleanup removes all of them.
    #[tokio::test]
    async fn test_cancellation_cleanup() {
        let tracker = CancellationManager::new();

        for i in 0..5 {
            let id = serde_json::json!(format!("test-request-{}", i));
            tracker
                .register_request(id, "test_method".to_string())
                .await
                .unwrap();
        }

        sleep(Duration::from_millis(100)).await;

        // Ages are compared in whole seconds, so this 50 ms limit is treated as
        // zero seconds by `cleanup_old_requests`.
        let max_age = Duration::from_millis(50);
        let removed = tracker.cleanup_old_requests(max_age).await.unwrap();
        assert_eq!(removed, 5);
    }

    /// `handle_ping` echoes the request payload back in the response.
    #[tokio::test]
    async fn test_ping_manager() {
        let pinger = PingManager::new(Duration::from_secs(1), Duration::from_secs(1));

        let payload = serde_json::json!({"test": "data"});
        let request = PingRequest::new().with_data(payload);
        let response = pinger.handle_ping(request).await.unwrap();

        assert_eq!(
            format!("{response:?}"),
            "PingResponse { data: Some(Object {\"test\": String(\"data\")}) }"
        );
    }

    /// A response built with `PingResponse::new` carries no data.
    #[test]
    fn test_ping_response() {
        let response = PingResponse::new();
        assert_eq!(format!("{response:?}"), "PingResponse { data: None }");
    }
}