Skip to main content

yarig/parser/
parser_expr.rs

1use std::{fmt::Display, ops::{Deref, DerefMut}};
2
3use winnow::{ascii::{space0, Caseless}, combinator::{alt, delimited}, error::StrContext, Parser};
4
5use crate::{error::RifError, rifgen::{order_dict::OrderDict, GenericRange, GenericValues}};
6use super::{identifier, val_f64, val_isize, ws, Res};
7
/// Kind of operator that can appear in an expression.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OpKind {
    /// Addition: `+`
    Plus,
    /// Subtraction: `-`
    Minus,
    /// Multiplication: `*`
    Mult,
    /// Division: `/`
    Div,
    /// Remainder: `%`
    Rem,
    /// Power: `^`
    Pow,
    /// Logical inversion, written `not x`, `!x` or `~x`
    Not,
    /// Shift left: `<<`
    ShiftLeft,
    /// Shift right: `>>`
    ShiftRight,
    /// Comparison: `==`
    Equal,
    /// Comparison: `!=`
    NotEqual,
    /// Comparison: `>`
    Greater,
    /// Comparison: `>=`
    GreaterEq,
    /// Comparison: `<`
    Lesser,
    /// Comparison: `<=`
    LesserEq,
}
29
/// Built-in functions that can be called inside an expression.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FuncKind {
    /// Base-2 logarithm
    Log2,
    /// Base-10 logarithm
    Log10,
    /// Power (two arguments: base, exponent)
    Power,
    /// Rounding to the nearest integer
    Round,
    /// Rounding up
    Ceil,
    /// Rounding down
    Floor,
}

