1use std::{fmt::Display, ops::{Deref, DerefMut}};
2
3use winnow::{ascii::{space0, Caseless}, combinator::{alt, delimited}, error::StrContext, Parser};
4
5use crate::{error::RifError, rifgen::{order_dict::OrderDict, GenericRange, GenericValues}};
6use super::{identifier, val_f64, val_isize, ws, Res};
7
/// Operators recognised by the expression parser.
///
/// `Not` is the only prefix (unary) operator; every other variant is a
/// binary infix operator.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum OpKind {
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Mult,
    /// `/`
    Div,
    /// `%` (integer remainder)
    Rem,
    /// `^` (exponentiation)
    Pow,
    /// `not`, `!` or `~` (logical negation)
    Not,
    /// `<<`
    ShiftLeft,
    /// `>>`
    ShiftRight,
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEq,
    /// `<`
    Lesser,
    /// `<=`
    LesserEq,
}
29
/// Built-in functions callable from expressions.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum FuncKind {
    /// `log2(x)`
    Log2,
    /// `log10(x)`
    Log10,
    /// `pow(base, exp)`: the only two-argument function.
    Power,
    /// `round(x)`, also spelled `int(x)`.
    Round,
    /// `ceil(x)`
    Ceil,
    /// `floor(x)`
    Floor,
}
34
35impl std::fmt::Display for FuncKind {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 use FuncKind::*;
38 match &self {
39 Log2 => write!(f, "log2"),
40 Log10 => write!(f, "log10"),
41 Power => write!(f, "pow"),
42 Round => write!(f, "round"),
43 Ceil => write!(f, "ceil"),
44 Floor => write!(f, "floor"),
45 }
46 }
47}
48
49
50#[derive(Clone, PartialEq, Debug)]
51pub enum Token {
52 Operator(OpKind),
54 FuncCall(FuncKind),
56 ParenL,
58 ParenR,
60 Comma,
62 Number(f64),
64 Var(String),
66}
67
68impl std::fmt::Display for Token {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 use Token::*;
71 use OpKind::*;
72 match &self {
73 Operator(Plus) => write!(f, "+"),
74 Operator(Minus) => write!(f, "-"),
75 Operator(Mult) => write!(f, "*"),
76 Operator(Div) => write!(f, "/"),
77 Operator(Pow) => write!(f, "^"),
78 Operator(Rem) => write!(f, "%"),
79 Operator(Equal) => write!(f, "=="),
80 Operator(NotEqual) => write!(f, "!="),
81 Operator(Greater) => write!(f, ">"),
82 Operator(GreaterEq) => write!(f, ">="),
83 Operator(Lesser) => write!(f, "<"),
84 Operator(LesserEq) => write!(f, "<="),
85 Operator(ShiftLeft) => write!(f, "<<"),
86 Operator(ShiftRight) => write!(f, ">>"),
87 Operator(Not) => write!(f, "!"),
88 FuncCall(s) => write!(f, "{s}()"),
89 ParenL => write!(f, "("),
90 ParenR => write!(f, ")"),
91 Comma => write!(f, ","),
92 Number(v) => write!(f, "{v}"),
93 Var(n) => write!(f, "${n}"),
94 }
95 }
96}
97
98fn operator<'a>(input: &mut &'a str) -> Res<'a, Token> {
99 use Token::Operator;
100 use OpKind::*;
101 alt((
102 ws("+").value(Operator(Plus)),
103 ws("-").value(Operator(Minus)),
104 ws("*").value(Operator(Mult)),
105 ws("/").value(Operator(Div)),
106 ws("^").value(Operator(Pow)),
107 ws("%").value(Operator(Rem)),
108 ws("==").value(Operator(Equal)),
109 ws("!=").value(Operator(NotEqual)),
110 ws(">=").value(Operator(GreaterEq)),
111 ws("<=").value(Operator(LesserEq)),
112 ws("<<").value(Operator(ShiftLeft)),
113 ws(">>").value(Operator(ShiftRight)),
114 ws(">").value(Operator(Greater)),
115 ws("<").value(Operator(Lesser)),
116 )).parse_next(input)
117}
118
119fn not<'a>(input: &mut &'a str) -> Res<'a, Token> {
120 ws(alt(("not","!","~"))).value(Token::Operator(OpKind::Not)).parse_next(input)
121}
122
123fn parenl<'a>(input: &mut &'a str) -> Res<'a, Token> {
124 ws("(").value(Token::ParenL).parse_next(input)
125}
126
127fn parenr<'a>(input: &mut &'a str) -> Res<'a, Token> {
128 ws(")").value(Token::ParenR).parse_next(input)
129}
130
131fn comma<'a>(input: &mut &'a str) -> Res<'a, Token> {
132 ws(",").value(Token::Comma).parse_next(input)
133}
134
135fn func_call<'a>(input: &mut &'a str) -> Res<'a, Token> {
136 use Token::FuncCall;
137 use FuncKind::*;
138 alt((
139 ws("log2(").value(FuncCall(Log2)),
140 ws("log10(").value(FuncCall(Log10)),
141 ws("pow(").value(FuncCall(Power)),
142 ws("int(").value(FuncCall(Round)),
143 ws("round(").value(FuncCall(Round)),
144 ws("ceil(").value(FuncCall(Ceil)),
145 ws("floor(").value(FuncCall(Floor)),
146 )).parse_next(input)
147}
148
149fn variable<'a>(input: &mut &'a str) -> Res<'a, Token> {
150 delimited("$", identifier, space0).map(|n| Token::Var(n.to_owned())).parse_next(input)
151}
152
153fn idx<'a>(input: &mut &'a str) -> Res<'a, Token> {
154 ws("i").value(Token::Var("i".to_owned())).parse_next(input)
155}
156
157fn number<'a>(input: &mut &'a str) -> Res<'a, Token> {
158 ws(alt((
159 val_isize.map(|v| Token::Number(v as f64)),
160 val_f64.map(Token::Number),
161 Caseless("true").value(Token::Number(1.0)),
162 Caseless("false").value(Token::Number(0.0)),
163 ))).parse_next(input)
164}
165
166fn precedence(op: OpKind) -> u8 {
167 match op {
168 OpKind::Not => 2,
170 OpKind::Mult => 3,
172 OpKind::Div => 3,
173 OpKind::Rem => 3,
174 OpKind::Plus => 4,
176 OpKind::Minus => 4,
177 OpKind::Pow => 5,
179 OpKind::ShiftLeft => 5,
181 OpKind::ShiftRight => 5,
182 OpKind::Equal => 7,
184 OpKind::NotEqual => 7,
185 OpKind::Greater => 6,
186 OpKind::GreaterEq => 6,
187 OpKind::Lesser => 6,
188 OpKind::LesserEq => 6,
189 }
190}
191
/// What `parse_expr` expects to read next.
#[derive(Clone, Copy, PartialEq, Debug)]
enum ExprState {
    /// An operand: number, variable, function call, `(` or a prefix `not`.
    Operand,
    /// A binary operator, or (depending on the context) a closing `)` or `,`.
    Operator,
}
197
/// Bracketed construct that `parse_expr` is currently inside.
#[derive(Clone, Copy, PartialEq, Debug)]
enum ExprContext {
    /// A parenthesised sub-expression.
    SubExpr,
    /// A function call's argument list. The value is the number of `,`
    /// separators still expected before the closing `)`.
    FuncCall(u8),
}
203
204
205#[allow(dead_code)]
206pub fn parse_expr(input: &str) -> Result<ExprTokens,RifError> {
222 let mut tokens = ExprTokens::new(2);
223 let mut op_stack = ExprTokens::new(1);
225 let mut cntxt : Vec<ExprContext> = Vec::new();
226 let mut state = ExprState::Operand;
227 let mut s = input;
229 while !s.is_empty() {
230
231 let token = match state {
232 ExprState::Operand => alt((parenl,variable,idx,number,func_call, not)).context(StrContext::Label("operand")).parse_next(&mut s)?,
233 ExprState::Operator => match cntxt.last() {
234 None => operator(&mut s)?,
235 Some(ExprContext::SubExpr) |
236 Some(ExprContext::FuncCall(0)) => alt((operator,parenr)).context(StrContext::Label("function call / Sub expression")).parse_next(&mut s)?,
237 Some(ExprContext::FuncCall(_)) => alt((operator,comma)).context(StrContext::Label("function call")).parse_next(&mut s)?,
238 }
239 };
240
241 match token {
243 Token::Number(_) |
245 Token::Var(_) => {
246 tokens.push(token);
247 state = ExprState::Operator;
248 },
249 Token::Operator(OpKind::Not) => op_stack.push(token),
251 Token::Operator(op_r) => {
254 while let Some(t) = op_stack.last() {
255 match t {
256 Token::Operator(op_l) if precedence(op_r) >= precedence(*op_l) => tokens.push(op_stack.pop().unwrap()),
257 _ => break,
258 }
259 }
260 op_stack.push(token);
261 state = ExprState::Operand;
262 },
263 Token::FuncCall(kind) => {
265 op_stack.push(token);
266 let nb_sep = match kind {
267 FuncKind::Power => 1,
268 _ => 0,
269 };
270 cntxt.push(ExprContext::FuncCall(nb_sep));
271 },
272 Token::ParenL => {
274 cntxt.push(ExprContext::SubExpr);
275 op_stack.push(Token::ParenL);
276 },
277 Token::ParenR => {
279 cntxt.pop();
280 while let Some(op) = op_stack.pop() {
281 match op {
282 Token::ParenL => {
283 break;
284 },
285 Token::FuncCall(_) => {
286 tokens.push(op);
287 break
288 },
289 _ => {tokens.push(op)},
290 }
291 }
292 },
293 Token::Comma => {
296 state = ExprState::Operand;
297 if let Some(ExprContext::FuncCall(n)) = cntxt.last_mut() {
298 *n -= 1;
299 }
300 }
301 }
302 }
303
304 while let Some(op) = op_stack.pop() {
307 tokens.push(op);
308 }
309
310 Ok(tokens)
311}
312
313#[derive(Clone, Debug, PartialEq, Default)]
314pub struct ExprTokens(Vec<Token>);
315
316impl Deref for ExprTokens {
317 type Target = Vec<Token>;
318 fn deref(&self) -> &Self::Target {
319 &self.0
320 }
321}
322
323impl DerefMut for ExprTokens {
324 fn deref_mut(&mut self) -> &mut Self::Target {
325 &mut self.0
326 }
327}
328
329impl ExprTokens {
330
331 pub fn new(capacity: usize) -> Self {
332 ExprTokens(Vec::with_capacity(capacity))
333 }
334
335 pub fn eval(&self, variables: &ParamValues) -> Result<isize, ExprError> {
336 if self.is_empty() {
337 return Ok(0);
338 }
339 let mut values : Vec<f64> = Vec::with_capacity(self.len()>>1);
340 for token in self.iter() {
342 match token {
343 Token::Number(v) => values.push(*v),
344 Token::Var(n) => {
345 let v = variables.get(n).ok_or(ExprError::UnknownVar(n.to_owned()))?;
346 values.push(*v as f64)
347 },
348 Token::Operator(op) => {
349 let v2 = if *op != OpKind::Not {
350 values.pop().ok_or(ExprError::Malformed)?
351 } else {
352 0.0
353 };
354 let v1 = values.pop().ok_or(ExprError::Malformed)?;
355 let res =
356 match op {
357 OpKind::Plus => v1+v2,
358 OpKind::Minus => v1-v2,
359 OpKind::Mult => v1*v2,
360 OpKind::Div => v1/v2,
361 OpKind::Rem => (v1 as isize % v2 as isize) as f64,
362 OpKind::Pow => v1.powf(v2),
363 OpKind::Not => if v1==0.0 {1.0} else {0.0},
365 OpKind::ShiftLeft => ((v1 as isize) << v2 as usize) as f64,
367 OpKind::ShiftRight => ((v1 as isize) >> v2 as usize) as f64,
368 OpKind::Equal => if v1 == v2 {1.0} else {0.0},
370 OpKind::NotEqual => if v1 != v2 {1.0} else {0.0},
371 OpKind::Greater => if v1 > v2 {1.0} else {0.0},
372 OpKind::GreaterEq => if v1 >= v2 {1.0} else {0.0},
373 OpKind::Lesser => if v1 < v2 {1.0} else {0.0},
374 OpKind::LesserEq => if v1 <= v2 {1.0} else {0.0},
375 };
376 values.push(res);
378 },
379 Token::FuncCall(func) => {
380 let v = values.pop().ok_or(ExprError::Malformed)?;
381 let res = match func {
382 FuncKind::Log2 => v.log2(),
383 FuncKind::Log10 => v.log10(),
384 FuncKind::Power => {
385 let base = values.pop().ok_or(ExprError::Malformed)?;
386 base.powf(v)
387 },
388 FuncKind::Round => v.round(),
389 FuncKind::Ceil => v.ceil(),
390 FuncKind::Floor => v.floor(),
391 };
392 values.push(res);
393 },
394 _ => return Err(ExprError::Malformed),
396 }
397 }
398 let result = values.pop().ok_or(ExprError::Malformed)?;
400 if values.is_empty() {
401 Ok(result.round() as isize)
402 } else {
403 Err(ExprError::Malformed)
404 }
405 }
406
407 pub fn eval_with_gen(&self, variables: &ParamValues, generics: &GenericValues) -> Result<ExprValue, ExprError> {
408 match self.eval(variables) {
409 Ok(n) => Ok(ExprValue::Value(n)),
410 Err(ExprError::UnknownVar(n)) => {
411 if self.len() > 1 {
412 Err(ExprError::Malformed)
413 } else if let Some(range) = generics.get(&n) {
414 Ok(ExprValue::Range(n,range.clone()))
415 } else {
416 Err(ExprError::UnknownVar(n))
417 }
418 }
419 Err(e) => Err(e)
420 }
421 }
422}
423
424#[derive(Clone, Debug, PartialEq)]
425pub enum ExprValue {
426 Value(isize),
427 Range(String,GenericRange),
428}
429
430impl ExprValue {
431
432 pub fn max(&self) -> isize {
433 match self {
434 ExprValue::Value(n) => *n,
435 ExprValue::Range(_,r) => r.max.into(),
436 }
437 }
438}
439
440impl Default for ExprValue {
441 fn default() -> Self {
442 ExprValue::Value(0)
443 }
444}
445
/// Failure modes when evaluating an `ExprTokens` sequence.
#[derive(Clone, Debug, PartialEq)]
pub enum ExprError {
    /// The expression cannot be evaluated as written, e.g. an unbalanced
    /// operand stack or a leftover parenthesis/comma token.
    Malformed,
    /// A referenced variable has no value; carries the variable's name.
    UnknownVar(String),
}
451
452impl From<ExprError> for String {
453 fn from(value: ExprError) -> Self {
454 match value {
455 ExprError::Malformed => "Malformed expression".to_owned(),
456 ExprError::UnknownVar(v) => format!("Unknown var {v} in expression"),
457 }
458 }
459}
460
461#[derive(Clone, Debug)]
462pub struct ParamValues(OrderDict<String,isize>);
463
464
465
466impl ParamValues {
467
468 pub fn new() -> Self {
469 ParamValues(OrderDict::new())
470 }
471
472 pub fn new_with_idx(idx: isize) -> Self {
473 let mut params = ParamValues(OrderDict::new());
474 params.0.insert("i".to_owned(), idx);
475 params
476 }
477
478 pub fn from_items<'a, I>(dict: I) -> Result<Self,String>
479 where I: Iterator<Item = (&'a String,&'a ExprTokens)> {
480 let mut params = ParamValues(OrderDict::new());
481 for (name,expr) in dict.into_iter() {
482 let v = expr.eval(¶ms).map_err(|_| format!("Malformed parameter {name} : {expr:?}"))?;
483 params.0.insert(name.to_owned(), v);
484 }
485 Ok(params)
486 }
487
488 pub fn compile<'a, I>(&mut self, dict: I) -> Result<(),String>
489 where I: Iterator<Item = (&'a String,&'a ExprTokens)> {
490 for (name,expr) in dict.into_iter() {
491 if self.0.contains_key(name) {
492 continue;
494 }
495 let v = expr.eval(self).map_err(|_| format!("Malformed parameter {name} : {expr:?}"))?;
496 self.0.insert(name.to_owned(), v);
498 }
499 Ok(())
500 }
501
502 pub fn get(&self, k: &String) -> Option<&isize> {
503 self.0.get(k)
504 }
505
506 pub fn insert(&mut self, k: String, v: isize) {
507 self.0.insert(k,v);
508 }
509
510 pub fn items(&self) -> impl Iterator<Item=(&String,&isize)> {
511 self.0.items()
512 }
513
514 #[allow(dead_code)]
515 pub fn len(&self) -> usize {
516 self.0.len()
517 }
518
519 #[allow(dead_code)]
520 pub fn is_empty(&self) -> bool {
521 self.0.len()==0
522 }
523}
524
525impl Default for ParamValues {
526 fn default() -> Self {
527 Self::new()
528 }
529}
530
531impl Display for ParamValues {
532
533 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534 let tab = if f.alternate() {"\t"} else {""};
535 let end = if f.alternate() {"\n"} else {", "};
536 if f.alternate() {
537 writeln!(f)?;
538 }
539 for (k,v) in self.items() {
540 write!(f, "{tab}{k} = {v}{end}")?;
541 }
542 Ok(())
543 }
544}
545
546
547
#[cfg(test)]
mod tests_parsing {
    use super::*;
    use super::FuncKind::*;
    use super::OpKind::*;
    use super::Token::*;

