yarig/parser/
parser_expr.rs

1use std::{fmt::Display, ops::{Deref, DerefMut}};
2
3use winnow::{ascii::{space0, Caseless}, combinator::{alt, delimited}, error::StrContext, Parser};
4
5use crate::{error::RifError, rifgen::{order_dict::OrderDict, GenericRange, GenericValues}};
6use super::{identifier, val_f64, val_isize, ws, Res};
7
/// Kind of operator that can appear inside an expression.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OpKind {
    /// Addition operator: `+`
    Plus,
    /// Subtraction operator: `-`
    Minus,
    /// Multiplication operator: `*`
    Mult,
    /// Division operator: `/`
    Div,
    /// Remainder operator: `%`
    Rem,
    /// Power operator: `^`
    Pow,
    /// Logical not operator, written `not x`, `!x` or `~x`
    Not,
    /// Shift left operator: `<<`
    ShiftLeft,
    /// Shift right operator: `>>`
    ShiftRight,
    /// Equality comparison: `==`
    Equal,
    /// Inequality comparison: `!=`
    NotEqual,
    /// Strictly greater comparison: `>`
    Greater,
    /// Greater or equal comparison: `>=`
    GreaterEq,
    /// Strictly lesser comparison: `<`
    Lesser,
    /// Lesser or equal comparison: `<=`
    LesserEq,
}
29
/// Kind of function that can be called inside an expression.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FuncKind {
    /// Base-2 logarithm
    Log2,
    /// Base-10 logarithm
    Log10,
    /// Power function taking two arguments (base, exponent)
    Power,
    /// Rounding to the nearest integer
    Round,
    /// Rounding up
    Ceil,
    /// Rounding down
    Floor,
}
34
35impl std::fmt::Display for FuncKind {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        use FuncKind::*;
38        match &self {
39            Log2  =>  write!(f, "log2"),
40            Log10 =>  write!(f, "log10"),
41            Power =>  write!(f, "pow"),
42            Round =>  write!(f, "round"),
43            Ceil  =>  write!(f, "ceil"),
44            Floor =>  write!(f, "floor"),
45        }
46    }
47}
48
49
50#[derive(Clone, PartialEq, Debug)]
51pub enum Token {
52    /// Basic math operator: +,-,*,/,%,^
53    Operator(OpKind),
54    /// Function call: ceil, log2
55    FuncCall(FuncKind),
56    /// Left Parenthesis
57    ParenL,
58    /// Right Parenthesis
59    ParenR,
60    /// Comma (used as argument separator in function call)
61    Comma,
62    /// Number
63    Number(f64),
64    /// Variable (starts with $)
65    Var(String),
66}
67
68impl std::fmt::Display for Token {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        use Token::*;
71        use OpKind::*;
72        match &self {
73            Operator(Plus)  => write!(f, "+"),
74            Operator(Minus) => write!(f, "-"),
75            Operator(Mult)  => write!(f, "*"),
76            Operator(Div)   => write!(f, "/"),
77            Operator(Pow)   => write!(f, "^"),
78            Operator(Rem)   => write!(f, "%"),
79            Operator(Equal)     => write!(f, "=="),
80            Operator(NotEqual)  => write!(f, "!="),
81            Operator(Greater)   => write!(f, ">"),
82            Operator(GreaterEq) => write!(f, ">="),
83            Operator(Lesser)    => write!(f, "<"),
84            Operator(LesserEq)  => write!(f, "<="),
85            Operator(ShiftLeft)  => write!(f, "<<"),
86            Operator(ShiftRight) => write!(f, ">>"),
87            Operator(Not) => write!(f, "!"),
88            FuncCall(s) => write!(f, "{s}()"),
89            ParenL     => write!(f, "("),
90            ParenR     => write!(f, ")"),
91            Comma      => write!(f, ","),
92            Number(v)  => write!(f, "{v}"),
93            Var(n)     => write!(f, "${n}"),
94        }
95    }
96}
97
98fn operator<'a>(input: &mut &'a str) -> Res<'a, Token> {
99    use Token::Operator;
100    use OpKind::*;
101    alt((
102        ws("+").value(Operator(Plus)),
103        ws("-").value(Operator(Minus)),
104        ws("*").value(Operator(Mult)),
105        ws("/").value(Operator(Div)),
106        ws("^").value(Operator(Pow)),
107        ws("%").value(Operator(Rem)),
108        ws("==").value(Operator(Equal)),
109        ws("!=").value(Operator(NotEqual)),
110        ws(">=").value(Operator(GreaterEq)),
111        ws("<=").value(Operator(LesserEq)),
112        ws("<<").value(Operator(ShiftLeft)),
113        ws(">>").value(Operator(ShiftRight)),
114        ws(">").value(Operator(Greater)),
115        ws("<").value(Operator(Lesser)),
116    )).parse_next(input)
117}
118
119fn not<'a>(input: &mut &'a str) -> Res<'a, Token> {
120    ws(alt(("not","!","~"))).value(Token::Operator(OpKind::Not)).parse_next(input)
121}
122
123fn parenl<'a>(input: &mut &'a str) -> Res<'a, Token> {
124    ws("(").value(Token::ParenL).parse_next(input)
125}
126
127fn parenr<'a>(input: &mut &'a str) -> Res<'a, Token> {
128    ws(")").value(Token::ParenR).parse_next(input)
129}
130
131fn comma<'a>(input: &mut &'a str) -> Res<'a, Token> {
132    ws(",").value(Token::Comma).parse_next(input)
133}
134
135fn func_call<'a>(input: &mut &'a str) -> Res<'a, Token> {
136    use Token::FuncCall;
137    use FuncKind::*;
138    alt((
139        ws("log2(").value(FuncCall(Log2)),
140        ws("log10(").value(FuncCall(Log10)),
141        ws("pow(").value(FuncCall(Power)),
142        ws("int(").value(FuncCall(Round)),
143        ws("round(").value(FuncCall(Round)),
144        ws("ceil(").value(FuncCall(Ceil)),
145        ws("floor(").value(FuncCall(Floor)),
146    )).parse_next(input)
147}
148
149fn variable<'a>(input: &mut &'a str) -> Res<'a, Token> {
150    delimited("$", identifier, space0).map(|n| Token::Var(n.to_owned())).parse_next(input)
151}
152
153fn idx<'a>(input: &mut &'a str) -> Res<'a, Token> {
154    ws("i").value(Token::Var("i".to_owned())).parse_next(input)
155}
156
157fn number<'a>(input: &mut &'a str) -> Res<'a, Token> {
158    ws(alt((
159        val_isize.map(|v| Token::Number(v as f64)),
160        val_f64.map(Token::Number),
161        Caseless("true").value(Token::Number(1.0)),
162        Caseless("false").value(Token::Number(0.0)),
163    ))).parse_next(input)
164}
165
166fn precedence(op: OpKind) -> u8 {
167    match op {
168        //
169        OpKind::Not => 2,
170        // Multiplication, Division Remainder
171        OpKind::Mult => 3,
172        OpKind::Div => 3,
173        OpKind::Rem => 3,
174        // Addition/Subtraction
175        OpKind::Plus => 4,
176        OpKind::Minus => 4,
177        // Power
178        OpKind::Pow => 5,
179        // Shift
180        OpKind::ShiftLeft => 5,
181        OpKind::ShiftRight => 5,
182        // Comparison
183        OpKind::Equal => 7,
184        OpKind::NotEqual => 7,
185        OpKind::Greater => 6,
186        OpKind::GreaterEq => 6,
187        OpKind::Lesser => 6,
188        OpKind::LesserEq => 6,
189    }
190}
191
/// Kind of token the expression parser expects next.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ExprState {
    /// An operand is expected
    Operand,
    /// An operator is expected
    Operator,
}
197
/// Enclosing construct the expression parser is currently inside of.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ExprContext {
    /// Inside a parenthesized sub-expression
    SubExpr,
    /// Inside a function call; the value is the number of
    /// argument separators (commas) still expected
    FuncCall(u8),
}
203
204
205#[allow(dead_code)]
206/// Parse a string and return a sequence of tokens in Reverse-Polish Notation (RPN)
207/// Use the Shunting Yard Algorithm to transform infix notation to RPN:
208///   - While there are tokens to be read:
209///     - Read a token
210///     - If it's a number add it to queue
211///     - If it's an operator
212///       - While there's an operator on the top of the stack with greater precedence:
213///         - Pop operators from the stack onto the output queue
214///       - Push the current operator onto the stack
215///     - If it's a left bracket push it onto the stack
216///     - If it's a right bracket
217///       - While there's not a left bracket at the top of the stack:
218///         - Pop operators from the stack onto the output queue.
219///       - Pop the left bracket from the stack and discard it
220///   - While there are operators on the stack, pop them to the queue
221pub fn parse_expr(input: &str) -> Result<ExprTokens,RifError> {
222    let mut tokens = ExprTokens::new(2);
223    //
224    let mut op_stack = ExprTokens::new(1);
225    let mut cntxt : Vec<ExprContext> = Vec::new();
226    let mut state = ExprState::Operand;
227    //
228    let mut s = input;
229    while !s.is_empty() {
230
231        let token = match state {
232            ExprState::Operand => alt((parenl,variable,idx,number,func_call, not)).context(StrContext::Label("operand")).parse_next(&mut s)?,
233            ExprState::Operator => match cntxt.last() {
234                None => operator(&mut s)?,
235                Some(ExprContext::SubExpr) |
236                Some(ExprContext::FuncCall(0)) => alt((operator,parenr)).context(StrContext::Label("function call / Sub expression")).parse_next(&mut s)?,
237                Some(ExprContext::FuncCall(_)) => alt((operator,comma)).context(StrContext::Label("function call")).parse_next(&mut s)?,
238            }
239        };
240
241        // println!("{state:?} : {token:?} s='{s}' | cntxt={cntxt:?} | Stack = {op_stack:?} | Output = {tokens:?}");
242        match token {
243            // Operand -> save token and change state to Operator
244            Token::Number(_) |
245            Token::Var(_) => {
246                tokens.push(token);
247                state = ExprState::Operator;
248            },
249            // Push not operator on stack
250            Token::Operator(OpKind::Not) => op_stack.push(token),
251            // Operator -> Move operator stack to output until higher precedence operator is found
252            // and then push operator to the stack, and change state to operand
253            Token::Operator(op_r) => {
254                while let Some(t) = op_stack.last() {
255                    match t {
256                        Token::Operator(op_l) if precedence(op_r) >= precedence(*op_l) => tokens.push(op_stack.pop().unwrap()),
257                        _ => break,
258                    }
259                }
260                op_stack.push(token);
261                state = ExprState::Operand;
262            },
263            // Function call: push on operator stack and increase parenthesis counter
264            Token::FuncCall(kind) => {
265                op_stack.push(token);
266                let nb_sep = match kind {
267                    FuncKind::Power => 1,
268                    _ => 0,
269                };
270                cntxt.push(ExprContext::FuncCall(nb_sep));
271            },
272            // Open parenthesis: Push ParenL on operator stack
273            Token::ParenL => {
274                cntxt.push(ExprContext::SubExpr);
275                op_stack.push(Token::ParenL);
276            },
277            // Closing parenthesis : Pop last context and pop operators stack
278            Token::ParenR => {
279                cntxt.pop();
280                while let Some(op) = op_stack.pop() {
281                    match op {
282                        Token::ParenL => {
283                            break;
284                        },
285                        Token::FuncCall(_) => {
286                            tokens.push(op);
287                            break
288                        },
289                        _ => {tokens.push(op)},
290                    }
291                }
292            },
293            // Argument separator : decrease the expected number of argument
294            // and now expect operand
295            Token::Comma => {
296                state = ExprState::Operand;
297                if let Some(ExprContext::FuncCall(n)) = cntxt.last_mut() {
298                    *n -= 1;
299                }
300            }
301        }
302    }
303
304    // println!("Done : {state:?} | cntxt={cntxt:?} | Stack = {op_stack:?} | Output = {tokens:?}");
305    // Empty the operator stack once all tokens have been parsed
306    while let Some(op) = op_stack.pop() {
307        tokens.push(op);
308    }
309
310    Ok(tokens)
311}
312
313#[derive(Clone, Debug, PartialEq, Default)]
314pub struct ExprTokens(Vec<Token>);
315
316impl Deref for ExprTokens {
317    type Target = Vec<Token>;
318    fn deref(&self) -> &Self::Target {
319        &self.0
320    }
321}
322
323impl DerefMut for ExprTokens {
324    fn deref_mut(&mut self) -> &mut Self::Target {
325        &mut self.0
326    }
327}
328
329impl ExprTokens {
330
331    pub fn new(capacity: usize) -> Self {
332        ExprTokens(Vec::with_capacity(capacity))
333    }
334
335    pub fn eval(&self, variables: &ParamValues) -> Result<isize, ExprError> {
336        if self.is_empty() {
337            return Ok(0);
338        }
339        let mut values : Vec<f64> = Vec::with_capacity(self.len()>>1);
340        // println!("[eval] Expression = {self:?}");
341        for token in self.iter() {
342            match token {
343                Token::Number(v) => values.push(*v),
344                Token::Var(n) => {
345                    let v = variables.get(n).ok_or(ExprError::UnknownVar(n.to_owned()))?;
346                    values.push(*v as f64)
347                },
348                Token::Operator(op) => {
349                    let v2 = if *op != OpKind::Not {
350                        values.pop().ok_or(ExprError::Malformed)?
351                    } else {
352                        0.0
353                    };
354                    let v1 = values.pop().ok_or(ExprError::Malformed)?;
355                    let res =
356                        match op {
357                            OpKind::Plus  => v1+v2,
358                            OpKind::Minus => v1-v2,
359                            OpKind::Mult  => v1*v2,
360                            OpKind::Div   => v1/v2,
361                            OpKind::Rem   => (v1 as isize % v2 as isize) as f64,
362                            OpKind::Pow   => v1.powf(v2),
363                            // Logical inversion
364                            OpKind::Not   => if v1==0.0 {1.0} else {0.0},
365                            // Shift
366                            OpKind::ShiftLeft  => ((v1 as isize) << v2 as usize) as f64,
367                            OpKind::ShiftRight => ((v1 as isize) >> v2 as usize) as f64,
368                            // Comparison
369                            OpKind::Equal     => if v1 == v2 {1.0} else {0.0},
370                            OpKind::NotEqual  => if v1 != v2 {1.0} else {0.0},
371                            OpKind::Greater   => if v1 >  v2 {1.0} else {0.0},
372                            OpKind::GreaterEq => if v1 >= v2 {1.0} else {0.0},
373                            OpKind::Lesser    => if v1 <  v2 {1.0} else {0.0},
374                            OpKind::LesserEq  => if v1 <= v2 {1.0} else {0.0},
375                        };
376                    // println!("[eval] {v1} {op:?} {v2} -> {res} | {values:?}");
377                    values.push(res);
378                },
379                Token::FuncCall(func) => {
380                    let v = values.pop().ok_or(ExprError::Malformed)?;
381                    let res = match func {
382                        FuncKind::Log2  => v.log2(),
383                        FuncKind::Log10 => v.log10(),
384                        FuncKind::Power   => {
385                            let base = values.pop().ok_or(ExprError::Malformed)?;
386                            base.powf(v)
387                        },
388                        FuncKind::Round => v.round(),
389                        FuncKind::Ceil  => v.ceil(),
390                        FuncKind::Floor => v.floor(),
391                    };
392                    values.push(res);
393                },
394                // Other token variant should never appear in the expression
395                _ => return Err(ExprError::Malformed),
396            }
397        }
398        // Cast result to integer and check the stack is empty at the end of the evaluation
399        let result = values.pop().ok_or(ExprError::Malformed)?;
400        if values.is_empty() {
401            Ok(result.round() as isize)
402        } else {
403            Err(ExprError::Malformed)
404        }
405    }
406
407    pub fn eval_with_gen(&self, variables: &ParamValues, generics: &GenericValues) -> Result<ExprValue, ExprError> {
408        match self.eval(variables) {
409            Ok(n) => Ok(ExprValue::Value(n)),
410            Err(ExprError::UnknownVar(n)) => {
411                if self.len() > 1 {
412                    Err(ExprError::Malformed)
413                } else if let Some(range) = generics.get(&n) {
414                    Ok(ExprValue::Range(n,range.clone()))
415                } else {
416                    Err(ExprError::UnknownVar(n))
417                }
418            }
419            Err(e) => Err(e)
420        }
421    }
422}
423
424#[derive(Clone, Debug, PartialEq)]
425pub enum ExprValue {
426    Value(isize),
427    Range(String,GenericRange),
428}
429
430impl ExprValue {
431
432    pub fn max(&self) -> isize {
433        match self {
434            ExprValue::Value(n) => *n,
435            ExprValue::Range(_,r) => r.max.into(),
436        }
437    }
438}
439
440impl Default for ExprValue {
441    fn default() -> Self {
442        ExprValue::Value(0)
443    }
444}
445
/// Error raised while evaluating an expression.
#[derive(Debug, Clone, PartialEq)]
pub enum ExprError {
    /// The token sequence cannot be evaluated
    Malformed,
    /// The named variable has no known value
    UnknownVar(String),
}
451
452impl From<ExprError> for String {
453    fn from(value: ExprError) -> Self {
454        match value {
455            ExprError::Malformed => "Malformed expression".to_owned(),
456            ExprError::UnknownVar(v) => format!("Unknown var {v} in expression"),
457        }
458    }
459}
460
461#[derive(Clone, Debug)]
462pub struct ParamValues(OrderDict<String,isize>);
463
464
465
466impl ParamValues {
467
468    pub fn new() -> Self {
469        ParamValues(OrderDict::new())
470    }
471
472    pub fn new_with_idx(idx: isize) -> Self {
473        let mut params = ParamValues(OrderDict::new());
474        params.0.insert("i".to_owned(), idx);
475        params
476    }
477
478    pub fn from_items<'a, I>(dict: I) -> Result<Self,String>
479    where I: Iterator<Item = (&'a String,&'a ExprTokens)> {
480        let mut params = ParamValues(OrderDict::new());
481        for (name,expr) in dict.into_iter() {
482            let v = expr.eval(&params).map_err(|_| format!("Malformed parameter {name} : {expr:?}"))?;
483            params.0.insert(name.to_owned(), v);
484        }
485        Ok(params)
486    }
487
488    pub fn compile<'a, I>(&mut self, dict: I) -> Result<(),String>
489    where I: Iterator<Item = (&'a String,&'a ExprTokens)> {
490        for (name,expr) in dict.into_iter() {
491            if self.0.contains_key(name) {
492                // println!("Skipping {name} = {expr:?} : current value is {:?}", self.0.get(name).unwrap());
493                continue;
494            }
495            let v = expr.eval(self).map_err(|_| format!("Malformed parameter {name} : {expr:?}"))?;
496            // println!("Compiling {name} = {expr:?} ==> {v}");
497            self.0.insert(name.to_owned(), v);
498        }
499        Ok(())
500    }
501
502    pub fn get(&self, k: &String) -> Option<&isize> {
503        self.0.get(k)
504    }
505
506    pub fn insert(&mut self, k: String, v: isize) {
507        self.0.insert(k,v);
508    }
509
510    pub fn items(&self) -> impl Iterator<Item=(&String,&isize)> {
511        self.0.items()
512    }
513
514    #[allow(dead_code)]
515    pub fn len(&self) -> usize {
516        self.0.len()
517    }
518
519    #[allow(dead_code)]
520    pub fn is_empty(&self) -> bool {
521        self.0.len()==0
522    }
523}
524
525impl Default for ParamValues {
526    fn default() -> Self {
527        Self::new()
528    }
529}
530
531impl Display for ParamValues {
532
533    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534        let tab = if f.alternate() {"\t"} else {""};
535        let end = if f.alternate() {"\n"} else {", "};
536        if f.alternate() {
537            writeln!(f)?;
538        }
539        for (k,v) in self.items() {
540            write!(f, "{tab}{k} = {v}{end}")?;
541        }
542        Ok(())
543    }
544}
545
546
547
#[cfg(test)]
mod tests_parsing {
    use super::*;
    use super::OpKind::*;
    use super::FuncKind::*;
    use super::Token::*;