impl std::fmt::Display for FuncKind {
    /// Write the lower-case function name (without parentheses).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            FuncKind::Log2 => "log2",
            FuncKind::Log10 => "log10",
            FuncKind::Power => "pow",
            FuncKind::Round => "round",
            FuncKind::Ceil => "ceil",
            FuncKind::Floor => "floor",
        };
        f.write_str(name)
    }
}
48
49
50#[derive(Clone, PartialEq, Debug)]
51pub enum Token {
52    /// Basic math operator: +,-,*,/,%,^
53    Operator(OpKind),
54    /// Function call: ceil, log2
55    FuncCall(FuncKind),
56    /// Left Parenthesis
57    ParenL,
58    /// Right Parenthesis
59    ParenR,
60    /// Comma (used as argument separator in function call)
61    Comma,
62    /// Number
63    Number(f64),
64    /// Variable (starts with $)
65    Var(String),
66}
67
68impl std::fmt::Display for Token {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        use Token::*;
71        use OpKind::*;
72        match &self {
73            Operator(Plus)  => write!(f, "+"),
74            Operator(Minus) => write!(f, "-"),
75            Operator(Mult)  => write!(f, "*"),
76            Operator(Div)   => write!(f, "/"),
77            Operator(Pow)   => write!(f, "^"),
78            Operator(Rem)   => write!(f, "%"),
79            Operator(Equal)     => write!(f, "=="),
80            Operator(NotEqual)  => write!(f, "!="),
81            Operator(Greater)   => write!(f, ">"),
82            Operator(GreaterEq) => write!(f, ">="),
83            Operator(Lesser)    => write!(f, "<"),
84            Operator(LesserEq)  => write!(f, "<="),
85            Operator(ShiftLeft)  => write!(f, "<<"),
86            Operator(ShiftRight) => write!(f, ">>"),
87            Operator(Not) => write!(f, "!"),
88            FuncCall(s) => write!(f, "{s}()"),
89            ParenL     => write!(f, "("),
90            ParenR     => write!(f, ")"),
91            Comma      => write!(f, ","),
92            Number(v)  => write!(f, "{v}"),
93            Var(n)     => write!(f, "${n}"),
94        }
95    }
96}
97
98fn operator<'a>(input: &mut &'a str) -> Res<'a, Token> {
99    use Token::Operator;
100    use OpKind::*;
101    alt((
102        alt((
103            ws("+").value(Operator(Plus)),
104            ws("-").value(Operator(Minus)),
105            ws("*").value(Operator(Mult)),
106            ws("/").value(Operator(Div)),
107            ws("^").value(Operator(Pow)),
108            ws("%").value(Operator(Rem)),
109        )),
110        alt((
111            ws("==").value(Operator(Equal)),
112            ws("!=").value(Operator(NotEqual)),
113            ws(">=").value(Operator(GreaterEq)),
114            ws("<=").value(Operator(LesserEq)),
115            ws("<<").value(Operator(ShiftLeft)),
116            ws(">>").value(Operator(ShiftRight)),
117            ws(">").value(Operator(Greater)),
118            ws("<").value(Operator(Lesser)),
119        ))
120    )).parse_next(input)
121}
122
123fn not<'a>(input: &mut &'a str) -> Res<'a, Token> {
124    ws(alt(("not","!","~"))).value(Token::Operator(OpKind::Not)).parse_next(input)
125}
126
127fn parenl<'a>(input: &mut &'a str) -> Res<'a, Token> {
128    ws("(").value(Token::ParenL).parse_next(input)
129}
130
131fn parenr<'a>(input: &mut &'a str) -> Res<'a, Token> {
132    ws(")").value(Token::ParenR).parse_next(input)
133}
134
135fn comma<'a>(input: &mut &'a str) -> Res<'a, Token> {
136    ws(",").value(Token::Comma).parse_next(input)
137}
138
139fn func_call<'a>(input: &mut &'a str) -> Res<'a, Token> {
140    use Token::FuncCall;
141    use FuncKind::*;
142    alt((
143        ws("log2(").value(FuncCall(Log2)),
144        ws("log10(").value(FuncCall(Log10)),
145        ws("pow(").value(FuncCall(Power)),
146        ws("int(").value(FuncCall(Round)),
147        ws("round(").value(FuncCall(Round)),
148        ws("ceil(").value(FuncCall(Ceil)),
149        ws("floor(").value(FuncCall(Floor)),
150    )).parse_next(input)
151}
152
153fn variable<'a>(input: &mut &'a str) -> Res<'a, Token> {
154    delimited("$", identifier, space0).map(|n| Token::Var(n.to_owned())).parse_next(input)
155}
156
157fn idx<'a>(input: &mut &'a str) -> Res<'a, Token> {
158    ws("i").value(Token::Var("i".to_owned())).parse_next(input)
159}
160
161fn number<'a>(input: &mut &'a str) -> Res<'a, Token> {
162    ws(alt((
163        val_isize.map(|v| Token::Number(v as f64)),
164        val_f64.map(Token::Number),
165        Caseless("true").value(Token::Number(1.0)),
166        Caseless("false").value(Token::Number(0.0)),
167    ))).parse_next(input)
168}
169
170fn precedence(op: OpKind) -> u8 {
171    match op {
172        //
173        OpKind::Not => 2,
174        // Multiplication, Division Remainder
175        OpKind::Mult => 3,
176        OpKind::Div => 3,
177        OpKind::Rem => 3,
178        // Addition/Subtraction
179        OpKind::Plus => 4,
180        OpKind::Minus => 4,
181        // Power
182        OpKind::Pow => 5,
183        // Shift
184        OpKind::ShiftLeft => 5,
185        OpKind::ShiftRight => 5,
186        // Comparison
187        OpKind::Equal => 7,
188        OpKind::NotEqual => 7,
189        OpKind::Greater => 6,
190        OpKind::GreaterEq => 6,
191        OpKind::Lesser => 6,
192        OpKind::LesserEq => 6,
193    }
194}
195
/// What the expression parser expects to read next.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ExprState {
    /// An operand (or something that starts one) is expected
    Operand,
    /// An operator (or something that ends a sub-expression) is expected
    Operator,
}
201
/// Kind of enclosing construct the expression parser is currently inside.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ExprContext {
    /// Inside a parenthesized sub-expression
    SubExpr,
    /// Inside a function call; the payload is the number of
    /// argument separators (commas) still expected
    FuncCall(u8),
}
207
208
209#[allow(dead_code)]
210/// Parse a string and return a sequence of tokens in Reverse-Polish Notation (RPN)
211/// Use the Shunting Yard Algorithm to transform infix notation to RPN:
212///   - While there are tokens to be read:
213///     - Read a token
214///     - If it's a number add it to queue
215///     - If it's an operator
216///       - While there's an operator on the top of the stack with greater precedence:
217///         - Pop operators from the stack onto the output queue
218///       - Push the current operator onto the stack
219///     - If it's a left bracket push it onto the stack
220///     - If it's a right bracket
221///       - While there's not a left bracket at the top of the stack:
222///         - Pop operators from the stack onto the output queue.
223///       - Pop the left bracket from the stack and discard it
224///   - While there are operators on the stack, pop them to the queue
225pub fn parse_expr(input: &str) -> Result<ExprTokens,RifError> {
226    let mut tokens = ExprTokens::new(2);
227    //
228    let mut op_stack = ExprTokens::new(1);
229    let mut cntxt : Vec<ExprContext> = Vec::new();
230    let mut state = ExprState::Operand;
231    //
232    let mut s = input;
233    while !s.is_empty() {
234
235        let token = match state {
236            ExprState::Operand => alt((parenl,variable,idx,number,func_call, not)).context(StrContext::Label("operand")).parse_next(&mut s)?,
237            ExprState::Operator => match cntxt.last() {
238                None => operator(&mut s)?,
239                Some(ExprContext::SubExpr) |
240                Some(ExprContext::FuncCall(0)) => alt((operator,parenr)).context(StrContext::Label("function call / Sub expression")).parse_next(&mut s)?,
241                Some(ExprContext::FuncCall(_)) => alt((operator,comma)).context(StrContext::Label("function call")).parse_next(&mut s)?,
242            }
243        };
244
245        // println!("{state:?} : {token:?} s='{s}' | cntxt={cntxt:?} | Stack = {op_stack:?} | Output = {tokens:?}");
246        match token {
247            // Operand -> save token and change state to Operator
248            Token::Number(_) |
249            Token::Var(_) => {
250                tokens.push(token);
251                state = ExprState::Operator;
252            },
253            // Push not operator on stack
254            Token::Operator(OpKind::Not) => op_stack.push(token),
255            // Operator -> Move operator stack to output until higher precedence operator is found
256            // and then push operator to the stack, and change state to operand
257            Token::Operator(op_r) => {
258                while let Some(t) = op_stack.last() {
259                    match t {
260                        Token::Operator(op_l) if precedence(op_r) >= precedence(*op_l) => tokens.push(op_stack.pop().unwrap()),
261                        _ => break,
262                    }
263                }
264                op_stack.push(token);
265                state = ExprState::Operand;
266            },
267            // Function call: push on operator stack and increase parenthesis counter
268            Token::FuncCall(kind) => {
269                op_stack.push(token);
270                let nb_sep = match kind {
271                    FuncKind::Power => 1,
272                    _ => 0,
273                };
274                cntxt.push(ExprContext::FuncCall(nb_sep));
275            },
276            // Open parenthesis: Push ParenL on operator stack
277            Token::ParenL => {
278                cntxt.push(ExprContext::SubExpr);
279                op_stack.push(Token::ParenL);
280            },
281            // Closing parenthesis : Pop last context and pop operators stack
282            Token::ParenR => {
283                cntxt.pop();
284                while let Some(op) = op_stack.pop() {
285                    match op {
286                        Token::ParenL => {
287                            break;
288                        },
289                        Token::FuncCall(_) => {
290                            tokens.push(op);
291                            break
292                        },
293                        _ => {tokens.push(op)},
294                    }
295                }
296            },
297            // Argument separator : decrease the expected number of argument
298            // and now expect operand
299            Token::Comma => {
300                state = ExprState::Operand;
301                if let Some(ExprContext::FuncCall(n)) = cntxt.last_mut() {
302                    *n -= 1;
303                }
304            }
305        }
306    }
307
308    // println!("Done : {state:?} | cntxt={cntxt:?} | Stack = {op_stack:?} | Output = {tokens:?}");
309    // Empty the operator stack once all tokens have been parsed
310    while let Some(op) = op_stack.pop() {
311        tokens.push(op);
312    }
313
314    Ok(tokens)
315}
316
317#[derive(Clone, Debug, PartialEq, Default)]
318pub struct ExprTokens(Vec<Token>);
319
320impl Deref for ExprTokens {
321    type Target = Vec<Token>;
322    fn deref(&self) -> &Self::Target {
323        &self.0
324    }
325}
326
327impl DerefMut for ExprTokens {
328    fn deref_mut(&mut self) -> &mut Self::Target {
329        &mut self.0
330    }
331}
332
333impl ExprTokens {
334
335    pub fn new(capacity: usize) -> Self {
336        ExprTokens(Vec::with_capacity(capacity))
337    }
338
339    pub fn eval(&self, variables: &ParamValues) -> Result<isize, ExprError> {
340        if self.is_empty() {
341            return Ok(0);
342        }
343        let mut values : Vec<f64> = Vec::with_capacity(self.len()>>1);
344        // println!("[eval] Expression = {self:?}");
345        for token in self.iter() {
346            match token {
347                Token::Number(v) => values.push(*v),
348                Token::Var(n) => {
349                    let v = variables.get(n).ok_or(ExprError::UnknownVar(n.to_owned()))?;
350                    values.push(*v as f64)
351                },
352                Token::Operator(op) => {
353                    let v2 = if *op != OpKind::Not {
354                        values.pop().ok_or(ExprError::Malformed)?
355                    } else {
356                        0.0
357                    };
358                    let v1 = values.pop().ok_or(ExprError::Malformed)?;
359                    let res =
360                        match op {
361                            OpKind::Plus  => v1+v2,
362                            OpKind::Minus => v1-v2,
363                            OpKind::Mult  => v1*v2,
364                            OpKind::Div   => v1/v2,
365                            OpKind::Rem   => (v1 as isize % v2 as isize) as f64,
366                            OpKind::Pow   => v1.powf(v2),
367                            // Logical inversion
368                            OpKind::Not   => if v1==0.0 {1.0} else {0.0},
369                            // Shift
370                            OpKind::ShiftLeft  => ((v1 as isize) << v2 as usize) as f64,
371                            OpKind::ShiftRight => ((v1 as isize) >> v2 as usize) as f64,
372                            // Comparison
373                            OpKind::Equal     => if v1 == v2 {1.0} else {0.0},
374                            OpKind::NotEqual  => if v1 != v2 {1.0} else {0.0},
375                            OpKind::Greater   => if v1 >  v2 {1.0} else {0.0},
376                            OpKind::GreaterEq => if v1 >= v2 {1.0} else {0.0},
377                            OpKind::Lesser    => if v1 <  v2 {1.0} else {0.0},
378                            OpKind::LesserEq  => if v1 <= v2 {1.0} else {0.0},
379                        };
380                    // println!("[eval] {v1} {op:?} {v2} -> {res} | {values:?}");
381                    values.push(res);
382                },
383                Token::FuncCall(func) => {
384                    let v = values.pop().ok_or(ExprError::Malformed)?;
385                    let res = match func {
386                        FuncKind::Log2  => v.log2(),
387                        FuncKind::Log10 => v.log10(),
388                        FuncKind::Power   => {
389                            let base = values.pop().ok_or(ExprError::Malformed)?;
390                            base.powf(v)
391                        },
392                        FuncKind::Round => v.round(),
393                        FuncKind::Ceil  => v.ceil(),
394                        FuncKind::Floor => v.floor(),
395                    };
396                    values.push(res);
397                },
398                // Other token variant should never appear in the expression
399                _ => return Err(ExprError::Malformed),
400            }
401        }
402        // Cast result to integer and check the stack is empty at the end of the evaluation
403        let result = values.pop().ok_or(ExprError::Malformed)?;
404        if values.is_empty() {
405            Ok(result.round() as isize)
406        } else {
407            Err(ExprError::Malformed)
408        }
409    }
410
411    pub fn eval_with_gen(&self, variables: &ParamValues, generics: &GenericValues) -> Result<ExprValue, ExprError> {
412        match self.eval(variables) {
413            Ok(n) => Ok(ExprValue::Value(n)),
414            Err(ExprError::UnknownVar(n)) => {
415                if self.len() > 1 {
416                    Err(ExprError::Malformed)
417                } else if let Some(range) = generics.get(&n) {
418                    Ok(ExprValue::Range(n,range.clone()))
419                } else {
420                    Err(ExprError::UnknownVar(n))
421                }
422            }
423            Err(e) => Err(e)
424        }
425    }
426}
427
428#[derive(Clone, Debug, PartialEq)]
429pub enum ExprValue {
430    Value(isize),
431    Range(String,GenericRange),
432}
433
434impl ExprValue {
435
436    pub fn max(&self) -> isize {
437        match self {
438            ExprValue::Value(n) => *n,
439            ExprValue::Range(_,r) => r.max as isize,
440        }
441    }
442}
443
444impl Default for ExprValue {
445    fn default() -> Self {
446        ExprValue::Value(0)
447    }
448}
449
/// Error raised while evaluating an expression.
#[derive(Debug, Clone, PartialEq)]
pub enum ExprError {
    /// The expression cannot be evaluated
    Malformed,
    /// The expression uses a variable which has no value; holds the variable name
    UnknownVar(String),
}

