1use std::future::Future;
36use std::pin::Pin;
37use std::sync::Arc;
38use std::time::{Duration, Instant};
39
40use tokio::sync::Mutex;
41use tokio::time::{sleep, timeout};
42use tower::{BoxError, Layer, Service, ServiceExt};
43
/// Shape of the delay schedule used by [`Backoff`].
#[derive(Debug, Clone, Copy)]
pub enum BackoffKind {
    /// Every retry waits the same amount of time (`Backoff::initial`).
    Fixed,
    /// Each retry waits `factor` times longer than the previous one, capped at `Backoff::max`.
    Exponential,
}
51
52#[derive(Debug, Clone, Copy)]
53pub struct Backoff {
54 pub kind: BackoffKind,
55 pub initial: Duration,
56 pub factor: f32,
57 pub max: Duration,
58}
59
60impl Backoff {
61 pub fn fixed(delay: Duration) -> Self {
62 Self {
63 kind: BackoffKind::Fixed,
64 initial: delay,
65 factor: 1.0,
66 max: delay,
67 }
68 }
69 pub fn exponential(initial: Duration, factor: f32, max: Duration) -> Self {
70 Self {
71 kind: BackoffKind::Exponential,
72 initial,
73 factor,
74 max,
75 }
76 }
77 pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
78 match self.kind {
79 BackoffKind::Fixed => self.initial,
80 BackoffKind::Exponential => {
81 let mult = self.factor.powi(attempt as i32);
82 let d = self.initial.mul_f32(mult);
83 if d > self.max {
84 self.max
85 } else {
86 d
87 }
88 }
89 }
90 }
91}
92
93#[derive(Debug, Clone, Copy)]
94pub struct RetryPolicy {
95 pub max_retries: usize,
96 pub backoff: Backoff,
97}
98
99pub trait ErrorClassifier: Send + Sync + 'static {
100 fn retryable(&self, error: &BoxError) -> bool;
101}
102
103#[derive(Debug, Clone, Copy)]
104pub struct AlwaysRetry;
105impl ErrorClassifier for AlwaysRetry {
106 fn retryable(&self, _error: &BoxError) -> bool {
107 true
108 }
109}
110
111pub struct RetryLayer<C> {
112 policy: RetryPolicy,
113 classifier: C,
114}
115
116impl<C> RetryLayer<C> {
117 pub fn new(policy: RetryPolicy, classifier: C) -> Self {
118 Self { policy, classifier }
119 }
120}
121
122pub struct Retry<S, C> {
123 inner: Arc<Mutex<S>>,
124 policy: RetryPolicy,
125 classifier: C,
126}
127
128impl<S, C> Layer<S> for RetryLayer<C>
129where
130 C: Clone,
131{
132 type Service = Retry<S, C>;
133 fn layer(&self, inner: S) -> Self::Service {
134 Retry {
135 inner: Arc::new(Mutex::new(inner)),
136 policy: self.policy,
137 classifier: self.classifier.clone(),
138 }
139 }
140}
141
142impl<S, C, Req> Service<Req> for Retry<S, C>
143where
144 Req: Clone + Send + 'static,
145 S: Service<Req, Error = BoxError> + Send + 'static,
146 S::Future: Send + 'static,
147 S::Response: Send + 'static,
148 C: ErrorClassifier + Send + Sync + Clone + 'static,
149{
150 type Response = S::Response;
151 type Error = BoxError;
152 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
153
154 fn poll_ready(
155 &mut self,
156 _cx: &mut std::task::Context<'_>,
157 ) -> std::task::Poll<Result<(), Self::Error>> {
158 std::task::Poll::Ready(Ok(()))
159 }
160
161 fn call(&mut self, req: Req) -> Self::Future {
162 let policy = self.policy;
163 let classifier = self.classifier.clone();
164 let req0 = req.clone();
165 let mut attempts: usize = 0;
166 let inner = self.inner.clone();
167 Box::pin(async move {
168 loop {
169 let result = {
170 let mut guard = inner.lock().await;
171 ServiceExt::ready(&mut *guard)
172 .await?
173 .call(req0.clone())
174 .await
175 };
176 match result {
177 Ok(resp) => return Ok(resp),
178 Err(e) => {
179 if attempts >= policy.max_retries || !classifier.retryable(&e) {
180 return Err(e);
181 }
182 let delay = policy.backoff.delay_for_attempt(attempts);
183 attempts += 1;
184 sleep(delay).await;
185 }
186 }
187 }
188 })
189 }
190}
191
/// Layer that wraps services in [`Timeout`], bounding how long each call may take.
pub struct TimeoutLayer {
    /// Time limit applied to every call of a wrapped service.
    dur: Duration,
}

