vector_expr/
parse.rs

1use crate::expression::{BindingId, BoolExpression, Expression, RealExpression};
2use crate::StringExpression;
3use num_traits::Float;
4use once_cell::sync::Lazy;
5use pest::iterators::Pairs;
6use pest::pratt_parser::{Assoc, Op, PrattParser};
7use pest::Parser;
8use pest_derive::Parser;
9use std::collections::HashSet;
10use std::str::FromStr;
11
/// Parser generated by `pest_derive` from the grammar file named in the
/// `#[grammar]` attribute below. The derive is also what provides the `Rule`
/// enum that the rest of this file matches on.
#[derive(Parser)]
#[grammar = "grammar.pest"] // relative to project `src`
struct ExpressionParser;
15
/// Error type returned by the parsing entry points in this file: a boxed pest
/// error over this grammar's `Rule`s.
// Boxed because error is much larger than Ok variant in most results.
pub type ParseError = Box<pest::error::Error<Rule>>;
18
19impl<Real: Float + FromStr> Expression<Real> {
20    /// Assume this expression is real-valued.
21    pub fn unwrap_real(self) -> RealExpression<Real> {
22        match self {
23            Self::Real(r) => r,
24            _ => panic!("Expected Real"),
25        }
26    }
27
28    /// Assume this expression is string-valued.
29    pub fn unwrap_string(self) -> StringExpression {
30        match self {
31            Self::String(s) => s,
32            _ => panic!("Expected String"),
33        }
34    }
35
36    /// Assume this expression is boolean-valued.
37    pub fn unwrap_bool(self) -> BoolExpression<Real> {
38        match self {
39            Self::Boolean(b) => b,
40            _ => panic!("Expected Boolean"),
41        }
42    }
43
44    pub fn parse_real_variable_names(input: &str) -> Result<HashSet<String>, ParseError> {
45        Ok(ExpressionParser::parse(Rule::calculation, input)?
46            .flatten()
47            .filter(|p| (p.as_rule() == Rule::real_variable))
48            .map(|p| p.as_str().to_string())
49            .collect())
50    }
51
52    pub fn parse_string_variable_names(input: &str) -> Result<HashSet<String>, ParseError> {
53        Ok(ExpressionParser::parse(Rule::calculation, input)?
54            .flatten()
55            .filter(|p| (p.as_rule() == Rule::str_variable))
56            .map(|p| p.as_str().to_string())
57            .collect())
58    }
59
60    /// Parse the expression from `input`.
61    ///
62    /// `binding_map` determines which variable name maps to each data binding.
63    /// As variable names are encountered during parsing, they are replaced by
64    /// [`BindingId`]s in the [`Expression`] syntax tree. This allows the
65    /// [`Expression`] to be efficiently reused with many different data
66    /// bindings.
67    pub fn parse(input: &str, binding_map: impl Fn(&str) -> BindingId) -> Result<Self, ParseError> {
68        let mut pairs = ExpressionParser::parse(Rule::calculation, input)?;
69        // HACK: Working around https://github.com/pest-parser/pest/issues/943
70        let inner_expr = pairs.next().unwrap().into_inner();
71        Ok(parse_recursive(inner_expr, &binding_map))
72    }
73}
74
75static PRATT_PARSER: Lazy<PrattParser<Rule>> = Lazy::new(|| {
76    use Assoc::*;
77    use Rule::*;
78
79    PrattParser::new()
80        .op(Op::infix(and, Left) | Op::infix(or, Left))
81        .op(Op::infix(str_eq, Left)
82            | Op::infix(str_neq, Left)
83            | Op::infix(real_eq, Left)
84            | Op::infix(real_neq, Left)
85            | Op::infix(less, Left)
86            | Op::infix(le, Left)
87            | Op::infix(greater, Left)
88            | Op::infix(ge, Left))
89        .op(Op::infix(add, Left) | Op::infix(subtract, Left))
90        .op(Op::infix(multiply, Left) | Op::infix(divide, Left))
91        .op(Op::infix(power, Right))
92});
93
94fn parse_recursive<Real: FromStr + Float>(
95    pairs: Pairs<Rule>,
96    binding_map: &impl Fn(&str) -> BindingId,
97) -> Expression<Real> {
98    PRATT_PARSER
99        .map_primary(|pair| match pair.as_rule() {
100            Rule::bool_expr => parse_recursive(pair.into_inner(), binding_map),
101            Rule::real_expr => parse_recursive(pair.into_inner(), binding_map),
102            Rule::string_expr => parse_recursive(pair.into_inner(), binding_map),
103            Rule::real_literal => {
104                let literal_str = pair.as_str();
105                if let Ok(value) = literal_str.parse::<Real>() {
106                    return Expression::Real(RealExpression::Literal(value));
107                }
108                panic!("Unexpected literal: {}", literal_str)
109            }
110            Rule::string_literal => parse_recursive(pair.into_inner(), binding_map),
111            Rule::string_literal_value => {
112                let literal_str = pair.as_str();
113                if let Ok(value) = literal_str.parse::<String>() {
114                    return Expression::String(StringExpression::Literal(value));
115                }
116                panic!("Unexpected literal: {}", literal_str)
117            }
118            Rule::unary_real_op_expr => {
119                let mut inner = pair.into_inner();
120                let unary = inner.next().unwrap();
121                match unary.as_rule() {
122                    Rule::neg => Expression::Real(RealExpression::Neg(Box::new(
123                        parse_recursive(inner, binding_map).unwrap_real(),
124                    ))),
125                    x => panic!("Unexpected unary logic operator: {x:?}"),
126                }
127            }
128            Rule::unary_logic_expr => {
129                let mut inner = pair.into_inner();
130                let unary = inner.next().unwrap();
131                match unary.as_rule() {
132                    Rule::not => Expression::Boolean(BoolExpression::Not(Box::new(
133                        parse_recursive(inner, binding_map).unwrap_bool(),
134                    ))),
135                    x => panic!("Unexpected unary logic operator: {x:?}"),
136                }
137            }
138            Rule::real_variable => {
139                Expression::Real(RealExpression::Binding(binding_map(pair.as_str())))
140            }
141            Rule::str_variable => {
142                Expression::String(StringExpression::Binding(binding_map(pair.as_str())))
143            }
144            x => panic!("Unexpected primary rule {x:?}"),
145        })
146        .map_infix(|lhs, op, rhs| match op.as_rule() {
147            Rule::add => Expression::Real(RealExpression::Add(
148                Box::new(lhs.unwrap_real()),
149                Box::new(rhs.unwrap_real()),
150            )),
151            Rule::subtract => Expression::Real(RealExpression::Sub(
152                Box::new(lhs.unwrap_real()),
153                Box::new(rhs.unwrap_real()),
154            )),
155            Rule::multiply => Expression::Real(RealExpression::Mul(
156                Box::new(lhs.unwrap_real()),
157                Box::new(rhs.unwrap_real()),
158            )),
159            Rule::divide => Expression::Real(RealExpression::Div(
160                Box::new(lhs.unwrap_real()),
161                Box::new(rhs.unwrap_real()),
162            )),
163            Rule::power => Expression::Real(RealExpression::Pow(
164                Box::new(lhs.unwrap_real()),
165                Box::new(rhs.unwrap_real()),
166            )),
167            Rule::real_eq => Expression::Boolean(BoolExpression::Equal(
168                Box::new(lhs.unwrap_real()),
169                Box::new(rhs.unwrap_real()),
170            )),
171            Rule::real_neq => Expression::Boolean(BoolExpression::NotEqual(
172                Box::new(lhs.unwrap_real()),
173                Box::new(rhs.unwrap_real()),
174            )),
175            Rule::str_eq => Expression::Boolean(BoolExpression::StrEqual(
176                lhs.unwrap_string(),
177                rhs.unwrap_string(),
178            )),
179            Rule::str_neq => Expression::Boolean(BoolExpression::StrNotEqual(
180                lhs.unwrap_string(),
181                rhs.unwrap_string(),
182            )),
183            Rule::less => Expression::Boolean(BoolExpression::Less(
184                Box::new(lhs.unwrap_real()),
185                Box::new(rhs.unwrap_real()),
186            )),
187            Rule::le => Expression::Boolean(BoolExpression::LessEqual(
188                Box::new(lhs.unwrap_real()),
189                Box::new(rhs.unwrap_real()),
190            )),
191            Rule::greater => Expression::Boolean(BoolExpression::Greater(
192                Box::new(lhs.unwrap_real()),
193                Box::new(rhs.unwrap_real()),
194            )),
195            Rule::ge => Expression::Boolean(BoolExpression::GreaterEqual(
196                Box::new(lhs.unwrap_real()),
197                Box::new(rhs.unwrap_real()),
198            )),
199            Rule::and => Expression::Boolean(BoolExpression::And(
200                Box::new(lhs.unwrap_bool()),
201                Box::new(rhs.unwrap_bool()),
202            )),
203            Rule::or => Expression::Boolean(BoolExpression::Or(
204                Box::new(lhs.unwrap_bool()),
205                Box::new(rhs.unwrap_bool()),
206            )),
207            x => panic!("Unexpected operator {x:?}"),
208        })
209        .parse(pairs)
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Variable-name extraction reports every real variable in a sum and the
    /// string variable in a string comparison.
    #[test]
    fn parse_variable_names() {
        let real_vars =
            Expression::<f32>::parse_real_variable_names("v1_dest + x + y + z99").unwrap();
        for name in ["x", "y", "z99", "v1_dest"] {
            assert!(real_vars.contains(name), "{real_vars:?}");
        }

        let string_vars = Expression::<f32>::parse_string_variable_names("x == \"W\"").unwrap();
        assert!(string_vars.contains("x"), "{string_vars:?}");
    }

    /// Every comparison operator between two variables parses successfully.
    #[test]
    fn parse_comparisons() {
        fn binding_map(var_name: &str) -> BindingId {
            match var_name {
                "x" => 0,
                "y" => 1,
                _ => unreachable!(),
            }
        }

        for operator in ["==", "!=", ">", "<", "<=", ">="] {
            let input = format!("x {operator} y");
            Expression::<f32>::parse(&input, binding_map).unwrap();
        }
    }
}