1use crate::streaming_parser::{ParseError, ParseResult};
2use crate::{header::HeaderKey, Headers, Method, Request};
3
/// Half-open `(start, end)` byte range of the request method inside the head buffer.
type MethodState = (usize, usize);
/// Half-open `(start, end)` byte range of the request path inside the head buffer.
type PathState = (usize, usize);
/// Half-open `(start, end)` byte range of the protocol string inside the head buffer.
type ProtocolState = (usize, usize);
/// Half-open `(start, end)` byte range of the header name currently being parsed.
type HeaderKeyState = (usize, usize);
8
/// Position of the head parser inside the request line / header section.
///
/// Every variant carries the byte ranges that have been recognised so far, so
/// the final variant contains everything needed to slice the request line out
/// of the head buffer.
enum State {
    /// No delimiter has been seen yet; the method is still being received.
    Nothing,
    /// The space terminating the method was seen; the path is being received.
    MethodParsed(MethodState),
    /// The space terminating the path was seen; the protocol is being received.
    PathParsed(MethodState, PathState),
    /// Waiting for the `:` that ends a header name, or for the blank line that
    /// ends the header section. The trailing `usize` is the index of the most
    /// recent `\r`; the header name is taken to start two bytes after it.
    HeaderKey(MethodState, PathState, ProtocolState, usize),
    /// The `:` of a header was seen; waiting for the `\r` that ends its value.
    HeaderValue(MethodState, PathState, ProtocolState, HeaderKeyState),
    /// The `\r` of the blank line was seen. The trailing `usize` is that index
    /// plus two; the head is complete once the byte at `index - 1` arrives.
    HeadersParsed(MethodState, PathState, ProtocolState, usize),
}
17
/// Coarse progress of a request: which part incoming bytes currently belong to.
enum ProgressState {
    /// Bytes are fed through the head state machine (request line and headers).
    Head,
    /// The head is complete; the payload is the total number of body bytes
    /// expected, as taken from the `Content-Length` header.
    Body(usize),
    /// Head and body are complete; further bytes do not belong to this request.
    Done,
}
23
/// Incremental HTTP request parser that is fed arbitrary chunks of bytes.
///
/// Head bytes and body bytes are stored in separate buffers; the request line
/// and headers are tracked as index ranges into the head buffer.
pub struct ReqParser {
    /// Raw bytes of the request line and header section, in arrival order.
    buffer: Vec<u8>,
    /// Raw bytes of the request body received so far.
    body_buffer: Vec<u8>,
    /// One `(name range, value range)` entry per parsed header; both ranges
    /// are half-open `(start, end)` indices into `buffer`.
    headers_buf: Vec<((usize, usize), (usize, usize))>,
    /// Current position of the head state machine.
    state: State,
    /// Which part of the request incoming bytes currently belong to.
    progress: ProgressState,
}
33
34impl ReqParser {
35 pub fn new_capacity(cap: usize) -> Self {
39 Self {
40 buffer: Vec::with_capacity(cap),
41 body_buffer: Vec::new(),
42 headers_buf: Vec::with_capacity(20),
43 state: State::Nothing,
44 progress: ProgressState::Head,
45 }
46 }
47
48 pub fn clear(&mut self) {
54 self.buffer.clear();
57 self.body_buffer.clear();
58 self.headers_buf.clear();
59
60 self.state = State::Nothing;
63 self.progress = ProgressState::Head;
64 }
65
66 fn parse(&mut self, byte: u8, current: usize) -> ProgressState {
67 match &mut self.state {
68 State::Nothing if byte == b' ' => {
69 let end = current;
70 self.state = State::MethodParsed((0, end));
71 ProgressState::Head
72 }
73 State::MethodParsed(method) if byte == b' ' => {
74 let start = method.1;
75 let end = current;
76
77 self.state = State::PathParsed(*method, (start + 1, end));
78 ProgressState::Head
79 }
80 State::PathParsed(method, path) if byte == b'\r' => {
81 let start = path.1;
82 let end = current;
83
84 self.state = State::HeaderKey(*method, *path, (start + 1, end), end);
85 ProgressState::Head
86 }
87 State::HeaderKey(method, path, protocol, raw_start)
88 if current == *raw_start + 2 && byte == b'\r' =>
89 {
90 self.state = State::HeadersParsed(*method, *path, *protocol, current + 2);
91 ProgressState::Head
92 }
93 State::HeaderKey(method, path, protocol, raw_start)
94 if byte == b':' && *raw_start + 2 <= current =>
95 {
96 let start = *raw_start + 2;
97 let end = current;
98
99 self.state = State::HeaderValue(*method, *path, *protocol, (start, end));
100 ProgressState::Head
101 }
102 State::HeaderValue(method, path, protocol, header_key)
103 if byte == b'\r' && header_key.1 + 2 <= current =>
104 {
105 let start = header_key.1 + 2;
106 let end = current;
107
108 self.headers_buf.push((*header_key, (start, end)));
109 self.state = State::HeaderKey(*method, *path, *protocol, end);
110 ProgressState::Head
111 }
112 State::HeadersParsed(_, _, _, end) if current == *end - 1 => {
113 let mut length: usize = 0;
115 for raw_header_pair in self.headers_buf.iter() {
116 let key_pair = raw_header_pair.0;
117 let value_pair = raw_header_pair.1;
118
119 let key_str = match std::str::from_utf8(&self.buffer[key_pair.0..key_pair.1]) {
120 Ok(k) => k,
121 Err(_) => {
122 continue;
123 }
124 };
125 if HeaderKey::StrRef(key_str) != HeaderKey::StrRef("Content-Length") {
126 continue;
127 }
128
129 let value_str =
130 match std::str::from_utf8(&self.buffer[value_pair.0..value_pair.1]) {
131 Ok(v) => v,
132 Err(_) => {
133 continue;
134 }
135 };
136
137 length = value_str.parse().unwrap();
138 break;
139 }
140
141 if length > 0 {
142 ProgressState::Body(length)
143 } else {
144 ProgressState::Done
145 }
146 }
147 _ => ProgressState::Head,
148 }
149 }
150
151 pub fn block_parse(&mut self, bytes: &[u8]) -> (bool, Option<usize>) {
158 match self.progress {
159 ProgressState::Head => {
160 let start_point = self.buffer.len();
161 self.buffer.reserve(bytes.len());
162
163 for (index, tmp_byte) in bytes.iter().enumerate() {
164 self.buffer.push(*tmp_byte);
165 self.progress = self.parse(*tmp_byte, start_point + index);
166 match self.progress {
167 ProgressState::Body(length) => {
168 self.body_buffer.reserve(length);
169 return self.block_parse(&bytes[index + 1..]);
170 }
171 ProgressState::Done => {
172 return self.block_parse(&bytes[index + 1..]);
173 }
174 _ => {}
175 }
176 }
177
178 (false, None)
179 }
180 ProgressState::Body(length) => {
181 let left_to_read = length - self.body_buffer.len();
182 if left_to_read == 0 {
183 self.progress = ProgressState::Done;
184 return self.block_parse(&[]);
185 }
186
187 let chunk_size = bytes.len();
188 if left_to_read >= chunk_size {
189 self.body_buffer.extend_from_slice(bytes);
190 (self.body_buffer.len() == length, None)
191 } else {
192 self.body_buffer.extend_from_slice(&bytes[..left_to_read]);
193 self.progress = ProgressState::Done;
194 self.block_parse(&bytes[left_to_read..])
195 }
196 }
197 ProgressState::Done => {
198 let length = bytes.len();
199 let rest = (length > 0).then(|| length);
200
201 (true, rest)
202 }
203 }
204 }
205
206 pub fn finish<'a, 'b>(&'a self) -> ParseResult<Request<'b>>
210 where
211 'a: 'b,
212 {
213 let (method, path, protocol) = match &self.state {
214 State::HeadersParsed(m, p, pt, _) => (m, p, pt),
215 State::Nothing => {
216 return Err(ParseError::MissingMethod);
217 }
218 State::MethodParsed(_) => {
219 return Err(ParseError::MissingPath);
220 }
221 State::PathParsed(_, _) => {
222 return Err(ParseError::MissingProtocol);
223 }
224 State::HeaderKey(_, _, _, _) | State::HeaderValue(_, _, _, _) => {
225 return Err(ParseError::MissingHeaders);
226 }
227 };
228
229 let raw_method = &self.buffer[method.0..method.1];
230 let raw_path = &self.buffer[path.0..path.1];
231 let raw_protocol = &self.buffer[protocol.0..protocol.1];
232
233 let method = unsafe { std::str::from_utf8_unchecked(raw_method) };
234 let path = unsafe { std::str::from_utf8_unchecked(raw_path) };
235 let protocol = unsafe { std::str::from_utf8_unchecked(raw_protocol) };
236
237 let parsed_method = match Method::parse(method) {
238 Some(m) => m,
239 None => return Err(ParseError::MissingMethod),
240 };
241
242 let header_count = self.headers_buf.len();
243 let mut headers = Headers::with_capacity(header_count);
244 for tmp_header in self.headers_buf.iter() {
245 let key_range = tmp_header.0;
246 let raw_key = &self.buffer[key_range.0..key_range.1];
247
248 let value_range = tmp_header.1;
249 let raw_value = &self.buffer[value_range.0..value_range.1];
250
251 let key = unsafe { std::str::from_utf8_unchecked(raw_key) };
252 let value = unsafe { std::str::from_utf8_unchecked(raw_value) };
253
254 headers.append(key, value);
257 }
258
259 let body = &self.body_buffer;
260
261 Ok(Request::new(protocol, parsed_method, path, headers, body))
262 }
263
264 pub fn buffer(&self) -> &[u8] {
266 &self.buffer
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    // --- complete requests -------------------------------------------------

    #[test]
    fn parser_parse_no_body() {
        let raw = "GET /path/ HTTP/1.1\r\nTest-1: Value-1\r\n\r\n";

        let mut req_parser = ReqParser::new_capacity(4096);
        let outcome = req_parser.block_parse(raw.as_bytes());
        assert_eq!((true, None), outcome);

        let mut expected_headers = Headers::new();
        expected_headers.set("Test-1", "Value-1");
        let expected = Request::new(
            "HTTP/1.1",
            Method::GET,
            "/path/",
            expected_headers,
            "".as_bytes(),
        );
        assert_eq!(Ok(expected), req_parser.finish());
    }

    #[test]
    fn parser_parse_with_body() {
        let raw = "GET /path/ HTTP/1.1\r\nContent-Length: 22\r\n\r\nThis is just some body";

        let mut req_parser = ReqParser::new_capacity(4096);
        let outcome = req_parser.block_parse(raw.as_bytes());
        assert_eq!((true, None), outcome);

        let mut expected_headers = Headers::new();
        expected_headers.set("Content-Length", "22");
        let expected = Request::new(
            "HTTP/1.1",
            Method::GET,
            "/path/",
            expected_headers,
            "This is just some body".as_bytes(),
        );
        assert_eq!(Ok(expected), req_parser.finish());
    }

    #[test]
    fn parser_parse_multiple_headers_with_body() {
        let raw =
            "GET /path/ HTTP/1.1\r\nContent-Length: 22\r\nTest-2: Value-2\r\n\r\nThis is just some body";

        let mut req_parser = ReqParser::new_capacity(4096);
        let outcome = req_parser.block_parse(raw.as_bytes());
        assert_eq!((true, None), outcome);

        let mut expected_headers = Headers::new();
        expected_headers.set("Content-Length", "22");
        expected_headers.set("Test-2", "Value-2");
        let expected = Request::new(
            "HTTP/1.1",
            Method::GET,
            "/path/",
            expected_headers,
            "This is just some body".as_bytes(),
        );
        assert_eq!(Ok(expected), req_parser.finish());
    }

    #[test]
    fn parser_parse_multiple_headers_with_body_set_shorter() {
        let raw =
            "GET /path/ HTTP/1.1\r\nContent-Length: 10\r\nTest-2: Value-2\r\n\r\nThis is just some body";

        let mut req_parser = ReqParser::new_capacity(4096);
        // Only 10 of the 22 trailing bytes belong to the body; 12 are left over.
        let outcome = req_parser.block_parse(raw.as_bytes());
        assert_eq!((true, Some(12)), outcome);

        let mut expected_headers = Headers::new();
        expected_headers.set("Content-Length", "10");
        expected_headers.set("Test-2", "Value-2");
        let expected = Request::new(
            "HTTP/1.1",
            Method::GET,
            "/path/",
            expected_headers,
            "This is ju".as_bytes(),
        );
        assert_eq!(Ok(expected), req_parser.finish());
    }

    // --- truncated requests ------------------------------------------------

    #[test]
    fn parser_missing_method() {
        let raw = "";

        let mut req_parser = ReqParser::new_capacity(4096);
        let outcome = req_parser.block_parse(raw.as_bytes());
        assert_eq!((false, None), outcome);

        assert_eq!(Err(ParseError::MissingMethod), req_parser.finish());
    }

    #[test]
    fn parser_missing_path() {
        let raw = "GET ";

        let mut req_parser = ReqParser::new_capacity(4096);
        let outcome = req_parser.block_parse(raw.as_bytes());
        assert_eq!((false, None), outcome);

        assert_eq!(Err(ParseError::MissingPath), req_parser.finish());
    }

    #[test]
    fn parser_missing_protocol() {
        let raw = "GET /path/ ";

        let mut req_parser = ReqParser::new_capacity(4096);
        let outcome = req_parser.block_parse(raw.as_bytes());
        assert_eq!((false, None), outcome);

        assert_eq!(Err(ParseError::MissingProtocol), req_parser.finish());
    }

    #[test]
    fn parser_missing_headers() {
        let raw = "GET /path/ HTTP/1.1\r\n";

        let mut req_parser = ReqParser::new_capacity(4096);
        let outcome = req_parser.block_parse(raw.as_bytes());
        assert_eq!((false, None), outcome);

        assert_eq!(Err(ParseError::MissingHeaders), req_parser.finish());
    }

    // --- regression inputs found by fuzzing --------------------------------

    #[test]
    fn parser_fuzzing_bug_0() {
        let raw: &[u8] = &[
            13, 36, 32, 32, 36, 13, 58, 32, 32, 13, 36, 13, 36, 32, 32, 36, 13, 58, 36, 32, 32, 36,
            13, 58, 1,
        ];

        let mut req_parser = ReqParser::new_capacity(2048);
        let outcome = req_parser.block_parse(raw);

        assert_eq!((false, None), outcome);
    }

    #[test]
    fn parser_fuzzing_bug_1() {
        let raw: &[u8] = &[
            84, 82, 65, 67, 69, 32, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 85, 58, 13, 36, 36, 58, 93, 0, 36,
            32, 32, 13, 213, 58, 13, 36, 36, 58, 13, 36, 32, 32, 13, 85, 58, 13, 36, 36, 58, 93, 0,
            36, 32, 32, 13, 213, 58, 13, 36, 36, 58, 13, 64, 13, 36, 64,
        ];

        let mut req_parser = ReqParser::new_capacity(2048);
        let outcome = req_parser.block_parse(raw);

        assert_eq!((true, Some(1)), outcome);
        assert!(req_parser.finish().is_ok());
    }
}