// substrait_explain/parser/relations.rs

1use substrait::proto::rel::RelType;
2use substrait::proto::rel_common::{Emit, EmitKind};
3use substrait::proto::{
4    AggregateRel, Expression, FilterRel, NamedStruct, ProjectRel, ReadRel, Rel, RelCommon, Type,
5    aggregate_rel, read_rel, r#type,
6};
7
8use super::{ErrorKind, MessageParseError, Rule, ScopedParsePair, unwrap_single_pair};
9use crate::extensions::SimpleExtensions;
10use crate::parser::expressions::{Name, reference};
11use crate::parser::{ParsePair, RuleIter};
12
13/// A trait for parsing relations with full context needed for tree building.
14/// This includes extensions, the parsed pair, input children, and output field count.
15pub trait RelationParsePair: Sized {
16    fn rule() -> Rule;
17
18    fn message() -> &'static str;
19
20    /// Parse a relation with full context for tree building.
21    ///
22    /// Args:
23    /// - extensions: The extensions context
24    /// - pair: The parsed pest pair
25    /// - input_children: The input relations (for wiring)
26    /// - input_field_count: Number of output fields from input children (for output mapping)
27    fn parse_pair_with_context(
28        extensions: &SimpleExtensions,
29        pair: pest::iterators::Pair<Rule>,
30        input_children: Vec<Box<Rel>>,
31        input_field_count: usize,
32    ) -> Result<Self, MessageParseError>;
33
34    fn into_rel(self) -> Rel;
35}
36
/// A (possibly qualified) table name, stored as one string per name
/// component; e.g. `ab.cd.ef` is held as `["ab", "cd", "ef"]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableName(Vec<String>);
38
39impl ParsePair for TableName {
40    fn rule() -> Rule {
41        Rule::table_name
42    }
43
44    fn message() -> &'static str {
45        "TableName"
46    }
47
48    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
49        assert_eq!(pair.as_rule(), Self::rule());
50        let pairs = pair.into_inner();
51        let mut names = Vec::with_capacity(pairs.len());
52        let mut iter = RuleIter::from(pairs);
53        while let Some(name) = iter.parse_if_next::<Name>() {
54            names.push(name.0);
55        }
56        iter.done();
57        Self(names)
58    }
59}
60
61#[derive(Debug, Clone)]
62pub struct Column {
63    pub name: String,
64    pub typ: Type,
65}
66
67impl ScopedParsePair for Column {
68    fn rule() -> Rule {
69        Rule::named_column
70    }
71
72    fn message() -> &'static str {
73        "Column"
74    }
75
76    fn parse_pair(
77        extensions: &SimpleExtensions,
78        pair: pest::iterators::Pair<Rule>,
79    ) -> Result<Self, MessageParseError> {
80        assert_eq!(pair.as_rule(), Self::rule());
81        let mut iter = RuleIter::from(pair.into_inner());
82        let name = iter.parse_next::<Name>().0;
83        let typ = iter.parse_next_scoped(extensions)?;
84        iter.done();
85        Ok(Self { name, typ })
86    }
87}
88
89pub struct NamedColumnList(Vec<Column>);
90
91impl ScopedParsePair for NamedColumnList {
92    fn rule() -> Rule {
93        Rule::named_column_list
94    }
95
96    fn message() -> &'static str {
97        "NamedColumnList"
98    }
99
100    fn parse_pair(
101        extensions: &SimpleExtensions,
102        pair: pest::iterators::Pair<Rule>,
103    ) -> Result<Self, MessageParseError> {
104        assert_eq!(pair.as_rule(), Self::rule());
105        let mut columns = Vec::new();
106        for col in pair.into_inner() {
107            columns.push(Column::parse_pair(extensions, col)?);
108        }
109        Ok(Self(columns))
110    }
111}
112
113/// This is a utility function for extracting a single child from the list of
114/// children, to be used in the RelationParsePair trait. The RelationParsePair
115/// trait passes a Vec of children, because some relations have multiple
116/// children - but most accept exactly one child.
117#[allow(clippy::vec_box)]
118pub(crate) fn expect_one_child(
119    message: &'static str,
120    pair: &pest::iterators::Pair<Rule>,
121    mut input_children: Vec<Box<Rel>>,
122) -> Result<Box<Rel>, MessageParseError> {
123    match input_children.len() {
124        0 => Err(MessageParseError::invalid(
125            message,
126            pair.as_span(),
127            format!("{message} missing child"),
128        )),
129        1 => Ok(input_children.pop().unwrap()),
130        n => Err(MessageParseError::invalid(
131            message,
132            pair.as_span(),
133            format!("{message} should have 1 input child, got {n}"),
134        )),
135    }
136}
137
138impl RelationParsePair for ReadRel {
139    fn rule() -> Rule {
140        Rule::read_relation
141    }
142
143    fn message() -> &'static str {
144        "ReadRel"
145    }
146
147    fn into_rel(self) -> Rel {
148        Rel {
149            rel_type: Some(RelType::Read(Box::new(self))),
150        }
151    }
152
153    fn parse_pair_with_context(
154        extensions: &SimpleExtensions,
155        pair: pest::iterators::Pair<Rule>,
156        input_children: Vec<Box<Rel>>,
157        input_field_count: usize,
158    ) -> Result<Self, MessageParseError> {
159        assert_eq!(pair.as_rule(), Self::rule());
160        // ReadRel is a leaf node - it should have no input children and 0 input fields
161        if !input_children.is_empty() {
162            return Err(MessageParseError::invalid(
163                Self::message(),
164                pair.as_span(),
165                "ReadRel should have no input children",
166            ));
167        }
168        if input_field_count != 0 {
169            let error = pest::error::Error::new_from_span(
170                pest::error::ErrorVariant::CustomError {
171                    message: "ReadRel should have 0 input fields".to_string(),
172                },
173                pair.as_span(),
174            );
175            return Err(MessageParseError::new(
176                "ReadRel",
177                ErrorKind::InvalidValue,
178                Box::new(error),
179            ));
180        }
181
182        let mut iter = RuleIter::from(pair.into_inner());
183        let table = iter.parse_next::<TableName>().0;
184        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
185        iter.done();
186
187        let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
188        let struct_ = r#type::Struct {
189            types,
190            type_variation_reference: 0,
191            nullability: r#type::Nullability::Required as i32,
192        };
193        let named_struct = NamedStruct {
194            names,
195            r#struct: Some(struct_),
196        };
197
198        let read_rel = ReadRel {
199            base_schema: Some(named_struct),
200            read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
201                names: table,
202                advanced_extension: None,
203            })),
204            ..Default::default()
205        };
206
207        Ok(read_rel)
208    }
209}
210
211impl RelationParsePair for FilterRel {
212    fn rule() -> Rule {
213        Rule::filter_relation
214    }
215
216    fn message() -> &'static str {
217        "FilterRel"
218    }
219
220    fn into_rel(self) -> Rel {
221        Rel {
222            rel_type: Some(RelType::Filter(Box::new(self))),
223        }
224    }
225
226    fn parse_pair_with_context(
227        extensions: &SimpleExtensions,
228        pair: pest::iterators::Pair<Rule>,
229        input_children: Vec<Box<Rel>>,
230        _input_field_count: usize,
231    ) -> Result<Self, MessageParseError> {
232        assert_eq!(pair.as_rule(), Self::rule());
233        let input = expect_one_child(Self::message(), &pair, input_children)?;
234        let mut iter = RuleIter::from(pair.into_inner());
235        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
236        let references_pair = iter.pop(Rule::reference_list);
237        let output_mapping = references_pair
238            .into_inner()
239            .map(|p| {
240                let inner = crate::parser::unwrap_single_pair(p);
241                inner.as_str().parse::<i32>().unwrap()
242            })
243            .collect::<Vec<i32>>();
244        iter.done();
245        let emit = EmitKind::Emit(Emit { output_mapping });
246        let common = RelCommon {
247            emit_kind: Some(emit),
248            ..Default::default()
249        };
250        Ok(FilterRel {
251            input: Some(input),
252            condition: Some(Box::new(condition)),
253            common: Some(common),
254            advanced_extension: None,
255        })
256    }
257}
258
259impl RelationParsePair for ProjectRel {
260    fn rule() -> Rule {
261        Rule::project_relation
262    }
263
264    fn message() -> &'static str {
265        "ProjectRel"
266    }
267
268    fn into_rel(self) -> Rel {
269        Rel {
270            rel_type: Some(RelType::Project(Box::new(self))),
271        }
272    }
273
274    fn parse_pair_with_context(
275        extensions: &SimpleExtensions,
276        pair: pest::iterators::Pair<Rule>,
277        input_children: Vec<Box<Rel>>,
278        input_field_count: usize,
279    ) -> Result<Self, MessageParseError> {
280        assert_eq!(pair.as_rule(), Self::rule());
281        let input = expect_one_child(Self::message(), &pair, input_children)?;
282
283        // Get the argument list (contains references and expressions)
284        let arguments_pair = unwrap_single_pair(pair);
285
286        let mut expressions = Vec::new();
287        let mut output_mapping = Vec::new();
288
289        // Process each argument (can be either a reference or expression)
290        for arg in arguments_pair.into_inner() {
291            let inner_arg = crate::parser::unwrap_single_pair(arg);
292            match inner_arg.as_rule() {
293                Rule::reference => {
294                    // Parse reference like "$0" -> 0
295                    let inner = crate::parser::unwrap_single_pair(inner_arg);
296                    let ref_index = inner.as_str().parse::<i32>().unwrap();
297                    output_mapping.push(ref_index);
298                }
299                Rule::expression => {
300                    // Parse as expression (e.g., 42, add($0, $1))
301                    let _expr = Expression::parse_pair(extensions, inner_arg)?;
302                    expressions.push(_expr);
303                    // Expression: index after all input fields
304                    output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
305                }
306                _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
307            }
308        }
309
310        let emit = EmitKind::Emit(Emit { output_mapping });
311        let common = RelCommon {
312            emit_kind: Some(emit),
313            ..Default::default()
314        };
315
316        Ok(ProjectRel {
317            input: Some(input),
318            expressions,
319            common: Some(common),
320            advanced_extension: None,
321        })
322    }
323}
324
325impl RelationParsePair for AggregateRel {
326    fn rule() -> Rule {
327        Rule::aggregate_relation
328    }
329
330    fn message() -> &'static str {
331        "AggregateRel"
332    }
333
334    fn into_rel(self) -> Rel {
335        Rel {
336            rel_type: Some(RelType::Aggregate(Box::new(self))),
337        }
338    }
339
340    fn parse_pair_with_context(
341        extensions: &SimpleExtensions,
342        pair: pest::iterators::Pair<Rule>,
343        input_children: Vec<Box<Rel>>,
344        _input_field_count: usize,
345    ) -> Result<Self, MessageParseError> {
346        assert_eq!(pair.as_rule(), Self::rule());
347        let input = expect_one_child(Self::message(), &pair, input_children)?;
348        let mut iter = RuleIter::from(pair.into_inner());
349        let group_by_pair = iter.pop(Rule::aggregate_group_by);
350        let output_pair = iter.pop(Rule::aggregate_output);
351        iter.done();
352        let mut grouping_expressions = Vec::new();
353        for group_by_item in group_by_pair.into_inner() {
354            match group_by_item.as_rule() {
355                Rule::reference => {
356                    let inner = crate::parser::unwrap_single_pair(group_by_item);
357                    let ref_index = inner.as_str().parse::<i32>().unwrap();
358                    grouping_expressions.push(Expression {
359                        rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
360                            reference(ref_index),
361                        ))),
362                    });
363                }
364                Rule::empty => {
365                    // No grouping expressions to add
366                }
367                _ => panic!(
368                    "Unexpected group-by item rule: {:?}",
369                    group_by_item.as_rule()
370                ),
371            }
372        }
373
374        // Parse output items (can be references or aggregate measures)
375        let mut measures = Vec::new();
376        let mut output_mapping = Vec::new();
377        let group_by_count = grouping_expressions.len();
378        let mut measure_count = 0;
379
380        for output_item in output_pair.into_inner() {
381            let inner_item = unwrap_single_pair(output_item);
382            match inner_item.as_rule() {
383                Rule::reference => {
384                    let inner = crate::parser::unwrap_single_pair(inner_item);
385                    let ref_index = inner.as_str().parse::<i32>().unwrap();
386                    output_mapping.push(ref_index);
387                }
388                Rule::aggregate_measure => {
389                    let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
390                    measures.push(measure);
391                    output_mapping.push(group_by_count as i32 + measure_count);
392                    measure_count += 1;
393                }
394                _ => panic!(
395                    "Unexpected inner output item rule: {:?}",
396                    inner_item.as_rule()
397                ),
398            }
399        }
400
401        let emit = EmitKind::Emit(Emit { output_mapping });
402        let common = RelCommon {
403            emit_kind: Some(emit),
404            ..Default::default()
405        };
406
407        Ok(AggregateRel {
408            input: Some(input),
409            grouping_expressions,
410            groupings: vec![], // TODO: Create groupings from grouping_expressions for complex grouping scenarios
411            measures,
412            common: Some(common),
413            advanced_extension: None,
414        })
415    }
416}
417
#[cfg(test)]
mod tests {
    use pest::Parser;

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::{ExpressionParser, Rule};

