1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11use tokio::time::Instant;
12use tokio_util::sync::CancellationToken;
13use tracing::{debug, error, instrument, warn};
14use uuid::Uuid;
15
16use crate::ServerError;
17use crate::config::TimeoutConfig;
18use crate::metrics::ServerMetrics;
19
/// Runs tool operations under a timeout and tracks them while they are in flight.
///
/// Clones share the same execution registry and metrics handle (both are held
/// behind `Arc`), so a clone observes and can cancel executions started through
/// any other clone.
#[derive(Debug, Clone)]
pub struct ToolTimeoutManager {
    /// Timeout settings: the default tool timeout plus per-tool overrides.
    config: TimeoutConfig,

    /// Registry of in-flight executions, keyed by a per-execution UUID.
    active_executions: Arc<RwLock<HashMap<Uuid, ToolExecution>>>,

    /// Shared counters updated as executions start, finish, time out or are cancelled.
    metrics: Arc<ServerMetrics>,
}
35
/// Bookkeeping record for one in-flight tool execution.
#[derive(Debug, Clone)]
struct ToolExecution {
    /// Name of the tool being executed.
    tool_name: String,
    /// Moment the execution was registered.
    started_at: Instant,
    /// Timeout that was resolved for this tool when the execution started.
    timeout_duration: Duration,
    /// Token whose cancellation makes the execution stop waiting on its operation.
    cancellation_token: CancellationToken,
    /// Set to `true` by `cancel_all_executions` once cancellation has been requested.
    cancelled: bool,
}
50
51impl ToolTimeoutManager {
52 pub fn new(config: TimeoutConfig, metrics: Arc<ServerMetrics>) -> Self {
54 Self {
55 config,
56 active_executions: Arc::new(RwLock::new(HashMap::new())),
57 metrics,
58 }
59 }
60
61 pub fn get_tool_timeout(&self, tool_name: &str) -> Duration {
65 self.config
66 .tool_timeouts
67 .get(tool_name)
68 .map(|&seconds| Duration::from_secs(seconds))
69 .unwrap_or(self.config.tool_execution_timeout)
70 }
71
72 #[instrument(skip(self, operation), fields(tool_name = %tool_name))]
79 pub async fn execute_with_timeout_and_cancellation<F, T>(
80 &self,
81 tool_name: &str,
82 operation: F,
83 ) -> Result<(T, CancellationToken), ToolTimeoutError>
84 where
85 F: std::future::Future<Output = Result<T, ServerError>>,
86 T: Send,
87 {
88 let execution_id = Uuid::new_v4();
89 let timeout_duration = self.get_tool_timeout(tool_name);
90 let started_at = Instant::now();
91
92 let cancellation_token = CancellationToken::new();
94
95 {
97 let mut executions = self.active_executions.write().await;
98 executions.insert(
99 execution_id,
100 ToolExecution {
101 tool_name: tool_name.to_string(),
102 started_at,
103 timeout_duration,
104 cancellation_token: cancellation_token.clone(),
105 cancelled: false,
106 },
107 );
108 }
109
110 self.metrics
112 .tool_executions_active
113 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
114
115 debug!(
116 tool_name = %tool_name,
117 execution_id = %execution_id,
118 timeout_seconds = timeout_duration.as_secs(),
119 "Starting tool execution with timeout"
120 );
121
122 let result = tokio::select! {
125 operation_result = operation => {
127 TimeoutResult::Completed(operation_result)
128 },
129 _ = tokio::time::sleep(timeout_duration) => {
131 cancellation_token.cancel();
133 TimeoutResult::TimedOut
134 },
135 _ = cancellation_token.cancelled() => {
137 warn!(
138 tool_name = %tool_name,
139 execution_id = %execution_id,
140 "Tool execution cancelled cooperatively"
141 );
142 TimeoutResult::Cancelled
143 },
144 };
145
146 {
148 let mut executions = self.active_executions.write().await;
149 executions.remove(&execution_id);
150 }
151
152 self.metrics
154 .tool_executions_active
155 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
156
157 let elapsed = started_at.elapsed();
158
159 match result {
160 TimeoutResult::Completed(Ok(value)) => {
161 self.metrics
163 .tool_executions_successful
164 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
165
166 debug!(
167 tool_name = %tool_name,
168 execution_id = %execution_id,
169 elapsed_ms = elapsed.as_millis(),
170 "Tool execution completed successfully"
171 );
172 Ok((value, cancellation_token))
173 }
174 TimeoutResult::Completed(Err(server_error)) => {
175 warn!(
176 tool_name = %tool_name,
177 execution_id = %execution_id,
178 elapsed_ms = elapsed.as_millis(),
179 error = %server_error,
180 "Tool execution failed with server error"
181 );
182 Err(ToolTimeoutError::ServerError(server_error))
183 }
184 TimeoutResult::TimedOut => {
185 self.metrics
187 .tool_timeouts_total
188 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
189 self.metrics.timeout_wasted_time_us.fetch_add(
190 elapsed.as_micros() as u64,
191 std::sync::atomic::Ordering::Relaxed,
192 );
193 self.metrics
194 .errors_timeout
195 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
196
197 error!(
199 tool_name = %tool_name,
200 execution_id = %execution_id,
201 timeout_seconds = timeout_duration.as_secs(),
202 elapsed_ms = elapsed.as_millis(),
203 event_type = "TIMEOUT_EVENT",
204 security_concern = "potential_dos_indicator",
205 "🔒 SECURITY AUDIT: Tool execution timed out"
206 );
207 Err(ToolTimeoutError::Timeout {
208 tool_name: tool_name.to_string(),
209 timeout_duration,
210 elapsed,
211 })
212 }
213 TimeoutResult::Cancelled => {
214 self.metrics
216 .tool_cancellations_total
217 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
218
219 warn!(
221 tool_name = %tool_name,
222 execution_id = %execution_id,
223 elapsed_ms = elapsed.as_millis(),
224 event_type = "CANCELLATION_EVENT",
225 "🔒 AUDIT: Tool execution was cancelled cooperatively"
226 );
227 Err(ToolTimeoutError::Cancelled {
228 tool_name: tool_name.to_string(),
229 elapsed,
230 })
231 }
232 }
233 }
234
235 #[instrument(skip(self, operation, cancellation_token), fields(tool_name = %tool_name))]
240 pub async fn execute_with_external_token<F, T>(
241 &self,
242 tool_name: &str,
243 operation: F,
244 cancellation_token: CancellationToken,
245 ) -> Result<T, ToolTimeoutError>
246 where
247 F: std::future::Future<Output = Result<T, ServerError>>,
248 T: Send,
249 {
250 let execution_id = Uuid::new_v4();
251 let timeout_duration = self.get_tool_timeout(tool_name);
252 let started_at = Instant::now();
253
254 {
256 let mut executions = self.active_executions.write().await;
257 executions.insert(
258 execution_id,
259 ToolExecution {
260 tool_name: tool_name.to_string(),
261 started_at,
262 timeout_duration,
263 cancellation_token: cancellation_token.clone(),
264 cancelled: false,
265 },
266 );
267 }
268
269 debug!(
270 tool_name = %tool_name,
271 execution_id = %execution_id,
272 timeout_seconds = timeout_duration.as_secs(),
273 "Starting tool execution with provided cancellation token"
274 );
275
276 let result = tokio::select! {
278 operation_result = operation => {
280 TimeoutResult::Completed(operation_result)
281 },
282 _ = tokio::time::sleep(timeout_duration) => {
284 cancellation_token.cancel();
286 TimeoutResult::TimedOut
287 },
288 _ = cancellation_token.cancelled() => {
290 warn!(
291 tool_name = %tool_name,
292 execution_id = %execution_id,
293 "Tool execution cancelled via external token"
294 );
295 TimeoutResult::Cancelled
296 },
297 };
298
299 {
301 let mut executions = self.active_executions.write().await;
302 executions.remove(&execution_id);
303 }
304
305 self.metrics
307 .tool_executions_active
308 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
309
310 let elapsed = started_at.elapsed();
311
312 match result {
313 TimeoutResult::Completed(Ok(value)) => {
314 debug!(
315 tool_name = %tool_name,
316 execution_id = %execution_id,
317 elapsed_ms = elapsed.as_millis(),
318 "Tool execution completed successfully"
319 );
320 Ok(value)
321 }
322 TimeoutResult::Completed(Err(server_error)) => {
323 warn!(
324 tool_name = %tool_name,
325 execution_id = %execution_id,
326 elapsed_ms = elapsed.as_millis(),
327 error = %server_error,
328 "Tool execution failed with server error"
329 );
330 Err(ToolTimeoutError::ServerError(server_error))
331 }
332 TimeoutResult::TimedOut => {
333 self.metrics
335 .tool_timeouts_total
336 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
337 self.metrics.timeout_wasted_time_us.fetch_add(
338 elapsed.as_micros() as u64,
339 std::sync::atomic::Ordering::Relaxed,
340 );
341 self.metrics
342 .errors_timeout
343 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
344
345 error!(
347 tool_name = %tool_name,
348 execution_id = %execution_id,
349 timeout_seconds = timeout_duration.as_secs(),
350 elapsed_ms = elapsed.as_millis(),
351 event_type = "TIMEOUT_EVENT",
352 security_concern = "potential_dos_indicator",
353 "🔒 SECURITY AUDIT: Tool execution timed out"
354 );
355 Err(ToolTimeoutError::Timeout {
356 tool_name: tool_name.to_string(),
357 timeout_duration,
358 elapsed,
359 })
360 }
361 TimeoutResult::Cancelled => {
362 self.metrics
364 .tool_cancellations_total
365 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
366
367 warn!(
369 tool_name = %tool_name,
370 execution_id = %execution_id,
371 elapsed_ms = elapsed.as_millis(),
372 event_type = "CANCELLATION_EVENT",
373 "🔒 AUDIT: Tool execution was cancelled cooperatively"
374 );
375 Err(ToolTimeoutError::Cancelled {
376 tool_name: tool_name.to_string(),
377 elapsed,
378 })
379 }
380 }
381 }
382
383 #[instrument(skip(self, operation), fields(tool_name = %tool_name))]
389 pub async fn execute_with_timeout<F, T>(
390 &self,
391 tool_name: &str,
392 operation: F,
393 ) -> Result<T, ToolTimeoutError>
394 where
395 F: std::future::Future<Output = Result<T, ServerError>>,
396 T: Send,
397 {
398 match self
400 .execute_with_timeout_and_cancellation(tool_name, operation)
401 .await
402 {
403 Ok((result, _token)) => Ok(result),
404 Err(error) => Err(error),
405 }
406 }
407
408 pub async fn get_active_executions(&self) -> Vec<ActiveExecutionInfo> {
413 let executions = self.active_executions.read().await;
414 executions
415 .iter()
416 .map(|(&id, execution)| ActiveExecutionInfo {
417 execution_id: id,
418 tool_name: execution.tool_name.clone(),
419 started_at: execution.started_at,
420 timeout_duration: execution.timeout_duration,
421 elapsed: execution.started_at.elapsed(),
422 cancellation_token: execution.cancellation_token.clone(),
423 cancelled: execution.cancelled,
424 })
425 .collect()
426 }
427
428 #[instrument(skip(self))]
434 pub async fn cancel_all_executions(&self) {
435 let mut executions = self.active_executions.write().await;
436 let count = executions.len();
437
438 if count > 0 {
439 warn!(
441 active_count = count,
442 event_type = "BULK_CANCELLATION_EVENT",
443 security_note = "emergency_shutdown_or_resource_cleanup",
444 "🔒 SECURITY AUDIT: Cancelling all active tool executions"
445 );
446
447 for execution in executions.values_mut() {
448 execution.cancellation_token.cancel();
450 execution.cancelled = true;
451 }
452
453 self.metrics
455 .tool_cancellations_total
456 .fetch_add(count as u64, std::sync::atomic::Ordering::Relaxed);
457
458 debug!(
459 cancelled_count = count,
460 "Sent cooperative cancellation signals to all active tool executions"
461 );
462 }
463
464 }
467}
468
/// Point-in-time view of one tracked execution, as returned by
/// `ToolTimeoutManager::get_active_executions`.
#[derive(Debug, Clone)]
pub struct ActiveExecutionInfo {
    /// Identifier assigned to the execution when it was registered.
    pub execution_id: Uuid,
    /// Name of the tool being executed.
    pub tool_name: String,
    /// Moment the execution was registered.
    pub started_at: Instant,
    /// Timeout that applies to this execution.
    pub timeout_duration: Duration,
    /// Time since `started_at`, measured when the snapshot was taken.
    pub elapsed: Duration,
    /// Clone of the execution's cancellation token.
    pub cancellation_token: CancellationToken,
    /// Whether `cancel_all_executions` has already flagged this execution.
    pub cancelled: bool,
}
487
/// Ways a timeout-managed tool execution can fail.
#[derive(Debug, thiserror::Error)]
pub enum ToolTimeoutError {
    /// The tool's timeout elapsed before the operation finished.
    #[error("Tool '{tool_name}' timed out after {timeout_duration:?} (elapsed: {elapsed:?})")]
    Timeout {
        /// Name of the tool that timed out.
        tool_name: String,
        /// Timeout that was in force for the execution.
        timeout_duration: Duration,
        /// Time measured from registration until the outcome was recorded.
        elapsed: Duration,
    },

