sierradb_server/
parser.rs

1use std::collections::HashSet;
2use std::fmt;
3
4use combine::error::{StreamError, Tracked};
5use combine::parser::choice::or;
6use combine::stream::{ResetStream, StreamErrorFor, StreamOnce};
7use combine::{ParseError, Parser, Positioned, attempt, choice, easy, satisfy_map};
8use redis_protocol::resp3::types::{BytesFrame, VerbatimStringFormat};
9use sierradb::StreamId;
10use sierradb::bucket::PartitionId;
11use sierradb_protocol::{ErrorCode, ExpectedVersion};
12use uuid::Uuid;
13
14use crate::request::{PartitionSelector, RangeValue};
15
/// Error type produced while parsing a slice of [`BytesFrame`]s.
///
/// This is a thin wrapper around combine's `easy::Errors`, specialised to
/// frame tokens, frame-slice ranges and `usize` positions, so that this module
/// can supply its own `ParseError` and `Display` implementations.
#[derive(Debug, PartialEq)]
pub struct FrameStreamErrors<'a> {
    /// The wrapped combine error collection (position plus individual errors).
    errors: easy::Errors<&'a BytesFrame, &'a [BytesFrame], usize>,
}
20
21impl<'a> ParseError<&'a BytesFrame, &'a [BytesFrame], usize> for FrameStreamErrors<'a> {
22    type StreamError = easy::Error<&'a BytesFrame, &'a [BytesFrame]>;
23
24    fn empty(position: usize) -> Self {
25        FrameStreamErrors {
26            errors: easy::Errors::empty(position),
27        }
28    }
29
30    fn set_position(&mut self, position: usize) {
31        self.errors.set_position(position);
32    }
33
34    fn add(&mut self, err: Self::StreamError) {
35        self.errors.add(err);
36    }
37
38    fn set_expected<F>(self_: &mut Tracked<Self>, info: Self::StreamError, f: F)
39    where
40        F: FnOnce(&mut Tracked<Self>),
41    {
42        let start = self_.error.errors.errors.len();
43        f(self_);
44        // Replace all expected errors that were added from the previous add_error
45        // with this expected error
46        let mut i = 0;
47        self_.error.errors.errors.retain(|e| {
48            if i < start {
49                i += 1;
50                true
51            } else {
52                !matches!(*e, easy::Error::Expected(_))
53            }
54        });
55        self_.error.errors.add(info);
56    }
57
58    fn is_unexpected_end_of_input(&self) -> bool {
59        self.errors.is_unexpected_end_of_input()
60    }
61
62    fn into_other<T>(self) -> T
63    where
64        T: ParseError<&'a BytesFrame, &'a [BytesFrame], usize>,
65    {
66        self.errors.into_other()
67    }
68}
69
70fn frame_kind(frame: &BytesFrame) -> &'static str {
71    match frame {
72        BytesFrame::BlobString { .. } => "string",
73        BytesFrame::BlobError { .. } => "error",
74        BytesFrame::SimpleString { .. } => "string",
75        BytesFrame::SimpleError { .. } => "error",
76        BytesFrame::Boolean { .. } => "boolean",
77        BytesFrame::Null => "null",
78        BytesFrame::Number { .. } => "number",
79        BytesFrame::Double { .. } => "double",
80        BytesFrame::BigNumber { .. } => "number",
81        BytesFrame::VerbatimString { .. } => "string",
82        BytesFrame::Array { .. } => "array",
83        BytesFrame::Map { .. } => "map",
84        BytesFrame::Set { .. } => "set",
85        BytesFrame::Push { .. } => "push",
86        BytesFrame::Hello { .. } => "hello",
87        BytesFrame::ChunkedString(_) => "chunk",
88    }
89}
90
91fn display_error<'a>(
92    f: &mut fmt::Formatter<'_>,
93    err: &easy::Error<&'a BytesFrame, &'a [BytesFrame]>,
94) -> fmt::Result {
95    match err {
96        easy::Error::Unexpected(info) => display_info(f, info),
97        easy::Error::Expected(info) => display_info(f, info),
98        easy::Error::Message(info) => display_info(f, info),
99        easy::Error::Other(error) => write!(f, "{error}"),
100    }
101}
102
103fn display_info<'a>(
104    f: &mut fmt::Formatter<'_>,
105    info: &easy::Info<&'a BytesFrame, &'a [BytesFrame]>,
106) -> fmt::Result {
107    match info {
108        easy::Info::Token(token) => write!(f, "{}", frame_kind(token)),
109        easy::Info::Range(range) => {
110            for (i, frame) in range.iter().enumerate() {
111                if i < range.len() - 1 {
112                    write!(f, "{}, ", frame_kind(frame))?;
113                } else {
114                    write!(f, "{}", frame_kind(frame))?;
115                }
116            }
117            Ok(())
118        }
119        easy::Info::Owned(msg) => write!(f, "{msg}"),
120        easy::Info::Static(msg) => write!(f, "{msg}"),
121    }
122}
123
124impl<'a> fmt::Display for FrameStreamErrors<'a> {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        write!(f, "{} ", ErrorCode::InvalidArg)?;
127
128        // First print the token that we did not expect
129        // There should really just be one unexpected message at this point though we
130        // print them all to be safe
131        let unexpected = self
132            .errors
133            .errors
134            .iter()
135            .filter(|e| matches!(**e, easy::Error::Unexpected(_)));
136        let mut has_unexpected = false;
137        for err in unexpected {
138            if !has_unexpected {
139                write!(f, "unexpected ")?;
140            }
141            has_unexpected = true;
142            display_error(f, err)?;
143        }
144
145        // Then we print out all the things that were expected in a comma separated list
146        // 'Expected 'a', 'expression' or 'let'
147        let iter = || {
148            self.errors.errors.iter().filter_map(|err| match *err {
149                easy::Error::Expected(ref err) => Some(err),
150                _ => None,
151            })
152        };
153        let expected_count = iter().count();
154        for (i, message) in iter().enumerate() {
155            if has_unexpected {
156                write!(f, ": ")?;
157                has_unexpected = false;
158            }
159            let s = match i {
160                0 => "expected",
161                _ if i < expected_count - 1 => ",",
162                // Last expected message to be written
163                _ => " or",
164            };
165            write!(f, "{s} ")?;
166            display_info(f, message)?;
167        }
168        // If there are any generic messages we print them out last
169        let messages = self
170            .errors
171            .errors
172            .iter()
173            .filter(|e| matches!(**e, easy::Error::Message(_) | easy::Error::Other(_)));
174        for (i, err) in messages.enumerate() {
175            if i == 0 && expected_count != 0 {
176                write!(f, ": ")?;
177            }
178            display_error(f, err)?;
179        }
180        Ok(())
181    }
182}
183
/// Parser input over a borrowed slice of [`BytesFrame`]s.
///
/// The trait implementations for this type in this module let combine parsers
/// consume the slice one frame at a time.
#[derive(Clone, Debug, PartialEq)]
pub struct FrameStream<'a> {
    /// Frames that have not been consumed yet.
    frames: &'a [BytesFrame],
    /// Number of frames consumed so far; starts at 0 in `frame_stream`.
    position: usize,
}
190
191impl<'a> StreamOnce for FrameStream<'a> {
192    type Error = FrameStreamErrors<'a>;
193    type Position = usize;
194    type Range = &'a [BytesFrame];
195    type Token = &'a BytesFrame;
196
197    fn uncons(&mut self) -> Result<Self::Token, StreamErrorFor<Self>> {
198        match self.frames.split_first() {
199            Some((first, rest)) => {
200                self.frames = rest;
201                self.position += 1;
202                Ok(first)
203            }
204            None => Err(easy::Error::end_of_input()),
205        }
206    }
207
208    fn is_partial(&self) -> bool {
209        false
210    }
211}
212
213impl<'a> ResetStream for FrameStream<'a> {
214    type Checkpoint = (usize, &'a [BytesFrame]);
215
216    fn checkpoint(&self) -> Self::Checkpoint {
217        (self.position, self.frames)
218    }
219
220    fn reset(&mut self, checkpoint: Self::Checkpoint) -> Result<(), Self::Error> {
221        self.position = checkpoint.0;
222        self.frames = checkpoint.1;
223        Ok(())
224    }
225}
226
/// Exposes the stream's current position to combine.
impl<'a> Positioned for FrameStream<'a> {
    /// Returns the number of frames consumed so far.
    fn position(&self) -> Self::Position {
        self.position
    }
}
232
233// Helper function to create a FrameStream
234pub fn frame_stream(frames: &'_ [BytesFrame]) -> FrameStream<'_> {
235    FrameStream {
236        frames,
237        position: 0,
238    }
239}
240
241// Basic frame parsers using regular functions instead of the parser! macro
242pub fn string<'a>() -> impl Parser<FrameStream<'a>, Output = &'a str> + 'a {
243    satisfy_map(|frame: &'a BytesFrame| match frame {
244        BytesFrame::BlobString { data, .. }
245        | BytesFrame::SimpleString { data, .. }
246        | BytesFrame::VerbatimString {
247            data,
248            format: VerbatimStringFormat::Text,
249            ..
250        } => str::from_utf8(data).ok(),
251        _ => None,
252    })
253    .expected("string")
254}
255
256pub fn data<'a>() -> impl Parser<FrameStream<'a>, Output = &'a [u8]> + 'a {
257    satisfy_map(|frame: &'a BytesFrame| match frame {
258        BytesFrame::BlobString { data, .. }
259        | BytesFrame::SimpleString { data, .. }
260        | BytesFrame::VerbatimString {
261            data,
262            format: VerbatimStringFormat::Text,
263            ..
264        } => Some(&**data),
265        _ => None,
266    })
267    .expected("string or bytes")
268}
269
/// Same as [`data`], but copies the parsed bytes into an owned `Vec<u8>`.
pub fn data_owned<'a>() -> impl Parser<FrameStream<'a>, Output = Vec<u8>> + 'a {
    data().map(ToOwned::to_owned)
}
273
274pub fn keyword<'a>(kw: &'static str) -> impl Parser<FrameStream<'a>, Output = &'a str> + 'a {
275    debug_assert_eq!(kw, kw.to_uppercase(), "keywords should be uppercase");
276    satisfy_map(move |frame: &'a BytesFrame| match frame {
277        BytesFrame::BlobString { data, .. }
278        | BytesFrame::SimpleString { data, .. }
279        | BytesFrame::VerbatimString {
280            data,
281            format: VerbatimStringFormat::Text,
282            ..
283        } => str::from_utf8(data).ok().and_then(|s| {
284            if s.to_uppercase() == kw {
285                Some(s)
286            } else {
287                None
288            }
289        }),
290        _ => None,
291    })
292    .expected(kw)
293}
294
295pub fn number_u64<'a>() -> impl Parser<FrameStream<'a>, Output = u64> + 'a {
296    satisfy_map(|frame: &'a BytesFrame| match frame {
297        BytesFrame::BlobString { data, .. }
298        | BytesFrame::SimpleString { data, .. }
299        | BytesFrame::BigNumber { data, .. }
300        | BytesFrame::VerbatimString {
301            data,
302            format: VerbatimStringFormat::Text,
303            ..
304        } => str::from_utf8(data).ok().and_then(|s| s.parse().ok()),
305        BytesFrame::Number { data, .. } => (*data).try_into().ok(),
306        _ => None,
307    })
308    .expected("number")
309}
310
311pub fn number_u64_min<'a>(min: u64) -> impl Parser<FrameStream<'a>, Output = u64> + 'a {
312    number_u64().and_then(move |n| {
313        if n < min {
314            Err(easy::Error::message_format(format!(
315                "number {n} must not be less than {min}"
316            )))
317        } else {
318            Ok(n)
319        }
320    })
321}
322
323pub fn number_i64<'a>() -> impl Parser<FrameStream<'a>, Output = i64> + 'a {
324    satisfy_map(|frame: &'a BytesFrame| match frame {
325        BytesFrame::BlobString { data, .. }
326        | BytesFrame::SimpleString { data, .. }
327        | BytesFrame::BigNumber { data, .. }
328        | BytesFrame::VerbatimString {
329            data,
330            format: VerbatimStringFormat::Text,
331            ..
332        } => str::from_utf8(data).ok().and_then(|s| s.parse().ok()),
333        BytesFrame::Number { data, .. } => Some(*data),
334        _ => None,
335    })
336    .expected("number")
337}
338
339pub fn number_u32<'a>() -> impl Parser<FrameStream<'a>, Output = u32> + 'a {
340    satisfy_map(|frame: &'a BytesFrame| match frame {
341        BytesFrame::BlobString { data, .. }
342        | BytesFrame::SimpleString { data, .. }
343        | BytesFrame::BigNumber { data, .. }
344        | BytesFrame::VerbatimString {
345            data,
346            format: VerbatimStringFormat::Text,
347            ..
348        } => str::from_utf8(data).ok().and_then(|s| s.parse().ok()),
349        BytesFrame::Number { data, .. } => (*data).try_into().ok(),
350        _ => None,
351    })
352    .expected("number")
353}
354
355pub fn partition_id<'a>() -> impl Parser<FrameStream<'a>, Output = PartitionId> + 'a {
356    satisfy_map(|frame: &'a BytesFrame| match frame {
357        BytesFrame::BlobString { data, .. }
358        | BytesFrame::SimpleString { data, .. }
359        | BytesFrame::BigNumber { data, .. }
360        | BytesFrame::VerbatimString {
361            data,
362            format: VerbatimStringFormat::Text,
363            ..
364        } => str::from_utf8(data).ok().and_then(|s| s.parse().ok()),
365        BytesFrame::Number { data, .. } => (*data).try_into().ok(),
366        _ => None,
367    })
368    .expected("partition id")
369}
370
371fn uuid<'a>(expected: &'static str) -> impl Parser<FrameStream<'a>, Output = Uuid> + 'a {
372    string()
373        .and_then(move |s| {
374            Uuid::parse_str(s.trim())
375                .map_err(|_| easy::Error::message_format(format!("invalid {expected}")))
376        })
377        .expected(expected)
378}
379
/// Parses a UUID from a string frame, reporting failures as "event id".
pub fn event_id<'a>() -> impl Parser<FrameStream<'a>, Output = Uuid> + 'a {
    uuid("event id")
}
383
/// Parses a UUID from a string frame, reporting failures as "partition key".
pub fn partition_key<'a>() -> impl Parser<FrameStream<'a>, Output = Uuid> + 'a {
    uuid("partition key")
}
387
/// Parses a UUID from a string frame, reporting failures as "subscription id".
pub fn subscription_id<'a>() -> impl Parser<FrameStream<'a>, Output = Uuid> + 'a {
    uuid("subscription id")
}
391
392pub fn expected_version<'a>() -> impl Parser<FrameStream<'a>, Output = ExpectedVersion> + 'a {
393    let exact = number_u64().map(ExpectedVersion::Exact);
394
395    let keyword = choice((
396        keyword("ANY").map(|_| ExpectedVersion::Any),
397        keyword("EXISTS").map(|_| ExpectedVersion::Exists),
398        keyword("EMPTY").map(|_| ExpectedVersion::Empty),
399    ));
400
401    exact
402        .or(keyword)
403        .message("expected version number or 'any', 'exists', 'empty'")
404}
405
406pub fn range_value<'a>() -> impl Parser<FrameStream<'a>, Output = RangeValue> + 'a {
407    choice!(
408        keyword("-").map(|_| RangeValue::Start),
409        keyword("+").map(|_| RangeValue::End),
410        number_u64().map(RangeValue::Value)
411    )
412    .expected("range value (-, +, or number)")
413}
414
415pub fn partition_selector<'a>() -> impl Parser<FrameStream<'a>, Output = PartitionSelector> + 'a {
416    or(
417        attempt(partition_key().map(PartitionSelector::ByKey)),
418        partition_id().map(PartitionSelector::ById),
419    )
420    .expected("partition id or key")
421}
422
/// Parses a string frame containing `*` and returns the character `'*'`.
pub fn all_selector<'a>() -> impl Parser<FrameStream<'a>, Output = char> + 'a {
    keyword("*").map(|_| '*')
}
426
427pub fn partition_ids<'a>() -> impl Parser<FrameStream<'a>, Output = HashSet<PartitionId>> + 'a {
428    satisfy_map(|frame: &'a BytesFrame| {
429        match frame {
430            BytesFrame::BlobString { data, .. }
431            | BytesFrame::SimpleString { data, .. }
432            | BytesFrame::VerbatimString {
433                data,
434                format: VerbatimStringFormat::Text,
435                ..
436            } => {
437                let data = str::from_utf8(data).ok()?;
438
439                // Otherwise try comma-separated list
440                data.split(',')
441                    .map(|part| part.trim().parse::<PartitionId>())
442                    .collect::<Result<_, _>>()
443                    .ok()
444            }
445            BytesFrame::BigNumber { data, .. } => {
446                // Handle Number frames directly
447                let id: PartitionId = str::from_utf8(data).ok()?.parse().ok()?;
448                Some(HashSet::from_iter([id]))
449            }
450            BytesFrame::Number { data, .. } => {
451                // Handle Number frames directly
452                let id: PartitionId = (*data).try_into().ok()?;
453                Some(HashSet::from_iter([id]))
454            }
455            _ => None,
456        }
457    })
458    .expected("comma-separated partition ids")
459}
460
461// <p1>=<s1>
462pub fn partition_id_sequence<'a>() -> impl Parser<FrameStream<'a>, Output = (PartitionId, u64)> + 'a
463{
464    satisfy_map(|frame: &'a BytesFrame| match frame {
465        BytesFrame::BlobString { data, .. }
466        | BytesFrame::SimpleString { data, .. }
467        | BytesFrame::VerbatimString {
468            data,
469            format: VerbatimStringFormat::Text,
470            ..
471        } => {
472            let (partition_id, sequence) = str::from_utf8(data).ok()?.split_once('=')?;
473            let partition_id: PartitionId = partition_id.parse().ok()?;
474            let sequence: u64 = sequence.parse().ok()?;
475            Some((partition_id, sequence))
476        }
477        _ => None,
478    })
479    .expected("partition id sequence value")
480}
481
482// <s1>=<v1>
483pub fn stream_id_version<'a>() -> impl Parser<FrameStream<'a>, Output = (StreamId, u64)> + 'a {
484    string()
485        .and_then(|s| {
486            let (stream_id, version) = s
487                .split_once('=')
488                .ok_or_else(|| easy::Error::message_format("missing `=` in stream id version"))?;
489            let stream_id = StreamId::new(stream_id).map_err(easy::Error::message_format)?;
490            let version: u64 = version
491                .parse()
492                .map_err(|_| easy::Error::message_format("invalid stream id version number"))?;
493            Ok::<_, easy::Error<_, _>>((stream_id, version))
494        })
495        .expected("stream id version value")
496}
497
498pub fn stream_id<'a>() -> impl Parser<FrameStream<'a>, Output = StreamId> + 'a {
499    string()
500        .and_then(|s| StreamId::new(s).map_err(easy::Error::message_format))
501        .expected("stream id")
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SimpleString` frame holding `text`, with no attributes.
    fn simple_string(text: &str) -> BytesFrame {
        BytesFrame::SimpleString {
            data: text.as_bytes().to_vec().into(),
            attributes: None,
        }
    }

    /// Builds a `Number` frame holding `value`, with no attributes.
    fn number(value: i64) -> BytesFrame {
        BytesFrame::Number {
            data: value,
            attributes: None,
        }
    }

    #[test]
    fn test_string_parser() {
        let frames = [simple_string("GET")];

        let result = string().parse(frame_stream(&frames));
        assert!(result.is_ok());
        let (parsed, _) = result.unwrap();
        assert_eq!(parsed, "GET");
    }

    #[test]
    fn test_keyword_parser() {
        let frames = [simple_string("PARTITION_KEY")];

        let (parsed, _) = keyword("PARTITION_KEY")
            .parse(frame_stream(&frames))
            .unwrap();
        assert_eq!(parsed, "PARTITION_KEY");
    }

    #[test]
    fn test_partition_id_parser() {
        // A numeric frame and a string frame holding the same digits must
        // both parse to the same id.
        for frame in [number(10), simple_string("10")] {
            let frames = [frame];
            let (parsed, _) = partition_id().parse(frame_stream(&frames)).unwrap();
            assert_eq!(parsed, 10);
        }
    }

    #[test]
    fn test_partition_ids_parser() {
        // Single id, given as a number frame.
        let frames = [number(10)];
        let (parsed, _) = partition_ids().parse(frame_stream(&frames)).unwrap();
        assert_eq!(parsed, HashSet::from_iter([10]));

        // Single id, given as a string frame.
        let frames = [simple_string("10")];
        let (parsed, _) = partition_ids().parse(frame_stream(&frames)).unwrap();
        assert_eq!(parsed, HashSet::from_iter([10]));

        // Comma-separated list of ids.
        let frames = [simple_string("10,59,24")];
        let (parsed, _) = partition_ids().parse(frame_stream(&frames)).unwrap();
        assert_eq!(parsed, HashSet::from_iter([10, 59, 24]));
    }
}