1use std::{fmt::Display, ops::{Deref, DerefMut}};
2
3use winnow::{ascii::{space0, Caseless}, combinator::{alt, delimited}, error::StrContext, Parser};
4
5use crate::{error::RifError, rifgen::{order_dict::OrderDict, GenericRange, GenericValues}};
6use super::{identifier, val_f64, val_isize, ws, Res};
7
/// Kind of operator that can appear in an expression.
///
/// All variants are binary operators except [`OpKind::Not`], which is unary.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum OpKind {
    /// Addition (`+`)
    Plus,
    /// Subtraction (`-`)
    Minus,
    /// Multiplication (`*`)
    Mult,
    /// Division (`/`)
    Div,
    /// Remainder (`%`)
    Rem,
    /// Exponentiation (`^`)
    Pow,
    /// Logical negation (unary)
    Not,
    /// Left shift (`<<`)
    ShiftLeft,
    /// Right shift (`>>`)
    ShiftRight,
    /// Equality test (`==`)
    Equal,
    /// Inequality test (`!=`)
    NotEqual,
    /// Strictly greater than (`>`)
    Greater,
    /// Greater than or equal (`>=`)
    GreaterEq,
    /// Strictly lesser than (`<`)
    Lesser,
    /// Lesser than or equal (`<=`)
    LesserEq,
}
29
/// Built-in functions that can be called inside an expression.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum FuncKind {
    /// Base-2 logarithm
    Log2,
    /// Base-10 logarithm
    Log10,
    /// Power function (two arguments: base, exponent)
    Power,
    /// Rounding to the nearest integer
    Round,
    /// Rounding up
    Ceil,
    /// Rounding down
    Floor,
}
34
35impl std::fmt::Display for FuncKind {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 use FuncKind::*;
38 match &self {
39 Log2 => write!(f, "log2"),
40 Log10 => write!(f, "log10"),
41 Power => write!(f, "pow"),
42 Round => write!(f, "round"),
43 Ceil => write!(f, "ceil"),
44 Floor => write!(f, "floor"),
45 }
46 }
47}
48
49
50#[derive(Clone, PartialEq, Debug)]
51pub enum Token {
52 Operator(OpKind),
54 FuncCall(FuncKind),
56 ParenL,
58 ParenR,
60 Comma,
62 Number(f64),
64 Var(String),
66}
67
68impl std::fmt::Display for Token {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 use Token::*;
71 use OpKind::*;
72 match &self {
73 Operator(Plus) => write!(f, "+"),
74 Operator(Minus) => write!(f, "-"),
75 Operator(Mult) => write!(f, "*"),
76 Operator(Div) => write!(f, "/"),
77 Operator(Pow) => write!(f, "^"),
78 Operator(Rem) => write!(f, "%"),
79 Operator(Equal) => write!(f, "=="),
80 Operator(NotEqual) => write!(f, "!="),
81 Operator(Greater) => write!(f, ">"),
82 Operator(GreaterEq) => write!(f, ">="),
83 Operator(Lesser) => write!(f, "<"),
84 Operator(LesserEq) => write!(f, "<="),
85 Operator(ShiftLeft) => write!(f, "<<"),
86 Operator(ShiftRight) => write!(f, ">>"),
87 Operator(Not) => write!(f, "!"),
88 FuncCall(s) => write!(f, "{s}()"),
89 ParenL => write!(f, "("),
90 ParenR => write!(f, ")"),
91 Comma => write!(f, ","),
92 Number(v) => write!(f, "{v}"),
93 Var(n) => write!(f, "${n}"),
94 }
95 }
96}
97
98fn operator<'a>(input: &mut &'a str) -> Res<'a, Token> {
99 use Token::Operator;
100 use OpKind::*;
101 alt((
102 alt((
103 ws("+").value(Operator(Plus)),
104 ws("-").value(Operator(Minus)),
105 ws("*").value(Operator(Mult)),
106 ws("/").value(Operator(Div)),
107 ws("^").value(Operator(Pow)),
108 ws("%").value(Operator(Rem)),
109 )),
110 alt((
111 ws("==").value(Operator(Equal)),
112 ws("!=").value(Operator(NotEqual)),
113 ws(">=").value(Operator(GreaterEq)),
114 ws("<=").value(Operator(LesserEq)),
115 ws("<<").value(Operator(ShiftLeft)),
116 ws(">>").value(Operator(ShiftRight)),
117 ws(">").value(Operator(Greater)),
118 ws("<").value(Operator(Lesser)),
119 ))
120 )).parse_next(input)
121}
122
123fn not<'a>(input: &mut &'a str) -> Res<'a, Token> {
124 ws(alt(("not","!","~"))).value(Token::Operator(OpKind::Not)).parse_next(input)
125}
126
127fn parenl<'a>(input: &mut &'a str) -> Res<'a, Token> {
128 ws("(").value(Token::ParenL).parse_next(input)
129}
130
131fn parenr<'a>(input: &mut &'a str) -> Res<'a, Token> {
132 ws(")").value(Token::ParenR).parse_next(input)
133}
134
135fn comma<'a>(input: &mut &'a str) -> Res<'a, Token> {
136 ws(",").value(Token::Comma).parse_next(input)
137}
138
139fn func_call<'a>(input: &mut &'a str) -> Res<'a, Token> {
140 use Token::FuncCall;
141 use FuncKind::*;
142 alt((
143 ws("log2(").value(FuncCall(Log2)),
144 ws("log10(").value(FuncCall(Log10)),
145 ws("pow(").value(FuncCall(Power)),
146 ws("int(").value(FuncCall(Round)),
147 ws("round(").value(FuncCall(Round)),
148 ws("ceil(").value(FuncCall(Ceil)),
149 ws("floor(").value(FuncCall(Floor)),
150 )).parse_next(input)
151}
152
153fn variable<'a>(input: &mut &'a str) -> Res<'a, Token> {
154 delimited("$", identifier, space0).map(|n| Token::Var(n.to_owned())).parse_next(input)
155}
156
157fn idx<'a>(input: &mut &'a str) -> Res<'a, Token> {
158 ws("i").value(Token::Var("i".to_owned())).parse_next(input)
159}
160
161fn number<'a>(input: &mut &'a str) -> Res<'a, Token> {
162 ws(alt((
163 val_isize.map(|v| Token::Number(v as f64)),
164 val_f64.map(Token::Number),
165 Caseless("true").value(Token::Number(1.0)),
166 Caseless("false").value(Token::Number(0.0)),
167 ))).parse_next(input)
168}
169
170fn precedence(op: OpKind) -> u8 {
171 match op {
172 OpKind::Not => 2,
174 OpKind::Mult => 3,
176 OpKind::Div => 3,
177 OpKind::Rem => 3,
178 OpKind::Plus => 4,
180 OpKind::Minus => 4,
181 OpKind::Pow => 5,
183 OpKind::ShiftLeft => 5,
185 OpKind::ShiftRight => 5,
186 OpKind::Equal => 7,
188 OpKind::NotEqual => 7,
189 OpKind::Greater => 6,
190 OpKind::GreaterEq => 6,
191 OpKind::Lesser => 6,
192 OpKind::LesserEq => 6,
193 }
194}
195
/// Kind of token the expression parser expects next.
#[derive(Clone, Copy, PartialEq, Debug)]
enum ExprState {
    /// An operand (or something opening one) is expected
    Operand,
    /// An operator (or something closing the current context) is expected
    Operator,
}
201
/// Enclosing construct the expression parser is currently inside.
#[derive(Clone, Copy, PartialEq, Debug)]
enum ExprContext {
    /// Inside a parenthesised sub-expression
    SubExpr,
    /// Inside a function call; the payload is the number of
    /// argument separators (commas) still expected
    FuncCall(u8),
}
207
208
209#[allow(dead_code)]
210pub fn parse_expr(input: &str) -> Result<ExprTokens,RifError> {
226 let mut tokens = ExprTokens::new(2);
227 let mut op_stack = ExprTokens::new(1);
229 let mut cntxt : Vec<ExprContext> = Vec::new();
230 let mut state = ExprState::Operand;
231 let mut s = input;
233 while !s.is_empty() {
234
235 let token = match state {
236 ExprState::Operand => alt((parenl,variable,idx,number,func_call, not)).context(StrContext::Label("operand")).parse_next(&mut s)?,
237 ExprState::Operator => match cntxt.last() {
238 None => operator(&mut s)?,
239 Some(ExprContext::SubExpr) |
240 Some(ExprContext::FuncCall(0)) => alt((operator,parenr)).context(StrContext::Label("function call / Sub expression")).parse_next(&mut s)?,
241 Some(ExprContext::FuncCall(_)) => alt((operator,comma)).context(StrContext::Label("function call")).parse_next(&mut s)?,
242 }
243 };
244
245 match token {
247 Token::Number(_) |
249 Token::Var(_) => {
250 tokens.push(token);
251 state = ExprState::Operator;
252 },
253 Token::Operator(OpKind::Not) => op_stack.push(token),
255 Token::Operator(op_r) => {
258 while let Some(t) = op_stack.last() {
259 match t {
260 Token::Operator(op_l) if precedence(op_r) >= precedence(*op_l) => tokens.push(op_stack.pop().unwrap()),
261 _ => break,
262 }
263 }
264 op_stack.push(token);
265 state = ExprState::Operand;
266 },
267 Token::FuncCall(kind) => {
269 op_stack.push(token);
270 let nb_sep = match kind {
271 FuncKind::Power => 1,
272 _ => 0,
273 };
274 cntxt.push(ExprContext::FuncCall(nb_sep));
275 },
276 Token::ParenL => {
278 cntxt.push(ExprContext::SubExpr);
279 op_stack.push(Token::ParenL);
280 },
281 Token::ParenR => {
283 cntxt.pop();
284 while let Some(op) = op_stack.pop() {
285 match op {
286 Token::ParenL => {
287 break;
288 },
289 Token::FuncCall(_) => {
290 tokens.push(op);
291 break
292 },
293 _ => {tokens.push(op)},
294 }
295 }
296 },
297 Token::Comma => {
300 state = ExprState::Operand;
301 if let Some(ExprContext::FuncCall(n)) = cntxt.last_mut() {
302 *n -= 1;
303 }
304 }
305 }
306 }
307
308 while let Some(op) = op_stack.pop() {
311 tokens.push(op);
312 }
313
314 Ok(tokens)
315}
316
317#[derive(Clone, Debug, PartialEq, Default)]
318pub struct ExprTokens(Vec<Token>);
319
320impl Deref for ExprTokens {
321 type Target = Vec<Token>;
322 fn deref(&self) -> &Self::Target {
323 &self.0
324 }
325}
326
327impl DerefMut for ExprTokens {
328 fn deref_mut(&mut self) -> &mut Self::Target {
329 &mut self.0
330 }
331}
332
333impl ExprTokens {
334
335 pub fn new(capacity: usize) -> Self {
336 ExprTokens(Vec::with_capacity(capacity))
337 }
338
339 pub fn eval(&self, variables: &ParamValues) -> Result<isize, ExprError> {
340 if self.is_empty() {
341 return Ok(0);
342 }
343 let mut values : Vec<f64> = Vec::with_capacity(self.len()>>1);
344 for token in self.iter() {
346 match token {
347 Token::Number(v) => values.push(*v),
348 Token::Var(n) => {
349 let v = variables.get(n).ok_or(ExprError::UnknownVar(n.to_owned()))?;
350 values.push(*v as f64)
351 },
352 Token::Operator(op) => {
353 let v2 = if *op != OpKind::Not {
354 values.pop().ok_or(ExprError::Malformed)?
355 } else {
356 0.0
357 };
358 let v1 = values.pop().ok_or(ExprError::Malformed)?;
359 let res =
360 match op {
361 OpKind::Plus => v1+v2,
362 OpKind::Minus => v1-v2,
363 OpKind::Mult => v1*v2,
364 OpKind::Div => v1/v2,
365 OpKind::Rem => (v1 as isize % v2 as isize) as f64,
366 OpKind::Pow => v1.powf(v2),
367 OpKind::Not => if v1==0.0 {1.0} else {0.0},
369 OpKind::ShiftLeft => ((v1 as isize) << v2 as usize) as f64,
371 OpKind::ShiftRight => ((v1 as isize) >> v2 as usize) as f64,
372 OpKind::Equal => if v1 == v2 {1.0} else {0.0},
374 OpKind::NotEqual => if v1 != v2 {1.0} else {0.0},
375 OpKind::Greater => if v1 > v2 {1.0} else {0.0},
376 OpKind::GreaterEq => if v1 >= v2 {1.0} else {0.0},
377 OpKind::Lesser => if v1 < v2 {1.0} else {0.0},
378 OpKind::LesserEq => if v1 <= v2 {1.0} else {0.0},
379 };
380 values.push(res);
382 },
383 Token::FuncCall(func) => {
384 let v = values.pop().ok_or(ExprError::Malformed)?;
385 let res = match func {
386 FuncKind::Log2 => v.log2(),
387 FuncKind::Log10 => v.log10(),
388 FuncKind::Power => {
389 let base = values.pop().ok_or(ExprError::Malformed)?;
390 base.powf(v)
391 },
392 FuncKind::Round => v.round(),
393 FuncKind::Ceil => v.ceil(),
394 FuncKind::Floor => v.floor(),
395 };
396 values.push(res);
397 },
398 _ => return Err(ExprError::Malformed),
400 }
401 }
402 let result = values.pop().ok_or(ExprError::Malformed)?;
404 if values.is_empty() {
405 Ok(result.round() as isize)
406 } else {
407 Err(ExprError::Malformed)
408 }
409 }
410
411 pub fn eval_with_gen(&self, variables: &ParamValues, generics: &GenericValues) -> Result<ExprValue, ExprError> {
412 match self.eval(variables) {
413 Ok(n) => Ok(ExprValue::Value(n)),
414 Err(ExprError::UnknownVar(n)) => {
415 if self.len() > 1 {
416 Err(ExprError::Malformed)
417 } else if let Some(range) = generics.get(&n) {
418 Ok(ExprValue::Range(n,range.clone()))
419 } else {
420 Err(ExprError::UnknownVar(n))
421 }
422 }
423 Err(e) => Err(e)
424 }
425 }
426}
427
428#[derive(Clone, Debug, PartialEq)]
429pub enum ExprValue {
430 Value(isize),
431 Range(String,GenericRange),
432}
433
434impl ExprValue {
435
436 pub fn max(&self) -> isize {
437 match self {
438 ExprValue::Value(n) => *n,
439 ExprValue::Range(_,r) => r.max as isize,
440 }
441 }
442}
443
444impl Default for ExprValue {
445 fn default() -> Self {
446 ExprValue::Value(0)
447 }
448}
449
/// Error raised while evaluating an expression.
#[derive(Clone, Debug, PartialEq)]
pub enum ExprError {
    /// The tokens do not form a valid expression
    Malformed,
    /// The expression refers to a variable (payload: its name) that is not defined
    UnknownVar(String),
}
455
456impl From<ExprError> for String {
457 fn from(value: ExprError) -> Self {
458 match value {
459 ExprError::Malformed => "Malformed expression".to_owned(),
460 ExprError::UnknownVar(v) => format!("Unknown var {v} in expression"),
461 }
462 }
463}
464
465#[derive(Clone, Debug)]
466pub struct ParamValues(OrderDict<String,isize>);
467
468
469
470impl ParamValues {
471
472 pub fn new() -> Self {
473 ParamValues(OrderDict::new())
474 }
475
476 pub fn new_with_idx(idx: isize) -> Self {
477 let mut params = ParamValues(OrderDict::new());
478 params.0.insert("i".to_owned(), idx);
479 params
480 }
481
482 pub fn from_items<'a, I>(dict: I) -> Result<Self,String>
483 where I: Iterator<Item = (&'a String,&'a ExprTokens)> {
484 let mut params = ParamValues(OrderDict::new());
485 for (name,expr) in dict.into_iter() {
486 let v = expr.eval(¶ms).map_err(|_| format!("Malformed parameter {name} : {expr:?}"))?;
487 params.0.insert(name.to_owned(), v);
488 }
489 Ok(params)
490 }
491
492 pub fn compile<'a, I>(&mut self, dict: I) -> Result<(),String>
493 where I: Iterator<Item = (&'a String,&'a ExprTokens)> {
494 for (name,expr) in dict.into_iter() {
495 if self.0.contains_key(name) {
496 continue;
498 }
499 let v = expr.eval(self).map_err(|_| format!("Malformed parameter {name} : {expr:?}"))?;
500 self.0.insert(name.to_owned(), v);
502 }
503 Ok(())
504 }
505
506 pub fn get(&self, k: &String) -> Option<&isize> {
507 self.0.get(k)
508 }
509
510 pub fn insert(&mut self, k: String, v: isize) {
511 self.0.insert(k,v);
512 }
513
514 pub fn items(&self) -> impl Iterator<Item=(&String,&isize)> {
515 self.0.items()
516 }
517
518 #[allow(dead_code)]
519 pub fn len(&self) -> usize {
520 self.0.len()
521 }
522
523 #[allow(dead_code)]
524 pub fn is_empty(&self) -> bool {
525 self.0.len()==0
526 }
527}
528
529impl Default for ParamValues {
530 fn default() -> Self {
531 Self::new()
532 }
533}
534
535impl Display for ParamValues {
536
537 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
538 let tab = if f.alternate() {"\t"} else {""};
539 let end = if f.alternate() {"\n"} else {", "};
540 if f.alternate() {
541 writeln!(f)?;
542 }
543 for (k,v) in self.items() {
544 write!(f, "{tab}{k} = {v}{end}")?;
545 }
546 Ok(())
547 }
548}
549
550
551
#[cfg(test)]
mod tests_parsing {
    use super::*;
    use super::OpKind::*;
    use super::FuncKind::*;
    use super::Token::*;