    /// Shorthand to build a variable token from its name
    fn var(name: &str) -> Token {
        Var(name.to_owned())
    }

    /// Check the RPN token sequence produced for a few expressions
    #[test]
    fn test_parse_expr() {
        let cases = vec![
            ("256 ", vec![Number(256.0)]),
            ("$v1 +3", vec![var("v1"), Number(3.0), Operator(Plus)]),
            ("ceil(log2($v3-5))", vec![var("v3"), Number(5.0), Operator(Minus), FuncCall(Log2), FuncCall(Ceil)]),
            ("pow(3,$x )-1", vec![Number(3.0), var("x"), FuncCall(Power), Number(1.0), Operator(Minus)]),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_expr(input), Ok(ExprTokens(expected)));
        }
    }

    /// Check the value computed for expressions using variables
    #[test]
    fn test_eval_expr() {
        let mut variables = ParamValues::new();
        variables.insert("v1".to_owned(), 1);
        variables.insert("x".to_owned(), 17);

        let not_expr = parse_expr("16*(not $v1) + 256*$v1").unwrap();
        assert_eq!(not_expr.eval(&variables),Ok(256));

        let pow_expr = parse_expr("pow(2, $x) - 1").unwrap();
        assert_eq!(pow_expr.eval(&variables),Ok((1<<17)-1));
    }

}