impl From<ExprError> for String {
    /// Turn the error into a human readable message.
    fn from(value: ExprError) -> Self {
        match value {
            ExprError::UnknownVar(name) => format!("Unknown var {name} in expression"),
            ExprError::Malformed => String::from("Malformed expression"),
        }
    }
}
464
465#[derive(Clone, Debug)]
466pub struct ParamValues(OrderDict<String,isize>);
467
468
469
470impl ParamValues {
471
472    pub fn new() -> Self {
473        ParamValues(OrderDict::new())
474    }
475
476    pub fn new_with_idx(idx: isize) -> Self {
477        let mut params = ParamValues(OrderDict::new());
478        params.0.insert("i".to_owned(), idx);
479        params
480    }
481
482    pub fn from_items<'a, I>(dict: I) -> Result<Self,String>
483    where I: Iterator<Item = (&'a String,&'a ExprTokens)> {
484        let mut params = ParamValues(OrderDict::new());
485        for (name,expr) in dict.into_iter() {
486            let v = expr.eval(&params).map_err(|_| format!("Malformed parameter {name} : {expr:?}"))?;
487            params.0.insert(name.to_owned(), v);
488        }
489        Ok(params)
490    }
491
492    pub fn compile<'a, I>(&mut self, dict: I) -> Result<(),String>
493    where I: Iterator<Item = (&'a String,&'a ExprTokens)> {
494        for (name,expr) in dict.into_iter() {
495            if self.0.contains_key(name) {
496                // println!("Skipping {name} = {expr:?} : current value is {:?}", self.0.get(name).unwrap());
497                continue;
498            }
499            let v = expr.eval(self).map_err(|_| format!("Malformed parameter {name} : {expr:?}"))?;
500            // println!("Compiling {name} = {expr:?} ==> {v}");
501            self.0.insert(name.to_owned(), v);
502        }
503        Ok(())
504    }
505
506    pub fn get(&self, k: &String) -> Option<&isize> {
507        self.0.get(k)
508    }
509
510    pub fn insert(&mut self, k: String, v: isize) {
511        self.0.insert(k,v);
512    }
513
514    pub fn items(&self) -> impl Iterator<Item=(&String,&isize)> {
515        self.0.items()
516    }
517
518    #[allow(dead_code)]
519    pub fn len(&self) -> usize {
520        self.0.len()
521    }
522
523    #[allow(dead_code)]
524    pub fn is_empty(&self) -> bool {
525        self.0.len()==0
526    }
527}
528
529impl Default for ParamValues {
530    fn default() -> Self {
531        Self::new()
532    }
533}
534
535impl Display for ParamValues {
536
537    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
538        let tab = if f.alternate() {"\t"} else {""};
539        let end = if f.alternate() {"\n"} else {", "};
540        if f.alternate() {
541            writeln!(f)?;
542        }
543        for (k,v) in self.items() {
544            write!(f, "{tab}{k} = {v}{end}")?;
545        }
546        Ok(())
547    }
548}
549
550
551
#[cfg(test)]
mod tests_parsing {
    use super::*;
    use super::OpKind::*;
    use super::FuncKind::*;
    use super::Token::*;