    /// Number of columns produced by [`example_read_relation`].
    const EXAMPLE_FIELD_COUNT: usize = 3;

    /// Parses `input` as `rule`, requiring that the whole input is consumed
    /// and that exactly one top-level pair is produced.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Produces a ReadRel with 3 columns: a:i32, b:string?, c:i64
    fn example_read_relation() -> ReadRel {
        let extensions = SimpleExtensions::default();
        ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
    }

    /// The example read relation, wrapped as the single input child of a
    /// relation under test.
    // Matches the `Vec<Box<Rel>>` parameter of `parse_pair_with_context`.
    #[allow(clippy::vec_box)]
    fn example_input() -> Vec<Box<Rel>> {
        vec![Box::new(example_read_relation().into_rel())]
    }

    /// Extensions declaring the `sum` and `count` aggregate functions.
    fn aggregate_extensions() -> SimpleExtensions {
        TestContext::new()
            .with_uri(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions
    }

    /// Parses `input` as an aggregate over the example read relation.
    fn parse_aggregate(input: &str) -> AggregateRel {
        AggregateRel::parse_pair_with_context(
            &aggregate_extensions(),
            parse_exact(Rule::aggregate_relation, input),
            example_input(),
            EXAMPLE_FIELD_COUNT,
        )
        .unwrap()
    }

    /// Returns the emit output mapping of a relation's common section,
    /// panicking if it is absent or is not an explicit `Emit`.
    fn output_mapping(common: Option<&RelCommon>) -> &[i32] {
        let emit_kind = common
            .expect("relation should have a common section")
            .emit_kind
            .as_ref()
            .expect("relation should have an emit kind");
        match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        }
    }

