1pub mod config;
83pub mod error;
85pub mod events;
87pub mod selection;
89
90pub use config::WeightedRouterBuilder;
91pub use error::WeightedRouterError;
92pub use events::RouterEvent;
93pub use selection::SelectionStrategy;
94
95use config::RouterConfig;
96use selection::WeightedSelector;
97use std::task::{Context, Poll};
98use tower_service::Service;
99
100pub struct WeightedRouter<S> {
108 backends: Vec<(S, u32)>,
110 selector: WeightedSelector,
112 config: RouterConfig,
114}
115
116impl<S> WeightedRouter<S> {
117 pub fn builder() -> WeightedRouterBuilder<S> {
136 WeightedRouterBuilder::new()
137 }
138
139 pub(crate) fn new(backends: Vec<(S, u32)>, config: RouterConfig) -> Self {
140 let weights: Vec<u32> = backends.iter().map(|(_, w)| *w).collect();
141 let selector = WeightedSelector::new(&weights, config.strategy);
142 Self {
143 backends,
144 selector,
145 config,
146 }
147 }
148
149 pub fn backend_count(&self) -> usize {
151 self.backends.len()
152 }
153
154 pub fn weights(&self) -> Vec<u32> {
156 self.backends.iter().map(|(_, w)| *w).collect()
157 }
158
159 pub fn name(&self) -> &str {
161 &self.config.name
162 }
163}
164
165impl<S: Clone> Clone for WeightedRouter<S> {
166 fn clone(&self) -> Self {
167 Self {
168 backends: self.backends.clone(),
169 selector: self.selector.clone(),
170 config: self.config.clone(),
171 }
172 }
173}
174
175impl<S, Request> Service<Request> for WeightedRouter<S>
176where
177 S: Service<Request>,
178{
179 type Response = S::Response;
180 type Error = S::Error;
181 type Future = S::Future;
182
183 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184 for (svc, _) in &mut self.backends {
186 match svc.poll_ready(cx)? {
187 Poll::Ready(()) => {}
188 Poll::Pending => return Poll::Pending,
189 }
190 }
191 Poll::Ready(Ok(()))
192 }
193
194 fn call(&mut self, request: Request) -> Self::Future {
195 let idx = self.selector.select();
196 let (svc, weight) = &mut self.backends[idx];
197
198 #[cfg(feature = "metrics")]
199 {
200 let labels = [
201 ("router", self.config.name.clone()),
202 ("backend", idx.to_string()),
203 ];
204 metrics::counter!("router_requests_routed_total", &labels).increment(1);
205 }
206
207 #[cfg(feature = "tracing")]
208 {
209 tracing::debug!(
210 router = %self.config.name,
211 backend_index = idx,
212 backend_weight = *weight,
213 "routing request to backend"
214 );
215 }
216
217 self.config
218 .event_listeners
219 .emit(&RouterEvent::RequestRouted {
220 pattern_name: self.config.name.clone(),
221 timestamp: std::time::Instant::now(),
222 backend_index: idx,
223 backend_weight: *weight,
224 });
225
226 svc.call(request)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tower::util::BoxService;
    use tower::ServiceExt;

    /// Type-erased backend used throughout these tests.
    type BoxSvc = BoxService<(), &'static str, TestError>;

    /// Error returned by deliberately failing backends.
    #[derive(Clone, Debug)]
    struct TestError;

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "test error")
        }
    }

    impl std::error::Error for TestError {}

    /// Backend that increments `counter` on every call and responds with `label`.
    fn counting_svc(counter: Arc<AtomicUsize>, label: &'static str) -> BoxSvc {
        BoxService::new(tower::service_fn(move |_: ()| {
            let c = Arc::clone(&counter);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>(label)
            }
        }))
    }

    /// Backend that always succeeds with `"ok"` and records nothing.
    fn ok_svc() -> BoxSvc {
        BoxService::new(tower::service_fn(|_: ()| async {
            Ok::<_, TestError>("ok")
        }))
    }

    /// Sends `n` requests through `router`, waiting for readiness before each one.
    ///
    /// Panics if the router fails to become ready or any backend call fails.
    /// This keeps routing tests from passing while errors are silently discarded.
    async fn drive(router: &mut WeightedRouter<BoxSvc>, n: usize) {
        for _ in 0..n {
            router
                .ready()
                .await
                .expect("router should become ready")
                .call(())
                .await
                .expect("backend call should succeed");
        }
    }

    #[tokio::test]
    async fn routes_by_weight_deterministic() {
        let count_a = Arc::new(AtomicUsize::new(0));
        let count_b = Arc::new(AtomicUsize::new(0));

        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&count_a), "a"), 80)
            .route(counting_svc(Arc::clone(&count_b), "b"), 20)
            .build();

        drive(&mut router, 100).await;

        assert_eq!(count_a.load(Ordering::SeqCst), 80);
        assert_eq!(count_b.load(Ordering::SeqCst), 20);
    }

    #[tokio::test]
    async fn single_backend_gets_all_traffic() {
        let count = Arc::new(AtomicUsize::new(0));

        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&count), "ok"), 1)
            .build();

        drive(&mut router, 50).await;

        assert_eq!(count.load(Ordering::SeqCst), 50);
    }

    #[tokio::test]
    async fn three_backends() {
        let counts: Vec<Arc<AtomicUsize>> = (0..3).map(|_| Arc::new(AtomicUsize::new(0))).collect();

        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&counts[0]), "0"), 50)
            .route(counting_svc(Arc::clone(&counts[1]), "1"), 30)
            .route(counting_svc(Arc::clone(&counts[2]), "2"), 20)
            .build();

        drive(&mut router, 100).await;

        assert_eq!(counts[0].load(Ordering::SeqCst), 50);
        assert_eq!(counts[1].load(Ordering::SeqCst), 30);
        assert_eq!(counts[2].load(Ordering::SeqCst), 20);
    }

    #[tokio::test]
    async fn error_propagates_from_backend() {
        let svc: BoxSvc = BoxService::new(tower::service_fn(|_: ()| async {
            Err::<&str, _>(TestError)
        }));

        let mut router = WeightedRouter::builder().route(svc, 1).build();

        let result = router.ready().await.unwrap().call(()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn event_listener_fires() {
        let routed_count = Arc::new(AtomicUsize::new(0));
        let rc = Arc::clone(&routed_count);

        let mut router = WeightedRouter::builder()
            .route(ok_svc(), 1)
            .on_request_routed(move |_idx, _weight| {
                rc.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        drive(&mut router, 5).await;

        assert_eq!(routed_count.load(Ordering::SeqCst), 5);
    }

    // Purely synchronous: building and inspecting a router needs no runtime.
    #[test]
    fn builder_accessors() {
        let router = WeightedRouter::builder()
            .name("canary")
            .route(counting_svc(Arc::new(AtomicUsize::new(0)), "a"), 90)
            .route(counting_svc(Arc::new(AtomicUsize::new(0)), "b"), 10)
            .build();

        assert_eq!(router.backend_count(), 2);
        assert_eq!(router.weights(), vec![90, 10]);
        assert_eq!(router.name(), "canary");
    }

    #[test]
    #[should_panic(expected = "at least one backend is required")]
    fn panics_on_no_backends() {
        let _router: WeightedRouter<BoxSvc> = WeightedRouter::builder().build();
    }

    #[test]
    #[should_panic(expected = "weight 0")]
    fn panics_on_zero_weight() {
        let _router = WeightedRouter::builder().route(ok_svc(), 0).build();
    }

    #[tokio::test]
    async fn random_strategy_converges() {
        let count_a = Arc::new(AtomicUsize::new(0));
        let count_b = Arc::new(AtomicUsize::new(0));

        let mut router = WeightedRouter::builder()
            .route(counting_svc(Arc::clone(&count_a), "a"), 80)
            .route(counting_svc(Arc::clone(&count_b), "b"), 20)
            .random()
            .build();

        let total: usize = 10_000;
        drive(&mut router, total).await;

        let a = count_a.load(Ordering::SeqCst);
        let b = count_b.load(Ordering::SeqCst);
        // Every request must land on exactly one backend.
        assert_eq!(a + b, total);

        let ratio = a as f64 / total as f64;
        assert!(
            (0.75..=0.85).contains(&ratio),
            "expected ~80%, got {:.1}%",
            ratio * 100.0
        );
    }
}