Skip to main content

rust_web_server/canary/
mod.rs

//! Weighted canary / A-B traffic-splitting middleware.
//!
//! [`CanaryLayer`] implements [`Middleware`] and distributes incoming requests
//! across a set of backends according to configurable weights.  While all
//! backends are healthy, a backend with weight 3 receives three times as many
//! requests as one with weight 1.
//!
//! Backends are contacted over plain HTTP/1.1.  If a backend is unavailable the
//! next distinct backend in the rotation is tried; once every backend has failed
//! the middleware returns `502 Bad Gateway`.
//!
//! # Example
//!
//! ```rust,no_run
//! use rust_web_server::app::App;
//! use rust_web_server::core::New;
//! use rust_web_server::canary::{CanaryLayer, WeightedBackend};
//! use rust_web_server::middleware::WithMiddleware;
//!
//! // 75 % of traffic → stable, 25 % → canary
//! let app = WithMiddleware::new(App::new())
//!     .wrap(
//!         CanaryLayer::new(vec![
//!             WeightedBackend::new("http://stable:8080", 3),
//!             WeightedBackend::new("http://canary:8080", 1),
//!         ])
//!         .path_prefix("/api"),
//!     );
//! ```
29
30#[cfg(test)]
31mod tests;
32
33use std::sync::atomic::{AtomicUsize, Ordering};
34use std::time::Duration;
35
36use crate::application::Application;
37use crate::core::New;
38use crate::middleware::Middleware;
39use crate::mime_type::MimeType;
40use crate::range::Range;
41use crate::request::Request;
42use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
43use crate::server::ConnectionInfo;
44
45// ── WeightedBackend ───────────────────────────────────────────────────────────
46
/// A backend URL together with a relative traffic weight.
///
/// The URL may be given as `[scheme://]host[:port][/path]`; only the host and
/// port are used.  A weight of 0 causes the backend to be skipped entirely.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WeightedBackend {
    /// Backend address, e.g. `http://stable:8080`.
    pub url: String,
    /// Relative share of traffic; 0 disables the backend.
    pub weight: u32,
}

impl WeightedBackend {
    /// Create a new weighted backend from any string-like `url` and a
    /// relative `weight`.
    pub fn new(url: impl Into<String>, weight: u32) -> Self {
        Self { url: url.into(), weight }
    }
}
62
63// ── CanaryLayer ───────────────────────────────────────────────────────────────
64
65/// Weighted traffic-splitting proxy middleware.
66///
67/// The rotation is pre-expanded so that each backend appears exactly `weight`
68/// times.  An atomic counter selects the next entry in the rotation on every
69/// request, giving a deterministic, lock-free weighted round-robin distribution.
70pub struct CanaryLayer {
71    /// Expanded rotation: each entry is `(host, port)` and appears `weight` times.
72    pub(crate) rotation: Vec<(String, u16)>,
73    counter: AtomicUsize,
74    connect_timeout: Duration,
75    read_timeout: Duration,
76    path_prefix: Option<String>,
77}
78
79impl CanaryLayer {
80    /// Build a `CanaryLayer` from the given weighted backends.
81    ///
82    /// Backends with `weight == 0` are ignored.
83    pub fn new(backends: Vec<WeightedBackend>) -> Self {
84        let mut rotation: Vec<(String, u16)> = Vec::new();
85        for wb in &backends {
86            if wb.weight == 0 {
87                continue;
88            }
89            if let Some((host, port)) = parse_backend_url(&wb.url) {
90                for _ in 0..wb.weight {
91                    rotation.push((host.clone(), port));
92                }
93            }
94        }
95        Self {
96            rotation,
97            counter: AtomicUsize::new(0),
98            connect_timeout: Duration::from_secs(5),
99            read_timeout: Duration::from_secs(30),
100            path_prefix: None,
101        }
102    }
103
104    /// Only proxy requests whose URI starts with `prefix`; pass others through.
105    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
106        self.path_prefix = Some(prefix.into());
107        self
108    }
109
110    /// Override the TCP connect timeout (default: 5 000 ms).
111    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
112        self.connect_timeout = Duration::from_millis(ms);
113        self
114    }
115
116    /// Override the response read timeout (default: 30 000 ms).
117    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
118        self.read_timeout = Duration::from_millis(ms);
119        self
120    }
121
122    /// Try every backend in rotation order until one succeeds.
123    fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
124        if self.rotation.is_empty() {
125            return Err("CanaryLayer: no backends in rotation".to_string());
126        }
127        let n = self.rotation.len();
128        let start = self.counter.fetch_add(1, Ordering::Relaxed);
129        // Deduplicate by (host, port) so we don't hit the same backend twice
130        // when it appears multiple times in the rotation.
131        let mut tried: Vec<usize> = Vec::new();
132        for attempt in 0..n {
133            let idx = (start + attempt) % n;
134            let backend = &self.rotation[idx];
135            // Check if we already tried this (host, port) pair.
136            let already_tried = tried.iter().any(|&i| self.rotation[i] == *backend);
137            if already_tried {
138                continue;
139            }
140            tried.push(idx);
141            match crate::proxy::proxy_http1(
142                request,
143                &connection.client.ip,
144                &backend.0,
145                backend.1,
146                self.connect_timeout,
147                self.read_timeout,
148            ) {
149                Ok(resp) => return Ok(resp),
150                Err(_) => continue,
151            }
152        }
153        Err("CanaryLayer: all backends failed".to_string())
154    }
155}
156
157impl Middleware for CanaryLayer {
158    fn handle(
159        &self,
160        request: &Request,
161        connection: &ConnectionInfo,
162        next: &dyn Application,
163    ) -> Result<Response, String> {
164        if let Some(prefix) = &self.path_prefix {
165            if !request.request_uri.starts_with(prefix.as_str()) {
166                return next.execute(request, connection);
167            }
168        }
169        match self.proxy(request, connection) {
170            Ok(resp) => Ok(resp),
171            Err(_) => Ok(bad_gateway()),
172        }
173    }
174}
175
176// ── helpers ───────────────────────────────────────────────────────────────────
177
/// Split a backend URL into `(host, port)`.
///
/// Accepts `[scheme://]host[:port][/path]`, where the optional scheme is one of
/// `https://`, `http://` or `h2://`.  Everything from the first `/` onward is
/// ignored.  The port comes from the text after the last `:` when it parses
/// as a `u16`.  Otherwise the whole authority (colon included) is kept as the
/// host and the port defaults to 80.  Returns `None` when the resulting host
/// is empty.
fn parse_backend_url(url: &str) -> Option<(String, u16)> {
    const SCHEMES: [&str; 3] = ["https://", "http://", "h2://"];

    let without_scheme = SCHEMES
        .iter()
        .find_map(|&scheme| url.strip_prefix(scheme))
        .unwrap_or(url);

    let authority = match without_scheme.find('/') {
        Some(slash) => &without_scheme[..slash],
        None => without_scheme,
    };

    let (host, port) = authority
        .rsplit_once(':')
        .and_then(|(name, digits)| digits.parse::<u16>().ok().map(|p| (name, p)))
        .unwrap_or((authority, 80));

    if host.is_empty() {
        return None;
    }
    Some((host.to_string(), port))
}
200
201fn bad_gateway() -> Response {
202    let cr = Range::get_content_range(
203        b"502 Bad Gateway".to_vec(),
204        MimeType::TEXT_PLAIN.to_string(),
205    );
206    let mut r = Response::new();
207    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
208    r.reason_phrase = STATUS_CODE_REASON_PHRASE.n502_bad_gateway.reason_phrase.to_string();
209    r.content_range_list = vec![cr];
210    r
211}