Skip to main content

rust_web_server/canary/
mod.rs

1//! Weighted canary / A-B traffic splitting middleware.
2//!
3//! [`CanaryLayer`] implements [`Middleware`] and distributes incoming requests
4//! across a set of backends according to configurable weights.  A backend with
5//! weight 3 receives three times as many requests as one with weight 1.
6//!
7//! Backends are contacted over plain HTTP/1.1, or over TLS when the backend
8//! URL uses an `https://`, `h2s://`, or `grpcs://` scheme (requires the
9//! `http-client` or `http2` feature — both pull in `rustls`). If a backend is
10//! unavailable the next one in the rotation is tried; after exhausting all
11//! backends the middleware returns `502 Bad Gateway`.
12//!
13//! # Example
14//!
15//! ```rust,no_run
16//! use rust_web_server::app::App;
17//! use rust_web_server::core::New;
18//! use rust_web_server::canary::{CanaryLayer, WeightedBackend};
19//! use rust_web_server::middleware::WithMiddleware;
20//!
21//! // 75 % of traffic → stable, 25 % → canary
22//! let app = WithMiddleware::new(App::new())
23//!     .wrap(
24//!         CanaryLayer::new(vec![
25//!             WeightedBackend::new("http://stable:8080", 3),
26//!             WeightedBackend::new("http://canary:8080", 1),
27//!         ])
28//!         .path_prefix("/api"),
29//!     );
30//! ```
31
32#[cfg(test)]
33mod tests;
34
35use std::sync::atomic::{AtomicUsize, Ordering};
36use std::time::Duration;
37
38use crate::application::Application;
39use crate::core::New;
40use crate::middleware::Middleware;
41use crate::mime_type::MimeType;
42use crate::range::Range;
43use crate::request::Request;
44use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
45use crate::server::ConnectionInfo;
46
47// ── WeightedBackend ───────────────────────────────────────────────────────────
48
/// A backend URL together with a relative traffic weight.
///
/// Weights are relative to one another: a backend with weight 3 is placed in
/// the rotation three times as often as a backend with weight 1.
/// A weight of 0 causes the backend to be skipped entirely.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WeightedBackend {
    /// Backend address in the form `[scheme://]host[:port][/path]`.
    pub url: String,
    /// Relative share of traffic; `0` disables the backend.
    pub weight: u32,
}

