1use crate::expression::{BindingId, BoolExpression, Expression, RealExpression};
2use crate::StringExpression;
3use num_traits::Float;
4use once_cell::sync::Lazy;
5use pest::iterators::Pairs;
6use pest::pratt_parser::{Assoc, Op, PrattParser};
7use pest::Parser;
8use pest_derive::Parser;
9use std::collections::HashSet;
10use std::str::FromStr;
11
/// Parser generated by `pest_derive` from `grammar.pest`.
///
/// The derive also generates the `Rule` enum that the rest of this module matches on
/// (`Rule::calculation`, `Rule::real_variable`, the operator rules, ...).
#[derive(Parser)]
#[grammar = "grammar.pest"]
struct ExpressionParser;
15
/// Error returned when input text does not match the expression grammar.
///
/// The pest error is boxed — presumably to keep `Result<_, ParseError>` small, since the pest
/// error carries position/line information (NOTE(review): confirm intent).
pub type ParseError = Box<pest::error::Error<Rule>>;
18
19impl<Real: Float + FromStr> Expression<Real> {
20 pub fn unwrap_real(self) -> RealExpression<Real> {
22 match self {
23 Self::Real(r) => r,
24 _ => panic!("Expected Real"),
25 }
26 }
27
28 pub fn unwrap_string(self) -> StringExpression {
30 match self {
31 Self::String(s) => s,
32 _ => panic!("Expected String"),
33 }
34 }
35
36 pub fn unwrap_bool(self) -> BoolExpression<Real> {
38 match self {
39 Self::Boolean(b) => b,
40 _ => panic!("Expected Boolean"),
41 }
42 }
43
44 pub fn parse_real_variable_names(input: &str) -> Result<HashSet<String>, ParseError> {
45 Ok(ExpressionParser::parse(Rule::calculation, input)?
46 .flatten()
47 .filter(|p| (p.as_rule() == Rule::real_variable))
48 .map(|p| p.as_str().to_string())
49 .collect())
50 }
51
52 pub fn parse_string_variable_names(input: &str) -> Result<HashSet<String>, ParseError> {
53 Ok(ExpressionParser::parse(Rule::calculation, input)?
54 .flatten()
55 .filter(|p| (p.as_rule() == Rule::str_variable))
56 .map(|p| p.as_str().to_string())
57 .collect())
58 }
59
60 pub fn parse(input: &str, binding_map: impl Fn(&str) -> BindingId) -> Result<Self, ParseError> {
68 let mut pairs = ExpressionParser::parse(Rule::calculation, input)?;
69 let inner_expr = pairs.next().unwrap().into_inner();
71 Ok(parse_recursive(inner_expr, &binding_map))
72 }
73}
74
75static PRATT_PARSER: Lazy<PrattParser<Rule>> = Lazy::new(|| {
76 use Assoc::*;
77 use Rule::*;
78
79 PrattParser::new()
80 .op(Op::infix(and, Left) | Op::infix(or, Left))
81 .op(Op::infix(str_eq, Left)
82 | Op::infix(str_neq, Left)
83 | Op::infix(real_eq, Left)
84 | Op::infix(real_neq, Left)
85 | Op::infix(less, Left)
86 | Op::infix(le, Left)
87 | Op::infix(greater, Left)
88 | Op::infix(ge, Left))
89 .op(Op::infix(add, Left) | Op::infix(subtract, Left))
90 .op(Op::infix(multiply, Left) | Op::infix(divide, Left))
91 .op(Op::infix(power, Right))
92});
93
94fn parse_recursive<Real: FromStr + Float>(
95 pairs: Pairs<Rule>,
96 binding_map: &impl Fn(&str) -> BindingId,
97) -> Expression<Real> {
98 PRATT_PARSER
99 .map_primary(|pair| match pair.as_rule() {
100 Rule::bool_expr => parse_recursive(pair.into_inner(), binding_map),
101 Rule::real_expr => parse_recursive(pair.into_inner(), binding_map),
102 Rule::string_expr => parse_recursive(pair.into_inner(), binding_map),
103 Rule::real_literal => {
104 let literal_str = pair.as_str();
105 if let Ok(value) = literal_str.parse::<Real>() {
106 return Expression::Real(RealExpression::Literal(value));
107 }
108 panic!("Unexpected literal: {}", literal_str)
109 }
110 Rule::string_literal => parse_recursive(pair.into_inner(), binding_map),
111 Rule::string_literal_value => {
112 let literal_str = pair.as_str();
113 if let Ok(value) = literal_str.parse::<String>() {
114 return Expression::String(StringExpression::Literal(value));
115 }
116 panic!("Unexpected literal: {}", literal_str)
117 }
118 Rule::unary_real_op_expr => {
119 let mut inner = pair.into_inner();
120 let unary = inner.next().unwrap();
121 match unary.as_rule() {
122 Rule::neg => Expression::Real(RealExpression::Neg(Box::new(
123 parse_recursive(inner, binding_map).unwrap_real(),
124 ))),
125 x => panic!("Unexpected unary logic operator: {x:?}"),
126 }
127 }
128 Rule::unary_logic_expr => {
129 let mut inner = pair.into_inner();
130 let unary = inner.next().unwrap();
131 match unary.as_rule() {
132 Rule::not => Expression::Boolean(BoolExpression::Not(Box::new(
133 parse_recursive(inner, binding_map).unwrap_bool(),
134 ))),
135 x => panic!("Unexpected unary logic operator: {x:?}"),
136 }
137 }
138 Rule::real_variable => {
139 Expression::Real(RealExpression::Binding(binding_map(pair.as_str())))
140 }
141 Rule::str_variable => {
142 Expression::String(StringExpression::Binding(binding_map(pair.as_str())))
143 }
144 x => panic!("Unexpected primary rule {x:?}"),
145 })
146 .map_infix(|lhs, op, rhs| match op.as_rule() {
147 Rule::add => Expression::Real(RealExpression::Add(
148 Box::new(lhs.unwrap_real()),
149 Box::new(rhs.unwrap_real()),
150 )),
151 Rule::subtract => Expression::Real(RealExpression::Sub(
152 Box::new(lhs.unwrap_real()),
153 Box::new(rhs.unwrap_real()),
154 )),
155 Rule::multiply => Expression::Real(RealExpression::Mul(
156 Box::new(lhs.unwrap_real()),
157 Box::new(rhs.unwrap_real()),
158 )),
159 Rule::divide => Expression::Real(RealExpression::Div(
160 Box::new(lhs.unwrap_real()),
161 Box::new(rhs.unwrap_real()),
162 )),
163 Rule::power => Expression::Real(RealExpression::Pow(
164 Box::new(lhs.unwrap_real()),
165 Box::new(rhs.unwrap_real()),
166 )),
167 Rule::real_eq => Expression::Boolean(BoolExpression::Equal(
168 Box::new(lhs.unwrap_real()),
169 Box::new(rhs.unwrap_real()),
170 )),
171 Rule::real_neq => Expression::Boolean(BoolExpression::NotEqual(
172 Box::new(lhs.unwrap_real()),
173 Box::new(rhs.unwrap_real()),
174 )),
175 Rule::str_eq => Expression::Boolean(BoolExpression::StrEqual(
176 lhs.unwrap_string(),
177 rhs.unwrap_string(),
178 )),
179 Rule::str_neq => Expression::Boolean(BoolExpression::StrNotEqual(
180 lhs.unwrap_string(),
181 rhs.unwrap_string(),
182 )),
183 Rule::less => Expression::Boolean(BoolExpression::Less(
184 Box::new(lhs.unwrap_real()),
185 Box::new(rhs.unwrap_real()),
186 )),
187 Rule::le => Expression::Boolean(BoolExpression::LessEqual(
188 Box::new(lhs.unwrap_real()),
189 Box::new(rhs.unwrap_real()),
190 )),
191 Rule::greater => Expression::Boolean(BoolExpression::Greater(
192 Box::new(lhs.unwrap_real()),
193 Box::new(rhs.unwrap_real()),
194 )),
195 Rule::ge => Expression::Boolean(BoolExpression::GreaterEqual(
196 Box::new(lhs.unwrap_real()),
197 Box::new(rhs.unwrap_real()),
198 )),
199 Rule::and => Expression::Boolean(BoolExpression::And(
200 Box::new(lhs.unwrap_bool()),
201 Box::new(rhs.unwrap_bool()),
202 )),
203 Rule::or => Expression::Boolean(BoolExpression::Or(
204 Box::new(lhs.unwrap_bool()),
205 Box::new(rhs.unwrap_bool()),
206 )),
207 x => panic!("Unexpected operator {x:?}"),
208 })
209 .parse(pairs)
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Variable-name extraction finds every real variable in an arithmetic expression and
    /// the string variable in a string comparison.
    #[test]
    fn parse_variable_names() {
        let vars = Expression::<f32>::parse_real_variable_names("v1_dest + x + y + z99").unwrap();
        for expected in ["x", "y", "z99", "v1_dest"] {
            assert!(vars.contains(expected), "{vars:?}");
        }

        let vars = Expression::<f32>::parse_string_variable_names("x == \"W\"").unwrap();
        assert!(vars.contains("x"), "{vars:?}");
    }

    /// Every comparison operator between two real variables parses successfully.
    #[test]
    fn parse_comparisons() {
        /// Resolves the only two variables these inputs use; anything else is a test bug.
        fn binding_map(var_name: &str) -> BindingId {
            match var_name {
                "x" => 0,
                "y" => 1,
                _ => unreachable!(),
            }
        }

        let sources = ["x == y", "x != y", "x > y", "x < y", "x <= y", "x >= y"];
        for source in sources {
            Expression::<f32>::parse(source, binding_map).unwrap();
        }
    }
}