rust_web_server/canary/
mod.rs1#[cfg(test)]
33mod tests;
34
35use std::sync::atomic::{AtomicUsize, Ordering};
36use std::time::Duration;
37
38use crate::application::Application;
39use crate::core::New;
40use crate::middleware::Middleware;
41use crate::mime_type::MimeType;
42use crate::range::Range;
43use crate::request::Request;
44use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
45use crate::server::ConnectionInfo;
46
/// A proxy backend paired with its relative share of traffic.
///
/// `CanaryLayer::new` skips backends whose `weight` is 0 and repeats every
/// other backend `weight` times in its rotation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WeightedBackend {
    /// Backend address, e.g. `http://host:port`; parsed when the layer is built.
    pub url: String,
    /// Number of rotation slots this backend occupies (0 disables it).
    pub weight: u32,
}

impl WeightedBackend {
    /// Creates a backend from anything convertible into a `String` URL and a weight.
    pub fn new(url: impl Into<String>, weight: u32) -> Self {
        Self {
            url: url.into(),
            weight,
        }
    }
}
64
/// Middleware state for proxying requests across a weighted set of backends.
#[derive(Debug)]
pub struct CanaryLayer {
    /// Expanded rotation of `(host, port, tls)` entries; a backend appears
    /// once per unit of its weight.
    pub(crate) rotation: Vec<(String, u16, bool)>,
    /// Monotonic request counter used to pick the starting rotation slot.
    counter: AtomicUsize,
    /// Timeout handed to the proxy call for establishing the backend connection.
    connect_timeout: Duration,
    /// Timeout handed to the proxy call for reading the backend response.
    read_timeout: Duration,
    /// When set, only requests whose URI starts with this prefix are proxied.
    path_prefix: Option<String>,
}
82
83impl CanaryLayer {
84 pub fn new(backends: Vec<WeightedBackend>) -> Self {
88 let mut rotation: Vec<(String, u16, bool)> = Vec::new();
89 for wb in &backends {
90 if wb.weight == 0 {
91 continue;
92 }
93 if let Some((host, port, tls)) = parse_backend_url(&wb.url) {
94 for _ in 0..wb.weight {
95 rotation.push((host.clone(), port, tls));
96 }
97 }
98 }
99 Self {
100 rotation,
101 counter: AtomicUsize::new(0),
102 connect_timeout: Duration::from_secs(5),
103 read_timeout: Duration::from_secs(30),
104 path_prefix: None,
105 }
106 }
107
108 pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
110 self.path_prefix = Some(prefix.into());
111 self
112 }
113
114 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
116 self.connect_timeout = Duration::from_millis(ms);
117 self
118 }
119
120 pub fn read_timeout_ms(mut self, ms: u64) -> Self {
122 self.read_timeout = Duration::from_millis(ms);
123 self
124 }
125
126 fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
128 if self.rotation.is_empty() {
129 return Err("CanaryLayer: no backends in rotation".to_string());
130 }
131 let n = self.rotation.len();
132 let start = self.counter.fetch_add(1, Ordering::Relaxed);
133 let mut tried: Vec<usize> = Vec::new();
136 for attempt in 0..n {
137 let idx = (start + attempt) % n;
138 let backend = &self.rotation[idx];
139 let already_tried = tried.iter().any(|&i| self.rotation[i] == *backend);
141 if already_tried {
142 continue;
143 }
144 tried.push(idx);
145 let (host, port, tls) = backend;
146 let result = if *tls {
147 #[cfg(any(feature = "http-client", feature = "http2"))]
148 {
149 crate::proxy::proxy_https1(
150 request,
151 &connection.client.ip,
152 host,
153 *port,
154 self.connect_timeout,
155 self.read_timeout,
156 )
157 }
158 #[cfg(not(any(feature = "http-client", feature = "http2")))]
159 {
160 Err("CanaryLayer: TLS backend requires the http-client or http2 feature".to_string())
161 }
162 } else {
163 crate::proxy::proxy_http1(
164 request,
165 &connection.client.ip,
166 host,
167 *port,
168 self.connect_timeout,
169 self.read_timeout,
170 )
171 };
172 match result {
173 Ok(resp) => return Ok(resp),
174 Err(_) => continue,
175 }
176 }
177 Err("CanaryLayer: all backends failed".to_string())
178 }
179}
180
181impl Middleware for CanaryLayer {
182 fn handle(
183 &self,
184 request: &Request,
185 connection: &ConnectionInfo,
186 next: &dyn Application,
187 ) -> Result<Response, String> {
188 if let Some(prefix) = &self.path_prefix {
189 if !request.request_uri.starts_with(prefix.as_str()) {
190 return next.execute(request, connection);
191 }
192 }
193 match self.proxy(request, connection) {
194 Ok(resp) => Ok(resp),
195 Err(_) => Ok(bad_gateway()),
196 }
197 }
198}
199
/// Splits a backend URL into `(host, port, tls)`.
///
/// Recognised schemes are `https://`, `h2s://`, `grpcs://` (TLS, default port
/// 443) and `http://`, `h2://`, `grpc://` (plain, default port 80). Input
/// without one of these prefixes is treated as a plain `host[:port]` with
/// default port 80. Anything from the first `/`, `?` or `#` onwards is
/// ignored.
///
/// If the text after the last `:` is not a valid `u16`, the whole authority
/// is kept as the host and the scheme's default port is used.
///
/// Returns `None` when the resulting host is empty.
fn parse_backend_url(url: &str) -> Option<(String, u16, bool)> {
    // (scheme prefix, uses TLS, default port)
    const SCHEMES: [(&str, bool, u16); 6] = [
        ("https://", true, 443),
        ("h2s://", true, 443),
        ("grpcs://", true, 443),
        ("http://", false, 80),
        ("h2://", false, 80),
        ("grpc://", false, 80),
    ];
    let (rest, tls, default_port) = SCHEMES
        .iter()
        .find_map(|&(scheme, tls, port)| url.strip_prefix(scheme).map(|rest| (rest, tls, port)))
        .unwrap_or((url, false, 80));

    // The authority ends at the path, query or fragment, whichever comes first.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(rest);

    let (host, port) = match authority.rfind(':') {
        Some(colon) => match authority[colon + 1..].parse::<u16>() {
            Ok(port) => (&authority[..colon], port),
            // Not a port (e.g. the inside of a bracketed IPv6 literal).
            Err(_) => (authority, default_port),
        },
        None => (authority, default_port),
    };

    if host.is_empty() {
        None
    } else {
        Some((host.to_string(), port, tls))
    }
}
239
240fn bad_gateway() -> Response {
241 let cr = Range::get_content_range(
242 b"502 Bad Gateway".to_vec(),
243 MimeType::TEXT_PLAIN.to_string(),
244 );
245 let mut r = Response::new();
246 r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
247 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n502_bad_gateway.reason_phrase.to_string();
248 r.content_range_list = vec![cr];
249 r
250}