    /// Shorthand for a variable token.
    fn var(name: &str) -> Token {
        Var(name.to_owned())
    }

    /// Expected successful parsing result made of the given RPN tokens.
    fn rpn(tokens: Vec<Token>) -> Result<ExprTokens, RifError> {
        Ok(ExprTokens(tokens))
    }

    #[test]
    fn test_parse_expr() {
        // Single literal with trailing space
        assert_eq!(parse_expr("256 "), rpn(vec![Number(256.0)]));

        // Variable and binary operator
        assert_eq!(
            parse_expr("$v1 +3"),
            rpn(vec![var("v1"), Number(3.0), Operator(Plus)])
        );

        // Nested function calls
        assert_eq!(
            parse_expr("ceil(log2($v3-5))"),
            rpn(vec![var("v3"), Number(5.0), Operator(Minus), FuncCall(Log2), FuncCall(Ceil)])
        );

        // Function with two arguments followed by an operator
        assert_eq!(
            parse_expr("pow(3,$x )-1"),
            rpn(vec![Number(3.0), var("x"), FuncCall(Power), Number(1.0), Operator(Minus)])
        );
    }

    #[test]
    fn test_eval_expr() {
        let mut variables = ParamValues::new();
        variables.insert("v1".to_owned(), 1);
        variables.insert("x".to_owned(), 17);

        let inverted = parse_expr("16*(not $v1) + 256*$v1").unwrap();
        assert_eq!(inverted.eval(&variables), Ok(256));

        let power = parse_expr("pow(2, $x) - 1").unwrap();
        assert_eq!(power.eval(&variables), Ok((1<<17)-1));
    }

}