Skip to main content

sozu_command_lib/
request.rs

1use std::{
2    error,
3    fmt::{self, Display},
4    fs::File,
5    io::{BufReader, Read},
6    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
7    str::FromStr,
8};
9
10use prost::{DecodeError, Message};
11use rusty_ulid::Ulid;
12
13use crate::{
14    proto::{
15        command::{
16            InitialState, IpAddress, LoadBalancingAlgorithms, PathRuleKind, Request,
17            RequestHttpFrontend, RulePosition, SocketAddress, Uint128, WorkerRequest, ip_address,
18            request::RequestType,
19        },
20        display::format_request_type,
21    },
22    response::HttpFrontend,
23};
24
25#[derive(thiserror::Error, Debug)]
26pub enum RequestError {
27    #[error("invalid value {value} for field '{name}'")]
28    InvalidValue { name: String, value: i32 },
29    #[error("Could not read requests from file: {0}")]
30    ReadFile(std::io::Error),
31    #[error("Could not decode requests: {0}")]
32    Decode(DecodeError),
33}
34
35impl Request {
36    /// determine to which of the four proxies (HTTP, HTTPS, TCP, UDP) a request is destined
37    pub fn get_destinations(&self) -> ProxyDestinations {
38        let mut proxy_destination = ProxyDestinations {
39            to_http_proxy: false,
40            to_https_proxy: false,
41            to_tcp_proxy: false,
42            to_udp_proxy: false,
43        };
44        let request_type = match &self.request_type {
45            Some(t) => t,
46            None => return proxy_destination,
47        };
48
49        match request_type {
50            RequestType::AddHttpFrontend(_) | RequestType::RemoveHttpFrontend(_) => {
51                proxy_destination.to_http_proxy = true
52            }
53
54            RequestType::AddHttpsFrontend(_)
55            | RequestType::RemoveHttpsFrontend(_)
56            | RequestType::AddCertificate(_)
57            | RequestType::QueryCertificatesFromWorkers(_)
58            | RequestType::ReplaceCertificate(_)
59            | RequestType::RemoveCertificate(_) => proxy_destination.to_https_proxy = true,
60
61            RequestType::AddTcpFrontend(_) | RequestType::RemoveTcpFrontend(_) => {
62                proxy_destination.to_tcp_proxy = true
63            }
64
65            RequestType::AddUdpFrontend(_) | RequestType::RemoveUdpFrontend(_) => {
66                proxy_destination.to_udp_proxy = true
67            }
68
69            RequestType::AddCluster(_)
70            | RequestType::AddBackend(_)
71            | RequestType::RemoveCluster(_)
72            | RequestType::RemoveBackend(_)
73            | RequestType::SetHealthCheck(_)
74            | RequestType::RemoveHealthCheck(_)
75            | RequestType::SoftStop(_)
76            | RequestType::HardStop(_)
77            | RequestType::Status(_) => {
78                proxy_destination.to_http_proxy = true;
79                proxy_destination.to_https_proxy = true;
80                proxy_destination.to_tcp_proxy = true;
81                proxy_destination.to_udp_proxy = true;
82            }
83
84            // handled at worker level prior to this call
85            RequestType::ConfigureMetrics(_)
86            | RequestType::SetMetricDetail(_)
87            | RequestType::QueryMetrics(_)
88            | RequestType::Logging(_)
89            | RequestType::QueryClustersHashes(_)
90            | RequestType::QueryClusterById(_)
91            | RequestType::QueryClustersByDomain(_)
92            | RequestType::SetMaxConnectionsPerIp(_)
93            | RequestType::QueryMaxConnectionsPerIp(_) => {}
94
95            // the Add***Listener / Update***Listener and other Listener orders will be
96            // handled separately by the notify_proxys function, so we don't give them
97            // destinations here
98            RequestType::AddHttpsListener(_)
99            | RequestType::AddHttpListener(_)
100            | RequestType::AddTcpListener(_)
101            | RequestType::AddUdpListener(_)
102            | RequestType::UpdateHttpListener(_)
103            | RequestType::UpdateHttpsListener(_)
104            | RequestType::UpdateTcpListener(_)
105            | RequestType::UpdateUdpListener(_)
106            | RequestType::RemoveListener(_)
107            | RequestType::ActivateListener(_)
108            | RequestType::DeactivateListener(_)
109            | RequestType::ReturnListenSockets(_) => {}
110
111            // These won't ever reach a worker anyway
112            RequestType::SaveState(_)
113            | RequestType::CountRequests(_)
114            | RequestType::QueryCertificatesFromTheState(_)
115            | RequestType::QueryHealthChecks(_)
116            | RequestType::LoadState(_)
117            | RequestType::ListWorkers(_)
118            | RequestType::ListFrontends(_)
119            | RequestType::ListListeners(_)
120            | RequestType::LaunchWorker(_)
121            | RequestType::UpgradeMain(_)
122            | RequestType::UpgradeWorker(_)
123            | RequestType::SubscribeEvents(_)
124            | RequestType::ReloadConfiguration(_) => {}
125        }
126
127        // POST: HTTP-frontend orders route to the HTTP proxy ONLY, HTTPS /
128        // certificate orders to the HTTPS proxy ONLY, and TCP-frontend orders
129        // to the TCP proxy ONLY — a frontend order must never fan out across
130        // protocol planes (that would double-apply the order). Cluster-wide
131        // and broadcast orders (AddCluster, SoftStop, …) legitimately target
132        // all three, so we only assert the single-plane exclusivity here.
133        debug_assert!(
134            !(proxy_destination.to_http_proxy
135                && proxy_destination.to_https_proxy
136                && proxy_destination.to_tcp_proxy)
137                || matches!(
138                    self.request_type,
139                    Some(
140                        RequestType::AddCluster(_)
141                            | RequestType::AddBackend(_)
142                            | RequestType::RemoveCluster(_)
143                            | RequestType::RemoveBackend(_)
144                            | RequestType::SetHealthCheck(_)
145                            | RequestType::RemoveHealthCheck(_)
146                            | RequestType::SoftStop(_)
147                            | RequestType::HardStop(_)
148                            | RequestType::Status(_)
149                    )
150                ),
151            "only cluster-wide / broadcast orders may target all three proxy planes"
152        );
153        // POST: a None request_type carries no destination at all.
154        debug_assert!(
155            self.request_type.is_some()
156                || (!proxy_destination.to_http_proxy
157                    && !proxy_destination.to_https_proxy
158                    && !proxy_destination.to_tcp_proxy),
159            "a request without a request_type must have no proxy destination"
160        );
161        proxy_destination
162    }
163
164    /// True if the request is a SoftStop or a HardStop
165    pub fn is_a_stop(&self) -> bool {
166        matches!(
167            self.request_type,
168            Some(RequestType::SoftStop(_)) | Some(RequestType::HardStop(_))
169        )
170    }
171
172    pub fn short_name(&self) -> &str {
173        match &self.request_type {
174            Some(request_type) => format_request_type(request_type),
175            None => "Unallowed",
176        }
177    }
178}
179
180impl WorkerRequest {
181    pub fn new(id: String, content: Request) -> Self {
182        Self { id, content }
183    }
184}
185
186impl fmt::Display for WorkerRequest {
187    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
188        write!(f, "{}-{:?}", self.id, self.content)
189    }
190}
191
192pub fn read_initial_state_from_file(file: &mut File) -> Result<InitialState, RequestError> {
193    let mut buf_reader = BufReader::new(file);
194    read_initial_state(&mut buf_reader)
195}
196
197pub fn read_initial_state<R: Read>(reader: &mut R) -> Result<InitialState, RequestError> {
198    let mut buffer = Vec::new();
199    reader
200        .read_to_end(&mut buffer)
201        .map_err(RequestError::ReadFile)?;
202
203    InitialState::decode(&buffer[..]).map_err(RequestError::Decode)
204}
205
/// The set of proxies (one flag per protocol plane) a request must be
/// forwarded to, as returned by `Request::get_destinations`.
///
/// The [`Default`] value has every flag cleared, i.e. it targets no proxy.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct ProxyDestinations {
    /// forward to the HTTP proxy
    pub to_http_proxy: bool,
    /// forward to the HTTPS proxy
    pub to_https_proxy: bool,
    /// forward to the TCP proxy
    pub to_tcp_proxy: bool,
    /// forward to the UDP proxy
    pub to_udp_proxy: bool,
}
213
214impl RequestHttpFrontend {
215    /// convert a requested frontend to a usable one by parsing its address
216    pub fn to_frontend(self) -> Result<HttpFrontend, RequestError> {
217        let requested_hostname = self.hostname.clone();
218        let requested_cluster_id = self.cluster_id.clone();
219        let frontend = HttpFrontend {
220            address: self.address.into(),
221            cluster_id: self.cluster_id,
222            hostname: self.hostname,
223            path: self.path,
224            method: self.method,
225            position: RulePosition::try_from(self.position).map_err(|_| {
226                RequestError::InvalidValue {
227                    name: "position".to_string(),
228                    value: self.position,
229                }
230            })?,
231            tags: Some(self.tags),
232            redirect: self.redirect,
233            redirect_scheme: self.redirect_scheme,
234            redirect_template: self.redirect_template,
235            rewrite_host: self.rewrite_host,
236            rewrite_path: self.rewrite_path,
237            rewrite_port: self.rewrite_port,
238            required_auth: self.required_auth,
239            headers: self.headers,
240            hsts: self.hsts,
241        };
242
243        // POST: routing identity (hostname + cluster_id) is carried through
244        // unchanged — only the address is reparsed and the position is mapped
245        // through the proto enum. A frontend whose hostname or cluster shifted
246        // here would route traffic to the wrong place.
247        debug_assert_eq!(
248            frontend.hostname, requested_hostname,
249            "hostname must survive the frontend conversion"
250        );
251        debug_assert_eq!(
252            frontend.cluster_id, requested_cluster_id,
253            "cluster_id must survive the frontend conversion"
254        );
255        Ok(frontend)
256    }
257}
258
259impl Display for RequestHttpFrontend {
260    /// Used to create a unique summary of the frontend, used as a key in maps
261    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262        let s = match &PathRuleKind::try_from(self.path.kind) {
263            Ok(PathRuleKind::Prefix) => {
264                format!("{};{};P{}", self.address, self.hostname, self.path.value)
265            }
266            Ok(PathRuleKind::Regex) => {
267                format!("{};{};R{}", self.address, self.hostname, self.path.value)
268            }
269            Ok(PathRuleKind::Equals) => {
270                format!("{};{};={}", self.address, self.hostname, self.path.value)
271            }
272            Err(e) => format!("Wrong variant of PathRuleKind: {e}"),
273        };
274
275        match &self.method {
276            Some(method) => write!(f, "{s};{method}"),
277            None => write!(f, "{s}"),
278        }
279    }
280}
281
/// Error returned when a string does not name a known load-balancing policy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseErrorLoadBalancing;

