simple_expressions/
parser.rs

1use crate::types::error::{Error, Result};
2use crate::types::expression::{BinaryOp, Expr, UnaryOp};
3use crate::types::primitive::Primitive;
4use pest::Parser;
5use pest::iterators::Pair;
6use pest::pratt_parser::{Assoc, Op, PrattParser};
7
8#[derive(pest_derive::Parser)]
9#[grammar = "expr.pest"]
10struct InnerParser;
11
12pub fn parse_expression(input: &str) -> Result<Expr> {
13    parse_internal(input, Rule::program).map(|r| r.0)
14}
15
16pub(crate) fn parse_internal(input: &str, rule: Rule) -> Result<(Expr, usize)> {
17    let mut pairs = InnerParser::parse(rule, input).map_err(|e| Error::ParseError(format!("parse error: {}", e)))?;
18    let pair = pairs.next().expect("program always produces one pair");
19
20    debug_assert_eq!(pair.as_rule(), rule);
21    let end_pos = pair.as_span().end_pos().pos();
22    let expr_pair = pair.into_inner().next().expect("program contains expr");
23    let expr = parse_expr(expr_pair)?;
24    Ok((expr, end_pos))
25}
26
27fn pratt() -> PrattParser<Rule> {
28    PrattParser::new()
29        .op(Op::infix(Rule::op_or, Assoc::Left))
30        .op(Op::infix(Rule::op_and, Assoc::Left))
31        .op(Op::infix(Rule::op_eq, Assoc::Left))
32        .op(Op::infix(Rule::op_cmp, Assoc::Left))
33        .op(Op::infix(Rule::op_add, Assoc::Left))
34        .op(Op::infix(Rule::op_mul, Assoc::Left))
35        .op(Op::infix(Rule::op_pow, Assoc::Right))
36}
37
38fn parse_expr(pair: Pair<Rule>) -> Result<Expr> {
39    match pair.as_rule() {
40        Rule::expr => {
41            let pairs = pair.into_inner();
42            pratt()
43                .map_primary(|p: Pair<Rule>| parse_unary(p))
44                .map_infix(|lhs: Result<Expr>, op: Pair<Rule>, rhs: Result<Expr>| {
45                    let left = lhs?;
46                    let right = rhs?;
47                    let mut l = left;
48                    let mut r = right;
49                    let bop = match op.as_rule() {
50                        Rule::op_or => BinaryOp::Or,
51                        Rule::op_and => BinaryOp::And,
52                        Rule::op_eq => {
53                            let s = op.as_str();
54                            if s.contains("==") { BinaryOp::Eq } else { BinaryOp::Ne }
55                        }
56                        Rule::op_cmp => {
57                            let s = op.as_str();
58                            if s.contains("<=") {
59                                // a <= b  ==>  b >= a
60                                std::mem::swap(&mut l, &mut r);
61                                BinaryOp::Ge
62                            } else if s.contains(">=") {
63                                BinaryOp::Ge
64                            } else if s.contains('<') {
65                                BinaryOp::Lt
66                            } else {
67                                BinaryOp::Gt
68                            }
69                        }
70                        Rule::op_add => {
71                            if op.as_str().contains('-') {
72                                BinaryOp::Sub
73                            } else {
74                                BinaryOp::Add
75                            }
76                        }
77                        Rule::op_mul => {
78                            let s = op.as_str();
79                            if s.contains('*') {
80                                BinaryOp::Mul
81                            } else if s.contains('/') {
82                                BinaryOp::Div
83                            } else {
84                                BinaryOp::Mod
85                            }
86                        }
87                        Rule::op_pow => BinaryOp::Pow,
88                        r => {
89                            return Err(Error::InternalParserError(format!("unexpected infix op: {:?}", r)));
90                        }
91                    };
92                    Ok(Expr::Binary {
93                        left: Box::new(l),
94                        op: bop,
95                        right: Box::new(r),
96                    })
97                })
98                .parse(pairs)
99        }
100        _ => Err(Error::InternalParserError(format!("expected expr, got: {:?}", pair))),
101    }
102}
103
104fn parse_unary(pair: Pair<Rule>) -> Result<Expr> {
105    match pair.as_rule() {
106        Rule::unary => {
107            let mut ops: Vec<UnaryOp> = Vec::new();
108            let mut inner = pair.into_inner();
109            // Collect zero or more unary_op then the postfix expression
110            loop {
111                let Some(next) = inner.peek() else { break };
112                match next.as_rule() {
113                    Rule::unary_op => {
114                        let op_pair = inner.next().unwrap();
115                        let op_inner = op_pair.into_inner().next().unwrap();
116                        let op = match op_inner.as_rule() {
117                            Rule::not_op => UnaryOp::Not,
118                            Rule::neg_op => UnaryOp::Neg,
119                            r => {
120                                return Err(Error::InternalParserError(format!("unexpected unary op: {:?}", r)));
121                            }
122                        };
123                        ops.push(op);
124                    }
125                    _ => break,
126                }
127            }
128            let post = inner.next().expect("unary must end with postfix");
129            let mut expr = parse_postfix(post)?;
130            for op in ops.into_iter().rev() {
131                expr = Expr::Unary { op, expr: Box::new(expr) };
132            }
133            Ok(expr)
134        }
135        _ => parse_postfix(pair),
136    }
137}
138
139fn parse_postfix(pair: Pair<Rule>) -> Result<Expr> {
140    match pair.as_rule() {
141        Rule::postfix => {
142            let mut inner = pair.into_inner();
143            let first = inner.next().expect("postfix starts with primary");
144            let mut expr = parse_primary(first)?;
145            for next in inner {
146                match next.as_rule() {
147                    Rule::call => {
148                        let args = parse_call_args(next)?;
149                        expr = Expr::Call { callee: Box::new(expr), args };
150                    }
151                    Rule::index => {
152                        let idx_pair = next.into_inner().next().expect("index inner expr");
153                        let index_expr = parse_expr(idx_pair)?;
154                        expr = Expr::Index {
155                            object: Box::new(expr),
156                            index: Box::new(index_expr),
157                        };
158                    }
159                    Rule::property => {
160                        let name = next.into_inner().next().expect("property ident").as_str().to_string();
161                        expr = Expr::Member { object: Box::new(expr), field: name };
162                    }
163                    r => {
164                        return Err(Error::InternalParserError(format!("unexpected postfix op: {:?}", r)));
165                    }
166                }
167            }
168            Ok(expr)
169        }
170        _ => parse_primary(pair),
171    }
172}
173
174fn parse_call_args(pair: Pair<Rule>) -> Result<Vec<Expr>> {
175    debug_assert_eq!(pair.as_rule(), Rule::call);
176    let mut args = Vec::new();
177    for p in pair.into_inner() {
178        // call contains expr separated by commas -> grammar emits only expr pairs inside
179        if matches!(p.as_rule(), Rule::expr) {
180            args.push(parse_expr(p)?);
181        }
182    }
183    Ok(args)
184}
185
186fn parse_primary(pair: Pair<Rule>) -> Result<Expr> {
187    match pair.as_rule() {
188        Rule::primary => parse_primary(pair.into_inner().next().unwrap()),
189        Rule::parens => parse_expr(pair.into_inner().next().unwrap()),
190        Rule::ident => Ok(Expr::Var(pair.as_str().to_string())),
191        Rule::number => parse_number(pair),
192        Rule::boolean => {
193            let inner = pair.into_inner().next().unwrap();
194            let val = matches!(inner.as_rule(), Rule::true_kw);
195            Ok(Expr::Literal(Primitive::Bool(val)))
196        }
197        Rule::string => {
198            let s = unescape_string(pair.as_str())?;
199            Ok(Expr::Literal(Primitive::Str(s)))
200        }
201        Rule::list => parse_list(pair),
202        Rule::dict => parse_dict(pair),
203        r => Err(Error::InternalParserError(format!("unexpected primary op: {:?}", r))),
204    }
205}
206
207fn parse_number(pair: Pair<Rule>) -> Result<Expr> {
208    let inner = pair.into_inner().next().unwrap();
209    match inner.as_rule() {
210        Rule::int => {
211            let s = inner.as_str();
212            let v: i64 = s.parse().map_err(|_| Error::ParseError(format!("invalid int: {}", s)))?;
213            Ok(Expr::Literal(Primitive::Int(v)))
214        }
215        Rule::float => {
216            let s = inner.as_str();
217            let v: f64 = s.parse().map_err(|_| Error::ParseError(format!("invalid float: {}", s)))?;
218            Ok(Expr::Literal(Primitive::Float(v)))
219        }
220        r => Err(Error::InternalParserError(format!("unexpected number: {:?}", r))),
221    }
222}
223
224fn parse_list(pair: Pair<Rule>) -> Result<Expr> {
225    let mut elems = Vec::new();
226    for p in pair.into_inner() {
227        if let Rule::expr = p.as_rule() {
228            elems.push(parse_expr(p)?);
229        }
230    }
231    Ok(Expr::ListLiteral(elems))
232}
233
234fn parse_dict(pair: Pair<Rule>) -> Result<Expr> {
235    let mut items = Vec::new();
236    for p in pair.into_inner() {
237        if let Rule::pair = p.as_rule() {
238            let mut it = p.into_inner();
239            let key_pair = it.next().expect("pair key expr");
240            let key = parse_expr(key_pair)?;
241            let value_pair = it.next().expect("pair value expr");
242            let value = parse_expr(value_pair)?;
243            items.push((key, value));
244        }
245    }
246    Ok(Expr::DictLiteral(items))
247}
248
249fn unescape_string(src: &str) -> Result<String> {
250    // strip surrounding quotes if present (supports both ' and ")
251    // let raw = if src.starts_with('"') && src.ends_with('"') && src.len() >= 2 {
252    //     &src[1..src.len() - 1]
253    // } else if src.starts_with('\'') && src.ends_with('\'') && src.len() >= 2 {
254    //     &src[1..src.len() - 1]
255    // } else {
256    //     src
257    // };
258    let escape_char = src.chars().next().unwrap();
259    let mut out = String::with_capacity(src.len() - 2);
260    let mut chars = src[1..src.len() - 1].chars().peekable();
261    while let Some(c) = chars.next() {
262        if c == '\\' {
263            match chars.next() {
264                Some('n') => out.push('\n'),
265                Some('\\') => out.push('\\'),
266                Some(next) if next == escape_char => out.push(escape_char),
267                _ => return Err(Error::ParseError(format!("invalid escape character {}", c))),
268            }
269        } else {
270            out.push(c);
271        }
272    }
273    Ok(out)
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Parsing with `Rule::delimited_expr` yields the integer literal and reports an end
    /// position of 4 for the input `123}x`, i.e. the trailing `x` is not part of the match.
    #[test]
    fn test_interpolated_expr() {
        let parsed = parse_internal("123}x", Rule::delimited_expr);
        let (expr, end) = parsed.unwrap();
        let expected = Expr::Literal(Primitive::Int(123));
        assert_eq!(expr, expected);
        assert_eq!(end, 4);
    }
}