tower_llm/resilience/
mod.rs

1//! Resilience layers: timeout, retry, rate-limit, circuit-breaker
2//!
3//! What this module provides (spec)
4//! - Cross-cutting, reusable Tower middleware for reliability under failure and load
5//!
6//! Exports
7//! - Models
8//!   - `RetryPolicy { max_retries, backoff: Backoff }`
9//!   - `RateLimit { qps, burst }`
10//!   - `BreakerConfig { failure_threshold, window, reset_timeout }`
11//!   - `ErrorKind` and classifier function `fn classify(&Error) -> ErrorKind`
12//! - Layers
13//!   - `TimeoutLayer(Duration)` (thin wrapper around `tower::timeout::Timeout`)
14//!   - `RetryLayer<Classifier, Policy>`
15//!   - `RateLimitLayer` (token bucket)
16//!   - `CircuitBreakerLayer` (stateful gate)
17//! - Utils
18//!   - Backoff builders (fixed, exponential, jitter), default classifiers
19//!
20//! Implementation strategy
21//! - Timeout: use `tower::timeout` directly; just re-expose with our config type
22//! - Retry: wrap inner service; on transient errors per classifier, retry with backoff
23//! - Rate limit: keep a token bucket (Arc<Mutex>) and check/consume tokens per request
24//! - Circuit breaker: track success/failure counts in a sliding window; on open, short-circuit with error; half-open lets a probe through
25//!
26//! Composition
27//! - Builder sugar chooses where to apply: model only, tools only, or entire step/agent
//! - Example: `ServiceBuilder::new().layer(TimeoutLayer::new(dur)).layer(RetryLayer::new(policy, classifier)).service(step)`
29//!
30//! Testing strategy
31//! - Use fake services that error in a scripted pattern (e.g., E E S) and assert retry timing and counts
32//! - For breaker: simulate sustained failures, confirm open/half-open transitions with timers
33//! - Rate limit: simulate bursts to validate token consumption and backoff behavior
34
35use std::future::Future;
36use std::pin::Pin;
37use std::sync::Arc;
38use std::time::{Duration, Instant};
39
40use tokio::sync::Mutex;
41use tokio::time::{sleep, timeout};
42use tower::{BoxError, Layer, Service, ServiceExt};
43
44// ===== Retry =====
45
/// Shape of the delay curve used between retry attempts.
#[derive(Debug, Clone, Copy)]
pub enum BackoffKind {
    /// Every retry waits the same `initial` delay.
    Fixed,
    /// The delay is `initial * factor^attempt`, capped at `max`.
    Exponential,
}

/// Delay schedule applied between retry attempts.
#[derive(Debug, Clone, Copy)]
pub struct Backoff {
    /// Which curve [`Backoff::delay_for_attempt`] follows.
    pub kind: BackoffKind,
    /// Delay for attempt `0`.
    pub initial: Duration,
    /// Per-attempt multiplier; only consulted for [`BackoffKind::Exponential`].
    pub factor: f32,
    /// Upper bound on the delay; only consulted for [`BackoffKind::Exponential`].
    pub max: Duration,
}

impl Backoff {
    /// Builds a schedule that always waits `delay`.
    pub fn fixed(delay: Duration) -> Self {
        Self {
            kind: BackoffKind::Fixed,
            initial: delay,
            factor: 1.0,
            max: delay,
        }
    }

    /// Builds a schedule that starts at `initial`, multiplies by `factor` on
    /// every attempt and never exceeds `max`.
    pub fn exponential(initial: Duration, factor: f32, max: Duration) -> Self {
        Self {
            kind: BackoffKind::Exponential,
            initial,
            factor,
            max,
        }
    }

