Skip to main content

tower_resilience_timelimiter/
lib.rs

1//! Advanced timeout handling for Tower services.
2//!
3//! Provides timeout functionality with:
4//! - Configurable timeout duration (fixed or per-request)
5//! - Optional future cancellation on timeout
//! - Event system for observability (`on_success`, `on_error`, and `on_timeout` listeners)
//! - Optional metrics and tracing integration (behind the `metrics` / `tracing` cargo features)
8//!
9//! ## Presets
10//!
11//! ```rust
12//! use tower_resilience_timelimiter::TimeLimiterLayer;
13//!
14//! let fast = TimeLimiterLayer::fast().build();        // 1s, cancel on timeout
15//! let standard = TimeLimiterLayer::standard().build(); // 5s, cancel on timeout
16//! let slow = TimeLimiterLayer::slow().build();         // 30s, cancel on timeout
17//! let stream = TimeLimiterLayer::streaming().build();  // 60s, no cancellation
18//! ```
19//!
20//! Presets return builders, so you can customize further:
21//!
22//! ```rust
23//! use tower_resilience_timelimiter::TimeLimiterLayer;
24//!
25//! let layer = TimeLimiterLayer::fast()
26//!     .name("api-timeout")
27//!     .on_timeout(|| eprintln!("Request timed out!"))
28//!     .build();
29//! ```
30//!
31//! ## Basic Example (Fixed Timeout - No Type Parameters!)
32//!
33//! ```rust
34//! use tower_resilience_timelimiter::TimeLimiterLayer;
35//! use tower::{Layer, service_fn};
36//! use std::time::Duration;
37//!
38//! # async fn example() {
39//! // No type parameters needed for fixed timeout!
40//! let layer = TimeLimiterLayer::builder()
41//!     .timeout_duration(Duration::from_secs(5))
42//!     .cancel_running_future(true)
43//!     .on_timeout(|| {
44//!         eprintln!("Request timed out!");
45//!     })
46//!     .build();
47//!
48//! let svc = service_fn(|req: String| async move {
49//!     Ok::<String, ()>(req)
50//! });
51//!
52//! let mut service = layer.layer(svc);
53//! # }
54//! ```
55//!
56//! ## Per-Request Timeout
57//!
//! Derive the timeout from each request, e.g. to give different operations different SLAs:
59//!
60//! ```rust
61//! use tower_resilience_timelimiter::TimeLimiterLayer;
62//! use tower::{Layer, service_fn};
63//! use std::time::Duration;
64//!
65//! #[derive(Clone)]
66//! struct MyRequest {
67//!     operation: String,
68//!     timeout_ms: Option<u64>,
69//! }
70//!
71//! # async fn example() {
72//! // Types inferred from closure signature
73//! let layer = TimeLimiterLayer::builder()
74//!     .timeout_fn(|req: &MyRequest| {
75//!         req.timeout_ms
76//!             .map(Duration::from_millis)
77//!             .unwrap_or(Duration::from_secs(5))
78//!     })
79//!     .build();
80//!
81//! let svc = service_fn(|req: MyRequest| async move {
82//!     Ok::<String, ()>(format!("Processed: {}", req.operation))
83//! });
84//!
85//! let mut service = layer.layer(svc);
86//! # }
87//! ```
88//!
89//! ## Event Listeners
90//!
91//! ```rust
92//! use tower_resilience_timelimiter::TimeLimiterLayer;
93//! use std::time::Duration;
94//!
95//! # async fn example() {
96//! let layer = TimeLimiterLayer::builder()
97//!     .timeout_duration(Duration::from_secs(5))
98//!     .on_success(|duration| {
99//!         println!("Call succeeded in {:?}", duration);
100//!     })
101//!     .on_error(|duration| {
102//!         println!("Call failed after {:?}", duration);
103//!     })
104//!     .on_timeout(|| {
105//!         println!("Call timed out");
106//!     })
107//!     .build();
108//! # }
109//! ```
110
111use futures::future::BoxFuture;
112use std::sync::Arc;
113use std::task::{Context, Poll};
114use std::time::Instant;
115use tokio::time::timeout;
116use tower::Service;
117
118#[cfg(feature = "metrics")]
119use metrics::{counter, describe_counter, describe_histogram, histogram};
120
121#[cfg(feature = "tracing")]
122use tracing::{debug, warn};
123
124pub use config::{
125    DynamicTimeout, FixedTimeout, TimeLimiterConfig, TimeLimiterConfigBuilder, TimeoutFn,
126};
127pub use error::TimeLimiterError;
128pub use events::TimeLimiterEvent;
129pub use layer::TimeLimiterLayer;
130
131mod config;
132mod error;
133mod events;
134mod layer;
135
136/// A Tower service that applies timeout limiting to an inner service.
137///
138/// The type parameter `T` is the timeout source:
139/// - `FixedTimeout` - uses the same timeout for all requests
140/// - `DynamicTimeout<F>` - extracts timeout from each request using closure F
141pub struct TimeLimiter<S, T> {
142    inner: S,
143    config: Arc<TimeLimiterConfig<T>>,
144}
145
146impl<S: Clone, T> Clone for TimeLimiter<S, T> {
147    fn clone(&self) -> Self {
148        Self {
149            inner: self.inner.clone(),
150            config: Arc::clone(&self.config),
151        }
152    }
153}
154
155impl<S, T> TimeLimiter<S, T> {
156    /// Creates a new time limiter wrapping the given service.
157    pub(crate) fn new(inner: S, config: Arc<TimeLimiterConfig<T>>) -> Self {
158        #[cfg(feature = "metrics")]
159        {
160            describe_counter!(
161                "timelimiter_calls_total",
162                "Total number of time limiter calls (success, error, or timeout)"
163            );
164            describe_histogram!(
165                "timelimiter_call_duration_seconds",
166                "Duration of calls (successful or failed)"
167            );
168        }
169
170        Self { inner, config }
171    }
172}
173
174impl<S, T, Req> Service<Req> for TimeLimiter<S, T>
175where
176    S: Service<Req> + Clone + Send + 'static,
177    S::Future: Send + 'static,
178    S::Response: Send + 'static,
179    S::Error: Send + 'static,
180    Req: Send + 'static,
181    T: TimeoutFn<Req> + 'static,
182{
183    type Response = S::Response;
184    type Error = TimeLimiterError<S::Error>;
185    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
186
187    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188        self.inner.poll_ready(cx).map_err(TimeLimiterError::Inner)
189    }
190
191    fn call(&mut self, req: Req) -> Self::Future {
192        let mut inner = self.inner.clone();
193        let config = Arc::clone(&self.config);
194
195        // Extract timeout from request before moving it
196        let timeout_duration = config.timeout_source.get_timeout(&req);
197        let cancel_on_timeout = config.cancel_running_future;
198
199        Box::pin(async move {
200            let start = Instant::now();
201
202            // Use Option to represent timeout (None = timed out, Some = got result)
203            let result: Option<Result<S::Response, S::Error>> = if cancel_on_timeout {
204                // Default behavior: timeout cancels the future by dropping it
205                timeout(timeout_duration, inner.call(req)).await.ok()
206            } else {
207                // Non-cancelling behavior: spawn the future and let it continue on timeout
208                let (tx, rx) = tokio::sync::oneshot::channel();
209
210                tokio::spawn(async move {
211                    let result = inner.call(req).await;
212                    // Ignore send error - receiver may have been dropped on timeout
213                    let _ = tx.send(result);
214                });
215
216                tokio::select! {
217                    result = rx => {
218                        // Task completed - unwrap the channel result
219                        result.ok()
220                    }
221                    _ = tokio::time::sleep(timeout_duration) => {
222                        // Timeout fired, but the spawned task continues running
223                        None
224                    }
225                }
226            };
227
228            match result {
229                Some(Ok(response)) => {
230                    let duration = start.elapsed();
231                    config.event_listeners.emit(&TimeLimiterEvent::Success {
232                        pattern_name: config.name.clone(),
233                        timestamp: Instant::now(),
234                        duration,
235                    });
236
237                    #[cfg(feature = "metrics")]
238                    {
239                        counter!("timelimiter_calls_total", "timelimiter" => config.name.clone(), "result" => "success").increment(1);
240                        histogram!("timelimiter_call_duration_seconds", "timelimiter" => config.name.clone())
241                            .record(duration.as_secs_f64());
242                    }
243
244                    #[cfg(feature = "tracing")]
245                    debug!(
246                        timelimiter = %config.name,
247                        duration_ms = duration.as_millis(),
248                        "Call succeeded within timeout"
249                    );
250
251                    Ok(response)
252                }
253                Some(Err(err)) => {
254                    let duration = start.elapsed();
255                    config.event_listeners.emit(&TimeLimiterEvent::Error {
256                        pattern_name: config.name.clone(),
257                        timestamp: Instant::now(),
258                        duration,
259                    });
260
261                    #[cfg(feature = "metrics")]
262                    {
263                        counter!("timelimiter_calls_total", "timelimiter" => config.name.clone(), "result" => "error").increment(1);
264                        histogram!("timelimiter_call_duration_seconds", "timelimiter" => config.name.clone())
265                            .record(duration.as_secs_f64());
266                    }
267
268                    #[cfg(feature = "tracing")]
269                    debug!(
270                        timelimiter = %config.name,
271                        duration_ms = duration.as_millis(),
272                        "Call failed within timeout"
273                    );
274
275                    Err(TimeLimiterError::Inner(err))
276                }
277                None => {
278                    config.event_listeners.emit(&TimeLimiterEvent::Timeout {
279                        pattern_name: config.name.clone(),
280                        timestamp: Instant::now(),
281                        timeout_duration,
282                    });
283
284                    #[cfg(feature = "metrics")]
285                    {
286                        counter!("timelimiter_calls_total", "timelimiter" => config.name.clone(), "result" => "timeout").increment(1);
287                    }
288
289                    #[cfg(feature = "tracing")]
290                    warn!(
291                        timelimiter = %config.name,
292                        timeout_ms = timeout_duration.as_millis(),
293                        "Call timed out"
294                    );
295
296                    Err(TimeLimiterError::Timeout)
297                }
298            }
299        })
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tokio::time::sleep;
    use tower::{service_fn, Layer, ServiceExt};

    // These tests run against real timers. Each sleep/timeout pair is kept far
    // enough apart that ordinary scheduler jitter should not flip the outcome.

    #[tokio::test]
    async fn test_success_within_timeout() {
        // Fixed timeout: no type parameters needed.
        let layer = TimeLimiterLayer::builder()
            .timeout_duration(Duration::from_millis(100))
            .build();

        let svc = service_fn(|_req: ()| async {
            sleep(Duration::from_millis(10)).await;
            Ok::<_, ()>("success")
        });

        let mut service = layer.layer(svc);
        let result = service.ready().await.unwrap().call(()).await;

        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_timeout_occurs() {
        let layer = TimeLimiterLayer::builder()
            .timeout_duration(Duration::from_millis(10))
            .build();

        let svc = service_fn(|_req: ()| async {
            sleep(Duration::from_millis(100)).await;
            Ok::<_, ()>("success")
        });

        let mut service = layer.layer(svc);
        let result = service.ready().await.unwrap().call(()).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().is_timeout());
    }

    #[tokio::test]
    async fn test_inner_error_propagates() {
        let layer = TimeLimiterLayer::builder()
            .timeout_duration(Duration::from_millis(100))
            .build();

        let svc = service_fn(|_req: ()| async { Err::<(), _>("inner error") });

        let mut service = layer.layer(svc);
        let result = service.ready().await.unwrap().call(()).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(!err.is_timeout());
        assert_eq!(err.into_inner(), Some("inner error"));
    }

    #[tokio::test]
    async fn test_event_listeners() {
        let success_count = Arc::new(AtomicUsize::new(0));
        let timeout_count = Arc::new(AtomicUsize::new(0));

        let sc = Arc::clone(&success_count);
        let tc = Arc::clone(&timeout_count);

        let layer = TimeLimiterLayer::builder()
            .timeout_duration(Duration::from_millis(50))
            .on_success(move |_| {
                sc.fetch_add(1, Ordering::SeqCst);
            })
            .on_timeout(move || {
                tc.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        // Success: only the success listener fires.
        let svc = service_fn(|_req: ()| async {
            sleep(Duration::from_millis(10)).await;
            Ok::<_, ()>("ok")
        });
        let mut service = layer.layer(svc);
        let _ = service.ready().await.unwrap().call(()).await;
        assert_eq!(success_count.load(Ordering::SeqCst), 1);
        assert_eq!(timeout_count.load(Ordering::SeqCst), 0);

        // Timeout: only the timeout listener fires.
        let svc = service_fn(|_req: ()| async {
            sleep(Duration::from_millis(100)).await;
            Ok::<_, ()>("ok")
        });
        let mut service = layer.layer(svc);
        let _ = service.ready().await.unwrap().call(()).await;
        assert_eq!(timeout_count.load(Ordering::SeqCst), 1);
        // A timed-out call must not also be reported as a success.
        assert_eq!(success_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_per_request_timeout() {
        #[derive(Clone)]
        struct Request {
            timeout_ms: u64,
            sleep_ms: u64,
        }

        // Request/timeout types are inferred from the closure.
        let layer = TimeLimiterLayer::builder()
            .timeout_fn(|req: &Request| Duration::from_millis(req.timeout_ms))
            .build();

        let svc = service_fn(|req: Request| async move {
            sleep(Duration::from_millis(req.sleep_ms)).await;
            Ok::<_, ()>("done")
        });

        let mut service = layer.layer(svc);

        // Long timeout, short sleep: should succeed.
        let result = service
            .ready()
            .await
            .unwrap()
            .call(Request {
                timeout_ms: 100,
                sleep_ms: 10,
            })
            .await;
        assert!(result.is_ok());

        // Short timeout, long sleep: should time out.
        let result = service
            .ready()
            .await
            .unwrap()
            .call(Request {
                timeout_ms: 10,
                sleep_ms: 100,
            })
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_timeout());
    }

    #[tokio::test]
    async fn test_different_timeouts_per_request() {
        #[derive(Clone)]
        struct Request {
            // Only used to tell requests apart when reading the test.
            #[allow(dead_code)]
            id: u32,
            timeout_ms: Option<u64>,
        }

        let layer = TimeLimiterLayer::builder()
            .timeout_fn(|req: &Request| {
                req.timeout_ms
                    .map(Duration::from_millis)
                    .unwrap_or(Duration::from_millis(50)) // default
            })
            .build();

        let svc = service_fn(|_req: Request| async move {
            sleep(Duration::from_millis(30)).await;
            Ok::<_, ()>("done")
        });

        let mut service = layer.layer(svc);

        // Custom timeout (100ms): should succeed (30ms < 100ms).
        let result = service
            .ready()
            .await
            .unwrap()
            .call(Request {
                id: 1,
                timeout_ms: Some(100),
            })
            .await;
        assert!(result.is_ok());

        // Custom timeout (10ms): should time out (30ms > 10ms).
        let result = service
            .ready()
            .await
            .unwrap()
            .call(Request {
                id: 2,
                timeout_ms: Some(10),
            })
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_timeout());

        // Default timeout (50ms): should succeed (30ms < 50ms).
        let result = service
            .ready()
            .await
            .unwrap()
            .call(Request {
                id: 3,
                timeout_ms: None,
            })
            .await;
        assert!(result.is_ok());
    }

    // Note: The cancel_running_future flag is tested in tests/timelimiter/cancellation.rs
}