    #[test]
    fn test_parse_read_relation() {
        let extensions = SimpleExtensions::default();
        let read = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            0,
        )
        .unwrap();
        let names = match &read.read_type {
            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
            _ => panic!("Expected NamedTable"),
        };
        assert_eq!(names, &["ab", "cd", "ef"]);
        let columns = &read
            .base_schema
            .as_ref()
            .unwrap()
            .r#struct
            .as_ref()
            .unwrap()
            .types;
        assert_eq!(columns.len(), 2);
    }

    #[test]
    fn test_parse_filter_relation() {
        let extensions = SimpleExtensions::default();
        let filter = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            example_input(),
            EXAMPLE_FIELD_COUNT,
        )
        .unwrap();
        assert_eq!(output_mapping(filter.common.as_ref()), &[0, 1, 2]);
    }

    #[test]
    fn test_parse_project_relation() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
            example_input(),
            EXAMPLE_FIELD_COUNT,
        )
        .unwrap();

        // Should have 1 expression (42) and 2 references ($0, $1)
        assert_eq!(project.expressions.len(), 1);

        // Output mapping should be [0, 1, 3]. References are 0-2; expression is 3.
        assert_eq!(output_mapping(project.common.as_ref()), &[0, 1, 3]);
    }

    #[test]
    fn test_parse_project_relation_complex() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
            example_input(),
            5, // Assume 5 input fields
        )
        .unwrap();

        // Should have 2 expressions (42, 100) and 3 references ($0, $2, $1)
        assert_eq!(project.expressions.len(), 2);

        // Direct mapping: [input_fields..., 42, 100] (input fields first, then expressions)
        // Output mapping: [5, 0, 6, 2, 1] (to get: 42, $0, 100, $2, $1)
        assert_eq!(output_mapping(project.common.as_ref()), &[5, 0, 6, 2, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation() {
        let aggregate = parse_aggregate("Aggregate[$0, $1 => sum($2), $0, count($2)]");

        // Should have 2 group-by fields ($0, $1) and 2 measures (sum($2), count($2))
        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.measures.len(), 2);

        // Output mapping should be [2, 0, 3] (measures and group-by fields in order)
        // sum($2) -> 2, $0 -> 0, count($2) -> 3
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[2, 0, 3]);
    }

    #[test]
    fn test_parse_aggregate_relation_simple() {
        let aggregate = parse_aggregate("Aggregate[$0 => sum($1), count($1)]");

        // Should have 1 group-by field ($0) and 2 measures (sum($1), count($1))
        assert_eq!(aggregate.grouping_expressions.len(), 1);
        assert_eq!(aggregate.measures.len(), 2);

        // Output mapping should be [1, 2] (measures only)
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[1, 2]);
    }

    #[test]
    fn test_parse_aggregate_relation_no_group_by() {
        let aggregate = parse_aggregate("Aggregate[_ => sum($0), count($1)]");

        // Should have 0 group-by fields and 2 measures
        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);

        // Output mapping should be [0, 1] (measures only, no group-by fields)
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[0, 1]);
    }
}