    /// Shorthand building a variable token from a string slice.
    fn var(name: &str) -> Token {
        Var(name.to_owned())
    }

    /// Checks the postfix token list produced for a few representative inputs.
    #[test]
    fn test_parse_expr() {
        let cases = vec![
            // Lone literal with trailing space
            ("256 ", vec![Number(256.0)]),
            // Binary operator
            ("$v1 +3", vec![var("v1"), Number(3.0), Operator(Plus)]),
            // Nested function calls
            (
                "ceil(log2($v3-5))",
                vec![var("v3"), Number(5.0), Operator(Minus), FuncCall(Log2), FuncCall(Ceil)],
            ),
            // Two-argument function followed by an operator
            (
                "pow(3,$x )-1",
                vec![Number(3.0), var("x"), FuncCall(Power), Number(1.0), Operator(Minus)],
            ),
        ];
        for (source, expected) in cases {
            assert_eq!(parse_expr(source), Ok(ExprTokens(expected)));
        }
    }

    /// Checks the evaluation of parsed expressions against known variables.
    #[test]
    fn test_eval_expr() {
        let mut variables = ParamValues::new();
        variables.insert("v1".to_owned(), 1);
        variables.insert("x".to_owned(), 17);

        let negation = parse_expr("16*(not $v1) + 256*$v1").unwrap();
        assert_eq!(negation.eval(&variables), Ok(256));

        let power = parse_expr("pow(2, $x) - 1").unwrap();
        assert_eq!(power.eval(&variables), Ok((1<<17)-1));
    }

}