1use substrait::proto::rel::RelType;
2use substrait::proto::rel_common::{Emit, EmitKind};
3use substrait::proto::{
4 AggregateRel, Expression, FilterRel, NamedStruct, ProjectRel, ReadRel, Rel, RelCommon, Type,
5 aggregate_rel, read_rel, r#type,
6};
7
8use super::{ErrorKind, MessageParseError, Rule, ScopedParsePair, unwrap_single_pair};
9use crate::extensions::SimpleExtensions;
10use crate::parser::expressions::{Name, reference};
11use crate::parser::{ParsePair, RuleIter};
12
13pub trait RelationParsePair: Sized {
16 fn rule() -> Rule;
17
18 fn message() -> &'static str;
19
20 fn parse_pair_with_context(
28 extensions: &SimpleExtensions,
29 pair: pest::iterators::Pair<Rule>,
30 input_children: Vec<Box<Rel>>,
31 input_field_count: usize,
32 ) -> Result<Self, MessageParseError>;
33
34 fn into_rel(self) -> Rel;
35}
36
37pub struct TableName(Vec<String>);
38
39impl ParsePair for TableName {
40 fn rule() -> Rule {
41 Rule::table_name
42 }
43
44 fn message() -> &'static str {
45 "TableName"
46 }
47
48 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
49 assert_eq!(pair.as_rule(), Self::rule());
50 let pairs = pair.into_inner();
51 let mut names = Vec::with_capacity(pairs.len());
52 let mut iter = RuleIter::from(pairs);
53 while let Some(name) = iter.parse_if_next::<Name>() {
54 names.push(name.0);
55 }
56 iter.done();
57 Self(names)
58 }
59}
60
61#[derive(Debug, Clone)]
62pub struct Column {
63 pub name: String,
64 pub typ: Type,
65}
66
67impl ScopedParsePair for Column {
68 fn rule() -> Rule {
69 Rule::named_column
70 }
71
72 fn message() -> &'static str {
73 "Column"
74 }
75
76 fn parse_pair(
77 extensions: &SimpleExtensions,
78 pair: pest::iterators::Pair<Rule>,
79 ) -> Result<Self, MessageParseError> {
80 assert_eq!(pair.as_rule(), Self::rule());
81 let mut iter = RuleIter::from(pair.into_inner());
82 let name = iter.parse_next::<Name>().0;
83 let typ = iter.parse_next_scoped(extensions)?;
84 iter.done();
85 Ok(Self { name, typ })
86 }
87}
88
89pub struct NamedColumnList(Vec<Column>);
90
91impl ScopedParsePair for NamedColumnList {
92 fn rule() -> Rule {
93 Rule::named_column_list
94 }
95
96 fn message() -> &'static str {
97 "NamedColumnList"
98 }
99
100 fn parse_pair(
101 extensions: &SimpleExtensions,
102 pair: pest::iterators::Pair<Rule>,
103 ) -> Result<Self, MessageParseError> {
104 assert_eq!(pair.as_rule(), Self::rule());
105 let mut columns = Vec::new();
106 for col in pair.into_inner() {
107 columns.push(Column::parse_pair(extensions, col)?);
108 }
109 Ok(Self(columns))
110 }
111}
112
113#[allow(clippy::vec_box)]
118pub(crate) fn expect_one_child(
119 message: &'static str,
120 pair: &pest::iterators::Pair<Rule>,
121 mut input_children: Vec<Box<Rel>>,
122) -> Result<Box<Rel>, MessageParseError> {
123 match input_children.len() {
124 0 => Err(MessageParseError::invalid(
125 message,
126 pair.as_span(),
127 format!("{message} missing child"),
128 )),
129 1 => Ok(input_children.pop().unwrap()),
130 n => Err(MessageParseError::invalid(
131 message,
132 pair.as_span(),
133 format!("{message} should have 1 input child, got {n}"),
134 )),
135 }
136}
137
138impl RelationParsePair for ReadRel {
139 fn rule() -> Rule {
140 Rule::read_relation
141 }
142
143 fn message() -> &'static str {
144 "ReadRel"
145 }
146
147 fn into_rel(self) -> Rel {
148 Rel {
149 rel_type: Some(RelType::Read(Box::new(self))),
150 }
151 }
152
153 fn parse_pair_with_context(
154 extensions: &SimpleExtensions,
155 pair: pest::iterators::Pair<Rule>,
156 input_children: Vec<Box<Rel>>,
157 input_field_count: usize,
158 ) -> Result<Self, MessageParseError> {
159 assert_eq!(pair.as_rule(), Self::rule());
160 if !input_children.is_empty() {
162 return Err(MessageParseError::invalid(
163 Self::message(),
164 pair.as_span(),
165 "ReadRel should have no input children",
166 ));
167 }
168 if input_field_count != 0 {
169 let error = pest::error::Error::new_from_span(
170 pest::error::ErrorVariant::CustomError {
171 message: "ReadRel should have 0 input fields".to_string(),
172 },
173 pair.as_span(),
174 );
175 return Err(MessageParseError::new(
176 "ReadRel",
177 ErrorKind::InvalidValue,
178 Box::new(error),
179 ));
180 }
181
182 let mut iter = RuleIter::from(pair.into_inner());
183 let table = iter.parse_next::<TableName>().0;
184 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
185 iter.done();
186
187 let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
188 let struct_ = r#type::Struct {
189 types,
190 type_variation_reference: 0,
191 nullability: r#type::Nullability::Required as i32,
192 };
193 let named_struct = NamedStruct {
194 names,
195 r#struct: Some(struct_),
196 };
197
198 let read_rel = ReadRel {
199 base_schema: Some(named_struct),
200 read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
201 names: table,
202 advanced_extension: None,
203 })),
204 ..Default::default()
205 };
206
207 Ok(read_rel)
208 }
209}
210
211impl RelationParsePair for FilterRel {
212 fn rule() -> Rule {
213 Rule::filter_relation
214 }
215
216 fn message() -> &'static str {
217 "FilterRel"
218 }
219
220 fn into_rel(self) -> Rel {
221 Rel {
222 rel_type: Some(RelType::Filter(Box::new(self))),
223 }
224 }
225
226 fn parse_pair_with_context(
227 extensions: &SimpleExtensions,
228 pair: pest::iterators::Pair<Rule>,
229 input_children: Vec<Box<Rel>>,
230 _input_field_count: usize,
231 ) -> Result<Self, MessageParseError> {
232 assert_eq!(pair.as_rule(), Self::rule());
233 let input = expect_one_child(Self::message(), &pair, input_children)?;
234 let mut iter = RuleIter::from(pair.into_inner());
235 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
236 let references_pair = iter.pop(Rule::reference_list);
237 let output_mapping = references_pair
238 .into_inner()
239 .map(|p| {
240 let inner = crate::parser::unwrap_single_pair(p);
241 inner.as_str().parse::<i32>().unwrap()
242 })
243 .collect::<Vec<i32>>();
244 iter.done();
245 let emit = EmitKind::Emit(Emit { output_mapping });
246 let common = RelCommon {
247 emit_kind: Some(emit),
248 ..Default::default()
249 };
250 Ok(FilterRel {
251 input: Some(input),
252 condition: Some(Box::new(condition)),
253 common: Some(common),
254 advanced_extension: None,
255 })
256 }
257}
258
259impl RelationParsePair for ProjectRel {
260 fn rule() -> Rule {
261 Rule::project_relation
262 }
263
264 fn message() -> &'static str {
265 "ProjectRel"
266 }
267
268 fn into_rel(self) -> Rel {
269 Rel {
270 rel_type: Some(RelType::Project(Box::new(self))),
271 }
272 }
273
274 fn parse_pair_with_context(
275 extensions: &SimpleExtensions,
276 pair: pest::iterators::Pair<Rule>,
277 input_children: Vec<Box<Rel>>,
278 input_field_count: usize,
279 ) -> Result<Self, MessageParseError> {
280 assert_eq!(pair.as_rule(), Self::rule());
281 let input = expect_one_child(Self::message(), &pair, input_children)?;
282
283 let arguments_pair = unwrap_single_pair(pair);
285
286 let mut expressions = Vec::new();
287 let mut output_mapping = Vec::new();
288
289 for arg in arguments_pair.into_inner() {
291 let inner_arg = crate::parser::unwrap_single_pair(arg);
292 match inner_arg.as_rule() {
293 Rule::reference => {
294 let inner = crate::parser::unwrap_single_pair(inner_arg);
296 let ref_index = inner.as_str().parse::<i32>().unwrap();
297 output_mapping.push(ref_index);
298 }
299 Rule::expression => {
300 let _expr = Expression::parse_pair(extensions, inner_arg)?;
302 expressions.push(_expr);
303 output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
305 }
306 _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
307 }
308 }
309
310 let emit = EmitKind::Emit(Emit { output_mapping });
311 let common = RelCommon {
312 emit_kind: Some(emit),
313 ..Default::default()
314 };
315
316 Ok(ProjectRel {
317 input: Some(input),
318 expressions,
319 common: Some(common),
320 advanced_extension: None,
321 })
322 }
323}
324
325impl RelationParsePair for AggregateRel {
326 fn rule() -> Rule {
327 Rule::aggregate_relation
328 }
329
330 fn message() -> &'static str {
331 "AggregateRel"
332 }
333
334 fn into_rel(self) -> Rel {
335 Rel {
336 rel_type: Some(RelType::Aggregate(Box::new(self))),
337 }
338 }
339
340 fn parse_pair_with_context(
341 extensions: &SimpleExtensions,
342 pair: pest::iterators::Pair<Rule>,
343 input_children: Vec<Box<Rel>>,
344 _input_field_count: usize,
345 ) -> Result<Self, MessageParseError> {
346 assert_eq!(pair.as_rule(), Self::rule());
347 let input = expect_one_child(Self::message(), &pair, input_children)?;
348 let mut iter = RuleIter::from(pair.into_inner());
349 let group_by_pair = iter.pop(Rule::aggregate_group_by);
350 let output_pair = iter.pop(Rule::aggregate_output);
351 iter.done();
352 let mut grouping_expressions = Vec::new();
353 for group_by_item in group_by_pair.into_inner() {
354 match group_by_item.as_rule() {
355 Rule::reference => {
356 let inner = crate::parser::unwrap_single_pair(group_by_item);
357 let ref_index = inner.as_str().parse::<i32>().unwrap();
358 grouping_expressions.push(Expression {
359 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
360 reference(ref_index),
361 ))),
362 });
363 }
364 Rule::empty => {
365 }
367 _ => panic!(
368 "Unexpected group-by item rule: {:?}",
369 group_by_item.as_rule()
370 ),
371 }
372 }
373
374 let mut measures = Vec::new();
376 let mut output_mapping = Vec::new();
377 let group_by_count = grouping_expressions.len();
378 let mut measure_count = 0;
379
380 for output_item in output_pair.into_inner() {
381 let inner_item = unwrap_single_pair(output_item);
382 match inner_item.as_rule() {
383 Rule::reference => {
384 let inner = crate::parser::unwrap_single_pair(inner_item);
385 let ref_index = inner.as_str().parse::<i32>().unwrap();
386 output_mapping.push(ref_index);
387 }
388 Rule::aggregate_measure => {
389 let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
390 measures.push(measure);
391 output_mapping.push(group_by_count as i32 + measure_count);
392 measure_count += 1;
393 }
394 _ => panic!(
395 "Unexpected inner output item rule: {:?}",
396 inner_item.as_rule()
397 ),
398 }
399 }
400
401 let emit = EmitKind::Emit(Emit { output_mapping });
402 let common = RelCommon {
403 emit_kind: Some(emit),
404 ..Default::default()
405 };
406
407 Ok(AggregateRel {
408 input: Some(input),
409 grouping_expressions,
410 groupings: vec![], measures,
412 common: Some(common),
413 advanced_extension: None,
414 })
415 }
416}
417
#[cfg(test)]
mod tests {
    use pest::Parser;

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::{ExpressionParser, Rule};

