turbomcp_server/
timeout.rs

1//! Tool timeout and cancellation management
2//!
3//! This module provides proven timeout and cancellation capabilities
4//! for tool execution in TurboMCP servers. It follows 2025 Rust async best
5//! practices for cancellation safety and proper resource cleanup.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11use tokio::time::Instant;
12use tokio_util::sync::CancellationToken;
13use tracing::{debug, error, instrument, warn};
14use uuid::Uuid;
15
16use crate::ServerError;
17use crate::config::TimeoutConfig;
18use crate::metrics::ServerMetrics;
19
/// Tool timeout and cancellation manager
///
/// Resolves the timeout that applies to each tool (a per-tool override or the
/// configured default), keeps a registry of in-flight executions so they can
/// be inspected and cancelled, and records timeout/cancellation events in the
/// server metrics and the tracing log.
///
/// The type is `Clone`: clones share the same execution registry and metrics
/// through the `Arc` fields, while `config` is cloned into each copy.
#[derive(Debug, Clone)]
pub struct ToolTimeoutManager {
    /// Timeout configuration: the default tool timeout plus per-tool overrides
    config: TimeoutConfig,

    /// Registry of in-flight tool executions, keyed by a per-execution id.
    /// Shared between clones of the manager and guarded by an async `RwLock`.
    active_executions: Arc<RwLock<HashMap<Uuid, ToolExecution>>>,

