stream_httparse/streaming_parser/
req_parser.rs

1use crate::streaming_parser::{ParseError, ParseResult};
2use crate::{header::HeaderKey, Headers, Method, Request};
3
/// Byte range `(start, end)` of the method inside the head buffer,
/// `end` is exclusive
type MethodState = (usize, usize);
/// Byte range `(start, end)` of the path inside the head buffer,
/// `end` is exclusive
type PathState = (usize, usize);
/// Byte range `(start, end)` of the protocol inside the head buffer,
/// `end` is exclusive
type ProtocolState = (usize, usize);
/// Byte range `(start, end)` of a header key inside the head buffer,
/// `end` is exclusive
type HeaderKeyState = (usize, usize);

/// The position of the parser inside the head of a request.
///
/// Every variant carries the byte ranges of the parts that have been
/// recognised so far, so nothing needs to be copied out of the buffer
/// while parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No part of the request has been recognised yet
    Nothing,
    /// The method has been read, the path is being read
    MethodParsed(MethodState),
    /// Method and path have been read, the protocol is being read
    PathParsed(MethodState, PathState),
    /// The request line is complete and a header key (or the empty line
    /// that ends the head) is expected.
    ///
    /// The last field is the index of the `\r` that ended the previous
    /// line, the next line therefore starts two bytes after it
    HeaderKey(MethodState, PathState, ProtocolState, usize),
    /// A header key has been read, its value is being read
    HeaderValue(MethodState, PathState, ProtocolState, HeaderKeyState),
    /// The `\r` of the empty line that ends the head has been seen.
    ///
    /// The last field is the index of the first byte after the head
    HeadersParsed(MethodState, PathState, ProtocolState, usize),
}

/// The coarse progress of a request that is being received
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProgressState {
    /// The head (request line and headers) is still being received
    Head,
    /// The head is complete and a body with the given total length
    /// is being received
    Body(usize),
    /// The request has been received completely
    Done,
}