    /// The execution's cancellation token fired before the operation finished.
    #[error("Tool '{tool_name}' was cancelled (elapsed: {elapsed:?})")]
    Cancelled {
        /// Name of the tool that was cancelled.
        tool_name: String,
        /// Time measured from registration until the outcome was recorded.
        elapsed: Duration,
    },

    /// The operation ran to completion but returned an error of its own.
    #[error("Tool execution failed: {0}")]
    ServerError(ServerError),
}
515
/// Internal outcome of racing an operation against its timeout and its
/// cancellation token.
#[derive(Debug)]
enum TimeoutResult<T> {
    /// The operation finished first; carries its own success or failure.
    Completed(Result<T, ServerError>),
    /// The timeout elapsed first.
    TimedOut,
    /// The cancellation token fired first.
    Cancelled,
}
523
524impl From<ToolTimeoutError> for ServerError {
525 fn from(timeout_error: ToolTimeoutError) -> Self {
526 match timeout_error {
527 ToolTimeoutError::Timeout {
528 tool_name,
529 timeout_duration,
530 ..
531 } => ServerError::timeout(
532 format!("Tool '{}'", tool_name),
533 timeout_duration.as_millis() as u64,
534 ),
535 ToolTimeoutError::Cancelled { tool_name, .. } => {
536 ServerError::handler(format!("Tool '{}' was cancelled", tool_name))
537 }
538 ToolTimeoutError::ServerError(server_error) => server_error,
539 }
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ServerError;
    use tokio::time::{Duration, sleep};

    /// Config with a 3s default tool timeout and overrides for `fast_tool` (1s)
    /// and `slow_tool` (5s).
    fn create_test_config() -> TimeoutConfig {
        let mut tool_timeouts = HashMap::new();
        tool_timeouts.insert("fast_tool".to_string(), 1);
        tool_timeouts.insert("slow_tool".to_string(), 5);

        TimeoutConfig {
            request_timeout: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(10),
            keep_alive_timeout: Duration::from_secs(60),
            tool_execution_timeout: Duration::from_secs(3),
            tool_timeouts,
        }
    }

    fn create_test_metrics() -> Arc<ServerMetrics> {
        Arc::new(ServerMetrics::new())
    }

    #[tokio::test]
    async fn test_successful_tool_execution() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        let result = manager
            .execute_with_timeout("test_tool", async {
                sleep(Duration::from_millis(100)).await;
                Ok::<String, ServerError>("success".to_string())
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_tool_timeout() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        // `fast_tool` has a 1s override, so a 2s operation must time out.
        let result = manager
            .execute_with_timeout("fast_tool", async {
                sleep(Duration::from_secs(2)).await;
                Ok::<String, ServerError>("should_not_reach".to_string())
            })
            .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            ToolTimeoutError::Timeout {
                tool_name,
                timeout_duration,
                ..
            } => {
                assert_eq!(tool_name, "fast_tool");
                assert_eq!(timeout_duration, Duration::from_secs(1));
            }
            _ => panic!("Expected timeout error"),
        }
    }

    #[tokio::test]
    async fn test_per_tool_timeout_override() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        // Explicit override wins.
        assert_eq!(
            manager.get_tool_timeout("slow_tool"),
            Duration::from_secs(5)
        );

        // Unknown tools fall back to the default tool execution timeout.
        assert_eq!(
            manager.get_tool_timeout("unknown_tool"),
            Duration::from_secs(3)
        );
    }

    #[tokio::test]
    async fn test_server_error_propagation() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        let result = manager
            .execute_with_timeout("test_tool", async {
                Err::<String, ServerError>(ServerError::handler("custom error"))
            })
            .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            ToolTimeoutError::ServerError(server_error) => {
                assert!(server_error.to_string().contains("custom error"));
            }
            _ => panic!("Expected server error"),
        }
    }

    #[tokio::test]
    async fn test_active_executions_tracking() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        // The operation blocks on this channel, so it is guaranteed to still be
        // in flight while the registry is inspected (no timing assumptions).
        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();

        let worker = manager.clone();
        let handle = tokio::spawn(async move {
            worker
                .execute_with_timeout("long_running", async move {
                    let _ = release_rx.await;
                    Ok::<String, ServerError>("done".to_string())
                })
                .await
        });

        // Poll until the spawned task has registered itself (bounded wait, well
        // below the 3s default timeout that applies to `long_running`).
        let mut active = manager.get_active_executions().await;
        for _ in 0..100 {
            if !active.is_empty() {
                break;
            }
            sleep(Duration::from_millis(5)).await;
            active = manager.get_active_executions().await;
        }

        assert_eq!(active.len(), 1);
        assert_eq!(active[0].tool_name, "long_running");
        assert!(!active[0].cancelled);

        release_tx
            .send(())
            .expect("operation should still be waiting for release");
        let result = handle.await.expect("execution task should not panic");
        assert_eq!(result.unwrap(), "done");

        // Finished executions must be removed from the registry.
        assert!(manager.get_active_executions().await.is_empty());
    }
}