    /// Shared server metrics (counters for active executions, timeouts,
    /// cancellations, ...) updated as executions start and finish
    metrics: Arc<ServerMetrics>,
}
35
/// Registry record for one in-flight tool execution.
///
/// One record is inserted when an execution starts and removed when it
/// finishes, times out, or is cancelled.
#[derive(Debug, Clone)]
struct ToolExecution {
    /// Tool name being executed
    tool_name: String,
    /// When execution started (tokio clock)
    started_at: Instant,
    /// Timeout that was resolved for this tool when the execution started
    timeout_duration: Duration,
    /// Clone of the token the execution is waiting on; cancelling it asks the
    /// execution to stop cooperatively
    cancellation_token: CancellationToken,
    /// Whether execution has been marked for cancellation.
    /// Only `cancel_all_executions` sets this; it is not updated when the
    /// token is cancelled some other way (timeout or a caller-held token).
    cancelled: bool,
}
50
51impl ToolTimeoutManager {
52    /// Create a new timeout manager with the given configuration and metrics
53    pub fn new(config: TimeoutConfig, metrics: Arc<ServerMetrics>) -> Self {
54        Self {
55            config,
56            active_executions: Arc::new(RwLock::new(HashMap::new())),
57            metrics,
58        }
59    }
60
61    /// Get the timeout duration for a specific tool
62    ///
63    /// Returns per-tool override if configured, otherwise returns default timeout.
64    pub fn get_tool_timeout(&self, tool_name: &str) -> Duration {
65        self.config
66            .tool_timeouts
67            .get(tool_name)
68            .map(|&seconds| Duration::from_secs(seconds))
69            .unwrap_or(self.config.tool_execution_timeout)
70    }
71
72    /// Execute a tool with timeout and cooperative cancellation support
73    ///
74    /// This is the primary method for executing tools with comprehensive
75    /// timeout and cancellation handling following Tokio best practices.
76    ///
77    /// Returns both the result and the cancellation token for context propagation.
78    #[instrument(skip(self, operation), fields(tool_name = %tool_name))]
79    pub async fn execute_with_timeout_and_cancellation<F, T>(
80        &self,
81        tool_name: &str,
82        operation: F,
83    ) -> Result<(T, CancellationToken), ToolTimeoutError>
84    where
85        F: std::future::Future<Output = Result<T, ServerError>>,
86        T: Send,
87    {
88        let execution_id = Uuid::new_v4();
89        let timeout_duration = self.get_tool_timeout(tool_name);
90        let started_at = Instant::now();
91
92        // Create cancellation token for cooperative cancellation
93        let cancellation_token = CancellationToken::new();
94
95        // Register this execution for monitoring and update active execution count
96        {
97            let mut executions = self.active_executions.write().await;
98            executions.insert(
99                execution_id,
100                ToolExecution {
101                    tool_name: tool_name.to_string(),
102                    started_at,
103                    timeout_duration,
104                    cancellation_token: cancellation_token.clone(),
105                    cancelled: false,
106                },
107            );
108        }
109
110        // Update active executions metric
111        self.metrics
112            .tool_executions_active
113            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
114
115        debug!(
116            tool_name = %tool_name,
117            execution_id = %execution_id,
118            timeout_seconds = timeout_duration.as_secs(),
119            "Starting tool execution with timeout"
120        );
121
122        // Execute with cooperative cancellation using tokio::select!
123        // This allows tools to respond to cancellation signals gracefully
124        let result = tokio::select! {
125            // Tool execution completed
126            operation_result = operation => {
127                TimeoutResult::Completed(operation_result)
128            },
129            // Timeout occurred
130            _ = tokio::time::sleep(timeout_duration) => {
131                // Cancel the token to signal cooperative cancellation
132                cancellation_token.cancel();
133                TimeoutResult::TimedOut
134            },
135            // Explicit cancellation requested
136            _ = cancellation_token.cancelled() => {
137                warn!(
138                    tool_name = %tool_name,
139                    execution_id = %execution_id,
140                    "Tool execution cancelled cooperatively"
141                );
142                TimeoutResult::Cancelled
143            },
144        };
145
146        // Clean up execution tracking and update active count
147        {
148            let mut executions = self.active_executions.write().await;
149            executions.remove(&execution_id);
150        }
151
152        // Decrement active executions metric
153        self.metrics
154            .tool_executions_active
155            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
156
157        let elapsed = started_at.elapsed();
158
159        match result {
160            TimeoutResult::Completed(Ok(value)) => {
161                // Record successful execution metrics
162                self.metrics
163                    .tool_executions_successful
164                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
165
166                debug!(
167                    tool_name = %tool_name,
168                    execution_id = %execution_id,
169                    elapsed_ms = elapsed.as_millis(),
170                    "Tool execution completed successfully"
171                );
172                Ok((value, cancellation_token))
173            }
174            TimeoutResult::Completed(Err(server_error)) => {
175                warn!(
176                    tool_name = %tool_name,
177                    execution_id = %execution_id,
178                    elapsed_ms = elapsed.as_millis(),
179                    error = %server_error,
180                    "Tool execution failed with server error"
181                );
182                Err(ToolTimeoutError::ServerError(server_error))
183            }
184            TimeoutResult::TimedOut => {
185                // Record timeout metrics for security monitoring
186                self.metrics
187                    .tool_timeouts_total
188                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
189                self.metrics.timeout_wasted_time_us.fetch_add(
190                    elapsed.as_micros() as u64,
191                    std::sync::atomic::Ordering::Relaxed,
192                );
193                self.metrics
194                    .errors_timeout
195                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
196
197                // Security audit event - potential DoS indicator
198                error!(
199                    tool_name = %tool_name,
200                    execution_id = %execution_id,
201                    timeout_seconds = timeout_duration.as_secs(),
202                    elapsed_ms = elapsed.as_millis(),
203                    event_type = "TIMEOUT_EVENT",
204                    security_concern = "potential_dos_indicator",
205                    "🔒 SECURITY AUDIT: Tool execution timed out"
206                );
207                Err(ToolTimeoutError::Timeout {
208                    tool_name: tool_name.to_string(),
209                    timeout_duration,
210                    elapsed,
211                })
212            }
213            TimeoutResult::Cancelled => {
214                // Record cancellation metrics for monitoring
215                self.metrics
216                    .tool_cancellations_total
217                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
218
219                // Audit event for cancellation (normal operation)
220                warn!(
221                    tool_name = %tool_name,
222                    execution_id = %execution_id,
223                    elapsed_ms = elapsed.as_millis(),
224                    event_type = "CANCELLATION_EVENT",
225                    "🔒 AUDIT: Tool execution was cancelled cooperatively"
226                );
227                Err(ToolTimeoutError::Cancelled {
228                    tool_name: tool_name.to_string(),
229                    elapsed,
230                })
231            }
232        }
233    }
234
235    /// Execute a tool with a provided cancellation token
236    ///
237    /// This method allows external code to provide the cancellation token,
238    /// enabling tight integration with RequestContext and other systems.
239    #[instrument(skip(self, operation, cancellation_token), fields(tool_name = %tool_name))]
240    pub async fn execute_with_external_token<F, T>(
241        &self,
242        tool_name: &str,
243        operation: F,
244        cancellation_token: CancellationToken,
245    ) -> Result<T, ToolTimeoutError>
246    where
247        F: std::future::Future<Output = Result<T, ServerError>>,
248        T: Send,
249    {
250        let execution_id = Uuid::new_v4();
251        let timeout_duration = self.get_tool_timeout(tool_name);
252        let started_at = Instant::now();
253
254        // Register this execution for monitoring (using provided token)
255        {
256            let mut executions = self.active_executions.write().await;
257            executions.insert(
258                execution_id,
259                ToolExecution {
260                    tool_name: tool_name.to_string(),
261                    started_at,
262                    timeout_duration,
263                    cancellation_token: cancellation_token.clone(),
264                    cancelled: false,
265                },
266            );
267        }
268
269        debug!(
270            tool_name = %tool_name,
271            execution_id = %execution_id,
272            timeout_seconds = timeout_duration.as_secs(),
273            "Starting tool execution with provided cancellation token"
274        );
275
276        // Execute with cooperative cancellation using the provided token
277        let result = tokio::select! {
278            // Tool execution completed
279            operation_result = operation => {
280                TimeoutResult::Completed(operation_result)
281            },
282            // Timeout occurred
283            _ = tokio::time::sleep(timeout_duration) => {
284                // Cancel the token to signal cooperative cancellation
285                cancellation_token.cancel();
286                TimeoutResult::TimedOut
287            },
288            // Explicit cancellation requested via external token
289            _ = cancellation_token.cancelled() => {
290                warn!(
291                    tool_name = %tool_name,
292                    execution_id = %execution_id,
293                    "Tool execution cancelled via external token"
294                );
295                TimeoutResult::Cancelled
296            },
297        };
298
299        // Clean up execution tracking and update active count
300        {
301            let mut executions = self.active_executions.write().await;
302            executions.remove(&execution_id);
303        }
304
305        // Decrement active executions metric
306        self.metrics
307            .tool_executions_active
308            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
309
310        let elapsed = started_at.elapsed();
311
312        match result {
313            TimeoutResult::Completed(Ok(value)) => {
314                debug!(
315                    tool_name = %tool_name,
316                    execution_id = %execution_id,
317                    elapsed_ms = elapsed.as_millis(),
318                    "Tool execution completed successfully"
319                );
320                Ok(value)
321            }
322            TimeoutResult::Completed(Err(server_error)) => {
323                warn!(
324                    tool_name = %tool_name,
325                    execution_id = %execution_id,
326                    elapsed_ms = elapsed.as_millis(),
327                    error = %server_error,
328                    "Tool execution failed with server error"
329                );
330                Err(ToolTimeoutError::ServerError(server_error))
331            }
332            TimeoutResult::TimedOut => {
333                // Record timeout metrics for security monitoring
334                self.metrics
335                    .tool_timeouts_total
336                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
337                self.metrics.timeout_wasted_time_us.fetch_add(
338                    elapsed.as_micros() as u64,
339                    std::sync::atomic::Ordering::Relaxed,
340                );
341                self.metrics
342                    .errors_timeout
343                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
344
345                // Security audit event - potential DoS indicator
346                error!(
347                    tool_name = %tool_name,
348                    execution_id = %execution_id,
349                    timeout_seconds = timeout_duration.as_secs(),
350                    elapsed_ms = elapsed.as_millis(),
351                    event_type = "TIMEOUT_EVENT",
352                    security_concern = "potential_dos_indicator",
353                    "🔒 SECURITY AUDIT: Tool execution timed out"
354                );
355                Err(ToolTimeoutError::Timeout {
356                    tool_name: tool_name.to_string(),
357                    timeout_duration,
358                    elapsed,
359                })
360            }
361            TimeoutResult::Cancelled => {
362                // Record cancellation metrics for monitoring
363                self.metrics
364                    .tool_cancellations_total
365                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
366
367                // Audit event for cancellation (normal operation)
368                warn!(
369                    tool_name = %tool_name,
370                    execution_id = %execution_id,
371                    elapsed_ms = elapsed.as_millis(),
372                    event_type = "CANCELLATION_EVENT",
373                    "🔒 AUDIT: Tool execution was cancelled cooperatively"
374                );
375                Err(ToolTimeoutError::Cancelled {
376                    tool_name: tool_name.to_string(),
377                    elapsed,
378                })
379            }
380        }
381    }
382
383    /// Execute a tool with timeout (backward compatible API)
384    ///
385    /// This method provides backward compatibility for existing code that doesn't
386    /// need access to the cancellation token. New code should use
387    /// `execute_with_timeout_and_cancellation` for cooperative cancellation support.
388    #[instrument(skip(self, operation), fields(tool_name = %tool_name))]
389    pub async fn execute_with_timeout<F, T>(
390        &self,
391        tool_name: &str,
392        operation: F,
393    ) -> Result<T, ToolTimeoutError>
394    where
395        F: std::future::Future<Output = Result<T, ServerError>>,
396        T: Send,
397    {
398        // Use the new method and discard the cancellation token for compatibility
399        match self
400            .execute_with_timeout_and_cancellation(tool_name, operation)
401            .await
402        {
403            Ok((result, _token)) => Ok(result),
404            Err(error) => Err(error),
405        }
406    }
407
408    /// Get statistics about active tool executions
409    ///
410    /// Returns information about currently running tools for monitoring
411    /// and debugging purposes.
412    pub async fn get_active_executions(&self) -> Vec<ActiveExecutionInfo> {
413        let executions = self.active_executions.read().await;
414        executions
415            .iter()
416            .map(|(&id, execution)| ActiveExecutionInfo {
417                execution_id: id,
418                tool_name: execution.tool_name.clone(),
419                started_at: execution.started_at,
420                timeout_duration: execution.timeout_duration,
421                elapsed: execution.started_at.elapsed(),
422                cancellation_token: execution.cancellation_token.clone(),
423                cancelled: execution.cancelled,
424            })
425            .collect()
426    }
427
428    /// Cancel all active executions (for graceful shutdown)
429    ///
430    /// Signals cooperative cancellation to all active tool executions.
431    /// Tools that check their cancellation tokens will receive the signal
432    /// and can perform graceful cleanup before terminating.
433    #[instrument(skip(self))]
434    pub async fn cancel_all_executions(&self) {
435        let mut executions = self.active_executions.write().await;
436        let count = executions.len();
437
438        if count > 0 {
439            // Security audit event for bulk cancellation - could indicate emergency shutdown
440            warn!(
441                active_count = count,
442                event_type = "BULK_CANCELLATION_EVENT",
443                security_note = "emergency_shutdown_or_resource_cleanup",
444                "🔒 SECURITY AUDIT: Cancelling all active tool executions"
445            );
446
447            for execution in executions.values_mut() {
448                // Signal cooperative cancellation via the token
449                execution.cancellation_token.cancel();
450                execution.cancelled = true;
451            }
452
453            // Update cancellation metrics for bulk operation
454            self.metrics
455                .tool_cancellations_total
456                .fetch_add(count as u64, std::sync::atomic::Ordering::Relaxed);
457
458            debug!(
459                cancelled_count = count,
460                "Sent cooperative cancellation signals to all active tool executions"
461            );
462        }
463
464        // Cooperative cancellation tokens allow tools to respond gracefully
465        // Tools that check cancellation_token.is_cancelled() will see the signal
466    }
467}
468
/// Snapshot of one active tool execution, as returned by
/// `ToolTimeoutManager::get_active_executions`.
#[derive(Debug, Clone)]
pub struct ActiveExecutionInfo {
    /// Unique execution identifier
    pub execution_id: Uuid,
    /// Tool name being executed
    pub tool_name: String,
    /// When execution started (tokio clock)
    pub started_at: Instant,
    /// Configured timeout duration
    pub timeout_duration: Duration,
    /// How long execution had been running when the snapshot was taken
    pub elapsed: Duration,
    /// Clone of the cancellation token registered for this execution
    pub cancellation_token: CancellationToken,
    /// Whether the execution had been marked as cancelled in the registry
    /// when the snapshot was taken
    pub cancelled: bool,
}
487
/// Tool timeout error types
///
/// Returned by the `execute_*` methods of `ToolTimeoutManager`; can be
/// converted into a `ServerError` through the `From` impl in this module.
#[derive(Debug, thiserror::Error)]
pub enum ToolTimeoutError {
    /// Tool execution exceeded configured timeout
    #[error("Tool '{tool_name}' timed out after {timeout_duration:?} (elapsed: {elapsed:?})")]
    Timeout {
        /// Name of the tool that timed out
        tool_name: String,
        /// Configured timeout duration that was exceeded
        timeout_duration: Duration,
        /// Time measured from the start of the execution until the timeout
        /// was handled
        elapsed: Duration,
    },