impl WeightedBackend {
    /// Create a new weighted backend.
    ///
    /// `url` accepts anything convertible into a `String` (`&str`, `String`, …)
    /// and is stored as given; it is not validated here.
    pub fn new(url: impl Into<String>, weight: u32) -> Self {
        Self {
            url: url.into(),
            weight,
        }
    }
}
64
65// ── CanaryLayer ───────────────────────────────────────────────────────────────
66
/// Proxy middleware that splits traffic across backends by weight.
///
/// Every backend is written into `rotation` once per unit of weight, and an
/// atomic counter picks the starting slot for each request. The result is a
/// deterministic weighted round-robin that needs no lock.
pub struct CanaryLayer {
    /// Pre-expanded rotation of `(host, port, tls)` entries; a backend with
    /// weight `w` occupies `w` slots. `tls` is `true` when the backend URL
    /// used an `https://`, `h2s://`, or `grpcs://` scheme.
    pub(crate) rotation: Vec<(String, u16, bool)>,
    /// Request counter used to choose the starting slot in `rotation`.
    counter: AtomicUsize,
    /// Timeout handed to the proxy call for establishing the connection.
    connect_timeout: Duration,
    /// Timeout handed to the proxy call for reading the response.
    read_timeout: Duration,
    /// When set, only request URIs starting with this prefix are proxied.
    path_prefix: Option<String>,
}
82
83impl CanaryLayer {
84    /// Build a `CanaryLayer` from the given weighted backends.
85    ///
86    /// Backends with `weight == 0` are ignored.
87    pub fn new(backends: Vec<WeightedBackend>) -> Self {
88        let mut rotation: Vec<(String, u16, bool)> = Vec::new();
89        for wb in &backends {
90            if wb.weight == 0 {
91                continue;
92            }
93            if let Some((host, port, tls)) = parse_backend_url(&wb.url) {
94                for _ in 0..wb.weight {
95                    rotation.push((host.clone(), port, tls));
96                }
97            }
98        }
99        Self {
100            rotation,
101            counter: AtomicUsize::new(0),
102            connect_timeout: Duration::from_secs(5),
103            read_timeout: Duration::from_secs(30),
104            path_prefix: None,
105        }
106    }
107
108    /// Only proxy requests whose URI starts with `prefix`; pass others through.
109    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
110        self.path_prefix = Some(prefix.into());
111        self
112    }
113
114    /// Override the TCP connect timeout (default: 5 000 ms).
115    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
116        self.connect_timeout = Duration::from_millis(ms);
117        self
118    }
119
120    /// Override the response read timeout (default: 30 000 ms).
121    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
122        self.read_timeout = Duration::from_millis(ms);
123        self
124    }
125
126    /// Try every backend in rotation order until one succeeds.
127    fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
128        if self.rotation.is_empty() {
129            return Err("CanaryLayer: no backends in rotation".to_string());
130        }
131        let n = self.rotation.len();
132        let start = self.counter.fetch_add(1, Ordering::Relaxed);
133        // Deduplicate by (host, port) so we don't hit the same backend twice
134        // when it appears multiple times in the rotation.
135        let mut tried: Vec<usize> = Vec::new();
136        for attempt in 0..n {
137            let idx = (start + attempt) % n;
138            let backend = &self.rotation[idx];
139            // Check if we already tried this (host, port) pair.
140            let already_tried = tried.iter().any(|&i| self.rotation[i] == *backend);
141            if already_tried {
142                continue;
143            }
144            tried.push(idx);
145            let (host, port, tls) = backend;
146            let result = if *tls {
147                #[cfg(any(feature = "http-client", feature = "http2"))]
148                {
149                    crate::proxy::proxy_https1(
150                        request,
151                        &connection.client.ip,
152                        host,
153                        *port,
154                        self.connect_timeout,
155                        self.read_timeout,
156                    )
157                }
158                #[cfg(not(any(feature = "http-client", feature = "http2")))]
159                {
160                    Err("CanaryLayer: TLS backend requires the http-client or http2 feature".to_string())
161                }
162            } else {
163                crate::proxy::proxy_http1(
164                    request,
165                    &connection.client.ip,
166                    host,
167                    *port,
168                    self.connect_timeout,
169                    self.read_timeout,
170                )
171            };
172            match result {
173                Ok(resp) => return Ok(resp),
174                Err(_) => continue,
175            }
176        }
177        Err("CanaryLayer: all backends failed".to_string())
178    }
179}
180
181impl Middleware for CanaryLayer {
182    fn handle(
183        &self,
184        request: &Request,
185        connection: &ConnectionInfo,
186        next: &dyn Application,
187    ) -> Result<Response, String> {
188        if let Some(prefix) = &self.path_prefix {
189            if !request.request_uri.starts_with(prefix.as_str()) {
190                return next.execute(request, connection);
191            }
192        }
193        match self.proxy(request, connection) {
194            Ok(resp) => Ok(resp),
195            Err(_) => Ok(bad_gateway()),
196        }
197    }
198}
199
200// ── helpers ───────────────────────────────────────────────────────────────────
201
/// Parse a backend URL of the form `[scheme://]host[:port][/path]` into
/// `(host, port, tls)`.
///
/// `https://`, `h2s://`, and `grpcs://` set `tls = true` and default to port
/// 443; `http://`, `h2://`, `grpc://`, and a bare `host[:port]` set
/// `tls = false` and default to port 80 — matching `proxy::Backend::parse`'s
/// scheme conventions.
///
/// Details:
/// * Leading/trailing whitespace is ignored and schemes are matched
///   case-insensitively (`HTTPS://` behaves like `https://`).
/// * The authority ends at the first `/`, `?`, or `#`, so a query string or
///   fragment never leaks into the host.
/// * An empty port (`host:`) selects the scheme's default port.
/// * A bracketed IPv6 literal keeps its brackets (`[::1]:8080` yields host
///   `[::1]`). An unbracketed authority with more than one `:` is treated as
///   a bare IPv6 host without a port.
///   NOTE(review): whether the proxy layer can connect to a bare IPv6 host
///   is not visible from here — confirm.
/// * A port that is not a valid `u16` is not treated as a port: the whole
///   authority becomes the host and the default port is used.
///
/// Returns `None` when the resulting host is empty.
fn parse_backend_url(url: &str) -> Option<(String, u16, bool)> {
    // (scheme prefix, tls, default port)
    const SCHEMES: [(&str, bool, u16); 6] = [
        ("https://", true, 443),
        ("h2s://", true, 443),
        ("grpcs://", true, 443),
        ("http://", false, 80),
        ("h2://", false, 80),
        ("grpc://", false, 80),
    ];

    let url = url.trim();
    let (rest, tls, default_port) = SCHEMES
        .iter()
        .find_map(|&(scheme, tls, port)| {
            // `get` yields `None` when `url` is shorter than the scheme or the
            // cut would land inside a multi-byte character.
            let head = url.get(..scheme.len())?;
            if head.eq_ignore_ascii_case(scheme) {
                Some((&url[scheme.len()..], tls, port))
            } else {
                None
            }
        })
        .unwrap_or((url, false, 80));

    // Drop any path, query, or fragment component.
    let end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..end];

    // Several colons without brackets cannot be `host:port`.
    let bare_ipv6 = !authority.starts_with('[') && authority.matches(':').count() > 1;

    let (host, port) = match authority.rfind(':') {
        Some(colon) if !bare_ipv6 => {
            let port_str = &authority[colon + 1..];
            if port_str.is_empty() {
                // "host:" — an empty port means "use the default".
                (&authority[..colon], default_port)
            } else if let Ok(p) = port_str.parse::<u16>() {
                (&authority[..colon], p)
            } else {
                // Not a port (e.g. the `1]` tail of `[::1]`).
                (authority, default_port)
            }
        }
        _ => (authority, default_port),
    };

    if host.is_empty() {
        None
    } else {
        Some((host.to_string(), port, tls))
    }
}
239
240fn bad_gateway() -> Response {
241    let cr = Range::get_content_range(
242        b"502 Bad Gateway".to_vec(),
243        MimeType::TEXT_PLAIN.to_string(),
244    );
245    let mut r = Response::new();
246    r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
247    r.reason_phrase = STATUS_CODE_REASON_PHRASE.n502_bad_gateway.reason_phrase.to_string();
248    r.content_range_list = vec![cr];
249    r
250}