/// A single Instance of the RequestParser that is used
/// to parse HTTP-Requests
#[derive(Debug)]
pub struct ReqParser {
    /// The raw bytes of the head, all stored ranges index into this buffer
    buffer: Vec<u8>,
    /// The bytes of the body that have been received so far
    body_buffer: Vec<u8>,
    /// The `(key, value)` byte ranges of every header that has been parsed
    headers_buf: Vec<((usize, usize), (usize, usize))>,
    /// The position inside the head
    state: State,
    /// The overall progress of the current request
    progress: ProgressState,
}
33
34impl ReqParser {
35    /// Creates a new Request-Parser with the given
36    /// capacity as its pre-reserved capacity to store
37    /// the Head of the Request
38    pub fn new_capacity(cap: usize) -> Self {
39        Self {
40            buffer: Vec::with_capacity(cap),
41            body_buffer: Vec::new(),
42            headers_buf: Vec::with_capacity(20),
43            state: State::Nothing,
44            progress: ProgressState::Head,
45        }
46    }
47
48    /// Clears the internal Buffers and resets everything
49    /// to be ready to receive and parse a new request
50    ///
51    /// This should be the prefered way to parse mulitple
52    /// sequential requests, as this avoids extra allocations
53    pub fn clear(&mut self) {
54        // Clearing out the buffers of the previous request
55        // without needing to perform any new allocations
56        self.buffer.clear();
57        self.body_buffer.clear();
58        self.headers_buf.clear();
59
60        // Sets the Progress and internal State back to the
61        // beginning
62        self.state = State::Nothing;
63        self.progress = ProgressState::Head;
64    }
65
66    fn parse(&mut self, byte: u8, current: usize) -> ProgressState {
67        match &mut self.state {
68            State::Nothing if byte == b' ' => {
69                let end = current;
70                self.state = State::MethodParsed((0, end));
71                ProgressState::Head
72            }
73            State::MethodParsed(method) if byte == b' ' => {
74                let start = method.1;
75                let end = current;
76
77                self.state = State::PathParsed(*method, (start + 1, end));
78                ProgressState::Head
79            }
80            State::PathParsed(method, path) if byte == b'\r' => {
81                let start = path.1;
82                let end = current;
83
84                self.state = State::HeaderKey(*method, *path, (start + 1, end), end);
85                ProgressState::Head
86            }
87            State::HeaderKey(method, path, protocol, raw_start)
88                if current == *raw_start + 2 && byte == b'\r' =>
89            {
90                self.state = State::HeadersParsed(*method, *path, *protocol, current + 2);
91                ProgressState::Head
92            }
93            State::HeaderKey(method, path, protocol, raw_start)
94                if byte == b':' && *raw_start + 2 <= current =>
95            {
96                let start = *raw_start + 2;
97                let end = current;
98
99                self.state = State::HeaderValue(*method, *path, *protocol, (start, end));
100                ProgressState::Head
101            }
102            State::HeaderValue(method, path, protocol, header_key)
103                if byte == b'\r' && header_key.1 + 2 <= current =>
104            {
105                let start = header_key.1 + 2;
106                let end = current;
107
108                self.headers_buf.push((*header_key, (start, end)));
109                self.state = State::HeaderKey(*method, *path, *protocol, end);
110                ProgressState::Head
111            }
112            State::HeadersParsed(_, _, _, end) if current == *end - 1 => {
113                // The Length the body is supposed to have
114                let mut length: usize = 0;
115                for raw_header_pair in self.headers_buf.iter() {
116                    let key_pair = raw_header_pair.0;
117                    let value_pair = raw_header_pair.1;
118
119                    let key_str = match std::str::from_utf8(&self.buffer[key_pair.0..key_pair.1]) {
120                        Ok(k) => k,
121                        Err(_) => {
122                            continue;
123                        }
124                    };
125                    if HeaderKey::StrRef(key_str) != HeaderKey::StrRef("Content-Length") {
126                        continue;
127                    }
128
129                    let value_str =
130                        match std::str::from_utf8(&self.buffer[value_pair.0..value_pair.1]) {
131                            Ok(v) => v,
132                            Err(_) => {
133                                continue;
134                            }
135                        };
136
137                    length = value_str.parse().unwrap();
138                    break;
139                }
140
141                if length > 0 {
142                    ProgressState::Body(length)
143                } else {
144                    ProgressState::Done
145                }
146            }
147            _ => ProgressState::Head,
148        }
149    }
150
151    /// Returns a touple that stands for (done, data-left-in-buffer)
152    ///
153    /// Explanation:
154    /// * `done`: True if the request has been fully received and parsed
155    /// * `data-left-in-buffer`: The Amount of bytes at the end of the given
156    /// slice that were unused
157    pub fn block_parse(&mut self, bytes: &[u8]) -> (bool, Option<usize>) {
158        match self.progress {
159            ProgressState::Head => {
160                let start_point = self.buffer.len();
161                self.buffer.reserve(bytes.len());
162
163                for (index, tmp_byte) in bytes.iter().enumerate() {
164                    self.buffer.push(*tmp_byte);
165                    self.progress = self.parse(*tmp_byte, start_point + index);
166                    match self.progress {
167                        ProgressState::Body(length) => {
168                            self.body_buffer.reserve(length);
169                            return self.block_parse(&bytes[index + 1..]);
170                        }
171                        ProgressState::Done => {
172                            return self.block_parse(&bytes[index + 1..]);
173                        }
174                        _ => {}
175                    }
176                }
177
178                (false, None)
179            }
180            ProgressState::Body(length) => {
181                let left_to_read = length - self.body_buffer.len();
182                if left_to_read == 0 {
183                    self.progress = ProgressState::Done;
184                    return self.block_parse(&[]);
185                }
186
187                let chunk_size = bytes.len();
188                if left_to_read >= chunk_size {
189                    self.body_buffer.extend_from_slice(bytes);
190                    (self.body_buffer.len() == length, None)
191                } else {
192                    self.body_buffer.extend_from_slice(&bytes[..left_to_read]);
193                    self.progress = ProgressState::Done;
194                    self.block_parse(&bytes[left_to_read..])
195                }
196            }
197            ProgressState::Done => {
198                let length = bytes.len();
199                let rest = (length > 0).then(|| length);
200
201                (true, rest)
202            }
203        }
204    }
205
206    /// Finishes up the parsing and finalizes all the Data it received
207    /// and returns a Request-Instance containing the parsed out
208    /// Request
209    pub fn finish<'a, 'b>(&'a self) -> ParseResult<Request<'b>>
210    where
211        'a: 'b,
212    {
213        let (method, path, protocol) = match &self.state {
214            State::HeadersParsed(m, p, pt, _) => (m, p, pt),
215            State::Nothing => {
216                return Err(ParseError::MissingMethod);
217            }
218            State::MethodParsed(_) => {
219                return Err(ParseError::MissingPath);
220            }
221            State::PathParsed(_, _) => {
222                return Err(ParseError::MissingProtocol);
223            }
224            State::HeaderKey(_, _, _, _) | State::HeaderValue(_, _, _, _) => {
225                return Err(ParseError::MissingHeaders);
226            }
227        };
228
229        let raw_method = &self.buffer[method.0..method.1];
230        let raw_path = &self.buffer[path.0..path.1];
231        let raw_protocol = &self.buffer[protocol.0..protocol.1];
232
233        let method = unsafe { std::str::from_utf8_unchecked(raw_method) };
234        let path = unsafe { std::str::from_utf8_unchecked(raw_path) };
235        let protocol = unsafe { std::str::from_utf8_unchecked(raw_protocol) };
236
237        let parsed_method = match Method::parse(method) {
238            Some(m) => m,
239            None => return Err(ParseError::MissingMethod),
240        };
241
242        let header_count = self.headers_buf.len();
243        let mut headers = Headers::with_capacity(header_count);
244        for tmp_header in self.headers_buf.iter() {
245            let key_range = tmp_header.0;
246            let raw_key = &self.buffer[key_range.0..key_range.1];
247
248            let value_range = tmp_header.1;
249            let raw_value = &self.buffer[value_range.0..value_range.1];
250
251            let key = unsafe { std::str::from_utf8_unchecked(raw_key) };
252            let value = unsafe { std::str::from_utf8_unchecked(raw_value) };
253
254            // Use append to simply add the header at the end of the collection
255            // without checking for duplicates
256            headers.append(key, value);
257        }
258
259        let body = &self.body_buffer;
260
261        Ok(Request::new(protocol, parsed_method, path, headers, body))
262    }
263
264    /// Returns the current Buffer of the Parser
265    pub fn buffer(&self) -> &[u8] {
266        &self.buffer
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh parser with the given capacity, feeds it the whole
    /// input at once and hands back the parser together with the result
    /// of that single `block_parse` call
    fn feed(input: &[u8], capacity: usize) -> (ReqParser, (bool, Option<usize>)) {
        let mut parser = ReqParser::new_capacity(capacity);
        let outcome = parser.block_parse(input);
        (parser, outcome)
    }

    #[test]
    fn parser_parse_no_body() {
        let (parser, outcome) = feed(b"GET /path/ HTTP/1.1\r\nTest-1: Value-1\r\n\r\n", 4096);
        assert_eq!((true, None), outcome);

        let mut headers = Headers::new();
        headers.set("Test-1", "Value-1");
        let expected = Request::new("HTTP/1.1", Method::GET, "/path/", headers, "".as_bytes());
        assert_eq!(Ok(expected), parser.finish());
    }

    #[test]
    fn parser_parse_with_body() {
        let (parser, outcome) = feed(
            b"GET /path/ HTTP/1.1\r\nContent-Length: 22\r\n\r\nThis is just some body",
            4096,
        );
        assert_eq!((true, None), outcome);

        let mut headers = Headers::new();
        headers.set("Content-Length", "22");
        let expected = Request::new(
            "HTTP/1.1",
            Method::GET,
            "/path/",
            headers,
            "This is just some body".as_bytes(),
        );
        assert_eq!(Ok(expected), parser.finish());
    }

    #[test]
    fn parser_parse_multiple_headers_with_body() {
        let (parser, outcome) = feed(
            b"GET /path/ HTTP/1.1\r\nContent-Length: 22\r\nTest-2: Value-2\r\n\r\nThis is just some body",
            4096,
        );
        assert_eq!((true, None), outcome);

        let mut headers = Headers::new();
        headers.set("Content-Length", "22");
        headers.set("Test-2", "Value-2");
        let expected = Request::new(
            "HTTP/1.1",
            Method::GET,
            "/path/",
            headers,
            "This is just some body".as_bytes(),
        );
        assert_eq!(Ok(expected), parser.finish());
    }

    #[test]
    fn parser_parse_multiple_headers_with_body_set_shorter() {
        let (parser, outcome) = feed(
            b"GET /path/ HTTP/1.1\r\nContent-Length: 10\r\nTest-2: Value-2\r\n\r\nThis is just some body",
            4096,
        );
        // Only the announced 10 bytes belong to the body, the remaining
        // 12 bytes of the input are reported as unused
        assert_eq!((true, Some(12)), outcome);

        let mut headers = Headers::new();
        headers.set("Content-Length", "10");
        headers.set("Test-2", "Value-2");
        let expected = Request::new(
            "HTTP/1.1",
            Method::GET,
            "/path/",
            headers,
            "This is ju".as_bytes(),
        );
        assert_eq!(Ok(expected), parser.finish());
    }

    #[test]
    fn parser_missing_method() {
        let (parser, outcome) = feed(b"", 4096);
        assert_eq!((false, None), outcome);

        assert_eq!(Err(ParseError::MissingMethod), parser.finish());
    }

    #[test]
    fn parser_missing_path() {
        let (parser, outcome) = feed(b"GET ", 4096);
        assert_eq!((false, None), outcome);

        assert_eq!(Err(ParseError::MissingPath), parser.finish());
    }

    #[test]
    fn parser_missing_protocol() {
        let (parser, outcome) = feed(b"GET /path/ ", 4096);
        assert_eq!((false, None), outcome);

        assert_eq!(Err(ParseError::MissingProtocol), parser.finish());
    }

    #[test]
    fn parser_missing_headers() {
        let (parser, outcome) = feed(b"GET /path/ HTTP/1.1\r\n", 4096);
        assert_eq!((false, None), outcome);

        assert_eq!(Err(ParseError::MissingHeaders), parser.finish());
    }

    #[test]
    fn parser_fuzzing_bug_0() {
        let input: &[u8] = &[
            13, 36, 32, 32, 36, 13, 58, 32, 32, 13, 36, 13, 36, 32, 32, 36, 13, 58, 36, 32, 32, 36,
            13, 58, 1,
        ];

        let (_, outcome) = feed(input, 2048);
        assert_eq!((false, None), outcome);
    }

    #[test]
    fn parser_fuzzing_bug_1() {
        let input: &[u8] = &[
            84, 82, 65, 67, 69, 32, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 85, 58, 13, 36, 36, 58, 93, 0, 36,
            32, 32, 13, 213, 58, 13, 36, 36, 58, 13, 36, 32, 32, 13, 85, 58, 13, 36, 36, 58, 93, 0,
            36, 32, 32, 13, 213, 58, 13, 36, 36, 58, 13, 64, 13, 36, 64,
        ];

        let (parser, outcome) = feed(input, 2048);
        assert_eq!((true, Some(1)), outcome);
        // The parser accepts this input as a complete request
        assert!(parser.finish().is_ok());
    }
}
414}