sawp_resp/
lib.rs

1//! A RESP protocol parser. Given bytes and a [`sawp::parser::Direction`], it will
2//! attempt to parse the bytes and return a [`Message`]. The parser will
3//! inform the caller about what went wrong if no message is returned (see [`sawp::parser::Parse`]
4//! for details on possible return types).
5//!
6//! The following protocol references were used to create this module:
7//!
8//! [RESP Protocol Specification](https://redis.io/topics/protocol)
9//!
10//! # Example
11//! ```
12//! use sawp::parser::{Direction, Parse};
13//! use sawp::error::Error;
14//! use sawp::error::ErrorKind;
15//! use sawp_resp::{Resp, Message};
16//!
17//! fn parse_bytes(input: &[u8]) -> std::result::Result<&[u8], Error> {
18//!     let resp = Resp {};
19//!     let mut bytes = input;
20//!     while bytes.len() > 0 {
//!         // This parser does not use the direction when parsing RESP, so
//!         // `Direction::Unknown` works for both requests and responses
23//!         match resp.parse(bytes, Direction::Unknown) {
24//!             // The parser succeeded and returned the remaining bytes and the parsed RESP message
25//!             Ok((rest, Some(message))) => {
26//!                 println!("Resp message: {:?}", message);
27//!                 bytes = rest;
28//!             }
29//!             // The parser recognized that this might be RESP and made some progress,
30//!             // but more bytes are needed
31//!             Ok((rest, None)) => return Ok(rest),
32//!             // The parser was unable to determine whether this was RESP or not and more
33//!             // bytes are needed
34//!             Err(Error { kind: ErrorKind::Incomplete(_) }) => return Ok(bytes),
35//!             // The parser determined that this was not RESP
36//!             Err(e) => return Err(e)
37//!         }
38//!     }
39//!
40//!     Ok(bytes)
41//! }
42//! ```
43
44use sawp::error::Result;
45use sawp::parser::{Direction, Parse};
46use sawp::probe::{Probe, Status};
47use sawp::protocol::Protocol;
48use sawp_flags::{BitFlags, Flag, Flags};
49
50use nom::bytes::streaming::{take, take_until};
51use nom::character::streaming::crlf;
52use nom::number::streaming::be_u8;
53use nom::{AsBytes, FindToken, InputTakeAtPosition};
54
55use num_enum::TryFromPrimitive;
56
57use std::convert::TryFrom;
58
59/// FFI structs and Accessors
60#[cfg(feature = "ffi")]
61mod ffi;
62
63#[cfg(feature = "ffi")]
64use sawp_ffi::GenerateFFI;
65
/// Line terminator used by every RESP element.
pub const CRLF: &[u8] = b"\r\n";
/// The one-character tokens that introduce a RESP data type (see [`DataTypeToken`]).
pub const DATA_TYPE_TOKENS: &str = "$*+-:";
/// Maximum nesting depth of arrays the parser will descend into.
pub const MAX_ARRAY_DEPTH: usize = 64;
/// Bulk strings should not exceed 512 MB in length.
///
/// Expressed in bytes: 512 * 1024 * 1024.
pub const MAX_BULK_STRING_LEN: usize = 512 * 1024 * 1024;
71
/// Error flags raised while parsing RESP - to be used in the returned Message
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, BitFlags)]
pub enum ErrorFlags {
    /// Malformed data including invalid type tokens, invalid integers,
    /// or improperly formatted RESP has been parsed.
    InvalidData = 0b0000_0001,
    /// The declared length of a bulk string exceeds [`MAX_BULK_STRING_LEN`].
    /// SAWP will try to return the whole string.
    BulkStringExceedsMaxLen = 0b0000_0010,
    /// An array of arrays with > MAX_ARRAY_DEPTH depth was found. Message will truncate
    /// at the limit but further bytes WILL NOT be consumed.
    MaxArrayDepthReached = 0b0000_0100,
}
86
/// RESP signals data types by prepending these one-character tokens
#[derive(Clone, Copy, Debug, PartialEq, Eq, TryFromPrimitive)]
#[repr(u8)]
pub enum DataTypeToken {
    /// A single binary-safe string up to 512 MB in length. Precedes a CRLF-terminated length describing the following
    /// string. Can also be used to signal non-existence of a value using a special format that is used to represent a
    /// null value i.e. "$-1\r\n". Note the 'empty' string still CRLF-terminates unlike the null value
    /// i.e. "$0\r\n\r\n".
    BulkString = b'$',
    /// Clients send commands to the Redis server using RESP arrays. Servers can also return collections of elements
    /// with arrays. Precedes a CRLF-terminated length describing the following array. Empty type: "*0\r\n". Note
    /// that single elements to the array may be null, in which case the element will look like the null value as
    /// described in BulkString
    Array = b'*',
    /// Used to transmit non-binary-safe strings. These strings cannot contain '\r' or '\n' and are terminated by
    /// "\r\n".
    SimpleString = b'+',
    /// Exactly like simple strings but errors should be treated by the client as exceptions.
    Error = b'-',
    /// Like simple strings, but representing an integer.
    Integer = b':',
    /// Fallback for any byte that is not one of the tokens above; this is what
    /// [`DataTypeToken::from_raw`] returns when the conversion from `u8` fails.
    Unknown,
}
110
111impl DataTypeToken {
112    pub fn from_raw(val: u8) -> Self {
113        DataTypeToken::try_from(val).unwrap_or(DataTypeToken::Unknown)
114    }
115}
116
/// Entry types to return in the parsed message
#[cfg_attr(feature = "ffi", derive(GenerateFFI))]
#[cfg_attr(feature = "ffi", sawp_ffi(prefix = "sawp_resp"))]
#[derive(Debug, PartialEq, Eq)]
pub enum Entry {
    /// Arrays of entries. Entries may themselves be arrays.
    Array(Vec<Entry>),
    /// The same as a String in practice but it may be useful to differentiate.
    Error(Vec<u8>),
    /// Integers
    Integer(i64),
    /// Invalid Data: a special type used here to return the data that would otherwise be lost after a recoverable parsing failure.
    /// This data will not be structured and is subject to interpretation.
    Invalid(Vec<u8>),
    /// The null value. Used to indicate a requested resource doesn't exist.
    /// Client libraries are supposed to return a "nil/null object" depending on the implementation language's preferred word.
    Nil,
    /// Simple Strings and Bulk Strings
    String(Vec<u8>),
}
137
/// Outcome of parsing a CRLF-terminated integer field.
pub enum IntegerResult<'a> {
    /// The field was valid text that parsed as an `i64`.
    Integer(i64),
    /// The field could not be parsed as an `i64`; holds the raw bytes that preceded the CRLF.
    Data(&'a [u8]),
}
142
/// Outcome of parsing a bulk string.
pub enum StringResult<'a> {
    /// The string payload, borrowed from the input.
    String(&'a [u8]),
    /// The null bulk string (declared length of -1).
    Nil,
    /// The length field could not be parsed as an integer.
    Invalid(&'a [u8], &'a [u8]), // length/data pair
}
148
/// Breakdown of the parsed resp bytes
#[cfg_attr(feature = "ffi", derive(GenerateFFI))]
#[cfg_attr(feature = "ffi", sawp_ffi(prefix = "sawp_resp"))]
#[derive(Debug, PartialEq, Eq)]
pub struct Message {
    /// The top-level entry that was parsed.
    pub entry: Entry,
    /// Flags accumulated while parsing `entry`, including any nested entries.
    #[cfg_attr(feature = "ffi", sawp_ffi(flag = "u8"))]
    pub error_flags: Flags<ErrorFlags>,
}

