Skip to main content

sozu_lib/protocol/kawa_h1/
parser.rs

//! Kawa-adjacent H1 parsing helpers.
//!
//! Hosts the byte-comparison primitives, the `Method` enum, and the small
//! nom helpers Kawa relies on when its tolerant/strict mode boundaries do
//! not cover Sōzu-specific cases (custom methods, host/port authority
//! parsing, ASCII case-insensitive comparisons), plus `view`, a hex-dump
//! helper for debugging parse positions inside a buffer. Pure-function
//! module; no session state.
7
8use std::{
9    cmp::min,
10    fmt::{self, Write},
11    ops::Deref,
12};
13
14use nom::{
15    Err, IResult,
16    bytes::{self, complete::take_while},
17    character::{complete::digit1, is_alphanumeric},
18    combinator::opt,
19    error::{Error, ErrorKind},
20    sequence::preceded,
21};
22
/// Returns `true` when `left` and `right` are equal, ignoring ASCII case.
///
/// Only the letters `A-Z` / `a-z` are case-folded. Every other byte must match
/// exactly, including digits, punctuation and non-ASCII bytes `0x80..=0xFF`.
/// In particular, pairs such as `@` / `` ` `` or `[` / `{`, which differ only
/// in bit 5 just like upper/lower-case letters, are *not* considered equal.
/// Slices of different lengths are never equal.
pub fn compare_no_case(left: &[u8], right: &[u8]) -> bool {
    // The std slice method checks the lengths first and then compares each
    // pair of bytes with `to_ascii_lowercase`. That leaves non-letters
    // untouched, which matches the required semantics exactly.
    left.eq_ignore_ascii_case(right)
}
36
/// An HTTP request method.
///
/// The eight methods defined by RFC 9110 §9.3 each get a dedicated variant.
/// Any other method token is carried as text in [`Method::Custom`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Method {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `HEAD`
    Head,
    /// `OPTIONS`
    Options,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
    /// `TRACE`
    Trace,
    /// `CONNECT`
    Connect,
    /// Any method token that is not one of the variants above.
    Custom(String),
}
49
50impl Method {
51    pub fn new(s: &[u8]) -> Method {
52        if compare_no_case(s, b"GET") {
53            Method::Get
54        } else if compare_no_case(s, b"POST") {
55            Method::Post
56        } else if compare_no_case(s, b"HEAD") {
57            Method::Head
58        } else if compare_no_case(s, b"OPTIONS") {
59            Method::Options
60        } else if compare_no_case(s, b"PUT") {
61            Method::Put
62        } else if compare_no_case(s, b"DELETE") {
63            Method::Delete
64        } else if compare_no_case(s, b"TRACE") {
65            Method::Trace
66        } else if compare_no_case(s, b"CONNECT") {
67            Method::Connect
68        } else {
69            Method::Custom(String::from_utf8_lossy(s).into_owned())
70        }
71    }
72}
73
74impl AsRef<str> for Method {
75    fn as_ref(&self) -> &str {
76        match self {
77            Self::Get => "GET",
78            Self::Post => "POST",
79            Self::Head => "HEAD",
80            Self::Options => "OPTIONS",
81            Self::Put => "PUT",
82            Self::Delete => "DELETE",
83            Self::Trace => "TRACE",
84            Self::Connect => "CONNECT",
85            Self::Custom(custom) => custom,
86        }
87    }
88}
89
90impl Deref for Method {
91    type Target = str;
92
93    fn deref(&self) -> &Self::Target {
94        self.as_ref()
95    }
96}
97
98impl fmt::Display for Method {
99    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
100        write!(f, "{}", self.as_ref())
101    }
102}
103
#[cfg(feature = "tolerant-http1-parser")]
/// Tolerant hostname byte predicate: ASCII letters, digits, `-`, `.` and `_`.
fn is_hostname_char(i: u8) -> bool {
    // Label structure is not validated here: a leading hyphen or dot, labels
    // longer than 63 bytes and names longer than 253 bytes all pass. The
    // original authors questioned whether that matters, given the host is later
    // matched against the list of accepted clusters. NOTE(review): confirm.
    //
    // Unlike the strict variant, `_` is accepted. It is invalid in domain names
    // but some proxies and web servers accept it anyway,
    // see https://github.com/sozu-proxy/sozu/issues/480
    is_alphanumeric(i) || matches!(i, b'-' | b'.' | b'_')
}
118
119#[cfg(not(feature = "tolerant-http1-parser"))]
120fn is_hostname_char(i: u8) -> bool {
121    is_alphanumeric(i) ||
122  // the domain name should not start with a hyphen or dot
123  // but is it important here, since we will match this to
124  // the list of accepted clusters?
125  // BTW each label between dots has a max of 63 chars,
126  // and the whole domain should not be larger than 253 chars
127  b"-.".contains(&i)
128}
129
130pub fn hostname_and_port(i: &[u8]) -> IResult<&[u8], (&[u8], Option<u16>)> {
131    let (i, host) = take_while(is_hostname_char)(i)?;
132    if host.is_empty() {
133        return Err(Err::Error(Error::new(i, ErrorKind::Alpha)));
134    }
135    let (i, port_bytes) = opt(preceded(bytes::complete::tag(":"), digit1))(i)?;
136    let port = match port_bytes {
137        // `digit1` guarantees ASCII digits, so the slice is always valid UTF-8;
138        // `parse::<u16>()` rejects values that overflow the 16-bit port space
139        // (RFC 6335 §6) and `Ok(0)` is explicitly rejected because port 0 is
140        // reserved by the same RFC and not a routable authority.
141        Some(bytes) => match std::str::from_utf8(bytes).unwrap().parse::<u16>() {
142            Ok(p) if p != 0 => Some(p),
143            _ => return Err(Err::Error(Error::new(bytes, ErrorKind::Digit))),
144        },
145        None => None,
146    };
147
148    if !i.is_empty() {
149        return Err(Err::Error(Error::new(i, ErrorKind::Eof)));
150    }
151    Ok((i, (host, port)))
152}
153
154pub fn view(buf: &[u8], size: usize, points: &[usize]) -> String {
155    let mut view = format!("{points:?} => ");
156    let mut end = 0;
157    for (i, point) in points.iter().enumerate() {
158        if *point > buf.len() {
159            break;
160        }
161        let start = if end + size < *point {
162            view.push_str("... ");
163            point - size
164        } else {
165            end
166        };
167        let stop = if i + 1 < points.len() {
168            min(buf.len(), points[i + 1])
169        } else {
170            buf.len()
171        };
172        end = if point + size > stop {
173            stop
174        } else {
175            point + size
176        };
177        for element in &buf[start..*point] {
178            let _ = view.write_fmt(format_args!("{element:02X} "));
179        }
180        view.push_str("| ");
181        for element in &buf[*point..end] {
182            let _ = view.write_fmt(format_args!("{element:02X} "));
183        }
184    }
185    if end < buf.len() {
186        view.push_str("...")
187    }
188    view
189}
190
#[test]
fn test_view_out_of_bound() {
    // `80` lies past the end of the 16-byte buffer: it must be skipped instead
    // of panicking, and the dump must stop before it. Pinning the exact output
    // also checks that repeated points (5, 5 and 9, 9) do not re-print bytes.
    let buf: Vec<u8> = (0..16).collect();
    assert_eq!(
        view(&buf, 2, &[5, 5, 8, 9, 9, 13, 80]),
        "[5, 5, 8, 9, 9, 13, 80] => ... 03 04 | | 05 06 07 | 08 | | 09 0A 0B 0C | 0D 0E ..."
    );
}
202
#[test]
fn custom_method_does_not_assume_utf8() {
    // 0xFF is never valid in UTF-8, so it must surface as U+FFFD.
    let expected = Method::Custom(String::from("\u{FFFD}BAD"));
    assert_eq!(Method::new(b"\xFFBAD"), expected);
}
209
#[test]
fn hostname_and_port_rejects_empty_host() {
    // Nothing precedes the port separator, so there is no host at all.
    let parsed = hostname_and_port(b":80");
    assert!(parsed.is_err());
}
214
#[test]
fn hostname_and_port_rejects_out_of_range_port() {
    // 65536 is one past `u16::MAX`.
    let parsed = hostname_and_port(b"example.com:65536");
    assert!(parsed.is_err());
}
219
#[test]
fn hostname_and_port_rejects_port_zero() {
    // Port 0 is reserved by RFC 6335 §6 and cannot identify a TCP/UDP service.
    let parsed = hostname_and_port(b"example.com:0");
    assert!(parsed.is_err());
}
225
#[test]
fn hostname_and_port_accepts_u16_port() {
    // `u16::MAX` is the largest port that must still be accepted.
    let (rest, authority) = hostname_and_port(b"example.com:65535").unwrap();

    assert!(rest.is_empty());
    assert_eq!(authority, (&b"example.com"[..], Some(65535)));
}
234
#[test]
fn hostname_and_port_returns_no_port_when_absent() {
    // Without a `:` suffix the whole input is the host and no port is reported.
    let (rest, authority) = hostname_and_port(b"example.com").unwrap();

    assert!(rest.is_empty());
    assert_eq!(authority, (&b"example.com"[..], None));
}