substrait_explain/parser/
expressions.rs1use substrait::proto::aggregate_rel::Measure;
2use substrait::proto::expression::field_reference::ReferenceType;
3use substrait::proto::expression::literal::LiteralType;
4use substrait::proto::expression::{
5 FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment,
6};
7use substrait::proto::function_argument::ArgType;
8use substrait::proto::r#type::{I64, Kind, Nullability};
9use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
10
11use super::types::get_and_validate_anchor;
12use super::{
13 MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
14 unwrap_single_pair,
15};
16use crate::extensions::SimpleExtensions;
17use crate::extensions::simple::ExtensionKind;
18use crate::parser::ErrorKind;
19
20pub fn reference(index: i32) -> FieldReference {
22 FieldReference {
25 reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
26 reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
27 reference_segment::StructField {
28 field: index,
29 child: None,
30 },
31 ))),
32 })),
33 root_type: None,
34 }
35}
36
37impl ParsePair for FieldReference {
38 fn rule() -> Rule {
39 Rule::reference
40 }
41
42 fn message() -> &'static str {
43 "FieldReference"
44 }
45
46 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
47 assert_eq!(pair.as_rule(), Self::rule());
48 let inner = unwrap_single_pair(pair);
49 let index: i32 = inner.as_str().parse().unwrap();
50
51 reference(index)
53 }
54}
55
56fn to_int_literal(
57 value: pest::iterators::Pair<Rule>,
58 typ: Option<Type>,
59) -> Result<Literal, MessageParseError> {
60 assert_eq!(value.as_rule(), Rule::integer);
61 let parsed_value: i64 = value.as_str().parse().unwrap();
62
63 const DEFAULT_KIND: Kind = Kind::I64(I64 {
64 type_variation_reference: 0,
65 nullability: Nullability::Required as i32,
66 });
67
68 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
70
71 let (lit, nullability, tvar) = match &kind {
72 Kind::I8(i) => (
74 LiteralType::I8(parsed_value as i32),
75 i.nullability,
76 i.type_variation_reference,
77 ),
78 Kind::I16(i) => (
79 LiteralType::I16(parsed_value as i32),
80 i.nullability,
81 i.type_variation_reference,
82 ),
83 Kind::I32(i) => (
84 LiteralType::I32(parsed_value as i32),
85 i.nullability,
86 i.type_variation_reference,
87 ),
88 Kind::I64(i) => (
89 LiteralType::I64(parsed_value),
90 i.nullability,
91 i.type_variation_reference,
92 ),
93 k => {
94 let pest_error = pest::error::Error::new_from_span(
95 pest::error::ErrorVariant::CustomError {
96 message: format!("Invalid type for integer literal: {k:?}"),
97 },
98 value.as_span(),
99 );
100 let error = MessageParseError {
101 message: "int_literal_type",
102 kind: ErrorKind::InvalidValue,
103 error: Box::new(pest_error),
104 };
105 return Err(error);
106 }
107 };
108
109 Ok(Literal {
110 literal_type: Some(lit),
111 nullable: nullability != Nullability::Required as i32,
112 type_variation_reference: tvar,
113 })
114}
115
116impl ScopedParsePair for Literal {
117 fn rule() -> Rule {
118 Rule::literal
119 }
120
121 fn message() -> &'static str {
122 "Literal"
123 }
124
125 fn parse_pair(
126 extensions: &SimpleExtensions,
127 pair: pest::iterators::Pair<Rule>,
128 ) -> Result<Self, MessageParseError> {
129 assert_eq!(pair.as_rule(), Self::rule());
130 let mut pairs = pair.into_inner();
131 let value = pairs.next().unwrap(); let typ = pairs.next(); assert!(pairs.next().is_none());
134 let typ = match typ {
135 Some(t) => Some(Type::parse_pair(extensions, t)?),
136 None => None,
137 };
138 match value.as_rule() {
139 Rule::integer => to_int_literal(value, typ),
140 Rule::string_literal => Ok(Literal {
141 literal_type: Some(LiteralType::String(unescape_string(value))),
142 nullable: false,
143 type_variation_reference: 0,
144 }),
145 _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
146 }
147 }
148}
149
150impl ScopedParsePair for ScalarFunction {
151 fn rule() -> Rule {
152 Rule::function_call
153 }
154
155 fn message() -> &'static str {
156 "ScalarFunction"
157 }
158
159 fn parse_pair(
160 extensions: &SimpleExtensions,
161 pair: pest::iterators::Pair<Rule>,
162 ) -> Result<Self, MessageParseError> {
163 assert_eq!(pair.as_rule(), Self::rule());
164 let span = pair.as_span();
165 let mut iter = RuleIter::from(pair.into_inner());
166
167 let name = iter.parse_next::<Name>();
169
170 let anchor = iter
172 .try_pop(Rule::anchor)
173 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
174
175 let _uri_anchor = iter
177 .try_pop(Rule::uri_anchor)
178 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
179
180 let argument_list = iter.pop(Rule::argument_list);
182 let mut arguments = Vec::new();
183 for e in argument_list.into_inner() {
184 arguments.push(FunctionArgument {
185 arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
186 });
187 }
188
189 let output_type = match iter.try_pop(Rule::r#type) {
191 Some(t) => Some(Type::parse_pair(extensions, t)?),
192 None => None,
193 };
194
195 iter.done();
196 let anchor =
197 get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?;
198 Ok(ScalarFunction {
199 function_reference: anchor,
200 arguments,
201 options: vec![], output_type,
203 #[allow(deprecated)]
204 args: vec![],
205 })
206 }
207}
208
209impl ScopedParsePair for Expression {
210 fn rule() -> Rule {
211 Rule::expression
212 }
213
214 fn message() -> &'static str {
215 "Expression"
216 }
217
218 fn parse_pair(
219 extensions: &SimpleExtensions,
220 pair: pest::iterators::Pair<Rule>,
221 ) -> Result<Self, MessageParseError> {
222 assert_eq!(pair.as_rule(), Self::rule());
223 let inner = unwrap_single_pair(pair);
224
225 match inner.as_rule() {
226 Rule::literal => Ok(Expression {
227 rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
228 }),
229 Rule::function_call => Ok(Expression {
230 rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
231 extensions, inner,
232 )?)),
233 }),
234 Rule::reference => Ok(Expression {
235 rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
236 inner,
237 )))),
238 }),
239 _ => unimplemented!("Expression unexpected rule: {:?}", inner.as_rule()),
240 }
241 }
242}
243
244pub struct Name(pub String);
245
246impl ParsePair for Name {
247 fn rule() -> Rule {
248 Rule::name
249 }
250
251 fn message() -> &'static str {
252 "Name"
253 }
254
255 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
256 assert_eq!(pair.as_rule(), Self::rule());
257 let inner = unwrap_single_pair(pair);
258 match inner.as_rule() {
259 Rule::identifier => Name(inner.as_str().to_string()),
260 Rule::quoted_name => Name(unescape_string(inner)),
261 _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
262 }
263 }
264}
265
266impl ScopedParsePair for Measure {
267 fn rule() -> Rule {
268 Rule::aggregate_measure
269 }
270
271 fn message() -> &'static str {
272 "Measure"
273 }
274
275 fn parse_pair(
276 extensions: &SimpleExtensions,
277 pair: pest::iterators::Pair<Rule>,
278 ) -> Result<Self, MessageParseError> {
279 assert_eq!(pair.as_rule(), Self::rule());
280
281 let function_call_pair = unwrap_single_pair(pair);
283 assert_eq!(function_call_pair.as_rule(), Rule::function_call);
284
285 let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
287 Ok(Measure {
288 measure: Some(AggregateFunction {
289 function_reference: scalar.function_reference,
290 arguments: scalar.arguments,
291 options: scalar.options,
292 output_type: scalar.output_type,
293 invocation: 0, phase: 0, sorts: vec![], #[allow(deprecated)]
297 args: scalar.args,
298 }),
299 filter: None, })
301 }
302}
303
#[cfg(test)]
mod tests {
    use pest::Parser as PestParser;

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule`, requiring that the parse covers all of
    /// `input` and produces exactly one top-level pair, which is returned.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut parsed = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(parsed.as_str(), input);

        let only = parsed.next().unwrap();
        assert_eq!(parsed.next(), None);
        only
    }

    /// Asserts that `input` parses, via `T`'s own rule, to `expected`.
    fn assert_parses_to<T>(input: &str, expected: T)
    where
        T: ParsePair + PartialEq + std::fmt::Debug,
    {
        let actual = T::parse_pair(parse_exact(T::rule(), input));
        assert_eq!(actual, expected);
    }

    /// Like `assert_parses_to`, for parsers that need extensions; the parse
    /// must succeed.
    fn assert_parses_with<T>(ext: &SimpleExtensions, input: &str, expected: T)
    where
        T: ScopedParsePair + PartialEq + std::fmt::Debug,
    {
        let actual = T::parse_pair(ext, parse_exact(T::rule(), input)).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_parse_field_reference() {
        assert_parses_to("$1", reference(1));
    }

    #[test]
    fn test_parse_integer_literal() {
        // An untyped integer defaults to a non-nullable i64.
        let expected = Literal {
            literal_type: Some(LiteralType::I64(1)),
            nullable: false,
            type_variation_reference: 0,
        };
        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "1", expected);
    }
}
351 }