    /// Returns the delay to wait for the zero-based `attempt`.
    ///
    /// This never panics. For the exponential curve the product is computed
    /// in floating point and saturated into `[0, max]` *before* it is turned
    /// back into a `Duration`, so huge attempt numbers or growth factors simply
    /// yield `max`. A negative product (negative `factor`) saturates to zero
    /// and a NaN product (NaN `factor`) to `max`.
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        match self.kind {
            BackoffKind::Fixed => self.initial,
            BackoffKind::Exponential => {
                // Zero stays zero regardless of the factor; handling it first
                // keeps `0 * inf` (NaN) out of the arithmetic below.
                if self.initial.is_zero() {
                    return Duration::ZERO;
                }
                // Saturate rather than wrap: a plain `as i32` cast would turn
                // very large attempt counts into negative exponents.
                let exponent = attempt.min(i32::MAX as usize) as i32;
                let secs = self.initial.as_secs_f64() * f64::from(self.factor).powi(exponent);
                if secs.is_nan() || secs >= self.max.as_secs_f64() {
                    self.max
                } else if secs <= 0.0 {
                    Duration::ZERO
                } else {
                    // `secs` is finite, positive and below `max`, so the
                    // conversion cannot overflow; `min` guards float rounding.
                    Duration::from_secs_f64(secs).min(self.max)
                }
            }
        }
    }
}
92
/// Retry budget paired with the delay schedule used between attempts.
#[derive(Debug, Clone, Copy)]
pub struct RetryPolicy {
    /// Number of retries allowed after the initial attempt has failed.
    pub max_retries: usize,
    /// Delay schedule consulted before each retry.
    pub backoff: Backoff,
}
98
/// Decides whether a failed call is worth retrying.
///
/// Implementors must be `Send + Sync + 'static` (supertrait bounds).
pub trait ErrorClassifier: Send + Sync + 'static {
    /// Returns `true` when `error` should trigger another attempt.
    fn retryable(&self, error: &BoxError) -> bool;
}
102
/// Classifier that treats every error as retryable.
#[derive(Debug, Clone, Copy)]
pub struct AlwaysRetry;
impl ErrorClassifier for AlwaysRetry {
    /// Always returns `true`; the error itself is not inspected.
    fn retryable(&self, _error: &BoxError) -> bool {
        true
    }
}
110
111pub struct RetryLayer<C> {
112    policy: RetryPolicy,
113    classifier: C,
114}
115
116impl<C> RetryLayer<C> {
117    pub fn new(policy: RetryPolicy, classifier: C) -> Self {
118        Self { policy, classifier }
119    }
120}
121
122pub struct Retry<S, C> {
123    inner: Arc<Mutex<S>>,
124    policy: RetryPolicy,
125    classifier: C,
126}
127
128impl<S, C> Layer<S> for RetryLayer<C>
129where
130    C: Clone,
131{
132    type Service = Retry<S, C>;
133    fn layer(&self, inner: S) -> Self::Service {
134        Retry {
135            inner: Arc::new(Mutex::new(inner)),
136            policy: self.policy,
137            classifier: self.classifier.clone(),
138        }
139    }
140}
141
142impl<S, C, Req> Service<Req> for Retry<S, C>
143where
144    Req: Clone + Send + 'static,
145    S: Service<Req, Error = BoxError> + Send + 'static,
146    S::Future: Send + 'static,
147    S::Response: Send + 'static,
148    C: ErrorClassifier + Send + Sync + Clone + 'static,
149{
150    type Response = S::Response;
151    type Error = BoxError;
152    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
153
154    fn poll_ready(
155        &mut self,
156        _cx: &mut std::task::Context<'_>,
157    ) -> std::task::Poll<Result<(), Self::Error>> {
158        std::task::Poll::Ready(Ok(()))
159    }
160
161    fn call(&mut self, req: Req) -> Self::Future {
162        let policy = self.policy;
163        let classifier = self.classifier.clone();
164        let req0 = req.clone();
165        let mut attempts: usize = 0;
166        let inner = self.inner.clone();
167        Box::pin(async move {
168            loop {
169                let result = {
170                    let mut guard = inner.lock().await;
171                    ServiceExt::ready(&mut *guard)
172                        .await?
173                        .call(req0.clone())
174                        .await
175                };
176                match result {
177                    Ok(resp) => return Ok(resp),
178                    Err(e) => {
179                        if attempts >= policy.max_retries || !classifier.retryable(&e) {
180                            return Err(e);
181                        }
182                        let delay = policy.backoff.delay_for_attempt(attempts);
183                        attempts += 1;
184                        sleep(delay).await;
185                    }
186                }
187            }
188        })
189    }
190}
191
192// ===== Timeout =====
193
/// Layer that bounds how long each call to the wrapped service may take.
///
/// `Debug` and `Clone` are derived so the layer can be used in cloneable
/// middleware stacks and printed in diagnostics.
#[derive(Debug, Clone)]
pub struct TimeoutLayer {
    dur: Duration,
}

