Skip to main content

tower_resilience_router/
lib.rs

//! Weighted traffic routing for Tower services.
//!
//! This crate provides a `WeightedRouter` service that distributes requests
//! across multiple backend services according to configured weights. It is
//! designed for canary deployments, progressive rollouts, and controlled
//! traffic splitting.
//!
//! # Overview
//!
//! Unlike the other tower-resilience patterns, which wrap a single service and
//! modify its behavior, `WeightedRouter` *selects among* multiple services. It
//! is a standalone `Service`, not a `Layer`.
//!
//! All backends must be the same service type `S`, so they necessarily share
//! the same `Request`, `Response`, and `Error` types. For canary deployments
//! (the same service at different versions), this is the natural case.
//!
//! # Selection Strategies
//!
//! - **Deterministic** (default): Uses an atomic counter for predictable,
//!   repeatable distribution. With weights `[90, 10]`, every cycle of 100
//!   requests sends exactly 90 to the first backend and 10 to the second.
//!
//! - **Random**: Each request independently selects a backend with probability
//!   proportional to its weight. This suits high-volume traffic, where the
//!   observed split converges on the weights; at low volume the observed split
//!   can deviate noticeably from them.
//!
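//! # Readiness
//!
//! `poll_ready` reports ready only when **all** backends are ready. This is the
//! simplest and most predictable contract, but it means a single slow backend
//! holds back traffic to every backend. If a backend may be slow or failing,
//! wrap it in a circuit breaker so that readiness resolves quickly: while the
//! circuit is open, the breaker reports ready (or returns an error) immediately
//! instead of waiting on the backend.
//!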
//! # Example
//!
//! Because all backends must be the same type `S`, and every closure has its
//! own distinct type, erase the closures' types with `BoxService` when
//! building backends from different closures:
//!
//! ```rust,no_run
//! use tower::Service;
//! use tower::util::BoxService;
//! use tower_resilience_router::WeightedRouter;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let svc_v1: BoxService<String, String, std::io::Error> =
//!     BoxService::new(tower::service_fn(|req: String| async move {
//!         Ok(format!("v1: {}", req))
//!     }));
//! let svc_v2: BoxService<String, String, std::io::Error> =
//!     BoxService::new(tower::service_fn(|req: String| async move {
//!         Ok(format!("v2: {}", req))
//!     }));
//!
//! let mut router = WeightedRouter::builder()
//!     .route(svc_v1, 90)
//!     .route(svc_v2, 10)
//!     .build();
//! # Ok(())
//! # }
//! ```
//!
//! # Composability
//!
//! The natural composition pattern puts resilience middleware *inside* each
//! backend, so that every backend gets its own instance (for example, its own
//! circuit breaker). `CircuitBreakerLayer` is not part of this crate, so that
//! part of the snippet is shown commented out:
//!
//! ```rust,no_run
//! use tower::Layer;
//! use tower_resilience_router::WeightedRouter;
//! # use tower::service_fn;
//! # let svc_v1 = service_fn(|_: ()| async { Ok::<_, std::io::Error>(()) });
//! # let svc_v2 = service_fn(|_: ()| async { Ok::<_, std::io::Error>(()) });
//!
//! // Each backend gets its own circuit breaker
//! // let cb = CircuitBreakerLayer::standard().build();
//! // let router = WeightedRouter::builder()
//! //     .route(cb.layer(svc_v1), 90)
//! //     .route(cb.layer(svc_v2), 10)
//! //     .build();
//! ```

81/// Configuration and builder types for the weighted router.
82pub mod config;
83/// Error types for routing failures.
84pub mod error;
85/// Event types emitted when requests are routed.
86pub mod events;
87/// Backend selection strategies (deterministic, random).
88pub mod selection;
89
90pub use config::WeightedRouterBuilder;
91pub use error::WeightedRouterError;
92pub use events::RouterEvent;
93pub use selection::SelectionStrategy;
94
95use config::RouterConfig;
96use selection::WeightedSelector;
97use std::task::{Context, Poll};
98use tower_service::Service;
99
100/// A service that routes requests to one of several backends based on weights.
101///
102/// `WeightedRouter` is a standalone `Service`, not a `Layer`. It selects among
103/// multiple backend services of the same type, distributing traffic according
104/// to configured weights.
105///
106/// Use [`WeightedRouter::builder`] to construct a new router.
107pub struct WeightedRouter<S> {
108    /// Backend services with their weights.
109    backends: Vec<(S, u32)>,
110    /// Selector for choosing backends.
111    selector: WeightedSelector,
112    /// Configuration.
113    config: RouterConfig,
114}
115
116impl<S> WeightedRouter<S> {
117    /// Creates a new builder for configuring a `WeightedRouter`.
118    ///
119    /// # Examples
120    ///
121    /// ```rust,no_run
122    /// use tower_resilience_router::WeightedRouter;
123    /// use tower::util::BoxService;
124    ///
125    /// let svc_v1: BoxService<(), (), std::io::Error> =
126    ///     BoxService::new(tower::service_fn(|_: ()| async { Ok(()) }));
127    /// let svc_v2: BoxService<(), (), std::io::Error> =
128    ///     BoxService::new(tower::service_fn(|_: ()| async { Ok(()) }));
129    ///
130    /// let router = WeightedRouter::builder()
131    ///     .route(svc_v1, 90)
132    ///     .route(svc_v2, 10)
133    ///     .build();
134    /// ```
135    pub fn builder() -> WeightedRouterBuilder<S> {
136        WeightedRouterBuilder::new()
137    }
138
139    pub(crate) fn new(backends: Vec<(S, u32)>, config: RouterConfig) -> Self {
140        let weights: Vec<u32> = backends.iter().map(|(_, w)| *w).collect();
141        let selector = WeightedSelector::new(&weights, config.strategy);
142        Self {
143            backends,
144            selector,
145            config,
146        }
147    }
148
149    /// Returns the number of backends.
150    pub fn backend_count(&self) -> usize {
151        self.backends.len()
152    }
153
154    /// Returns the weights of all backends.
155    pub fn weights(&self) -> Vec<u32> {
156        self.backends.iter().map(|(_, w)| *w).collect()
157    }
158
159    /// Returns the name of this router instance.
160    pub fn name(&self) -> &str {
161        &self.config.name
162    }
163}
164
165impl<S: Clone> Clone for WeightedRouter<S> {
166    fn clone(&self) -> Self {
167        Self {
168            backends: self.backends.clone(),
169            selector: self.selector.clone(),
170            config: self.config.clone(),
171        }
172    }
173}
174
175impl<S, Request> Service<Request> for WeightedRouter<S>
176where
177    S: Service<Request>,
178{
179    type Response = S::Response;
180    type Error = S::Error;
181    type Future = S::Future;
182
183    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184        // All backends must be ready.
185        for (svc, _) in &mut self.backends {
186            match svc.poll_ready(cx)? {
187                Poll::Ready(()) => {}
188                Poll::Pending => return Poll::Pending,
189            }
190        }
191        Poll::Ready(Ok(()))
192    }
193
194    fn call(&mut self, request: Request) -> Self::Future {
195        let idx = self.selector.select();
196        let (svc, weight) = &mut self.backends[idx];
197
198        #[cfg(feature = "metrics")]
199        {
200            let labels = [
201                ("router", self.config.name.clone()),
202                ("backend", idx.to_string()),
203            ];
204            metrics::counter!("router_requests_routed_total", &labels).increment(1);
205        }
206
207        #[cfg(feature = "tracing")]
208        {
209            tracing::debug!(
210                router = %self.config.name,
211                backend_index = idx,
212                backend_weight = *weight,
213                "routing request to backend"
214            );
215        }
216
217        self.config
218            .event_listeners
219            .emit(&RouterEvent::RequestRouted {
220                pattern_name: self.config.name.clone(),
221                timestamp: std::time::Instant::now(),
222                backend_index: idx,
223                backend_weight: *weight,
224            });
225
226        svc.call(request)
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tower::util::BoxService;
    use tower::ServiceExt;

    type BoxSvc = BoxService<(), &'static str, TestError>;

    #[derive(Clone, Debug)]
    struct TestError;
    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "test error")
        }
    }
    impl std::error::Error for TestError {}

    fn counting_svc(counter: Arc<AtomicUsize>, label: &'static str) -> BoxSvc {
        BoxService::new(tower::service_fn(move |_: ()| {
            let c = Arc::clone(&counter);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>(label)
            }
        }))
    }

    /// `ReadinessSvc` state: `poll_ready` returns `Pending`.
    const NOT_READY: usize = 0;
    /// `ReadinessSvc` state: `poll_ready` returns `Ready(Ok(()))`.
    const READY: usize = 1;
    /// `ReadinessSvc` state: `poll_ready` returns `Ready(Err(TestError))`.
    const FAILED: usize = 2;

    /// Backend whose `poll_ready` outcome is driven by a shared state cell, so
    /// tests can exercise the router's readiness contract.
    struct ReadinessSvc {
        state: Arc<AtomicUsize>,
    }

    impl Service<()> for ReadinessSvc {
        type Response = &'static str;
        type Error = TestError;
        type Future = std::future::Ready<Result<&'static str, TestError>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), TestError>> {
            match self.state.load(Ordering::SeqCst) {
                READY => Poll::Ready(Ok(())),
                FAILED => Poll::Ready(Err(TestError)),
                _ => Poll::Pending,
            }
        }

        fn call(&mut self, _: ()) -> Self::Future {
            std::future::ready(Ok("gated"))
        }
    }

    fn readiness_svc(state: Arc<AtomicUsize>) -> BoxSvc {
        BoxService::new(ReadinessSvc { state })
    }

    fn always_ready_svc() -> BoxSvc {
        BoxService::new(tower::service_fn(|_: ()| async { Ok::<_, TestError>("ok") }))
    }

    /// Polls the router's readiness exactly once, using the current task's waker.
    async fn poll_ready_once(router: &mut WeightedRouter<BoxSvc>) -> Poll<Result<(), TestError>> {
        std::future::poll_fn(|cx| Poll::Ready(Service::<()>::poll_ready(&mut *router, cx))).await
    }

    #[tokio::test]
    async fn routes_by_weight_deterministic() {
        let count_a = Arc::new(AtomicUsize::new(0));
        let count_b = Arc::new(AtomicUsize::new(0));

        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&count_a), "a"), 80)
            .route(counting_svc(Arc::clone(&count_b), "b"), 20)
            .build();

        for _ in 0..100 {
            let _ = router.ready().await.unwrap().call(()).await;
        }

        assert_eq!(count_a.load(Ordering::SeqCst), 80);
        assert_eq!(count_b.load(Ordering::SeqCst), 20);
    }

    #[tokio::test]
    async fn single_backend_gets_all_traffic() {
        let count = Arc::new(AtomicUsize::new(0));

        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&count), "ok"), 1)
            .build();

        for _ in 0..50 {
            let _ = router.ready().await.unwrap().call(()).await;
        }

        assert_eq!(count.load(Ordering::SeqCst), 50);
    }

    #[tokio::test]
    async fn three_backends() {
        let counts: Vec<Arc<AtomicUsize>> = (0..3).map(|_| Arc::new(AtomicUsize::new(0))).collect();

        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&counts[0]), "0"), 50)
            .route(counting_svc(Arc::clone(&counts[1]), "1"), 30)
            .route(counting_svc(Arc::clone(&counts[2]), "2"), 20)
            .build();

        for _ in 0..100 {
            let _ = router.ready().await.unwrap().call(()).await;
        }

        assert_eq!(counts[0].load(Ordering::SeqCst), 50);
        assert_eq!(counts[1].load(Ordering::SeqCst), 30);
        assert_eq!(counts[2].load(Ordering::SeqCst), 20);
    }

    #[tokio::test]
    async fn error_propagates_from_backend() {
        let svc: BoxSvc = BoxService::new(tower::service_fn(|_: ()| async {
            Err::<&str, _>(TestError)
        }));

        let mut router = WeightedRouter::builder().route(svc, 1).build();

        let result = router.ready().await.unwrap().call(()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn not_ready_until_every_backend_is_ready() {
        let state = Arc::new(AtomicUsize::new(NOT_READY));
        let mut router = WeightedRouter::builder()
            .route(always_ready_svc(), 1)
            .route(readiness_svc(Arc::clone(&state)), 1)
            .build();

        assert!(
            poll_ready_once(&mut router).await.is_pending(),
            "router must not be ready while any backend is pending"
        );

        state.store(READY, Ordering::SeqCst);
        assert!(matches!(
            poll_ready_once(&mut router).await,
            Poll::Ready(Ok(()))
        ));
    }

    #[tokio::test]
    async fn readiness_error_propagates_from_backend() {
        let state = Arc::new(AtomicUsize::new(FAILED));
        let mut router = WeightedRouter::builder()
            .route(always_ready_svc(), 1)
            .route(readiness_svc(state), 1)
            .build();

        assert!(matches!(
            poll_ready_once(&mut router).await,
            Poll::Ready(Err(TestError))
        ));
    }

    #[tokio::test]
    async fn event_listener_fires() {
        let routed_count = Arc::new(AtomicUsize::new(0));
        let rc = Arc::clone(&routed_count);

        let svc: BoxSvc = BoxService::new(tower::service_fn(|_: ()| async {
            Ok::<_, TestError>("ok")
        }));

        let mut router = WeightedRouter::builder()
            .route(svc, 1)
            .on_request_routed(move |_idx, _weight| {
                rc.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        for _ in 0..5 {
            let _ = router.ready().await.unwrap().call(()).await;
        }

        assert_eq!(routed_count.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn builder_accessors() {
        let router = WeightedRouter::builder()
            .name("canary")
            .route(counting_svc(Arc::new(AtomicUsize::new(0)), "a"), 90)
            .route(counting_svc(Arc::new(AtomicUsize::new(0)), "b"), 10)
            .build();

        assert_eq!(router.backend_count(), 2);
        assert_eq!(router.weights(), vec![90, 10]);
        assert_eq!(router.name(), "canary");
    }

    #[test]
    #[should_panic(expected = "at least one backend is required")]
    fn panics_on_no_backends() {
        let _router: WeightedRouter<BoxSvc> = WeightedRouter::builder().build();
    }

    #[test]
    #[should_panic(expected = "weight 0")]
    fn panics_on_zero_weight() {
        let svc: BoxSvc = BoxService::new(tower::service_fn(|_: ()| async {
            Ok::<_, TestError>("ok")
        }));
        let _router = WeightedRouter::builder().route(svc, 0).build();
    }

    #[tokio::test]
    async fn random_strategy_converges() {
        let count_a = Arc::new(AtomicUsize::new(0));
        let count_b = Arc::new(AtomicUsize::new(0));

        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&count_a), "a"), 80)
            .route(counting_svc(Arc::clone(&count_b), "b"), 20)
            .random()
            .build();

        let total = 10_000;
        for _ in 0..total {
            let _ = router.ready().await.unwrap().call(()).await;
        }

        let a = count_a.load(Ordering::SeqCst);
        let ratio = a as f64 / total as f64;
        assert!(
            (0.75..=0.85).contains(&ratio),
            "expected ~80%, got {:.1}%",
            ratio * 100.0
        );
    }
}