    /// Shorthand for a variable token.
    fn var(name: &str) -> Token {
        Var(name.to_owned())
    }

    #[test]
    fn test_parse_expr() {
        let cases: [(&str, Vec<Token>); 4] = [
            ("256 ", vec![Number(256.0)]),
            ("$v1 +3", vec![var("v1"), Number(3.0), Operator(Plus)]),
            (
                "ceil(log2($v3-5))",
                vec![var("v3"), Number(5.0), Operator(Minus), FuncCall(Log2), FuncCall(Ceil)],
            ),
            (
                "pow(3,$x )-1",
                vec![Number(3.0), var("x"), FuncCall(Power), Number(1.0), Operator(Minus)],
            ),
        ];
        for (src, expected) in cases {
            assert_eq!(parse_expr(src), Ok(ExprTokens(expected)), "parsing {src:?}");
        }
    }

    #[test]
    fn test_eval_expr() {
        let mut variables = ParamValues(OrderDict::new());
        variables.0.insert("v1".to_owned(), 1);
        variables.0.insert("x".to_owned(), 17);

        let flag_expr = parse_expr("16*(not $v1) + 256*$v1").unwrap();
        assert_eq!(flag_expr.eval(&variables), Ok(256));

        let pow_expr = parse_expr("pow(2, $x) - 1").unwrap();
        assert_eq!(pow_expr.eval(&variables), Ok((1 << 17) - 1));
    }
}