substrait_explain/parser/
expressions.rs

1use substrait::proto::aggregate_rel::Measure;
2use substrait::proto::expression::field_reference::ReferenceType;
3use substrait::proto::expression::literal::LiteralType;
4use substrait::proto::expression::{
5    FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment,
6};
7use substrait::proto::function_argument::ArgType;
8use substrait::proto::r#type::{I64, Kind, Nullability};
9use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
10
11use super::types::get_and_validate_anchor;
12use super::{
13    MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
14    unwrap_single_pair,
15};
16use crate::extensions::SimpleExtensions;
17use crate::extensions::simple::ExtensionKind;
18use crate::parser::ErrorKind;
19
20/// Create a reference to a particular field.
21pub fn reference(index: i32) -> FieldReference {
22    // XXX: Why is it so many layers to make a struct field reference? This is
23    // surprisingly complex
24    FieldReference {
25        reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
26            reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
27                reference_segment::StructField {
28                    field: index,
29                    child: None,
30                },
31            ))),
32        })),
33        root_type: None,
34    }
35}
36
37impl ParsePair for FieldReference {
38    fn rule() -> Rule {
39        Rule::reference
40    }
41
42    fn message() -> &'static str {
43        "FieldReference"
44    }
45
46    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
47        assert_eq!(pair.as_rule(), Self::rule());
48        let inner = unwrap_single_pair(pair);
49        let index: i32 = inner.as_str().parse().unwrap();
50
51        // TODO: Other types of references.
52        reference(index)
53    }
54}
55
56fn to_int_literal(
57    value: pest::iterators::Pair<Rule>,
58    typ: Option<Type>,
59) -> Result<Literal, MessageParseError> {
60    assert_eq!(value.as_rule(), Rule::integer);
61    let parsed_value: i64 = value.as_str().parse().unwrap();
62
63    const DEFAULT_KIND: Kind = Kind::I64(I64 {
64        type_variation_reference: 0,
65        nullability: Nullability::Required as i32,
66    });
67
68    // If no type is provided, we assume i64, Nullability::Required.
69    let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
70
71    let (lit, nullability, tvar) = match &kind {
72        // If no type is provided, we assume i64, Nullability::Required.
73        Kind::I8(i) => (
74            LiteralType::I8(parsed_value as i32),
75            i.nullability,
76            i.type_variation_reference,
77        ),
78        Kind::I16(i) => (
79            LiteralType::I16(parsed_value as i32),
80            i.nullability,
81            i.type_variation_reference,
82        ),
83        Kind::I32(i) => (
84            LiteralType::I32(parsed_value as i32),
85            i.nullability,
86            i.type_variation_reference,
87        ),
88        Kind::I64(i) => (
89            LiteralType::I64(parsed_value),
90            i.nullability,
91            i.type_variation_reference,
92        ),
93        k => {
94            let pest_error = pest::error::Error::new_from_span(
95                pest::error::ErrorVariant::CustomError {
96                    message: format!("Invalid type for integer literal: {k:?}"),
97                },
98                value.as_span(),
99            );
100            let error = MessageParseError {
101                message: "int_literal_type",
102                kind: ErrorKind::InvalidValue,
103                error: Box::new(pest_error),
104            };
105            return Err(error);
106        }
107    };
108
109    Ok(Literal {
110        literal_type: Some(lit),
111        nullable: nullability != Nullability::Required as i32,
112        type_variation_reference: tvar,
113    })
114}
115
116impl ScopedParsePair for Literal {
117    fn rule() -> Rule {
118        Rule::literal
119    }
120
121    fn message() -> &'static str {
122        "Literal"
123    }
124
125    fn parse_pair(
126        extensions: &SimpleExtensions,
127        pair: pest::iterators::Pair<Rule>,
128    ) -> Result<Self, MessageParseError> {
129        assert_eq!(pair.as_rule(), Self::rule());
130        let mut pairs = pair.into_inner();
131        let value = pairs.next().unwrap(); // First item is always the value
132        let typ = pairs.next(); // Second item is optional type
133        assert!(pairs.next().is_none());
134        let typ = match typ {
135            Some(t) => Some(Type::parse_pair(extensions, t)?),
136            None => None,
137        };
138        match value.as_rule() {
139            Rule::integer => to_int_literal(value, typ),
140            Rule::string_literal => Ok(Literal {
141                literal_type: Some(LiteralType::String(unescape_string(value))),
142                nullable: false,
143                type_variation_reference: 0,
144            }),
145            _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
146        }
147    }
148}
149
150impl ScopedParsePair for ScalarFunction {
151    fn rule() -> Rule {
152        Rule::function_call
153    }
154
155    fn message() -> &'static str {
156        "ScalarFunction"
157    }
158
159    fn parse_pair(
160        extensions: &SimpleExtensions,
161        pair: pest::iterators::Pair<Rule>,
162    ) -> Result<Self, MessageParseError> {
163        assert_eq!(pair.as_rule(), Self::rule());
164        let span = pair.as_span();
165        let mut iter = RuleIter::from(pair.into_inner());
166
167        // Parse function name (required)
168        let name = iter.parse_next::<Name>();
169
170        // Parse optional anchor (e.g., #1)
171        let anchor = iter
172            .try_pop(Rule::anchor)
173            .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
174
175        // Parse optional URI anchor (e.g., @1)
176        let _uri_anchor = iter
177            .try_pop(Rule::uri_anchor)
178            .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
179
180        // Parse argument list (required)
181        let argument_list = iter.pop(Rule::argument_list);
182        let mut arguments = Vec::new();
183        for e in argument_list.into_inner() {
184            arguments.push(FunctionArgument {
185                arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
186            });
187        }
188
189        // Parse optional output type (e.g., :i64)
190        let output_type = match iter.try_pop(Rule::r#type) {
191            Some(t) => Some(Type::parse_pair(extensions, t)?),
192            None => None,
193        };
194
195        iter.done();
196        let anchor =
197            get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?;
198        Ok(ScalarFunction {
199            function_reference: anchor,
200            arguments,
201            options: vec![], // TODO: Function Options
202            output_type,
203            #[allow(deprecated)]
204            args: vec![],
205        })
206    }
207}
208
209impl ScopedParsePair for Expression {
210    fn rule() -> Rule {
211        Rule::expression
212    }
213
214    fn message() -> &'static str {
215        "Expression"
216    }
217
218    fn parse_pair(
219        extensions: &SimpleExtensions,
220        pair: pest::iterators::Pair<Rule>,
221    ) -> Result<Self, MessageParseError> {
222        assert_eq!(pair.as_rule(), Self::rule());
223        let inner = unwrap_single_pair(pair);
224
225        match inner.as_rule() {
226            Rule::literal => Ok(Expression {
227                rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
228            }),
229            Rule::function_call => Ok(Expression {
230                rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
231                    extensions, inner,
232                )?)),
233            }),
234            Rule::reference => Ok(Expression {
235                rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
236                    inner,
237                )))),
238            }),
239            _ => unimplemented!("Expression unexpected rule: {:?}", inner.as_rule()),
240        }
241    }
242}
243
244pub struct Name(pub String);
245
246impl ParsePair for Name {
247    fn rule() -> Rule {
248        Rule::name
249    }
250
251    fn message() -> &'static str {
252        "Name"
253    }
254
255    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
256        assert_eq!(pair.as_rule(), Self::rule());
257        let inner = unwrap_single_pair(pair);
258        match inner.as_rule() {
259            Rule::identifier => Name(inner.as_str().to_string()),
260            Rule::quoted_name => Name(unescape_string(inner)),
261            _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
262        }
263    }
264}
265
266impl ScopedParsePair for Measure {
267    fn rule() -> Rule {
268        Rule::aggregate_measure
269    }
270
271    fn message() -> &'static str {
272        "Measure"
273    }
274
275    fn parse_pair(
276        extensions: &SimpleExtensions,
277        pair: pest::iterators::Pair<Rule>,
278    ) -> Result<Self, MessageParseError> {
279        assert_eq!(pair.as_rule(), Self::rule());
280
281        // Extract the inner function_call from aggregate_measure
282        let function_call_pair = unwrap_single_pair(pair);
283        assert_eq!(function_call_pair.as_rule(), Rule::function_call);
284
285        // Parse as ScalarFunction, then convert to AggregateFunction
286        let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
287        Ok(Measure {
288            measure: Some(AggregateFunction {
289                function_reference: scalar.function_reference,
290                arguments: scalar.arguments,
291                options: scalar.options,
292                output_type: scalar.output_type,
293                invocation: 0, // TODO: support invocation (ALL, DISTINCT, etc.)
294                phase: 0, // TODO: support phase (INITIAL_TO_RESULT, PARTIAL_TO_INTERMEDIATE, etc.)
295                sorts: vec![], // TODO: support sorts for ordered aggregates
296                #[allow(deprecated)]
297                args: scalar.args,
298            }),
299            filter: None, // TODO: support filter conditions on aggregate measures
300        })
301    }
302}
303
#[cfg(test)]
mod tests {
    use pest::Parser as PestParser;

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parse `input` with `rule`, requiring that the parse consumes the whole
    /// input and yields exactly one top-level pair, which is returned.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut parsed = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(parsed.as_str(), input);
        let only = parsed.next().unwrap();
        assert_eq!(parsed.next(), None);
        only
    }

    /// Assert that `input` parses (via [`ParsePair`]) to `expected`.
    fn assert_parses_to<T>(input: &str, expected: T)
    where
        T: ParsePair + PartialEq + std::fmt::Debug,
    {
        let actual = T::parse_pair(parse_exact(T::rule(), input));
        assert_eq!(actual, expected);
    }

    /// Assert that `input` parses (via [`ScopedParsePair`], using the
    /// extensions in `ext`) to `expected`.
    fn assert_parses_with<T>(ext: &SimpleExtensions, input: &str, expected: T)
    where
        T: ScopedParsePair + PartialEq + std::fmt::Debug,
    {
        let actual = T::parse_pair(ext, parse_exact(T::rule(), input)).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_parse_field_reference() {
        assert_parses_to("$1", reference(1));
    }

    #[test]
    fn test_parse_integer_literal() {
        // An untyped integer defaults to a non-nullable i64.
        let extensions = SimpleExtensions::default();
        assert_parses_with(
            &extensions,
            "1",
            Literal {
                literal_type: Some(LiteralType::I64(1)),
                nullable: false,
                type_variation_reference: 0,
            },
        );
    }

    // NOTE(review): the disabled tests below use types that are not defined
    // or imported in this file (`FunctionCall`, `Literal::String`,
    // `Expression::FunctionCall`, `Reference`) -- presumably an earlier AST.
    // TODO: port them to the substrait proto types or remove them.

    // #[test]
    // fn test_parse_string_literal() {
    //     assert_parses_to("'hello'", Literal::String("hello".to_string()));
    // }

    // #[test]
    // fn test_parse_function_call_simple() {
    //     assert_parses_to(
    //         "add()",
    //         FunctionCall {
    //             name: "add".to_string(),
    //             parameters: None,
    //             anchor: None,
    //             uri_anchor: None,
    //             arguments: vec![],
    //         },
    //     );
    // }

    // #[test]
    // fn test_parse_function_call_with_parameters() {
    //     assert_parses_to(
    //         "add<param1, param2>()",
    //         FunctionCall {
    //             name: "add".to_string(),
    //             parameters: Some(vec!["param1".to_string(), "param2".to_string()]),
    //             anchor: None,
    //             uri_anchor: None,
    //             arguments: vec![],
    //         },
    //     );
    // }

    // #[test]
    // fn test_parse_function_call_with_anchor() {
    //     assert_parses_to(
    //         "add#1()",
    //         FunctionCall {
    //             name: "add".to_string(),
    //             parameters: None,
    //             anchor: Some(1),
    //             uri_anchor: None,
    //             arguments: vec![],
    //         },
    //     );
    // }

    // #[test]
    // fn test_parse_function_call_with_uri_anchor() {
    //     assert_parses_to(
    //         "add@1()",
    //         FunctionCall {
    //             name: "add".to_string(),
    //             parameters: None,
    //             anchor: None,
    //             uri_anchor: Some(1),
    //             arguments: vec![],
    //         },
    //     );
    // }

    // #[test]
    // fn test_parse_function_call_all_optionals() {
    //     assert_parses_to(
    //         "add<param1, param2>#1@2()",
    //         FunctionCall {
    //             name: "add".to_string(),
    //             parameters: Some(vec!["param1".to_string(), "param2".to_string()]),
    //             anchor: Some(1),
    //             uri_anchor: Some(2),
    //             arguments: vec![],
    //         },
    //     );
    // }

    // #[test]
    // fn test_parse_function_call_with_simple_arguments() {
    //     assert_parses_to(
    //         "add(1, 2)",
    //         FunctionCall {
    //             name: "add".to_string(),
    //             parameters: None,
    //             anchor: None,
    //             uri_anchor: None,
    //             arguments: vec![
    //                 Expression::Literal(Literal::Integer(1)),
    //                 Expression::Literal(Literal::Integer(2)),
    //             ],
    //         },
    //     );
    // }

    // #[test]
    // fn test_parse_function_call_with_nested_function() {
    //     assert_parses_to(
    //         "outer_func(inner_func(), $1)",
    //         Expression::FunctionCall(Box::new(FunctionCall {
    //             name: "outer_func".to_string(),
    //             parameters: None,
    //             anchor: None,
    //             uri_anchor: None,
    //             arguments: vec![
    //                 Expression::FunctionCall(Box::new(FunctionCall {
    //                     name: "inner_func".to_string(),
    //                     parameters: None,
    //                     anchor: None,
    //                     uri_anchor: None,
    //                     arguments: vec![],
    //                 })),
    //                 Expression::Reference(Reference(1)),
    //             ],
    //         })),
    //     );
    // }

    // #[test]
    // fn test_parse_function_call_funny_names() {
    //     assert_parses_to(
    //         "'funny name'<param1, param2>#1@2()",
    //         FunctionCall {
    //             name: "funny name".to_string(),
    //             parameters: Some(vec!["param1".to_string(), "param2".to_string()]),
    //             anchor: Some(1),
    //             uri_anchor: Some(2),
    //             arguments: vec![],
    //         },
    //     );
    // }

    // #[test]
    // fn test_parse_empty_string_literal() {
    //     assert_parses_to("''", Literal::String("".to_string()));
    // }
}