1use std::collections::HashSet;
2use std::fmt;
3
4use combine::error::{StreamError, Tracked};
5use combine::parser::choice::or;
6use combine::stream::{ResetStream, StreamErrorFor, StreamOnce};
7use combine::{ParseError, Parser, Positioned, attempt, choice, easy, satisfy_map};
8use redis_protocol::resp3::types::{BytesFrame, VerbatimStringFormat};
9use sierradb::StreamId;
10use sierradb::bucket::PartitionId;
11use sierradb_protocol::{ErrorCode, ExpectedVersion};
12use uuid::Uuid;
13
14use crate::request::{PartitionSelector, RangeValue};
15
16#[derive(Debug, PartialEq)]
17pub struct FrameStreamErrors<'a> {
18 errors: easy::Errors<&'a BytesFrame, &'a [BytesFrame], usize>,
19}
20
21impl<'a> ParseError<&'a BytesFrame, &'a [BytesFrame], usize> for FrameStreamErrors<'a> {
22 type StreamError = easy::Error<&'a BytesFrame, &'a [BytesFrame]>;
23
24 fn empty(position: usize) -> Self {
25 FrameStreamErrors {
26 errors: easy::Errors::empty(position),
27 }
28 }
29
30 fn set_position(&mut self, position: usize) {
31 self.errors.set_position(position);
32 }
33
34 fn add(&mut self, err: Self::StreamError) {
35 self.errors.add(err);
36 }
37
38 fn set_expected<F>(self_: &mut Tracked<Self>, info: Self::StreamError, f: F)
39 where
40 F: FnOnce(&mut Tracked<Self>),
41 {
42 let start = self_.error.errors.errors.len();
43 f(self_);
44 let mut i = 0;
47 self_.error.errors.errors.retain(|e| {
48 if i < start {
49 i += 1;
50 true
51 } else {
52 !matches!(*e, easy::Error::Expected(_))
53 }
54 });
55 self_.error.errors.add(info);
56 }
57
58 fn is_unexpected_end_of_input(&self) -> bool {
59 self.errors.is_unexpected_end_of_input()
60 }
61
62 fn into_other<T>(self) -> T
63 where
64 T: ParseError<&'a BytesFrame, &'a [BytesFrame], usize>,
65 {
66 self.errors.into_other()
67 }
68}
69
70fn frame_kind(frame: &BytesFrame) -> &'static str {
71 match frame {
72 BytesFrame::BlobString { .. } => "string",
73 BytesFrame::BlobError { .. } => "error",
74 BytesFrame::SimpleString { .. } => "string",
75 BytesFrame::SimpleError { .. } => "error",
76 BytesFrame::Boolean { .. } => "boolean",
77 BytesFrame::Null => "null",
78 BytesFrame::Number { .. } => "number",
79 BytesFrame::Double { .. } => "double",
80 BytesFrame::BigNumber { .. } => "number",
81 BytesFrame::VerbatimString { .. } => "string",
82 BytesFrame::Array { .. } => "array",
83 BytesFrame::Map { .. } => "map",
84 BytesFrame::Set { .. } => "set",
85 BytesFrame::Push { .. } => "push",
86 BytesFrame::Hello { .. } => "hello",
87 BytesFrame::ChunkedString(_) => "chunk",
88 }
89}
90
91fn display_error<'a>(
92 f: &mut fmt::Formatter<'_>,
93 err: &easy::Error<&'a BytesFrame, &'a [BytesFrame]>,
94) -> fmt::Result {
95 match err {
96 easy::Error::Unexpected(info) => display_info(f, info),
97 easy::Error::Expected(info) => display_info(f, info),
98 easy::Error::Message(info) => display_info(f, info),
99 easy::Error::Other(error) => write!(f, "{error}"),
100 }
101}
102
103fn display_info<'a>(
104 f: &mut fmt::Formatter<'_>,
105 info: &easy::Info<&'a BytesFrame, &'a [BytesFrame]>,
106) -> fmt::Result {
107 match info {
108 easy::Info::Token(token) => write!(f, "{}", frame_kind(token)),
109 easy::Info::Range(range) => {
110 for (i, frame) in range.iter().enumerate() {
111 if i < range.len() - 1 {
112 write!(f, "{}, ", frame_kind(frame))?;
113 } else {
114 write!(f, "{}", frame_kind(frame))?;
115 }
116 }
117 Ok(())
118 }
119 easy::Info::Owned(msg) => write!(f, "{msg}"),
120 easy::Info::Static(msg) => write!(f, "{msg}"),
121 }
122}
123
124impl<'a> fmt::Display for FrameStreamErrors<'a> {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 write!(f, "{} ", ErrorCode::InvalidArg)?;
127
128 let unexpected = self
132 .errors
133 .errors
134 .iter()
135 .filter(|e| matches!(**e, easy::Error::Unexpected(_)));
136 let mut has_unexpected = false;
137 for err in unexpected {
138 if !has_unexpected {
139 write!(f, "unexpected ")?;
140 }
141 has_unexpected = true;
142 display_error(f, err)?;
143 }
144
145 let iter = || {
148 self.errors.errors.iter().filter_map(|err| match *err {
149 easy::Error::Expected(ref err) => Some(err),
150 _ => None,
151 })
152 };
153 let expected_count = iter().count();
154 for (i, message) in iter().enumerate() {
155 if has_unexpected {
156 write!(f, ": ")?;
157 has_unexpected = false;
158 }
159 let s = match i {
160 0 => "expected",
161 _ if i < expected_count - 1 => ",",
162 _ => " or",
164 };
165 write!(f, "{s} ")?;
166 display_info(f, message)?;
167 }
168 let messages = self
170 .errors
171 .errors
172 .iter()
173 .filter(|e| matches!(**e, easy::Error::Message(_) | easy::Error::Other(_)));
174 for (i, err) in messages.enumerate() {
175 if i == 0 && expected_count != 0 {
176 write!(f, ": ")?;
177 }
178 display_error(f, err)?;
179 }
180 Ok(())
181 }
182}
183
184#[derive(Clone, Debug, PartialEq)]
186pub struct FrameStream<'a> {
187 frames: &'a [BytesFrame],
188 position: usize,
189}
190
191impl<'a> StreamOnce for FrameStream<'a> {
192 type Error = FrameStreamErrors<'a>;
193 type Position = usize;
194 type Range = &'a [BytesFrame];
195 type Token = &'a BytesFrame;
196
197 fn uncons(&mut self) -> Result<Self::Token, StreamErrorFor<Self>> {
198 match self.frames.split_first() {
199 Some((first, rest)) => {
200 self.frames = rest;
201 self.position += 1;
202 Ok(first)
203 }
204 None => Err(easy::Error::end_of_input()),
205 }
206 }
207
208 fn is_partial(&self) -> bool {
209 false
210 }
211}
212
213impl<'a> ResetStream for FrameStream<'a> {
214 type Checkpoint = (usize, &'a [BytesFrame]);
215
216 fn checkpoint(&self) -> Self::Checkpoint {
217 (self.position, self.frames)
218 }
219
220 fn reset(&mut self, checkpoint: Self::Checkpoint) -> Result<(), Self::Error> {
221 self.position = checkpoint.0;
222 self.frames = checkpoint.1;
223 Ok(())
224 }
225}
226
227impl<'a> Positioned for FrameStream<'a> {
228 fn position(&self) -> Self::Position {
229 self.position
230 }
231}
232
233pub fn frame_stream(frames: &'_ [BytesFrame]) -> FrameStream<'_> {
235 FrameStream {
236 frames,
237 position: 0,
238 }
239}
240
241pub fn string<'a>() -> impl Parser<FrameStream<'a>, Output = &'a str> + 'a {
243 satisfy_map(|frame: &'a BytesFrame| match frame {
244 BytesFrame::BlobString { data, .. }
245 | BytesFrame::SimpleString { data, .. }
246 | BytesFrame::VerbatimString {
247 data,
248 format: VerbatimStringFormat::Text,
249 ..
250 } => str::from_utf8(data).ok(),
251 _ => None,
252 })
253 .expected("string")
254}
255
256pub fn data<'a>() -> impl Parser<FrameStream<'a>, Output = &'a [u8]> + 'a {
257 satisfy_map(|frame: &'a BytesFrame| match frame {
258 BytesFrame::BlobString { data, .. }
259 | BytesFrame::SimpleString { data, .. }
260 | BytesFrame::VerbatimString {
261 data,
262 format: VerbatimStringFormat::Text,
263 ..
264 } => Some(&**data),
265 _ => None,
266 })
267 .expected("string or bytes")
268}
269
270pub fn data_owned<'a>() -> impl Parser<FrameStream<'a>, Output = Vec<u8>> + 'a {
271 data().map(ToOwned::to_owned)
272}
273
274pub fn keyword<'a>(kw: &'static str) -> impl Parser<FrameStream<'a>, Output = &'a str> + 'a {
275 debug_assert_eq!(kw, kw.to_uppercase(), "keywords should be uppercase");
276 satisfy_map(move |frame: &'a BytesFrame| match frame {
277 BytesFrame::BlobString { data, .. }
278 | BytesFrame::SimpleString { data, .. }
279 | BytesFrame::VerbatimString {
280 data,
281 format: VerbatimStringFormat::Text,
282 ..
283 } => str::from_utf8(data).ok().and_then(|s| {
284 if s.to_uppercase() == kw {
285 Some(s)
286 } else {
287 None
288 }
289 }),
290 _ => None,
291 })
292 .expected(kw)
293}
294
295pub fn number_u64<'a>() -> impl Parser<FrameStream<'a>, Output = u64> + 'a {
296 satisfy_map(|frame: &'a BytesFrame| match frame {
297 BytesFrame::BlobString { data, .. }
298 | BytesFrame::SimpleString { data, .. }
299 | BytesFrame::BigNumber { data, .. }
300 | BytesFrame::VerbatimString {
301 data,
302 format: VerbatimStringFormat::Text,
303 ..
304 } => str::from_utf8(data).ok().and_then(|s| s.parse().ok()),
305 BytesFrame::Number { data, .. } => (*data).try_into().ok(),
306 _ => None,
307 })
308 .expected("number")
309}
310
311pub fn number_u64_min<'a>(min: u64) -> impl Parser<FrameStream<'a>, Output = u64> + 'a {
312 number_u64().and_then(move |n| {
313 if n < min {
314 Err(easy::Error::message_format(format!(
315 "number {n} must not be less than {min}"
316 )))
317 } else {
318 Ok(n)
319 }
320 })
321}
322
323pub fn number_i64<'a>() -> impl Parser<FrameStream<'a>, Output = i64> + 'a {
324 satisfy_map(|frame: &'a BytesFrame| match frame {
325 BytesFrame::BlobString { data, .. }
326 | BytesFrame::SimpleString { data, .. }
327 | BytesFrame::BigNumber { data, .. }
328 | BytesFrame::VerbatimString {
329 data,
330 format: VerbatimStringFormat::Text,
331 ..
332 } => str::from_utf8(data).ok().and_then(|s| s.parse().ok()),
333 BytesFrame::Number { data, .. } => Some(*data),
334 _ => None,
335 })
336 .expected("number")
337}
338
339pub fn number_u32<'a>() -> impl Parser<FrameStream<'a>, Output = u32> + 'a {
340 satisfy_map(|frame: &'a BytesFrame| match frame {
341 BytesFrame::BlobString { data, .. }
342 | BytesFrame::SimpleString { data, .. }
343 | BytesFrame::BigNumber { data, .. }
344 | BytesFrame::VerbatimString {
345 data,
346 format: VerbatimStringFormat::Text,
347 ..
348 } => str::from_utf8(data).ok().and_then(|s| s.parse().ok()),
349 BytesFrame::Number { data, .. } => (*data).try_into().ok(),
350 _ => None,
351 })
352 .expected("number")
353}
354
355pub fn partition_id<'a>() -> impl Parser<FrameStream<'a>, Output = PartitionId> + 'a {
356 satisfy_map(|frame: &'a BytesFrame| match frame {
357 BytesFrame::BlobString { data, .. }
358 | BytesFrame::SimpleString { data, .. }
359 | BytesFrame::BigNumber { data, .. }
360 | BytesFrame::VerbatimString {
361 data,
362 format: VerbatimStringFormat::Text,
363 ..
364 } => str::from_utf8(data).ok().and_then(|s| s.parse().ok()),
365 BytesFrame::Number { data, .. } => (*data).try_into().ok(),
366 _ => None,
367 })
368 .expected("partition id")
369}
370
371fn uuid<'a>(expected: &'static str) -> impl Parser<FrameStream<'a>, Output = Uuid> + 'a {
372 string()
373 .and_then(move |s| {
374 Uuid::parse_str(s.trim())
375 .map_err(|_| easy::Error::message_format(format!("invalid {expected}")))
376 })
377 .expected(expected)
378}
379
380pub fn event_id<'a>() -> impl Parser<FrameStream<'a>, Output = Uuid> + 'a {
381 uuid("event id")
382}
383
384pub fn partition_key<'a>() -> impl Parser<FrameStream<'a>, Output = Uuid> + 'a {
385 uuid("partition key")
386}
387
388pub fn subscription_id<'a>() -> impl Parser<FrameStream<'a>, Output = Uuid> + 'a {
389 uuid("subscription id")
390}
391
392pub fn expected_version<'a>() -> impl Parser<FrameStream<'a>, Output = ExpectedVersion> + 'a {
393 let exact = number_u64().map(ExpectedVersion::Exact);
394
395 let keyword = choice((
396 keyword("ANY").map(|_| ExpectedVersion::Any),
397 keyword("EXISTS").map(|_| ExpectedVersion::Exists),
398 keyword("EMPTY").map(|_| ExpectedVersion::Empty),
399 ));
400
401 exact
402 .or(keyword)
403 .message("expected version number or 'any', 'exists', 'empty'")
404}
405
406pub fn range_value<'a>() -> impl Parser<FrameStream<'a>, Output = RangeValue> + 'a {
407 choice!(
408 keyword("-").map(|_| RangeValue::Start),
409 keyword("+").map(|_| RangeValue::End),
410 number_u64().map(RangeValue::Value)
411 )
412 .expected("range value (-, +, or number)")
413}
414
415pub fn partition_selector<'a>() -> impl Parser<FrameStream<'a>, Output = PartitionSelector> + 'a {
416 or(
417 attempt(partition_key().map(PartitionSelector::ByKey)),
418 partition_id().map(PartitionSelector::ById),
419 )
420 .expected("partition id or key")
421}
422
423pub fn all_selector<'a>() -> impl Parser<FrameStream<'a>, Output = char> + 'a {
424 keyword("*").map(|_| '*')
425}
426
427pub fn partition_ids<'a>() -> impl Parser<FrameStream<'a>, Output = HashSet<PartitionId>> + 'a {
428 satisfy_map(|frame: &'a BytesFrame| {
429 match frame {
430 BytesFrame::BlobString { data, .. }
431 | BytesFrame::SimpleString { data, .. }
432 | BytesFrame::VerbatimString {
433 data,
434 format: VerbatimStringFormat::Text,
435 ..
436 } => {
437 let data = str::from_utf8(data).ok()?;
438
439 data.split(',')
441 .map(|part| part.trim().parse::<PartitionId>())
442 .collect::<Result<_, _>>()
443 .ok()
444 }
445 BytesFrame::BigNumber { data, .. } => {
446 let id: PartitionId = str::from_utf8(data).ok()?.parse().ok()?;
448 Some(HashSet::from_iter([id]))
449 }
450 BytesFrame::Number { data, .. } => {
451 let id: PartitionId = (*data).try_into().ok()?;
453 Some(HashSet::from_iter([id]))
454 }
455 _ => None,
456 }
457 })
458 .expected("comma-separated partition ids")
459}
460
461pub fn partition_id_sequence<'a>() -> impl Parser<FrameStream<'a>, Output = (PartitionId, u64)> + 'a
463{
464 satisfy_map(|frame: &'a BytesFrame| match frame {
465 BytesFrame::BlobString { data, .. }
466 | BytesFrame::SimpleString { data, .. }
467 | BytesFrame::VerbatimString {
468 data,
469 format: VerbatimStringFormat::Text,
470 ..
471 } => {
472 let (partition_id, sequence) = str::from_utf8(data).ok()?.split_once('=')?;
473 let partition_id: PartitionId = partition_id.parse().ok()?;
474 let sequence: u64 = sequence.parse().ok()?;
475 Some((partition_id, sequence))
476 }
477 _ => None,
478 })
479 .expected("partition id sequence value")
480}
481
482pub fn stream_id_version<'a>() -> impl Parser<FrameStream<'a>, Output = (StreamId, u64)> + 'a {
484 string()
485 .and_then(|s| {
486 let (stream_id, version) = s
487 .split_once('=')
488 .ok_or_else(|| easy::Error::message_format("missing `=` in stream id version"))?;
489 let stream_id = StreamId::new(stream_id).map_err(easy::Error::message_format)?;
490 let version: u64 = version
491 .parse()
492 .map_err(|_| easy::Error::message_format("invalid stream id version number"))?;
493 Ok::<_, easy::Error<_, _>>((stream_id, version))
494 })
495 .expected("stream id version value")
496}
497
498pub fn stream_id<'a>() -> impl Parser<FrameStream<'a>, Output = StreamId> + 'a {
499 string()
500 .and_then(|s| StreamId::new(s).map_err(easy::Error::message_format))
501 .expected("stream id")
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a RESP3 simple-string frame with no attributes.
    fn simple_string(text: &str) -> BytesFrame {
        BytesFrame::SimpleString {
            data: text.as_bytes().to_vec().into(),
            attributes: None,
        }
    }

    /// Builds a RESP3 number frame with no attributes.
    fn number(value: i64) -> BytesFrame {
        BytesFrame::Number {
            data: value,
            attributes: None,
        }
    }

    #[test]
    fn test_string_parser() {
        let frames = [simple_string("GET")];
        let result = string().parse(frame_stream(&frames));
        assert!(result.is_ok());
        let (parsed, _) = result.unwrap();
        assert_eq!(parsed, "GET");
    }

    #[test]
    fn test_keyword_parser() {
        let frames = [simple_string("PARTITION_KEY")];
        let (parsed, _) = keyword("PARTITION_KEY")
            .parse(frame_stream(&frames))
            .unwrap();
        assert_eq!(parsed, "PARTITION_KEY");
    }

    #[test]
    fn test_partition_id_parser() {
        for frame in [number(10), simple_string("10")] {
            let frames = [frame];
            let (parsed, _) = partition_id().parse(frame_stream(&frames)).unwrap();
            assert_eq!(parsed, 10);
        }
    }

    #[test]
    fn test_partition_ids_parser() {
        let cases = [
            (number(10), vec![10]),
            (simple_string("10"), vec![10]),
            (simple_string("10,59,24"), vec![10, 59, 24]),
        ];
        for (frame, expected) in cases {
            let frames = [frame];
            let (parsed, _) = partition_ids().parse(frame_stream(&frames)).unwrap();
            assert_eq!(parsed, HashSet::from_iter(expected));
        }
    }
}