1mod backoff;
158mod budget;
159mod config;
160mod events;
161mod layer;
162mod policy;
163
164pub use backoff::{
165 ExponentialBackoff, ExponentialRandomBackoff, FixedInterval, FnInterval, IntervalFunction,
166};
167pub use budget::{AimdBudget, RetryBudget, RetryBudgetBuilder, TokenBucketBudget};
168pub use config::{MaxAttemptsSource, RetryConfig, RetryConfigBuilder};
169pub use events::RetryEvent;
170pub use layer::RetryLayer;
171pub use policy::{ResponsePredicate, RetryPolicy, RetryPredicate};
172
173use futures::future::BoxFuture;
174use std::marker::PhantomData;
175use std::sync::Arc;
176use std::task::{Context, Poll};
177use std::time::Instant;
178use tower::Service;
179
180#[cfg(feature = "metrics")]
181use metrics::{counter, describe_counter, describe_histogram, histogram};
182
183#[cfg(feature = "tracing")]
184use tracing::{debug, info, warn};
185
186pub struct Retry<S, Req, Res, E> {
191 inner: S,
192 config: Arc<RetryConfig<Req, Res, E>>,
193 _phantom: PhantomData<Req>,
194}
195
196impl<S, Req, Res, E> Retry<S, Req, Res, E> {
197 pub fn new(
199 inner: S,
200 config: Arc<RetryConfig<Req, Res, E>>,
201 _phantom: PhantomData<Req>,
202 ) -> Self {
203 #[cfg(feature = "metrics")]
204 {
205 describe_counter!(
206 "retry_calls_total",
207 "Total number of retry operations (success or exhausted)"
208 );
209 describe_counter!(
210 "retry_attempts_total",
211 "Total number of retry attempts across all calls"
212 );
213 describe_histogram!("retry_attempts", "Number of attempts per successful call");
214 }
215
216 Self {
217 inner,
218 config,
219 _phantom,
220 }
221 }
222}
223
224impl<S, Req, Res, E> Clone for Retry<S, Req, Res, E>
225where
226 S: Clone,
227{
228 fn clone(&self) -> Self {
229 Self {
230 inner: self.inner.clone(),
231 config: Arc::clone(&self.config),
232 _phantom: PhantomData,
233 }
234 }
235}
236
237impl<S, Req, Res, E> Service<Req> for Retry<S, Req, Res, E>
238where
239 S: Service<Req, Response = Res, Error = E> + Clone + Send + 'static,
240 S::Future: Send + 'static,
241 Req: Clone + Send + 'static,
242 Res: Send + 'static,
243 E: Clone + Send + 'static,
244{
245 type Response = Res;
246 type Error = E;
247 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
248
249 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
250 self.inner.poll_ready(cx)
251 }
252
253 fn call(&mut self, req: Req) -> Self::Future {
254 let mut service = self.inner.clone();
255 let config = Arc::clone(&self.config);
256
257 let max_attempts = config.max_attempts_source.get_max_attempts(&req);
259
260 Box::pin(async move {
261 let mut attempt = 0;
262
263 loop {
264 let result = service.call(req.clone()).await;
265
266 match result {
267 Ok(response) => {
268 if config.policy.should_retry_response(&response) {
270 if attempt + 1 >= max_attempts {
272 #[cfg(feature = "metrics")]
273 {
274 counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "exhausted").increment(1);
275 }
276
277 #[cfg(feature = "tracing")]
278 warn!(retry = %config.name, attempts = attempt + 1, max_attempts = max_attempts, "Retry attempts exhausted (response predicate)");
279
280 let event = RetryEvent::Error {
281 pattern_name: config.name.clone(),
282 timestamp: Instant::now(),
283 attempts: attempt + 1,
284 };
285 config.event_listeners.emit(&event);
286 return Ok(response);
287 }
288
289 if let Some(ref budget) = config.budget {
291 if !budget.try_withdraw() {
292 #[cfg(feature = "metrics")]
293 {
294 counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "budget_exhausted").increment(1);
295 }
296
297 #[cfg(feature = "tracing")]
298 warn!(retry = %config.name, attempt = attempt + 1, "Retry budget exhausted (response predicate)");
299
300 let event = RetryEvent::BudgetExhausted {
301 pattern_name: config.name.clone(),
302 timestamp: Instant::now(),
303 attempt: attempt + 1,
304 };
305 config.event_listeners.emit(&event);
306 return Ok(response);
307 }
308 }
309
310 let delay = config.policy.next_backoff(attempt);
312
313 #[cfg(feature = "metrics")]
314 {
315 counter!("retry_attempts_total", "retry" => config.name.clone())
316 .increment(1);
317 }
318
319 #[cfg(feature = "tracing")]
320 debug!(retry = %config.name, attempt = attempt + 1, delay_ms = delay.as_millis(), "Retrying after response predicate match");
321
322 let event = RetryEvent::Retry {
323 pattern_name: config.name.clone(),
324 timestamp: Instant::now(),
325 attempt,
326 delay,
327 };
328 config.event_listeners.emit(&event);
329
330 tokio::time::sleep(delay).await;
331 attempt += 1;
332 continue;
333 }
334
335 if let Some(ref budget) = config.budget {
337 budget.deposit();
338 }
339
340 #[cfg(feature = "metrics")]
341 {
342 counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "success").increment(1);
343 histogram!("retry_attempts", "retry" => config.name.clone())
344 .record((attempt + 1) as f64);
345 }
346
347 #[cfg(feature = "tracing")]
348 {
349 if attempt > 0 {
350 info!(retry = %config.name, attempts = attempt + 1, "Request succeeded after retries");
351 } else {
352 debug!(retry = %config.name, "Request succeeded on first attempt");
353 }
354 }
355
356 let event = RetryEvent::Success {
357 pattern_name: config.name.clone(),
358 timestamp: Instant::now(),
359 attempts: attempt + 1,
360 };
361 config.event_listeners.emit(&event);
362 return Ok(response);
363 }
364 Err(error) => {
365 if !config.policy.should_retry(&error) {
367 #[cfg(feature = "tracing")]
368 debug!(retry = %config.name, "Error not retryable, failing immediately");
369
370 let event = RetryEvent::IgnoredError {
371 pattern_name: config.name.clone(),
372 timestamp: Instant::now(),
373 };
374 config.event_listeners.emit(&event);
375 return Err(error);
376 }
377
378 if attempt + 1 >= max_attempts {
380 #[cfg(feature = "metrics")]
381 {
382 counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "exhausted").increment(1);
383 }
384
385 #[cfg(feature = "tracing")]
386 warn!(retry = %config.name, attempts = attempt + 1, max_attempts = max_attempts, "Retry attempts exhausted");
387
388 let event = RetryEvent::Error {
389 pattern_name: config.name.clone(),
390 timestamp: Instant::now(),
391 attempts: attempt + 1,
392 };
393 config.event_listeners.emit(&event);
394 return Err(error);
395 }
396
397 if let Some(ref budget) = config.budget {
399 if !budget.try_withdraw() {
400 #[cfg(feature = "metrics")]
401 {
402 counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "budget_exhausted").increment(1);
403 }
404
405 #[cfg(feature = "tracing")]
406 warn!(retry = %config.name, attempt = attempt + 1, "Retry budget exhausted, failing immediately");
407
408 let event = RetryEvent::BudgetExhausted {
409 pattern_name: config.name.clone(),
410 timestamp: Instant::now(),
411 attempt: attempt + 1,
412 };
413 config.event_listeners.emit(&event);
414 return Err(error);
415 }
416 }
417
418 let delay = config.policy.next_backoff(attempt);
420
421 #[cfg(feature = "metrics")]
422 {
423 counter!("retry_attempts_total", "retry" => config.name.clone())
424 .increment(1);
425 }
426
427 #[cfg(feature = "tracing")]
428 debug!(retry = %config.name, attempt = attempt + 1, delay_ms = delay.as_millis(), "Retrying after delay");
429
430 let event = RetryEvent::Retry {
431 pattern_name: config.name.clone(),
432 timestamp: Instant::now(),
433 attempt,
434 delay,
435 };
436 config.event_listeners.emit(&event);
437
438 tokio::time::sleep(delay).await;
439 attempt += 1;
440 }
441 }
442 }
443 })
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::{service_fn, Layer, ServiceExt};

    /// Minimal cloneable error type returned by the test services.
    #[derive(Debug, Clone)]
    struct TestError {
        #[allow(dead_code)]
        message: String,
    }

    impl TestError {
        /// Builds an error carrying `message`.
        fn new(message: &str) -> Self {
            TestError {
                message: message.to_string(),
            }
        }
    }

    /// Returns a zeroed shared counter together with a second handle that can
    /// be moved into a service closure or listener.
    fn shared_counter() -> (Arc<AtomicUsize>, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let handle = Arc::clone(&counter);
        (counter, handle)
    }

    /// A request that succeeds right away reaches the inner service once.
    #[tokio::test]
    async fn successful_request_no_retry() {
        let (calls, calls_handle) = shared_counter();

        let inner = service_fn(move |req: String| {
            let calls = Arc::clone(&calls_handle);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>(format!("Response: {}", req))
            }
        });

        let retry_layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let mut service = retry_layer.layer(inner);

        let ready = service.ready().await.unwrap();
        let response = ready.call("test".to_string()).await.unwrap();

        assert_eq!(response, "Response: test");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two failures followed by a success take three calls in total.
    #[tokio::test]
    async fn retries_on_failure() {
        let (calls, calls_handle) = shared_counter();

        let inner = service_fn(move |_req: String| {
            let calls = Arc::clone(&calls_handle);
            async move {
                let seen = calls.fetch_add(1, Ordering::SeqCst);
                if seen < 2 {
                    Err(TestError::new("temporary failure"))
                } else {
                    Ok::<_, TestError>("success".to_string())
                }
            }
        });

        let retry_layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let mut service = retry_layer.layer(inner);

        let ready = service.ready().await.unwrap();
        let response = ready.call("test".to_string()).await.unwrap();

        assert_eq!(response, "success");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A service that always fails is called `max_attempts` times and the
    /// error is surfaced.
    #[tokio::test]
    async fn exhausts_retries() {
        let (calls, calls_handle) = shared_counter();

        let inner = service_fn(move |_req: String| {
            let calls = Arc::clone(&calls_handle);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("permanent failure"))
            }
        });

        let retry_layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let mut service = retry_layer.layer(inner);

        let ready = service.ready().await.unwrap();
        let outcome = ready.call("test".to_string()).await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// An error rejected by the retry predicate fails after a single call.
    #[tokio::test]
    async fn retry_predicate_filters_errors() {
        let (calls, calls_handle) = shared_counter();

        let inner = service_fn(move |_req: String| {
            let calls = Arc::clone(&calls_handle);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("non-retryable"))
            }
        });

        let retry_layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .retry_on(|_: &TestError| false)
            .build();

        let mut service = retry_layer.layer(inner);

        let ready = service.ready().await.unwrap();
        let outcome = ready.call("test".to_string()).await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Retry and success listeners fire once per retry and once per success.
    #[tokio::test]
    async fn event_listeners_called() {
        let (retries, retries_handle) = shared_counter();
        let (successes, successes_handle) = shared_counter();
        let (_calls, calls_handle) = shared_counter();

        let inner = service_fn(move |_req: String| {
            let calls = Arc::clone(&calls_handle);
            async move {
                let seen = calls.fetch_add(1, Ordering::SeqCst);
                if seen < 2 {
                    Err(TestError::new("temporary"))
                } else {
                    Ok::<_, TestError>("success".to_string())
                }
            }
        });

        let retry_layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .on_retry(move |_, _| {
                retries_handle.fetch_add(1, Ordering::SeqCst);
            })
            .on_success(move |_| {
                successes_handle.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = retry_layer.layer(inner);

        let ready = service.ready().await.unwrap();
        let _ = ready.call("test".to_string()).await;

        assert_eq!(retries.load(Ordering::SeqCst), 2);
        assert_eq!(successes.load(Ordering::SeqCst), 1);
    }

    /// With a single budget token only one retry is allowed, even though the
    /// attempt limit is higher.
    #[tokio::test]
    async fn budget_limits_retries() {
        let (calls, calls_handle) = shared_counter();
        let (budget_hits, budget_hits_handle) = shared_counter();

        let budget = RetryBudgetBuilder::new()
            .token_bucket()
            .max_tokens(1)
            .initial_tokens(1)
            .build();

        let inner = service_fn(move |_req: String| {
            let calls = Arc::clone(&calls_handle);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("always fails"))
            }
        });

        let retry_layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(5)
            .fixed_backoff(Duration::from_millis(1))
            .budget(budget)
            .on_budget_exhausted(move |_| {
                budget_hits_handle.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = retry_layer.layer(inner);

        let ready = service.ready().await.unwrap();
        let outcome = ready.call("test".to_string()).await;

        assert!(outcome.is_err());
        // Initial call plus the one retry the single token pays for.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        assert_eq!(budget_hits.load(Ordering::SeqCst), 1);
    }

    /// Exercises the token-bucket budget directly: an empty budget refuses
    /// withdrawals until a deposit is made.
    #[tokio::test]
    async fn budget_replenishes_on_success() {
        let budget = RetryBudgetBuilder::new()
            .token_bucket()
            .max_tokens(10)
            .initial_tokens(0)
            .build();

        assert_eq!(budget.balance(), 0);
        assert!(!budget.try_withdraw());

        budget.deposit();
        assert_eq!(budget.balance(), 1);

        assert!(budget.try_withdraw());
        assert_eq!(budget.balance(), 0);
    }

    /// The attempt limit can be derived from each request.
    #[tokio::test]
    async fn per_request_max_attempts() {
        #[derive(Clone)]
        struct Request {
            is_idempotent: bool,
        }

        let (calls, calls_handle) = shared_counter();

        let inner = service_fn(move |_req: Request| {
            let calls = Arc::clone(&calls_handle);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("always fails"))
            }
        });

        let retry_layer = RetryLayer::<Request, String, TestError>::builder()
            .max_attempts_fn(|req: &Request| if req.is_idempotent { 5 } else { 1 })
            .fixed_backoff(Duration::from_millis(1))
            .build();

        let mut service = retry_layer.layer(inner);

        // Non-idempotent request: a single attempt only.
        calls.store(0, Ordering::SeqCst);
        let ready = service.ready().await.unwrap();
        let _ = ready
            .call(Request {
                is_idempotent: false,
            })
            .await;
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Idempotent request: all five attempts are used.
        calls.store(0, Ordering::SeqCst);
        let ready = service.ready().await.unwrap();
        let _ = ready
            .call(Request {
                is_idempotent: true,
            })
            .await;
        assert_eq!(calls.load(Ordering::SeqCst), 5);
    }

    /// A per-request limit decides whether a late success is still reached.
    #[tokio::test]
    async fn per_request_max_attempts_with_success() {
        #[derive(Clone)]
        struct Request {
            max_retries: usize,
            succeed_on_attempt: usize,
        }

        let (calls, calls_handle) = shared_counter();

        let inner = service_fn(move |req: Request| {
            let calls = Arc::clone(&calls_handle);
            async move {
                let seen = calls.fetch_add(1, Ordering::SeqCst);
                if seen >= req.succeed_on_attempt {
                    Ok::<_, TestError>("success".to_string())
                } else {
                    Err(TestError::new("not yet"))
                }
            }
        });

        let retry_layer = RetryLayer::<Request, String, TestError>::builder()
            .max_attempts_fn(|req: &Request| req.max_retries)
            .fixed_backoff(Duration::from_millis(1))
            .build();

        let mut service = retry_layer.layer(inner);

        // Enough attempts to reach the succeeding call.
        calls.store(0, Ordering::SeqCst);
        let ready = service.ready().await.unwrap();
        let outcome = ready
            .call(Request {
                max_retries: 5,
                succeed_on_attempt: 2,
            })
            .await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        // The limit is hit before the succeeding call.
        calls.store(0, Ordering::SeqCst);
        let ready = service.ready().await.unwrap();
        let outcome = ready
            .call(Request {
                max_retries: 2,
                succeed_on_attempt: 2,
            })
            .await;
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}