1use crate::streaming_parser::ParseError;
2use crate::{header::HeaderKey, Headers, Response, StatusCode};
3
/// `(start, end)` byte range of the protocol text inside the head buffer (`end` is exclusive).
type ProtocolState = (usize, usize);
/// `(start, end)` byte range of the status text inside the head buffer (`end` is exclusive).
type StatusCodeState = (usize, usize);
/// `(start, end)` byte range of a header key inside the head buffer (`end` is exclusive).
type HeaderKeyState = (usize, usize);
7
/// Position of the head state machine.
///
/// All stored ranges and indices refer to positions in `RespParser::buffer`.
enum ParseState {
    /// No space byte has been seen yet, so the protocol is still being read.
    Nothing,
    /// The first space was seen; holds the protocol range `(0, index of that space)`.
    ProtocolParsed(ProtocolState),
    /// Waiting for a header key (or the blank line). The `usize` is the index of the `\r`
    /// that ended the previous line; the next key is taken to start two bytes after it.
    HeaderKey(ProtocolState, StatusCodeState, usize),
    /// A `:` ended the header key whose range is stored; the value is being read.
    HeaderValue(ProtocolState, StatusCodeState, HeaderKeyState),
    /// The `\r` of the blank line was seen. The `usize` is that index plus two; head parsing
    /// finishes on the byte at index `value - 1`.
    HeadersParsed(ProtocolState, StatusCodeState, usize),
}
15
/// Which part of the response the parser is currently consuming.
#[derive(Debug)]
enum ProgressState {
    /// Incoming bytes still belong to the head (status line and headers).
    Head,
    /// The head is complete; the value is the total body length taken from `Content-Length`.
    Body(usize),
    /// Head and body are complete; any further bytes are not consumed.
    Done,
}
23
/// Incremental parser for an HTTP response that is fed in arbitrary byte chunks.
pub struct RespParser {
    /// Raw bytes of the response head; every recorded range indexes into this buffer.
    buffer: Vec<u8>,
    /// Body bytes collected so far.
    body_buffer: Vec<u8>,
    /// One `(key range, value range)` entry per parsed header, each a `(start, end)` pair
    /// into `buffer`.
    headers_buf: Vec<((usize, usize), (usize, usize))>,
    /// Position of the head state machine.
    state: ParseState,
    /// Which part of the response the next incoming bytes belong to.
    progress: ProgressState,
}
33
34impl RespParser {
35 pub fn new_capacity(head_cap: usize) -> Self {
39 Self {
40 buffer: Vec::with_capacity(head_cap),
41 body_buffer: Vec::new(),
42 headers_buf: Vec::with_capacity(20),
43 state: ParseState::Nothing,
44 progress: ProgressState::Head,
45 }
46 }
47
48 pub fn clear(&mut self) {
55 self.buffer.clear();
57 self.body_buffer.clear();
58 self.headers_buf.clear();
59
60 self.state = ParseState::Nothing;
62 self.progress = ProgressState::Head;
63 }
64
65 #[inline(always)]
66 fn parse(&mut self, byte: u8, current: usize) -> ProgressState {
67 match &mut self.state {
68 ParseState::Nothing if byte == b' ' => {
69 let end = current;
70 self.state = ParseState::ProtocolParsed((0, end));
71 ProgressState::Head
72 }
73 ParseState::ProtocolParsed(protocol) if byte == b'\r' => {
74 let start = protocol.1;
75 let end = current;
76
77 self.state = ParseState::HeaderKey(*protocol, (start + 1, end), end);
78 ProgressState::Head
79 }
80 ParseState::HeaderKey(protocol, status_code, raw_start)
81 if current == *raw_start + 2 && byte == b'\r' =>
82 {
83 self.state = ParseState::HeadersParsed(*protocol, *status_code, current + 2);
84 ProgressState::Head
85 }
86 ParseState::HeaderKey(protocol, status_code, raw_start)
87 if byte == b':' && *raw_start + 2 <= current =>
88 {
89 let start = *raw_start + 2;
90 let end = current;
91
92 self.state = ParseState::HeaderValue(*protocol, *status_code, (start, end));
93 ProgressState::Head
94 }
95 ParseState::HeaderValue(protocol, status_code, header_key)
96 if byte == b'\r' && header_key.1 + 2 <= current =>
97 {
98 let start = header_key.1 + 2;
99 let end = current;
100
101 self.headers_buf.push((*header_key, (start, end)));
102 self.state = ParseState::HeaderKey(*protocol, *status_code, end);
103 ProgressState::Head
104 }
105 ParseState::HeadersParsed(_, _, end) if current == *end - 1 => {
106 let mut length: usize = 0;
108 for raw_header_pair in self.headers_buf.iter() {
109 let key_pair = raw_header_pair.0;
110 let value_pair = raw_header_pair.1;
111
112 let key_str = match std::str::from_utf8(&self.buffer[key_pair.0..key_pair.1]) {
113 Ok(k) => k,
114 Err(_) => {
115 continue;
116 }
117 };
118 if HeaderKey::StrRef(key_str) != HeaderKey::StrRef("Content-Length") {
119 continue;
120 }
121
122 let value_str =
123 match std::str::from_utf8(&self.buffer[value_pair.0..value_pair.1]) {
124 Ok(v) => v,
125 Err(_) => {
126 continue;
127 }
128 };
129
130 length = value_str.parse().unwrap();
131 break;
132 }
133
134 if length > 0 {
135 ProgressState::Body(length)
136 } else {
137 ProgressState::Done
138 }
139 }
140 _ => ProgressState::Head,
141 }
142 }
143
144 pub fn block_parse(&mut self, bytes: &[u8]) -> (bool, usize) {
152 match self.progress {
153 ProgressState::Head => {
154 let start_point = self.buffer.len();
155 self.buffer.reserve(bytes.len());
156
157 for (index, tmp_byte) in bytes.iter().enumerate() {
158 self.buffer.push(*tmp_byte);
159 self.progress = self.parse(*tmp_byte, start_point + index);
160 match self.progress {
161 ProgressState::Body(length) => {
162 self.body_buffer.reserve(length);
163 return self.block_parse(&bytes[index + 1..]);
164 }
165 ProgressState::Done => {
166 return self.block_parse(&bytes[index + 1..]);
167 }
168 _ => {}
169 }
170 }
171 (false, 0)
172 }
173 ProgressState::Body(length) => {
174 let left_to_read = length - self.body_buffer.len();
175 if left_to_read == 0 {
176 self.progress = ProgressState::Done;
177 return self.block_parse(&[]);
178 }
179
180 let chunk_size = bytes.len();
181 if left_to_read >= chunk_size {
182 self.body_buffer.extend_from_slice(bytes);
183 (self.body_buffer.len() == length, 0)
184 } else {
185 self.body_buffer.extend_from_slice(&bytes[..left_to_read]);
186 self.progress = ProgressState::Done;
187 self.block_parse(&bytes[left_to_read..])
188 }
189 }
190 ProgressState::Done => (true, bytes.len()),
191 }
192 }
193
194 pub fn finish<'a, 'b>(&'a mut self) -> Result<Response<'b>, ParseError>
197 where
198 'a: 'b,
199 {
200 let (protocol, status_code) = match &self.state {
201 ParseState::HeadersParsed(p, stc, _) => (p, stc),
202 ParseState::Nothing => {
203 return Err(ParseError::MissingProtocol);
204 }
205 ParseState::ProtocolParsed(_) => {
206 return Err(ParseError::MissingStatusCode);
207 }
208 ParseState::HeaderKey(_, _, _) => {
209 return Err(ParseError::MissingHeaders);
210 }
211 ParseState::HeaderValue(_, _, _) => {
212 return Err(ParseError::MissingHeaders);
213 }
214 };
215
216 let raw_protocol = &self.buffer[protocol.0..protocol.1];
217 let raw_status_code = &self.buffer[status_code.0..status_code.1];
218
219 let protocol = unsafe { std::str::from_utf8_unchecked(raw_protocol) };
220 let status_code = match std::str::from_utf8(raw_status_code) {
221 Ok(s) => s,
222 Err(_) => {
223 return Err(ParseError::InvalidStatusCode);
224 }
225 };
226 if !status_code.is_ascii() {
227 return Err(ParseError::InvalidStatusCode);
228 }
229
230 let parsed_status_code = match StatusCode::parse(status_code) {
231 Some(s) => s,
232 None => return Err(ParseError::InvalidStatusCode),
233 };
234
235 let header_count = self.headers_buf.len();
236 let mut headers = Headers::with_capacity(header_count);
237 for tmp_header in self.headers_buf.iter() {
238 let key_range = tmp_header.0;
239 let raw_key = &self.buffer[key_range.0..key_range.1];
240
241 let value_range = tmp_header.1;
242 let raw_value = &self.buffer[value_range.0..value_range.1];
243
244 let key = unsafe { std::str::from_utf8_unchecked(raw_key) };
245 let value = unsafe { std::str::from_utf8_unchecked(raw_value) };
246
247 headers.append(key, value);
250 }
251
252 Ok(Response::new(
253 protocol,
254 parsed_status_code,
255 headers,
256 std::mem::take(&mut self.body_buffer),
257 ))
258 }
259
260 pub fn finish_owned<'a, 'owned>(&'a mut self) -> Result<Response<'owned>, ParseError> {
266 let (protocol, status_code) = match &self.state {
267 ParseState::HeadersParsed(p, stc, _) => (p, stc),
268 ParseState::Nothing => {
269 return Err(ParseError::MissingProtocol);
270 }
271 ParseState::ProtocolParsed(_) => {
272 return Err(ParseError::MissingStatusCode);
273 }
274 ParseState::HeaderKey(_, _, _) => {
275 return Err(ParseError::MissingHeaders);
276 }
277 ParseState::HeaderValue(_, _, _) => {
278 return Err(ParseError::MissingHeaders);
279 }
280 };
281
282 let raw_protocol = &self.buffer[protocol.0..protocol.1];
283 let raw_status_code = &self.buffer[status_code.0..status_code.1];
284
285 let protocol = unsafe { String::from_utf8_unchecked(raw_protocol.to_owned()) };
286 let status_code = match std::str::from_utf8(raw_status_code) {
287 Ok(s) => s,
288 Err(_) => {
289 return Err(ParseError::InvalidStatusCode);
290 }
291 };
292 if !status_code.is_ascii() {
293 return Err(ParseError::InvalidStatusCode);
294 }
295
296 let parsed_status_code = match StatusCode::parse(status_code) {
297 Some(s) => s,
298 None => return Err(ParseError::InvalidStatusCode),
299 };
300
301 let header_count = self.headers_buf.len();
302 let mut headers = Headers::with_capacity(header_count);
303 for tmp_header in self.headers_buf.iter() {
304 let key_range = tmp_header.0;
305 let raw_key = &self.buffer[key_range.0..key_range.1];
306
307 let value_range = tmp_header.1;
308 let raw_value = &self.buffer[value_range.0..value_range.1];
309
310 let key = unsafe { String::from_utf8_unchecked(raw_key.to_owned()) };
311 let value = unsafe { String::from_utf8_unchecked(raw_value.to_owned()) };
312
313 headers.append(key, value);
316 }
317
318 Ok(Response::new_owned(
319 protocol,
320 parsed_status_code,
321 headers,
322 std::mem::take(&mut self.body_buffer),
323 ))
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// A head without `Content-Length` is complete at the blank line and has an empty body.
    #[test]
    fn parser_parse_no_body() {
        let raw: &[u8] = b"HTTP/1.1 200 OK\r\nTest-1: Value-1\r\n\r\n";

        let mut parser = RespParser::new_capacity(1024);
        let outcome = parser.block_parse(raw);
        assert_eq!((true, 0), outcome);

        let mut expected_headers = Headers::new();
        expected_headers.set("Test-1", "Value-1");
        assert_eq!(
            Ok(Response::new(
                "HTTP/1.1",
                StatusCode::OK,
                expected_headers,
                b"".to_vec()
            )),
            parser.finish()
        );
    }

    /// The body announced by `Content-Length` is collected from the same chunk as the head.
    #[test]
    fn parser_parse_with_body() {
        let raw: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Length: 22\r\n\r\nThis is just some body";

        let mut parser = RespParser::new_capacity(1024);
        let outcome = parser.block_parse(raw);
        assert_eq!((true, 0), outcome);

        let mut expected_headers = Headers::new();
        expected_headers.set("Content-Length", "22");
        assert_eq!(
            Ok(Response::new(
                "HTTP/1.1",
                StatusCode::OK,
                expected_headers,
                b"This is just some body".to_vec()
            )),
            parser.finish()
        );
    }

    /// `Content-Length` is found even when it is not the first header.
    #[test]
    fn parser_parse_multiple_headers_with_body() {
        let raw: &[u8] =
            b"HTTP/1.1 200 OK\r\nTest-1: Value-1\r\nContent-Length: 22\r\n\r\nThis is just some body";

        let mut parser = RespParser::new_capacity(1024);
        let outcome = parser.block_parse(raw);
        assert_eq!((true, 0), outcome);

        let mut expected_headers = Headers::new();
        expected_headers.set("Test-1", "Value-1");
        expected_headers.set("Content-Length", "22");
        assert_eq!(
            Ok(Response::new(
                "HTTP/1.1",
                StatusCode::OK,
                expected_headers,
                b"This is just some body".to_vec()
            )),
            parser.finish()
        );
    }

    /// Bytes beyond the announced body length are not consumed and are reported as leftover.
    #[test]
    fn parser_parse_multiple_headers_with_body_longer_than_told() {
        let raw: &[u8] =
            b"HTTP/1.1 200 OK\r\nTest-1: Value-1\r\nContent-Length: 10\r\n\r\nThis is just some body";

        let mut parser = RespParser::new_capacity(1024);
        let outcome = parser.block_parse(raw);
        assert_eq!((true, 12), outcome);

        let mut expected_headers = Headers::new();
        expected_headers.set("Test-1", "Value-1");
        expected_headers.set("Content-Length", "10");
        assert_eq!(
            Ok(Response::new(
                "HTTP/1.1",
                StatusCode::OK,
                expected_headers,
                b"This is ju".to_vec()
            )),
            parser.finish()
        );
    }

    /// Regression input found by fuzzing.
    #[test]
    fn parser_fuzzing_bug_0() {
        let input: &[u8] = &[63, 32, 243, 13, 33, 13, 33, 242];
        let mut parser = RespParser::new_capacity(1024);

        assert_eq!((true, 1), parser.block_parse(input));
        assert!(parser.finish().is_err());
    }

    /// Regression input found by fuzzing.
    #[test]
    fn parser_fuzzing_bug_1() {
        let input: &[u8] = &[32, 13, 58, 13, 32, 13, 93];
        let mut parser = RespParser::new_capacity(1024);

        assert_eq!((true, 2), parser.block_parse(input));
    }

    /// Regression input found by fuzzing.
    #[test]
    fn parser_fuzzing_bug_2() {
        let input: &[u8] = &[
            32, 15, 93, 58, 156, 156, 156, 156, 156, 156, 13, 32, 13, 58, 11, 93, 13,
        ];
        let mut parser = RespParser::new_capacity(1024);

        assert_eq!((true, 3), parser.block_parse(input));
        assert!(parser.finish().is_err());
    }

    /// Regression input found by fuzzing.
    #[test]
    fn parser_fuzzing_bug_3() {
        let input: &[u8] = &[
            32, 52, 48, 200, 169, 58, 13, 58, 222, 13, 58, 52, 48, 58, 13, 58, 222, 21, 58, 13, 58,
            13, 29, 29, 58, 58, 43, 29, 58, 13, 13, 13, 29, 58, 9, 13,
        ];
        let mut parser = RespParser::new_capacity(1024);

        assert_eq!((true, 3), parser.block_parse(input));
        assert!(parser.finish().is_err());
    }
}