// No inherent methods are defined on `Message`.
impl Message {}
160
/// The RESP protocol parser. It has no fields and therefore carries no state
/// between calls.
#[derive(Debug)]
pub struct Resp {}
163
/// Ties the `Resp` parser to the [`Message`] type it produces.
impl<'a> Protocol<'a> for Resp {
    type Message = Message;

    /// Returns the protocol name, `"resp"`.
    fn name() -> &'static str {
        "resp"
    }
}
171
172impl<'a> Probe<'a> for Resp {
173    /// Probes the input to recognize if the underlying bytes likely match this
174    /// protocol.
175    ///
176    /// Returns a probe status. Probe again once more data is available when the
177    /// status is `Status::Incomplete`.
178    fn probe(&self, input: &'a [u8], direction: Direction) -> Status {
179        match self.parse(input, direction) {
180            Ok((
181                _,
182                Some(Message {
183                    entry: Entry::Invalid(_),
184                    error_flags: _,
185                }),
186            )) => Status::Unrecognized, // If the only message is Invalid it is probably not RESP
187            Ok(_) => Status::Recognized,
188            Err(sawp::error::Error {
189                kind: sawp::error::ErrorKind::Incomplete(_),
190            }) => Status::Incomplete,
191            Err(_) => Status::Unrecognized,
192        }
193    }
194}
195
196impl Resp {
197    fn advance_if_crlf(input: &[u8]) -> &[u8] {
198        crlf::<_, (&[u8], nom::error::ErrorKind)>(input)
199            .map(|(rem, _)| rem)
200            .unwrap_or(input)
201    }
202
203    fn parse_integer(input: &[u8]) -> Result<(&[u8], IntegerResult, Flags<ErrorFlags>)> {
204        let (rem, raw_len) = take_until(CRLF)(input)?;
205        // We don't know how long ret is but it is supposed to be valid text.
206        match std::str::from_utf8(raw_len) {
207            Ok(len_str) => match len_str.parse::<i64>() {
208                Ok(len) => Ok((
209                    Resp::advance_if_crlf(rem),
210                    IntegerResult::Integer(len),
211                    ErrorFlags::none(),
212                )),
213                Err(_) => Ok((
214                    Resp::advance_if_crlf(rem),
215                    IntegerResult::Data(raw_len),
216                    ErrorFlags::InvalidData.into(),
217                )),
218            },
219            Err(_) => Ok((
220                Resp::advance_if_crlf(rem),
221                IntegerResult::Data(raw_len),
222                ErrorFlags::InvalidData.into(),
223            )),
224        }
225    }
226
227    /// The TypeOr looks a bit complicated but essentially means that if there are no errors we just return the parsed string data TypeOr::Left.
228    /// If there is an error with the integer calculation we'll try to return the <integer data, string data> TyperOr::Right.
229    /// Indicates a nil entry return for the caller via the returned bool
230    fn parse_bulk_string(input: &[u8]) -> Result<(&[u8], StringResult, Flags<ErrorFlags>)> {
231        let (rem, wrapped_length, mut error_flags) = Resp::parse_integer(input)?;
232        match wrapped_length {
233            IntegerResult::Integer(length) => {
234                if length >= 0 {
235                    if length > MAX_BULK_STRING_LEN as i64 {
236                        error_flags |= ErrorFlags::BulkStringExceedsMaxLen
237                    }
238                    let (rem, ret) = take(length as usize)(rem)?;
239                    // The standard states that even bulk strings should end with CRLF, but it may not be strictly necessary based on implementation?
240                    Ok((
241                        Resp::advance_if_crlf(rem),
242                        StringResult::String(ret),
243                        error_flags,
244                    ))
245                } else {
246                    // Whether the result is the NULL string (-1 length) or some negative number, we can pass an "empty" result back with the inner error_flags and let the caller handle it.
247                    if length == -1 {
248                        return Ok((Resp::advance_if_crlf(rem), StringResult::Nil, error_flags));
249                    }
250                    error_flags |= ErrorFlags::InvalidData;
251                    Ok((
252                        Resp::advance_if_crlf(rem),
253                        StringResult::String(b""),
254                        error_flags,
255                    ))
256                }
257            }
258            IntegerResult::Data(bytes) => Ok((
259                Resp::advance_if_crlf(rem),
260                StringResult::Invalid(bytes, b""),
261                error_flags,
262            )),
263        }
264    }
265
266    fn parse_simple_string(input: &[u8]) -> Result<(&[u8], &[u8])> {
267        let (rem, ret) = take_until(CRLF)(input)?;
268        // Remove the CRLF from remaining bytes
269        Ok((Resp::advance_if_crlf(rem), ret))
270    }
271
272    fn parse_entry(input: &[u8], array_depth: usize) -> Result<(&[u8], Entry, Flags<ErrorFlags>)> {
273        let (input, raw_token) = be_u8(input)?;
274        let token = DataTypeToken::from_raw(raw_token);
275        match token {
276            DataTypeToken::BulkString => {
277                let (rem, parsed_data, error_flags) = Resp::parse_bulk_string(input)?;
278                match parsed_data {
279                    StringResult::String(string_data) => {
280                        Ok((rem, Entry::String(string_data.to_vec()), error_flags))
281                    }
282                    StringResult::Nil => Ok((rem, Entry::Nil, error_flags)),
283                    StringResult::Invalid(len, data) => {
284                        Ok((rem, Entry::Invalid([len, data].concat()), error_flags))
285                    }
286                }
287            }
288            DataTypeToken::Array => {
289                if array_depth < MAX_ARRAY_DEPTH {
290                    let (mut local_input, length, mut error_flags) = Resp::parse_integer(input)?;
291                    match length {
292                        IntegerResult::Integer(length) if length >= 0 => {
293                            let mut entries: Vec<Entry> = Vec::with_capacity(length as usize);
294
295                            for _ in 0..length {
296                                let (rem, entry, inner_error_flags) =
297                                    Resp::parse_entry(local_input, array_depth + 1)?;
298                                error_flags |= inner_error_flags;
299                                if error_flags.contains(ErrorFlags::MaxArrayDepthReached) {
300                                    return Ok((input, Entry::Array(entries), error_flags));
301                                }
302                                entries.push(entry);
303                                local_input = rem;
304                            }
305                            Ok((local_input, Entry::Array(entries), error_flags))
306                        }
307                        IntegerResult::Integer(-1) => Ok((local_input, Entry::Nil, error_flags)),
308                        IntegerResult::Integer(_length) => {
309                            error_flags |= ErrorFlags::InvalidData;
310                            Ok((
311                                Resp::advance_if_crlf(local_input),
312                                Entry::Array(vec![]),
313                                error_flags,
314                            ))
315                        }
316                        IntegerResult::Data(invalid_length) => Ok((
317                            Resp::advance_if_crlf(local_input),
318                            Entry::Invalid(
319                                [b"*", invalid_length].concat(), // include the token character in the returned value.
320                            ),
321                            error_flags,
322                        )),
323                    }
324                } else {
325                    Ok((
326                        input,
327                        Entry::Invalid(vec![]),
328                        ErrorFlags::MaxArrayDepthReached.into(),
329                    ))
330                }
331            }
332            DataTypeToken::SimpleString => {
333                let (rem, ret) = Resp::parse_simple_string(input)?;
334                Ok((rem, Entry::String(ret.to_vec()), ErrorFlags::none()))
335            }
336            DataTypeToken::Error => {
337                let (rem, ret) = Resp::parse_simple_string(input)?;
338                Ok((
339                    rem,
340                    Entry::Error(ret.as_bytes().to_vec()),
341                    ErrorFlags::none(),
342                ))
343            }
344            DataTypeToken::Integer => {
345                let (rem, ret, error_flags) = Resp::parse_integer(input)?;
346                match ret {
347                    IntegerResult::Integer(ret) => Ok((rem, Entry::Integer(ret), error_flags)),
348                    IntegerResult::Data(ret) => Ok((
349                        rem,
350                        Entry::Invalid([b":", ret].concat()), // include the token character in the returned value.
351                        error_flags,
352                    )),
353                }
354            }
355            DataTypeToken::Unknown => {
356                // Advance to the next possible data type token, returning the "in-between" as InvalidData
357                // Note we should include the first character in the returned value.
358                let (rem, data) =
359                    input.split_at_position_complete(|e: u8| DATA_TYPE_TOKENS.find_token(e))?;
360                Ok((
361                    Resp::advance_if_crlf(rem),
362                    Entry::Invalid([&[raw_token], data].concat()),
363                    ErrorFlags::InvalidData.into(),
364                ))
365            }
366        }
367    }
368}
369
370/// Returns ErrorKind::Incomplete if more data is needed.
371/// If part of the message was parsed successfully will attempt to return a partial message
372/// with an appropriate error_flags field indicating what went wrong.
373impl<'a> Parse<'a> for Resp {
374    fn parse(
375        &self,
376        input: &'a [u8],
377        _direction: Direction,
378    ) -> Result<(&'a [u8], Option<Self::Message>)> {
379        let (rem, entry, error_flags) = Resp::parse_entry(input, 0)?;
380
381        Ok((rem, Some(Message { entry, error_flags })))
382    }
383}
384
// Table-driven tests for the RESP parser.
#[cfg(test)]
mod test {
    use crate::{Entry, ErrorFlags, Message, Resp};
    use rstest::rstest;
    use sawp::error::Result;
    use sawp::parser::{Direction, Parse};
    use sawp_flags::Flag;