    /// Tool execution was cancelled cooperatively
    #[error("Tool '{tool_name}' was cancelled (elapsed: {elapsed:?})")]
    Cancelled {
        /// Name of the tool that was cancelled
        tool_name: String,
        /// Time elapsed before cancellation
        elapsed: Duration,
    },

    /// Tool execution failed with server error.
    /// The inner error's message is embedded in this variant's `Display`
    /// output; the variant carries no `#[source]`/`#[from]` attribute.
    #[error("Tool execution failed: {0}")]
    ServerError(ServerError),
}
515
/// Internal result type for timeout operations
///
/// Records which branch of the execution `select!` finished first.
#[derive(Debug)]
enum TimeoutResult<T> {
    /// The tool future finished first, with its own success or error
    Completed(Result<T, ServerError>),
    /// The timeout elapsed before the tool finished
    TimedOut,
    /// The cancellation token was cancelled before the tool finished
    Cancelled,
}
523
524impl From<ToolTimeoutError> for ServerError {
525    fn from(timeout_error: ToolTimeoutError) -> Self {
526        match timeout_error {
527            ToolTimeoutError::Timeout {
528                tool_name,
529                timeout_duration,
530                ..
531            } => ServerError::timeout(
532                format!("Tool '{}'", tool_name),
533                timeout_duration.as_millis() as u64,
534            ),
535            ToolTimeoutError::Cancelled { tool_name, .. } => {
536                ServerError::handler(format!("Tool '{}' was cancelled", tool_name))
537            }
538            ToolTimeoutError::ServerError(server_error) => server_error,
539        }
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ServerError;
    use tokio::time::{Duration, sleep};

    /// Config with a 3 second default timeout and two per-tool overrides.
    fn create_test_config() -> TimeoutConfig {
        let mut tool_timeouts = HashMap::new();
        tool_timeouts.insert("fast_tool".to_string(), 1); // 1 second
        tool_timeouts.insert("slow_tool".to_string(), 5); // 5 seconds

        TimeoutConfig {
            request_timeout: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(10),
            keep_alive_timeout: Duration::from_secs(60),
            tool_execution_timeout: Duration::from_secs(3), // 3 second default
            tool_timeouts,
        }
    }

    /// Fresh metrics instance for a single test.
    fn create_test_metrics() -> Arc<ServerMetrics> {
        Arc::new(ServerMetrics::new())
    }

    /// A tool that finishes well inside its timeout returns its value.
    #[tokio::test]
    async fn test_successful_tool_execution() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        let result = manager
            .execute_with_timeout("test_tool", async {
                sleep(Duration::from_millis(100)).await;
                Ok::<String, ServerError>("success".to_string())
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
    }

    /// A tool that outlives its per-tool override fails with `Timeout`.
    #[tokio::test]
    async fn test_tool_timeout() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        // This should timeout after 1 second (fast_tool override)
        let result = manager
            .execute_with_timeout("fast_tool", async {
                sleep(Duration::from_secs(2)).await; // Sleep longer than timeout
                Ok::<String, ServerError>("should_not_reach".to_string())
            })
            .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            ToolTimeoutError::Timeout {
                tool_name,
                timeout_duration,
                ..
            } => {
                assert_eq!(tool_name, "fast_tool");
                assert_eq!(timeout_duration, Duration::from_secs(1));
            }
            other => panic!("Expected timeout error, got {other:?}"),
        }
    }

    /// Overrides win over the default; unknown tools fall back to it.
    #[tokio::test]
    async fn test_per_tool_timeout_override() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        // Test that slow_tool gets its 5-second override
        assert_eq!(
            manager.get_tool_timeout("slow_tool"),
            Duration::from_secs(5)
        );

        // Test that unknown tool gets default timeout
        assert_eq!(
            manager.get_tool_timeout("unknown_tool"),
            Duration::from_secs(3)
        );
    }

    /// An error returned by the tool is surfaced as `ServerError`.
    #[tokio::test]
    async fn test_server_error_propagation() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        let result = manager
            .execute_with_timeout("test_tool", async {
                Err::<String, ServerError>(ServerError::handler("custom error"))
            })
            .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            ToolTimeoutError::ServerError(server_error) => {
                assert!(server_error.to_string().contains("custom error"));
            }
            other => panic!("Expected server error, got {other:?}"),
        }
    }

    /// Executions show up in the registry while running and are removed
    /// once they finish.
    #[tokio::test]
    async fn test_active_executions_tracking() {
        let manager = ToolTimeoutManager::new(create_test_config(), create_test_metrics());

        let manager_clone = manager.clone();
        let handle = tokio::spawn(async move {
            let _ = manager_clone
                .execute_with_timeout("long_running", async {
                    sleep(Duration::from_millis(100)).await;
                    Ok::<String, ServerError>("done".to_string())
                })
                .await;
        });

        // Give the task time to start
        sleep(Duration::from_millis(10)).await;

        // Scheduling-dependent: the execution is normally registered by now,
        // but the task may not have run yet, so only bound the count from above
        let active = manager.get_active_executions().await;
        assert!(active.len() <= 1);
        if let Some(info) = active.first() {
            assert_eq!(info.tool_name, "long_running");
            assert!(!info.cancelled);
        }

        // Once the tool has finished, its registry entry must be gone
        handle.await.expect("spawned tool task should not panic");
        assert!(manager.get_active_executions().await.is_empty());
    }
}
666}