    /// Standard Substrait aggregate-function extension file used by the aggregate tests.
    const AGGREGATE_FUNCTIONS_URI: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml";

    /// Extensions declaring `sum` (anchor 10) and `count` (anchor 11) under URI anchor 1.
    fn aggregate_extensions() -> SimpleExtensions {
        TestContext::new()
            .with_uri(1, AGGREGATE_FUNCTIONS_URI)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions
    }

    /// A three-column read (`a:i32, b:string?, c:i64`) used as the input relation.
    fn example_read_relation() -> ReadRel {
        let extensions = SimpleExtensions::default();
        ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
    }

    /// Returns the explicit emit mapping stored in a relation's `common` field.
    fn output_mapping(common: Option<&RelCommon>) -> &[i32] {
        let emit_kind = common
            .and_then(|common| common.emit_kind.as_ref())
            .expect("relation should carry an emit kind");
        match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            other => panic!("Expected EmitKind::Emit, got {other:?}"),
        }
    }

    /// Parses `input` with `rule`, asserting that the whole input forms exactly one pair.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    #[test]
    fn test_parse_read_relation() {
        let extensions = SimpleExtensions::default();
        let read = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            0,
        )
        .unwrap();
        let names = match &read.read_type {
            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
            _ => panic!("Expected NamedTable"),
        };
        assert_eq!(names, &["ab", "cd", "ef"]);
        let columns = &read
            .base_schema
            .as_ref()
            .unwrap()
            .r#struct
            .as_ref()
            .unwrap()
            .types;
        assert_eq!(columns.len(), 2);
    }

    #[test]
    fn test_read_relation_rejects_inputs() {
        let extensions = SimpleExtensions::default();
        let with_child = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![Box::new(example_read_relation().into_rel())],
            0,
        );
        assert!(with_child.is_err());

        let with_fields = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            1,
        );
        assert!(with_fields.is_err());
    }

    #[test]
    fn test_parse_filter_relation() {
        let extensions = SimpleExtensions::default();
        let filter = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();
        assert_eq!(output_mapping(filter.common.as_ref()), &[0, 1, 2]);
    }

    #[test]
    fn test_filter_relation_requires_one_child() {
        let extensions = SimpleExtensions::default();
        let result = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            vec![],
            3,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_project_relation() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(project.expressions.len(), 1);
        assert_eq!(output_mapping(project.common.as_ref()), &[0, 1, 3]);
    }

    #[test]
    fn test_parse_project_relation_complex() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
            vec![Box::new(example_read_relation().into_rel())],
            // Deliberately differs from the child's 3 columns: new-expression
            // indices are offset by this argument, not by the child's schema.
            5,
        )
        .unwrap();

        assert_eq!(project.expressions.len(), 2);
        assert_eq!(output_mapping(project.common.as_ref()), &[5, 0, 6, 2, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation() {
        let extensions = aggregate_extensions();
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0, $1 => sum($2), $0, count($2)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.measures.len(), 2);
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[2, 0, 3]);
    }

    #[test]
    fn test_parse_aggregate_relation_simple() {
        let extensions = aggregate_extensions();
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0 => sum($1), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 1);
        assert_eq!(aggregate.measures.len(), 2);
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[1, 2]);
    }

    #[test]
    fn test_parse_aggregate_relation_no_group_by() {
        let extensions = aggregate_extensions();
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[0, 1]);
    }
}