    // Each case supplies the raw input bytes and the expected parse result as
    // `Ok((number of bytes left unparsed, parsed message))`.
    #[rstest(
    input,
    expected,
    case::parse_simple_string(
        b"+OK\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::String(b"OK".to_vec()),
                    error_flags: ErrorFlags::none(),
                }
            )
        ))
    ),
    case::parse_error(
        b"-Error message\r\n",
        Ok((
        0,
            Some(
                Message {
                    entry: Entry::Error(b"Error message".to_vec()),
                    error_flags: ErrorFlags::none(),
                }
            )
        ))
    ),
    case::parse_integer(
        b":1000\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::Integer(1000),
                    error_flags: ErrorFlags::none(),
                }
            )
        ))
    ),
    case::parse_bulk_string(
        b"$6\r\nfoobar\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::String(b"foobar".to_vec()),
                    error_flags: ErrorFlags::none(),
                }
            )
        ))
    ),
    case::parse_array(
        b"*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::Array(vec!(
                        Entry::String(b"foo".to_vec()),
                        Entry::String(b"bar".to_vec()),
                    )),
                    error_flags: ErrorFlags::none(),
                }
            )
        ))
    ),
    case::parse_null_value_array(
        b"*-1\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::Nil,
                    error_flags: ErrorFlags::none(),
                }
            )
        ))
    ),
    case::invalid_negative_array_length(
        b"*-2\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::Array(vec![]),
                    error_flags: ErrorFlags::InvalidData.into(),
                }
            )
        ))
    ),
    case::parse_nested_array(
        b"*1\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry:
                        Entry::Array(vec!(
                            Entry::Array(vec!(
                                Entry::String(b"foo".to_vec()),
                                Entry::String(b"bar".to_vec()),
                            )),
                        ),
                    ),
                    error_flags: ErrorFlags::none(),
                }
            )
        ))
    ),
    case::parse_empty_array(
        b"*0\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::Array(vec![]),
                    error_flags: ErrorFlags::none(),
                }
            )
        ))
    ),
    // Nesting beyond the depth limit: the message is truncated and only the
    // outermost type token counts as consumed (268 bytes remain).
    case::nested_array_exceeds_max_depth(
    b"*2\r\n$3\r\nfoo\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n\
    *1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n\
    *1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n\
    *1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n\
    *1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n", // array depth 65
    Ok((
        268,
        Some(
            Message {
                entry:
                Entry::Array(vec![
                Entry::String(b"foo".to_vec()),
                ]),
                error_flags: ErrorFlags::MaxArrayDepthReached.into(),
            }
        )
    ))
    ),
    case::parse_empty_bulk_string_with_trailing_negative_int(
        b"*2\r\n$0\r\n\r\n:-100\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::Array(vec![
                        Entry::String(b"".to_vec()),
                        Entry::Integer(-100),
                    ]
                ),
                error_flags: ErrorFlags::none(),
                }
        )
        ))
    ),
    case::parse_null_value_string(
        b"$-1\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::Nil,
                    error_flags: ErrorFlags::none(),
                }
            )
        ))
    ),
    case::invalid_negative_bulk_string_length(
        b"$-2\r\n",
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::String(b"".to_vec()),
                    error_flags: ErrorFlags::InvalidData.into(),
                }
            )
        ))
    ),
    // An unknown leading byte: everything up to the next type token is
    // returned as invalid data, including the first byte.
    case::invalid_type_token(
    b"!1\r\n",
    Ok((
        0,
        Some(
            Message {
                entry: Entry::Invalid(b"!1\r\n".to_vec()),
                error_flags: ErrorFlags::InvalidData.into(),
            }
        )
    ))
    ),
    case::invalid_type_token_mixed_with_good_data(
        b"!1\r\n$6\r\nfoobar\r\n",
        Ok((
            12,
            Some(
                Message {
                    entry: Entry::Invalid(b"!1\r\n".to_vec()),
                    error_flags: ErrorFlags::InvalidData.into(),
                }
            )
        ))
    ),
    case::missing_type_token(
        b"1\r\n$6\r\nfoobar\r\n",
        Ok((
            12,
            Some(
                Message {
                    entry: Entry::Invalid(b"1\r\n".to_vec()),
                    error_flags: ErrorFlags::InvalidData.into(),
                }
            )
        ))
    ),
    case::parse_too_big_integer(
        b":9223372036854775808\r\n", // int64 max + 1
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::Invalid(b":9223372036854775808".to_vec()),
                    error_flags: ErrorFlags::InvalidData.into(),
                }
            )
        ))
    ),
    case::parse_too_small_integer(
        b":-9223372036854775809\r\n", // int64 min - 1
        Ok((
            0,
            Some(
                Message {
                    entry: Entry::Invalid(b":-9223372036854775809".to_vec()),
                    error_flags: ErrorFlags::InvalidData.into(),
                }
            )
        ))
    ),
    case::parse_invalid_integer(
    b":cats\r\n",
    Ok((
        0,
        Some(
            Message {
                entry: Entry::Invalid(b":cats".to_vec()),
                error_flags: ErrorFlags::InvalidData.into(),
            }
        )
    ))
    ),
    )]
    /// Parses `input` once and compares the outcome with `expected`, using the
    /// length of the unparsed remainder in place of the remainder itself.
    fn resp(input: &[u8], expected: Result<(usize, Option<Message>)>) {
        let resp = Resp {};
        assert_eq!(
            resp.parse(input, Direction::Unknown)
                .map(|(rem, msg)| (rem.len(), msg)),
            expected
        );
    }
}