impl TimeoutLayer {
    /// Creates a layer whose services give up on a call after `dur`.
    pub fn new(dur: Duration) -> Self {
        TimeoutLayer { dur }
    }
}
203
/// Service produced by [`TimeoutLayer`]; fails any call that does not finish within `dur`.
pub struct Timeout<S> {
    /// Time limit for a single call.
    dur: Duration,
    /// Wrapped service.
    inner: S,
}
208
209impl<S> Layer<S> for TimeoutLayer {
210 type Service = Timeout<S>;
211 fn layer(&self, inner: S) -> Self::Service {
212 Timeout {
213 inner,
214 dur: self.dur,
215 }
216 }
217}
218
219impl<S, Req> Service<Req> for Timeout<S>
220where
221 S: Service<Req, Error = BoxError> + Send + 'static,
222 S::Future: Send + 'static,
223 S::Response: Send + 'static,
224{
225 type Response = S::Response;
226 type Error = BoxError;
227 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
228
229 fn poll_ready(
230 &mut self,
231 cx: &mut std::task::Context<'_>,
232 ) -> std::task::Poll<Result<(), Self::Error>> {
233 self.inner.poll_ready(cx)
234 }
235
236 fn call(&mut self, req: Req) -> Self::Future {
237 let fut = self.inner.call(req);
238 let dur = self.dur;
239 Box::pin(async move {
240 match timeout(dur, fut).await {
241 Ok(r) => r,
242 Err(_) => Err::<S::Response, BoxError>("timeout".into()),
243 }
244 })
245 }
246}
247
/// Internal state of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BreakerState {
    /// Requests flow through normally.
    Closed,
    /// Requests are rejected with a `"circuit open"` error until the given instant.
    OpenUntil(Instant),
    /// The open period has elapsed and requests are let through again; the next success
    /// closes the breaker and the next failure re-opens it.
    HalfOpen,
}
256
/// Tuning knobs for [`CircuitBreakerLayer`].
#[derive(Debug, Clone, Copy)]
pub struct BreakerConfig {
    /// Number of consecutive failures (reset by any success) that opens the breaker.
    pub failure_threshold: usize,
    /// How long the breaker stays open before letting requests through again.
    pub reset_timeout: Duration,
}
262
263pub struct CircuitBreakerLayer {
264 cfg: BreakerConfig,
265}
266
267impl CircuitBreakerLayer {
268 pub fn new(cfg: BreakerConfig) -> Self {
269 Self { cfg }
270 }
271}
272
273pub struct CircuitBreaker<S> {
274 inner: S,
275 cfg: BreakerConfig,
276 state: Arc<Mutex<(BreakerState, usize)>>, }
278
279impl<S> Layer<S> for CircuitBreakerLayer {
280 type Service = CircuitBreaker<S>;
281 fn layer(&self, inner: S) -> Self::Service {
282 CircuitBreaker {
283 inner,
284 cfg: self.cfg,
285 state: Arc::new(Mutex::new((BreakerState::Closed, 0))),
286 }
287 }
288}
289
290impl<S, Req> Service<Req> for CircuitBreaker<S>
291where
292 S: Service<Req, Error = BoxError> + Send + 'static,
293 S::Future: Send + 'static,
294 S::Response: Send + 'static,
295{
296 type Response = S::Response;
297 type Error = BoxError;
298 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
299
300 fn poll_ready(
301 &mut self,
302 cx: &mut std::task::Context<'_>,
303 ) -> std::task::Poll<Result<(), Self::Error>> {
304 self.inner.poll_ready(cx)
305 }
306
307 fn call(&mut self, req: Req) -> Self::Future {
308 let cfg = self.cfg;
309 let state = self.state.clone();
310 let fut = self.inner.call(req);
311 Box::pin(async move {
312 {
314 let mut s = state.lock().await;
315 match s.0 {
316 BreakerState::Closed => {}
317 BreakerState::OpenUntil(t) => {
318 if Instant::now() < t {
319 return Err("circuit open".into());
320 }
321 s.0 = BreakerState::HalfOpen;
322 }
323 BreakerState::HalfOpen => {}
324 }
325 }
326
327 match fut.await {
328 Ok(resp) => {
329 let mut s = state.lock().await;
330 s.1 = 0; s.0 = BreakerState::Closed;
332 Ok(resp)
333 }
334 Err(e) => {
335 let mut s = state.lock().await;
336 s.1 += 1;
337 if s.1 >= cfg.failure_threshold {
338 s.0 = BreakerState::OpenUntil(Instant::now() + cfg.reset_timeout);
339 }
340 Err(e)
341 }
342 }
343 })
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::service_fn;

    /// Classifier that rejects every error; used to check that `Retry` consults it.
    #[derive(Debug, Clone, Copy)]
    struct NeverRetry;

    impl ErrorClassifier for NeverRetry {
        fn retryable(&self, _error: &BoxError) -> bool {
            false
        }
    }

    #[test]
    fn fixed_backoff_ignores_attempt() {
        let b = Backoff::fixed(Duration::from_millis(250));
        assert_eq!(b.delay_for_attempt(0), Duration::from_millis(250));
        assert_eq!(b.delay_for_attempt(9), Duration::from_millis(250));
    }

    #[test]
    fn exponential_backoff_grows_then_caps() {
        let b = Backoff::exponential(Duration::from_secs(1), 2.0, Duration::from_secs(3));
        assert_eq!(b.delay_for_attempt(0), Duration::from_secs(1));
        assert_eq!(b.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(b.delay_for_attempt(2), Duration::from_secs(3));
    }

    #[tokio::test]
    async fn retry_eventually_succeeds() {
        static COUNT: AtomicUsize = AtomicUsize::new(0);
        let svc = service_fn(|()| async move {
            let n = COUNT.fetch_add(1, Ordering::SeqCst);
            if n < 2 {
                Err::<(), BoxError>("e".into())
            } else {
                Ok::<(), BoxError>(())
            }
        });
        let layer = RetryLayer::new(
            RetryPolicy {
                max_retries: 5,
                backoff: Backoff::fixed(Duration::from_millis(1)),
            },
            AlwaysRetry,
        );
        let mut svc = layer.layer(svc);
        ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(())
            .await
            .unwrap();
        // Two failures, then success on the third call.
        assert_eq!(COUNT.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_gives_up_after_max_retries() {
        static CALLS: AtomicUsize = AtomicUsize::new(0);
        let svc = service_fn(|()| async move {
            CALLS.fetch_add(1, Ordering::SeqCst);
            Err::<(), BoxError>("always".into())
        });
        let mut svc = RetryLayer::new(
            RetryPolicy {
                max_retries: 2,
                backoff: Backoff::fixed(Duration::from_millis(1)),
            },
            AlwaysRetry,
        )
        .layer(svc);
        let err = ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(())
            .await
            .unwrap_err();
        assert_eq!(err.to_string(), "always");
        // One initial attempt plus two retries.
        assert_eq!(CALLS.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_stops_on_non_retryable_error() {
        static CALLS: AtomicUsize = AtomicUsize::new(0);
        let svc = service_fn(|()| async move {
            CALLS.fetch_add(1, Ordering::SeqCst);
            Err::<(), BoxError>("fatal".into())
        });
        let mut svc = RetryLayer::new(
            RetryPolicy {
                max_retries: 5,
                backoff: Backoff::fixed(Duration::from_millis(1)),
            },
            NeverRetry,
        )
        .layer(svc);
        let err = ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(())
            .await
            .unwrap_err();
        assert_eq!(err.to_string(), "fatal");
        assert_eq!(CALLS.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn timeout_triggers_error() {
        let svc = service_fn(|()| async move {
            sleep(Duration::from_millis(20)).await;
            Ok::<(), BoxError>(())
        });
        let mut svc = TimeoutLayer::new(Duration::from_millis(5)).layer(svc);
        let err = ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("timeout"));
    }

    #[tokio::test]
    async fn breaker_opens_after_failures() {
        static CALLED: AtomicUsize = AtomicUsize::new(0);
        let svc = service_fn(|()| async move {
            CALLED.fetch_add(1, Ordering::SeqCst);
            Err::<(), BoxError>("boom".into())
        });
        let mut svc = CircuitBreakerLayer::new(BreakerConfig {
            failure_threshold: 2,
            reset_timeout: Duration::from_millis(30),
        })
        .layer(svc);
        // Two failures reach the threshold; the third call is rejected while open.
        let _ = ServiceExt::ready(&mut svc).await.unwrap().call(()).await;
        let _ = ServiceExt::ready(&mut svc).await.unwrap().call(()).await;
        let rejected = ServiceExt::ready(&mut svc).await.unwrap().call(()).await;
        assert_eq!(rejected.unwrap_err().to_string(), "circuit open");
        assert_eq!(CALLED.load(Ordering::SeqCst), 2);
        // After the reset timeout a single probe reaches the inner service again.
        sleep(Duration::from_millis(35)).await;
        let _ = ServiceExt::ready(&mut svc).await.unwrap().call(()).await;
        assert_eq!(CALLED.load(Ordering::SeqCst), 3);
    }
}