impl TimeoutLayer {
    /// Creates a layer whose services allow each call at most `dur`.
    pub fn new(dur: Duration) -> Self {
        Self { dur }
    }
}
203
/// Service wrapper that pairs an inner service with a per-call time limit.
///
/// `Debug` and `Clone` are derived (available whenever `S` supports them) so
/// the wrapper does not make an otherwise cloneable service stack non-cloneable.
#[derive(Debug, Clone)]
pub struct Timeout<S> {
    inner: S,
    dur: Duration,
}
208
209impl<S> Layer<S> for TimeoutLayer {
210    type Service = Timeout<S>;
211    fn layer(&self, inner: S) -> Self::Service {
212        Timeout {
213            inner,
214            dur: self.dur,
215        }
216    }
217}
218
219impl<S, Req> Service<Req> for Timeout<S>
220where
221    S: Service<Req, Error = BoxError> + Send + 'static,
222    S::Future: Send + 'static,
223    S::Response: Send + 'static,
224{
225    type Response = S::Response;
226    type Error = BoxError;
227    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
228
229    fn poll_ready(
230        &mut self,
231        cx: &mut std::task::Context<'_>,
232    ) -> std::task::Poll<Result<(), Self::Error>> {
233        self.inner.poll_ready(cx)
234    }
235
236    fn call(&mut self, req: Req) -> Self::Future {
237        let fut = self.inner.call(req);
238        let dur = self.dur;
239        Box::pin(async move {
240            match timeout(dur, fut).await {
241                Ok(r) => r,
242                Err(_) => Err::<S::Response, BoxError>("timeout".into()),
243            }
244        })
245    }
246}
247
248// ===== Circuit Breaker (simplified) =====
249
/// Internal state machine of the circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BreakerState {
    /// Calls flow through to the inner service.
    Closed,
    /// Calls are rejected with a "circuit open" error until the contained
    /// instant has passed.
    OpenUntil(Instant),
    /// The open period has elapsed and calls are let through again; a success
    /// closes the breaker, a failure can send it back to `OpenUntil`.
    HalfOpen,
}
256
/// Tuning knobs for the circuit breaker.
#[derive(Debug, Clone, Copy)]
pub struct BreakerConfig {
    /// Number of consecutive failures (any success resets the count) at which
    /// the breaker opens.
    pub failure_threshold: usize,
    /// How long the breaker stays open before calls are let through again.
    pub reset_timeout: Duration,
}
262
263pub struct CircuitBreakerLayer {
264    cfg: BreakerConfig,
265}
266
267impl CircuitBreakerLayer {
268    pub fn new(cfg: BreakerConfig) -> Self {
269        Self { cfg }
270    }
271}
272
273pub struct CircuitBreaker<S> {
274    inner: S,
275    cfg: BreakerConfig,
276    state: Arc<Mutex<(BreakerState, usize)>>, // (state, consecutive_failures)
277}
278
279impl<S> Layer<S> for CircuitBreakerLayer {
280    type Service = CircuitBreaker<S>;
281    fn layer(&self, inner: S) -> Self::Service {
282        CircuitBreaker {
283            inner,
284            cfg: self.cfg,
285            state: Arc::new(Mutex::new((BreakerState::Closed, 0))),
286        }
287    }
288}
289
290impl<S, Req> Service<Req> for CircuitBreaker<S>
291where
292    S: Service<Req, Error = BoxError> + Send + 'static,
293    S::Future: Send + 'static,
294    S::Response: Send + 'static,
295{
296    type Response = S::Response;
297    type Error = BoxError;
298    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
299
300    fn poll_ready(
301        &mut self,
302        cx: &mut std::task::Context<'_>,
303    ) -> std::task::Poll<Result<(), Self::Error>> {
304        self.inner.poll_ready(cx)
305    }
306
307    fn call(&mut self, req: Req) -> Self::Future {
308        let cfg = self.cfg;
309        let state = self.state.clone();
310        let fut = self.inner.call(req);
311        Box::pin(async move {
312            // Check state
313            {
314                let mut s = state.lock().await;
315                match s.0 {
316                    BreakerState::Closed => {}
317                    BreakerState::OpenUntil(t) => {
318                        if Instant::now() < t {
319                            return Err("circuit open".into());
320                        }
321                        s.0 = BreakerState::HalfOpen;
322                    }
323                    BreakerState::HalfOpen => {}
324                }
325            }
326
327            match fut.await {
328                Ok(resp) => {
329                    let mut s = state.lock().await;
330                    s.1 = 0; // reset failures
331                    s.0 = BreakerState::Closed;
332                    Ok(resp)
333                }
334                Err(e) => {
335                    let mut s = state.lock().await;
336                    s.1 += 1;
337                    if s.1 >= cfg.failure_threshold {
338                        s.0 = BreakerState::OpenUntil(Instant::now() + cfg.reset_timeout);
339                    }
340                    Err(e)
341                }
342            }
343        })
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::service_fn;

    /// Two scripted failures followed by a success: the call must succeed and
    /// the inner service must have been attempted exactly three times.
    #[tokio::test]
    async fn retry_eventually_succeeds() {
        static COUNT: AtomicUsize = AtomicUsize::new(0);
        let svc = service_fn(|()| async move {
            let n = COUNT.fetch_add(1, Ordering::SeqCst);
            if n < 2 {
                Err::<(), BoxError>("e".into())
            } else {
                Ok::<(), BoxError>(())
            }
        });
        let layer = RetryLayer::new(
            RetryPolicy {
                max_retries: 5,
                backoff: Backoff::fixed(Duration::from_millis(1)),
            },
            AlwaysRetry,
        );
        let mut svc = layer.layer(svc);
        ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(())
            .await
            .unwrap();
        assert_eq!(COUNT.load(Ordering::SeqCst), 3);
    }

    /// An inner call slower than the limit must surface a "timeout" error.
    #[tokio::test]
    async fn timeout_triggers_error() {
        let svc = service_fn(|()| async move {
            sleep(Duration::from_millis(20)).await;
            Ok::<(), BoxError>(())
        });
        let mut svc = TimeoutLayer::new(Duration::from_millis(5)).layer(svc);
        let err = ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("timeout"));
    }

    /// Exact counts are asserted on purpose: a `<=` check would also pass if
    /// the breaker never let anything through.
    #[tokio::test]
    async fn breaker_opens_after_failures() {
        static CALLED: AtomicUsize = AtomicUsize::new(0);
        let svc = service_fn(|()| async move {
            CALLED.fetch_add(1, Ordering::SeqCst);
            Err::<(), BoxError>("boom".into())
        });
        let mut svc = CircuitBreakerLayer::new(BreakerConfig {
            failure_threshold: 2,
            reset_timeout: Duration::from_millis(30),
        })
        .layer(svc);
        // first two calls invoke inner and fail
        let first = ServiceExt::ready(&mut svc).await.unwrap().call(()).await;
        assert!(first.is_err());
        let second = ServiceExt::ready(&mut svc).await.unwrap().call(()).await;
        assert!(second.is_err());
        assert_eq!(CALLED.load(Ordering::SeqCst), 2);
        // now the breaker is open; the next call is short-circuited (no inner increment)
        let third = ServiceExt::ready(&mut svc).await.unwrap().call(()).await;
        assert_eq!(third.unwrap_err().to_string(), "circuit open");
        assert_eq!(CALLED.load(Ordering::SeqCst), 2);
        // wait out the open period, then call again to hit inner exactly once more
        sleep(Duration::from_millis(35)).await;
        let fourth = ServiceExt::ready(&mut svc).await.unwrap().call(()).await;
        assert_eq!(fourth.unwrap_err().to_string(), "boom");
        assert_eq!(CALLED.load(Ordering::SeqCst), 3);
    }
}
419}