tower_resilience_timelimiter/
lib.rs1use futures::future::BoxFuture;
112use std::sync::Arc;
113use std::task::{Context, Poll};
114use std::time::Instant;
115use tokio::time::timeout;
116use tower::Service;
117
118#[cfg(feature = "metrics")]
119use metrics::{counter, describe_counter, describe_histogram, histogram};
120
121#[cfg(feature = "tracing")]
122use tracing::{debug, warn};
123
124pub use config::{
125 DynamicTimeout, FixedTimeout, TimeLimiterConfig, TimeLimiterConfigBuilder, TimeoutFn,
126};
127pub use error::TimeLimiterError;
128pub use events::TimeLimiterEvent;
129pub use layer::TimeLimiterLayer;
130
131mod config;
132mod error;
133mod events;
134mod layer;
135
136pub struct TimeLimiter<S, T> {
142 inner: S,
143 config: Arc<TimeLimiterConfig<T>>,
144}
145
146impl<S: Clone, T> Clone for TimeLimiter<S, T> {
147 fn clone(&self) -> Self {
148 Self {
149 inner: self.inner.clone(),
150 config: Arc::clone(&self.config),
151 }
152 }
153}
154
155impl<S, T> TimeLimiter<S, T> {
156 pub(crate) fn new(inner: S, config: Arc<TimeLimiterConfig<T>>) -> Self {
158 #[cfg(feature = "metrics")]
159 {
160 describe_counter!(
161 "timelimiter_calls_total",
162 "Total number of time limiter calls (success, error, or timeout)"
163 );
164 describe_histogram!(
165 "timelimiter_call_duration_seconds",
166 "Duration of calls (successful or failed)"
167 );
168 }
169
170 Self { inner, config }
171 }
172}
173
174impl<S, T, Req> Service<Req> for TimeLimiter<S, T>
175where
176 S: Service<Req> + Clone + Send + 'static,
177 S::Future: Send + 'static,
178 S::Response: Send + 'static,
179 S::Error: Send + 'static,
180 Req: Send + 'static,
181 T: TimeoutFn<Req> + 'static,
182{
183 type Response = S::Response;
184 type Error = TimeLimiterError<S::Error>;
185 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
186
187 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188 self.inner.poll_ready(cx).map_err(TimeLimiterError::Inner)
189 }
190
191 fn call(&mut self, req: Req) -> Self::Future {
192 let mut inner = self.inner.clone();
193 let config = Arc::clone(&self.config);
194
195 let timeout_duration = config.timeout_source.get_timeout(&req);
197 let cancel_on_timeout = config.cancel_running_future;
198
199 Box::pin(async move {
200 let start = Instant::now();
201
202 let result: Option<Result<S::Response, S::Error>> = if cancel_on_timeout {
204 timeout(timeout_duration, inner.call(req)).await.ok()
206 } else {
207 let (tx, rx) = tokio::sync::oneshot::channel();
209
210 tokio::spawn(async move {
211 let result = inner.call(req).await;
212 let _ = tx.send(result);
214 });
215
216 tokio::select! {
217 result = rx => {
218 result.ok()
220 }
221 _ = tokio::time::sleep(timeout_duration) => {
222 None
224 }
225 }
226 };
227
228 match result {
229 Some(Ok(response)) => {
230 let duration = start.elapsed();
231 config.event_listeners.emit(&TimeLimiterEvent::Success {
232 pattern_name: config.name.clone(),
233 timestamp: Instant::now(),
234 duration,
235 });
236
237 #[cfg(feature = "metrics")]
238 {
239 counter!("timelimiter_calls_total", "timelimiter" => config.name.clone(), "result" => "success").increment(1);
240 histogram!("timelimiter_call_duration_seconds", "timelimiter" => config.name.clone())
241 .record(duration.as_secs_f64());
242 }
243
244 #[cfg(feature = "tracing")]
245 debug!(
246 timelimiter = %config.name,
247 duration_ms = duration.as_millis(),
248 "Call succeeded within timeout"
249 );
250
251 Ok(response)
252 }
253 Some(Err(err)) => {
254 let duration = start.elapsed();
255 config.event_listeners.emit(&TimeLimiterEvent::Error {
256 pattern_name: config.name.clone(),
257 timestamp: Instant::now(),
258 duration,
259 });
260
261 #[cfg(feature = "metrics")]
262 {
263 counter!("timelimiter_calls_total", "timelimiter" => config.name.clone(), "result" => "error").increment(1);
264 histogram!("timelimiter_call_duration_seconds", "timelimiter" => config.name.clone())
265 .record(duration.as_secs_f64());
266 }
267
268 #[cfg(feature = "tracing")]
269 debug!(
270 timelimiter = %config.name,
271 duration_ms = duration.as_millis(),
272 "Call failed within timeout"
273 );
274
275 Err(TimeLimiterError::Inner(err))
276 }
277 None => {
278 config.event_listeners.emit(&TimeLimiterEvent::Timeout {
279 pattern_name: config.name.clone(),
280 timestamp: Instant::now(),
281 timeout_duration,
282 });
283
284 #[cfg(feature = "metrics")]
285 {
286 counter!("timelimiter_calls_total", "timelimiter" => config.name.clone(), "result" => "timeout").increment(1);
287 }
288
289 #[cfg(feature = "tracing")]
290 warn!(
291 timelimiter = %config.name,
292 timeout_ms = timeout_duration.as_millis(),
293 "Call timed out"
294 );
295
296 Err(TimeLimiterError::Timeout)
297 }
298 }
299 })
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tokio::time::sleep;
    use tower::{service_fn, Layer, ServiceExt};

    #[tokio::test]
    async fn test_success_within_timeout() {
        let limiter = TimeLimiterLayer::builder()
            .timeout_duration(Duration::from_millis(100))
            .build();

        let backend = service_fn(|_req: ()| async {
            sleep(Duration::from_millis(10)).await;
            Ok::<_, ()>("success")
        });

        let mut wrapped = limiter.layer(backend);
        let outcome = wrapped.ready().await.unwrap().call(()).await;

        // A fast backend must come back with its own response untouched.
        assert_eq!(outcome.ok(), Some("success"));
    }

    #[tokio::test]
    async fn test_timeout_occurs() {
        let limiter = TimeLimiterLayer::builder()
            .timeout_duration(Duration::from_millis(10))
            .build();

        let backend = service_fn(|_req: ()| async {
            sleep(Duration::from_millis(100)).await;
            Ok::<_, ()>("success")
        });

        let mut wrapped = limiter.layer(backend);
        let failure = wrapped
            .ready()
            .await
            .unwrap()
            .call(())
            .await
            .expect_err("backend slower than the timeout must fail");

        assert!(failure.is_timeout());
    }

    #[tokio::test]
    async fn test_inner_error_propagates() {
        let limiter = TimeLimiterLayer::builder()
            .timeout_duration(Duration::from_millis(100))
            .build();

        let backend = service_fn(|_req: ()| async { Err::<(), _>("inner error") });

        let mut wrapped = limiter.layer(backend);
        let failure = wrapped
            .ready()
            .await
            .unwrap()
            .call(())
            .await
            .expect_err("backend error must surface");

        // The failure comes from the backend, not from the deadline.
        assert!(!failure.is_timeout());
        assert_eq!(failure.into_inner(), Some("inner error"));
    }

    #[tokio::test]
    async fn test_event_listeners() {
        let successes = Arc::new(AtomicUsize::new(0));
        let timeouts = Arc::new(AtomicUsize::new(0));

        let success_counter = Arc::clone(&successes);
        let timeout_counter = Arc::clone(&timeouts);

        let limiter = TimeLimiterLayer::builder()
            .timeout_duration(Duration::from_millis(50))
            .on_success(move |_| {
                success_counter.fetch_add(1, Ordering::SeqCst);
            })
            .on_timeout(move || {
                timeout_counter.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        // Finishes well inside the 50 ms budget -> one success event.
        let fast_backend = service_fn(|_req: ()| async {
            sleep(Duration::from_millis(10)).await;
            Ok::<_, ()>("ok")
        });
        let mut wrapped = limiter.layer(fast_backend);
        let _ = wrapped.ready().await.unwrap().call(()).await;
        assert_eq!(successes.load(Ordering::SeqCst), 1);

        // Overruns the 50 ms budget -> one timeout event.
        let slow_backend = service_fn(|_req: ()| async {
            sleep(Duration::from_millis(100)).await;
            Ok::<_, ()>("ok")
        });
        let mut wrapped = limiter.layer(slow_backend);
        let _ = wrapped.ready().await.unwrap().call(()).await;
        assert_eq!(timeouts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_per_request_timeout() {
        #[derive(Clone)]
        struct Request {
            timeout_ms: u64,
            sleep_ms: u64,
        }

        let limiter = TimeLimiterLayer::builder()
            .timeout_fn(|req: &Request| Duration::from_millis(req.timeout_ms))
            .build();

        let backend = service_fn(|req: Request| async move {
            sleep(Duration::from_millis(req.sleep_ms)).await;
            Ok::<_, ()>("done")
        });

        let mut wrapped = limiter.layer(backend);

        let generous = wrapped
            .ready()
            .await
            .unwrap()
            .call(Request {
                timeout_ms: 100,
                sleep_ms: 10,
            })
            .await;
        assert!(generous.is_ok());

        let tight_failure = wrapped
            .ready()
            .await
            .unwrap()
            .call(Request {
                timeout_ms: 10,
                sleep_ms: 100,
            })
            .await
            .expect_err("request with a tight budget must time out");
        assert!(tight_failure.is_timeout());
    }

    #[tokio::test]
    async fn test_different_timeouts_per_request() {
        #[derive(Clone)]
        struct Request {
            #[allow(dead_code)]
            id: u32,
            timeout_ms: Option<u64>,
        }

        let limiter = TimeLimiterLayer::builder()
            .timeout_fn(|req: &Request| {
                req.timeout_ms
                    .map_or(Duration::from_millis(50), Duration::from_millis)
            })
            .build();

        // The backend always takes 30 ms.
        let backend = service_fn(|_req: Request| async move {
            sleep(Duration::from_millis(30)).await;
            Ok::<_, ()>("done")
        });

        let mut wrapped = limiter.layer(backend);

        // (id, per-request timeout, expected to succeed)
        let cases = [(1, Some(100), true), (2, Some(10), false), (3, None, true)];
        for (id, timeout_ms, should_succeed) in cases {
            let outcome = wrapped
                .ready()
                .await
                .unwrap()
                .call(Request { id, timeout_ms })
                .await;
            assert_eq!(outcome.is_ok(), should_succeed, "request {id}");
        }
    }
}