impl fmt::Display for ParseErrorLoadBalancing {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Cannot find the load balancing policy asked")
    }
}

// `description` and `cause` are deprecated: the message lives in `Display`,
// and the default `source()` (hence `cause()`) already returns `None`.
impl error::Error for ParseErrorLoadBalancing {}
300
301impl FromStr for LoadBalancingAlgorithms {
302    type Err = ParseErrorLoadBalancing;
303
304    fn from_str(s: &str) -> Result<Self, Self::Err> {
305        match s.to_lowercase().as_str() {
306            "round_robin" => Ok(LoadBalancingAlgorithms::RoundRobin),
307            "random" => Ok(LoadBalancingAlgorithms::Random),
308            "power_of_two" => Ok(LoadBalancingAlgorithms::PowerOfTwo),
309            "least_loaded" => Ok(LoadBalancingAlgorithms::LeastLoaded),
310            "hrw" => Ok(LoadBalancingAlgorithms::Hrw),
311            "maglev" => Ok(LoadBalancingAlgorithms::Maglev),
312            _ => Err(ParseErrorLoadBalancing {}),
313        }
314    }
315}
316
317impl SocketAddress {
318    pub fn new_v4(a: u8, b: u8, c: u8, d: u8, port: u16) -> Self {
319        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), port).into()
320    }
321}
322
323impl From<SocketAddr> for SocketAddress {
324    fn from(socket_addr: SocketAddr) -> SocketAddress {
325        let ip_inner = match socket_addr {
326            SocketAddr::V4(ip_v4_addr) => ip_address::Inner::V4(u32::from(*ip_v4_addr.ip())),
327            SocketAddr::V6(ip_v6_addr) => {
328                ip_address::Inner::V6(Uint128::from(u128::from(*ip_v6_addr.ip())))
329            }
330        };
331
332        let encoded = SocketAddress {
333            port: socket_addr.port() as u32,
334            ip: IpAddress {
335                inner: Some(ip_inner),
336            },
337        };
338
339        // POST: the port widens losslessly (u16 → u32) and the proto address
340        // family matches the source family — a V4 SocketAddr must never encode
341        // as a V6 inner and vice versa, or the reverse `From` would synthesize
342        // the wrong address.
343        debug_assert_eq!(
344            encoded.port,
345            socket_addr.port() as u32,
346            "port must round-trip losslessly into the proto"
347        );
348        debug_assert_eq!(
349            matches!(encoded.ip.inner, Some(ip_address::Inner::V4(_))),
350            socket_addr.is_ipv4(),
351            "proto IP family must match the source SocketAddr family"
352        );
353        encoded
354    }
355}
356
357impl From<SocketAddress> for SocketAddr {
358    fn from(socket_address: SocketAddress) -> Self {
359        // PRE: a wire-sourced proto port may exceed u16::MAX (16-bit on the
360        // wire is carried as a 32-bit field). This is peer/config input, so we
361        // narrow rather than panic; the debug_assert only guards our *own*
362        // encoders, which never emit an out-of-range port.
363        debug_assert!(
364            socket_address.port <= u16::MAX as u32,
365            "self-encoded proto port must fit in a u16"
366        );
367        let had_inner = socket_address.ip.inner.is_some();
368        let port = socket_address.port as u16;
369
370        let ip = match socket_address.ip.inner {
371            Some(inner) => match inner {
372                ip_address::Inner::V4(v4_value) => IpAddr::V4(Ipv4Addr::from(v4_value)),
373                ip_address::Inner::V6(v6_value) => IpAddr::V6(Ipv6Addr::from(u128::from(v6_value))),
374            },
375            None => IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), // should never happen
376        };
377
378        let decoded = SocketAddr::new(ip, port);
379        // POST: a self-encoded proto address always carries an inner IP, so the
380        // unspecified-V4 fallback is only ever reached on malformed peer input.
381        debug_assert!(
382            had_inner || decoded.ip() == IpAddr::V4(Ipv4Addr::UNSPECIFIED),
383            "missing inner IP must decode to the unspecified-V4 sentinel"
384        );
385        decoded
386    }
387}
388
389impl From<Uint128> for u128 {
390    fn from(value: Uint128) -> Self {
391        let combined = value.low as u128 | ((value.high as u128) << 64);
392        // POST: the two 64-bit halves occupy disjoint bit ranges, so the low
393        // half is recoverable as the bottom 64 bits and the high half as the
394        // top 64 bits — the pack is bijective.
395        debug_assert_eq!(
396            combined as u64, value.low,
397            "low half must be the bottom 64 bits"
398        );
399        debug_assert_eq!(
400            (combined >> 64) as u64,
401            value.high,
402            "high half must be the top 64 bits"
403        );
404        combined
405    }
406}
407
408impl From<u128> for Uint128 {
409    fn from(value: u128) -> Self {
410        let low = value as u64;
411        let high = (value >> 64) as u64;
412        let packed = Uint128 { low, high };
413        // POST: splitting then recombining reproduces the original u128 — the
414        // split-into-halves and join-from-halves operations are mutual
415        // inverses (no bit is lost or duplicated).
416        debug_assert_eq!(
417            u128::from(Uint128 {
418                low: packed.low,
419                high: packed.high
420            }),
421            value,
422            "u128 → Uint128 → u128 must round-trip"
423        );
424        packed
425    }
426}
427
428impl From<i128> for Uint128 {
429    fn from(value: i128) -> Self {
430        Uint128::from(value as u128)
431    }
432}
433
434impl From<Ulid> for Uint128 {
435    fn from(value: Ulid) -> Self {
436        let (low, high) = value.into();
437        let packed = Uint128 { low, high };
438        // POST: the (low, high) tuple is carried verbatim into the proto, so
439        // re-reading it reconstructs the same Ulid — the encoding loses no bits.
440        debug_assert_eq!(
441            Ulid::from((packed.low, packed.high)),
442            value,
443            "Ulid → Uint128 must preserve all 128 bits"
444        );
445        packed
446    }
447}
448
449impl From<Uint128> for Ulid {
450    fn from(value: Uint128) -> Self {
451        let Uint128 { low, high } = value;
452        let ulid = Ulid::from((low, high));
453        // POST: the decode is the exact inverse of the encode above — the same
454        // halves go back out, so Uint128 → Ulid → Uint128 round-trips.
455        debug_assert_eq!(
456            Uint128::from(ulid),
457            value,
458            "Uint128 → Ulid must preserve all 128 bits"
459        );
460        ulid
461    }
462}