1#![cfg_attr(feature = "func", doc = "```")]
94#![cfg_attr(not(feature = "func"), doc = "```ignore")]
95#![cfg_attr(feature = "cond", doc = "```")]
121#![cfg_attr(not(feature = "cond"), doc = "```ignore")]
122use std::collections::HashMap;
138#[cfg(feature = "introspect")]
139use std::collections::HashSet;
140#[cfg(feature = "func")]
141use std::sync::Arc;
142
143use num_traits::{Num, NumCast, One, Zero};
144use std::ops::Neg;
145
146#[cfg(feature = "optimize")]
147pub mod optimize;
148
149pub trait Numeric:
158 Num
159 + NumCast
160 + Copy
161 + PartialOrd
162 + Zero
163 + One
164 + Neg<Output = Self>
165 + std::fmt::Debug
166 + Send
167 + Sync
168 + 'static
169{
170 fn supports_bitwise() -> bool;
172
173 fn is_float() -> bool;
175
176 fn numeric_pow(self, exp: Self) -> Option<Self>;
180}
181
182impl Numeric for f32 {
183 fn supports_bitwise() -> bool {
184 false
185 }
186 fn is_float() -> bool {
187 true
188 }
189 fn numeric_pow(self, exp: Self) -> Option<Self> {
190 Some(self.powf(exp))
191 }
192}
193
194impl Numeric for f64 {
195 fn supports_bitwise() -> bool {
196 false
197 }
198 fn is_float() -> bool {
199 true
200 }
201 fn numeric_pow(self, exp: Self) -> Option<Self> {
202 Some(self.powf(exp))
203 }
204}
205
206impl Numeric for i32 {
207 fn supports_bitwise() -> bool {
208 true
209 }
210 fn is_float() -> bool {
211 false
212 }
213 fn numeric_pow(self, exp: Self) -> Option<Self> {
214 if exp < 0 {
215 return None;
216 }
217 Some(self.pow(exp as u32))
218 }
219}
220
221impl Numeric for i64 {
222 fn supports_bitwise() -> bool {
223 true
224 }
225 fn is_float() -> bool {
226 false
227 }
228 fn numeric_pow(self, exp: Self) -> Option<Self> {
229 if exp < 0 {
230 return None;
231 }
232 Some(self.pow(exp as u32))
233 }
234}
235
#[cfg(feature = "func")]
/// A user-supplied function that expressions can call as `name(arg, ...)`.
///
/// The `Send + Sync` bound lets registries holding `Arc<dyn ExprFn>` be
/// shared across threads.
pub trait ExprFn: Send + Sync {
    /// Identifier under which the function is registered and called.
    fn name(&self) -> &str;

    /// Exact number of arguments the function accepts.
    ///
    /// The evaluator rejects calls with a different count
    /// (`EvalError::WrongArgCount`) before invoking [`ExprFn::call`].
    fn arg_count(&self) -> usize;

    /// Computes the function from already-evaluated argument values.
    ///
    /// When invoked by the evaluator, `args.len()` equals `arg_count()`.
    fn call(&self, args: &[f32]) -> f32;

    /// Optional hook for expressing a call to this function as an equivalent
    /// AST built from the argument expressions `_args`.
    ///
    /// The default implementation offers no rewrite and returns `None`.
    /// NOTE(review): not used by the evaluator in this file; presumably
    /// consumed by the `optimize` module — confirm.
    fn decompose(&self, _args: &[Ast]) -> Option<Ast> {
        None
    }
}
264
#[cfg(feature = "func")]
/// Lookup table of [`ExprFn`] implementations keyed by their `name()`.
///
/// Functions are stored behind `Arc`, so cloning a registry shares the
/// function objects instead of copying them.
#[derive(Clone, Default)]
pub struct FunctionRegistry {
    funcs: HashMap<String, Arc<dyn ExprFn>>,
}

#[cfg(feature = "func")]
impl FunctionRegistry {
    /// Returns a registry that contains no functions.
    pub fn new() -> Self {
        Default::default()
    }

    /// Adds `func` under the name reported by its `name()` method.
    ///
    /// A function already registered under that name is replaced.
    pub fn register<F: ExprFn + 'static>(&mut self, func: F) {
        let key: String = func.name().into();
        let shared: Arc<dyn ExprFn> = Arc::new(func);
        self.funcs.insert(key, shared);
    }

    /// Looks up a registered function by name.
    pub fn get(&self, name: &str) -> Option<&Arc<dyn ExprFn>> {
        self.funcs.get(name)
    }

    /// Iterates over the names of all registered functions, in unspecified
    /// (hash map) order.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.funcs.keys().map(String::as_str)
    }
}
294
/// Error produced while tokenizing or parsing an expression.
#[derive(Debug, Clone, PartialEq)]
pub enum ParseError {
    /// The lexer met a character it could not turn into a valid token.
    UnexpectedChar(char),
    /// The input ended where more of the expression was required.
    UnexpectedEnd,
    /// A token appeared where the grammar does not allow it. The payload is
    /// the token's `Debug` rendering.
    UnexpectedToken(String),
    /// A numeric literal's text could not be parsed as a number.
    InvalidNumber(String),
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnexpectedChar(c) => write!(f, "unexpected character: '{}'", c),
            Self::UnexpectedEnd => f.write_str("unexpected end of expression"),
            Self::UnexpectedToken(t) => write!(f, "unexpected token: '{}'", t),
            Self::InvalidNumber(s) => write!(f, "invalid number: '{}'", s),
        }
    }
}

impl std::error::Error for ParseError {}
320
/// Error produced while evaluating a parsed expression.
#[derive(Debug, Clone, PartialEq)]
pub enum EvalError {
    /// A variable was referenced that is neither `let`-bound nor present in
    /// the caller's variable map.
    UnknownVariable(String),
    /// A call named a function that is not registered.
    #[cfg(feature = "func")]
    UnknownFunction(String),
    /// A function was called with a different number of arguments than it declares.
    #[cfg(feature = "func")]
    WrongArgCount {
        /// Name of the called function.
        func: String,
        /// Argument count the function declares via `arg_count`.
        expected: usize,
        /// Number of arguments written in the call.
        got: usize,
    },
    /// The operator, identified by its symbol, cannot be evaluated for the
    /// numeric type in use.
    UnsupportedOperation(String),
}

impl std::fmt::Display for EvalError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownVariable(name) => write!(f, "unknown variable: '{}'", name),
            #[cfg(feature = "func")]
            Self::UnknownFunction(name) => write!(f, "unknown function: '{}'", name),
            #[cfg(feature = "func")]
            Self::WrongArgCount {
                func,
                expected,
                got,
            } => write!(
                f,
                "function '{}' expects {} args, got {}",
                func, expected, got
            ),
            Self::UnsupportedOperation(op) => write!(
                f,
                "unsupported operation for this numeric type: '{}'",
                op
            ),
        }
    }
}

impl std::error::Error for EvalError {}
363
/// Lexical token produced by the lexer and consumed by the parser.
///
/// Variants carrying a `cfg` attribute exist only when that feature is enabled.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    // --- literals and names ---
    /// Numeric literal.
    Number(f64),
    /// Identifier that is not a keyword.
    Ident(String),

    // --- arithmetic operators ---
    Plus,
    Minus,
    Star,
    Slash,
    Caret,
    Percent,

    // --- bitwise operators ---
    Ampersand,
    Pipe,
    Tilde,
    Shl,
    Shr,

    // --- grouping and argument separation ---
    LParen,
    RParen,
    #[cfg(feature = "func")]
    Comma,

    // --- `let name = value; body` ---
    Let,
    Assign,
    Semicolon,

    // --- comparison operators ---
    #[cfg(feature = "cond")]
    Lt,
    #[cfg(feature = "cond")]
    Le,
    #[cfg(feature = "cond")]
    Gt,
    #[cfg(feature = "cond")]
    Ge,
    #[cfg(feature = "cond")]
    Eq,
    #[cfg(feature = "cond")]
    Ne,

    // --- logical keywords ---
    #[cfg(feature = "cond")]
    And,
    #[cfg(feature = "cond")]
    Or,
    #[cfg(feature = "cond")]
    Not,

    // --- conditional keywords ---
    #[cfg(feature = "cond")]
    If,
    #[cfg(feature = "cond")]
    Then,
    #[cfg(feature = "cond")]
    Else,

    /// End of input.
    Eof,
}
420
421struct Lexer<'a> {
422 input: &'a str,
423 pos: usize,
424}
425
426impl<'a> Lexer<'a> {
427 fn new(input: &'a str) -> Self {
428 Self { input, pos: 0 }
429 }
430
431 fn peek_char(&self) -> Option<char> {
432 self.input[self.pos..].chars().next()
433 }
434
435 fn next_char(&mut self) -> Option<char> {
436 let c = self.peek_char()?;
437 self.pos += c.len_utf8();
438 Some(c)
439 }
440
441 fn skip_whitespace(&mut self) {
442 while let Some(c) = self.peek_char() {
443 if c.is_whitespace() {
444 self.next_char();
445 } else {
446 break;
447 }
448 }
449 }
450
451 fn read_number(&mut self) -> Result<f64, ParseError> {
452 let start = self.pos;
453 while let Some(c) = self.peek_char() {
454 if c.is_ascii_digit() || c == '.' {
455 self.next_char();
456 } else {
457 break;
458 }
459 }
460 let s = &self.input[start..self.pos];
461 s.parse()
462 .map_err(|_| ParseError::InvalidNumber(s.to_string()))
463 }
464
465 fn read_ident(&mut self) -> String {
466 let start = self.pos;
467 while let Some(c) = self.peek_char() {
468 if c.is_alphanumeric() || c == '_' {
469 self.next_char();
470 } else {
471 break;
472 }
473 }
474 self.input[start..self.pos].to_string()
475 }
476
477 fn next_token(&mut self) -> Result<Token, ParseError> {
478 self.skip_whitespace();
479
480 let Some(c) = self.peek_char() else {
481 return Ok(Token::Eof);
482 };
483
484 match c {
485 '+' => {
486 self.next_char();
487 Ok(Token::Plus)
488 }
489 '-' => {
490 self.next_char();
491 Ok(Token::Minus)
492 }
493 '*' => {
494 self.next_char();
495 Ok(Token::Star)
496 }
497 '/' => {
498 self.next_char();
499 Ok(Token::Slash)
500 }
501 '^' => {
502 self.next_char();
503 Ok(Token::Caret)
504 }
505 '%' => {
506 self.next_char();
507 Ok(Token::Percent)
508 }
509 '&' => {
510 self.next_char();
511 Ok(Token::Ampersand)
512 }
513 '|' => {
514 self.next_char();
515 Ok(Token::Pipe)
516 }
517 '~' => {
518 self.next_char();
519 Ok(Token::Tilde)
520 }
521 '(' => {
522 self.next_char();
523 Ok(Token::LParen)
524 }
525 ')' => {
526 self.next_char();
527 Ok(Token::RParen)
528 }
529 #[cfg(feature = "func")]
530 ',' => {
531 self.next_char();
532 Ok(Token::Comma)
533 }
534 ';' => {
535 self.next_char();
536 Ok(Token::Semicolon)
537 }
538 '<' => {
539 self.next_char();
540 if self.peek_char() == Some('<') {
541 self.next_char();
542 Ok(Token::Shl)
543 } else {
544 #[cfg(feature = "cond")]
545 {
546 if self.peek_char() == Some('=') {
547 self.next_char();
548 Ok(Token::Le)
549 } else {
550 Ok(Token::Lt)
551 }
552 }
553 #[cfg(not(feature = "cond"))]
554 Err(ParseError::UnexpectedChar('<'))
555 }
556 }
557 '>' => {
558 self.next_char();
559 if self.peek_char() == Some('>') {
560 self.next_char();
561 Ok(Token::Shr)
562 } else {
563 #[cfg(feature = "cond")]
564 {
565 if self.peek_char() == Some('=') {
566 self.next_char();
567 Ok(Token::Ge)
568 } else {
569 Ok(Token::Gt)
570 }
571 }
572 #[cfg(not(feature = "cond"))]
573 Err(ParseError::UnexpectedChar('>'))
574 }
575 }
576 '=' => {
577 self.next_char();
578 if self.peek_char() == Some('=') {
579 self.next_char();
580 #[cfg(feature = "cond")]
581 {
582 Ok(Token::Eq)
583 }
584 #[cfg(not(feature = "cond"))]
585 {
586 Err(ParseError::UnexpectedChar('='))
587 }
588 } else {
589 Ok(Token::Assign)
590 }
591 }
592 #[cfg(feature = "cond")]
593 '!' => {
594 self.next_char();
595 if self.peek_char() == Some('=') {
596 self.next_char();
597 Ok(Token::Ne)
598 } else {
599 Err(ParseError::UnexpectedChar('!'))
600 }
601 }
602 '0'..='9' | '.' => Ok(Token::Number(self.read_number()?)),
603 'a'..='z' | 'A'..='Z' | '_' => {
604 let ident = self.read_ident();
605 if ident == "let" {
607 return Ok(Token::Let);
608 }
609 #[cfg(feature = "cond")]
610 match ident.as_str() {
611 "and" => return Ok(Token::And),
612 "or" => return Ok(Token::Or),
613 "not" => return Ok(Token::Not),
614 "if" => return Ok(Token::If),
615 "then" => return Ok(Token::Then),
616 "else" => return Ok(Token::Else),
617 _ => {}
618 }
619 Ok(Token::Ident(ident))
620 }
621 _ => Err(ParseError::UnexpectedChar(c)),
622 }
623 }
624}
625
626#[derive(Debug, Clone, PartialEq)]
657pub enum Ast {
658 Num(f64),
660 Var(String),
662 BinOp(BinOp, Box<Ast>, Box<Ast>),
664 UnaryOp(UnaryOp, Box<Ast>),
666 #[cfg(feature = "func")]
668 Call(String, Vec<Ast>),
669 #[cfg(feature = "cond")]
671 Compare(CompareOp, Box<Ast>, Box<Ast>),
672 #[cfg(feature = "cond")]
674 And(Box<Ast>, Box<Ast>),
675 #[cfg(feature = "cond")]
677 Or(Box<Ast>, Box<Ast>),
678 #[cfg(feature = "cond")]
680 If(Box<Ast>, Box<Ast>, Box<Ast>),
681 Let {
683 name: String,
684 value: Box<Ast>,
685 body: Box<Ast>,
686 },
687}
688
/// Binary operators, named after the symbol that produces them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `^` (right-associative)
    Pow,
    /// `%`
    Rem,
    /// `&`
    BitAnd,
    /// `|`
    BitOr,
    /// `<<`
    Shl,
    /// `>>`
    Shr,
}
715
#[cfg(feature = "cond")]
/// Comparison operators. A comparison evaluates to `1.0` when it holds and
/// `0.0` otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompareOp {
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `==`
    Eq,
    /// `!=`
    Ne,
}
736
/// Prefix (unary) operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Arithmetic negation, `-x`.
    Neg,
    /// Logical negation, `not x`: evaluates to `1.0` when the operand is
    /// `0.0`, otherwise `0.0`.
    #[cfg(feature = "cond")]
    Not,
    /// Bitwise complement, `~x`.
    BitNot,
}
751
752impl std::fmt::Display for Ast {
757 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
758 match self {
759 Ast::Num(n) => {
760 if n.is_nan() {
761 write!(f, "(0.0 / 0.0)") } else if n.is_infinite() {
763 if *n > 0.0 {
764 write!(f, "(1.0 / 0.0)") } else {
766 write!(f, "(-1.0 / 0.0)") }
768 } else {
769 write!(f, "{}", n)
770 }
771 }
772 Ast::Var(name) => write!(f, "{}", name),
773 Ast::BinOp(op, left, right) => {
774 write!(f, "({} {} {})", left, op, right)
775 }
776 Ast::UnaryOp(op, inner) => {
777 write!(f, "({}{})", op, inner)
778 }
779 #[cfg(feature = "func")]
780 Ast::Call(name, args) => {
781 write!(f, "{}(", name)?;
782 for (i, arg) in args.iter().enumerate() {
783 if i > 0 {
784 write!(f, ", ")?;
785 }
786 write!(f, "{}", arg)?;
787 }
788 write!(f, ")")
789 }
790 #[cfg(feature = "cond")]
791 Ast::Compare(op, left, right) => {
792 write!(f, "({} {} {})", left, op, right)
793 }
794 #[cfg(feature = "cond")]
795 Ast::And(left, right) => {
796 write!(f, "({} and {})", left, right)
797 }
798 #[cfg(feature = "cond")]
799 Ast::Or(left, right) => {
800 write!(f, "({} or {})", left, right)
801 }
802 #[cfg(feature = "cond")]
803 Ast::If(cond, then_expr, else_expr) => {
804 write!(f, "(if {} then {} else {})", cond, then_expr, else_expr)
805 }
806 Ast::Let { name, value, body } => {
807 write!(f, "(let {} = {}; {})", name, value, body)
808 }
809 }
810 }
811}
812
813impl std::fmt::Display for BinOp {
814 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
815 match self {
816 BinOp::Add => write!(f, "+"),
817 BinOp::Sub => write!(f, "-"),
818 BinOp::Mul => write!(f, "*"),
819 BinOp::Div => write!(f, "/"),
820 BinOp::Pow => write!(f, "^"),
821 BinOp::Rem => write!(f, "%"),
822 BinOp::BitAnd => write!(f, "&"),
823 BinOp::BitOr => write!(f, "|"),
824 BinOp::Shl => write!(f, "<<"),
825 BinOp::Shr => write!(f, ">>"),
826 }
827 }
828}
829
830impl std::fmt::Display for UnaryOp {
831 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
832 match self {
833 UnaryOp::Neg => write!(f, "-"),
834 #[cfg(feature = "cond")]
835 UnaryOp::Not => write!(f, "not "),
836 UnaryOp::BitNot => write!(f, "~"),
837 }
838 }
839}
840
#[cfg(feature = "cond")]
/// Writes the comparison's source symbol, e.g. `<=` or `!=`.
impl std::fmt::Display for CompareOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            CompareOp::Lt => "<",
            CompareOp::Le => "<=",
            CompareOp::Gt => ">",
            CompareOp::Ge => ">=",
            CompareOp::Eq => "==",
            CompareOp::Ne => "!=",
        };
        f.write_str(symbol)
    }
}
854
#[cfg(feature = "introspect")]
impl Ast {
    /// Returns the names of variables this expression reads from its
    /// environment, i.e. every `Var` not bound by an enclosing `let`.
    ///
    /// Function names in calls are not included. The returned `&str`s borrow
    /// from `self`.
    pub fn free_vars(&self) -> HashSet<&str> {
        let mut found = HashSet::new();
        self.collect_vars(&mut found);
        found
    }

    /// Adds the free variables of `self` to `vars`.
    fn collect_vars<'a>(&'a self, vars: &mut HashSet<&'a str>) {
        match self {
            Ast::Num(_) => {}
            Ast::Var(name) => {
                vars.insert(name.as_str());
            }
            Ast::BinOp(_, lhs, rhs) => {
                lhs.collect_vars(vars);
                rhs.collect_vars(vars);
            }
            Ast::UnaryOp(_, operand) => operand.collect_vars(vars),
            #[cfg(feature = "func")]
            Ast::Call(_, args) => args.iter().for_each(|arg| arg.collect_vars(vars)),
            #[cfg(feature = "cond")]
            Ast::Compare(_, lhs, rhs) => {
                lhs.collect_vars(vars);
                rhs.collect_vars(vars);
            }
            #[cfg(feature = "cond")]
            Ast::And(lhs, rhs) | Ast::Or(lhs, rhs) => {
                lhs.collect_vars(vars);
                rhs.collect_vars(vars);
            }
            #[cfg(feature = "cond")]
            Ast::If(cond, then_expr, else_expr) => {
                cond.collect_vars(vars);
                then_expr.collect_vars(vars);
                else_expr.collect_vars(vars);
            }
            Ast::Let { name, value, body } => {
                // The bound value is computed in the enclosing scope, so its
                // variables are free here.
                value.collect_vars(vars);
                // `name` is bound only inside `body`. Collect the body into a
                // separate set so that removing `name` cannot erase a free use
                // of the same name found elsewhere.
                let mut inner = HashSet::new();
                body.collect_vars(&mut inner);
                vars.extend(inner.into_iter().filter(|v| *v != name.as_str()));
            }
        }
    }
}
936
937struct Parser<'a> {
942 lexer: Lexer<'a>,
943 current: Token,
944}
945
946impl<'a> Parser<'a> {
947 fn new(input: &'a str) -> Result<Self, ParseError> {
948 let mut lexer = Lexer::new(input);
949 let current = lexer.next_token()?;
950 Ok(Self { lexer, current })
951 }
952
953 fn advance(&mut self) -> Result<(), ParseError> {
954 self.current = self.lexer.next_token()?;
955 Ok(())
956 }
957
958 fn expect(&mut self, expected: Token) -> Result<(), ParseError> {
959 if self.current == expected {
960 self.advance()
961 } else {
962 Err(ParseError::UnexpectedToken(format!("{:?}", self.current)))
963 }
964 }
965
966 fn parse_expr(&mut self) -> Result<Ast, ParseError> {
982 self.parse_let()
983 }
984
985 fn parse_let(&mut self) -> Result<Ast, ParseError> {
986 if self.current == Token::Let {
987 self.advance()?;
988 let name = match &self.current {
990 Token::Ident(s) => s.clone(),
991 _ => return Err(ParseError::UnexpectedToken(format!("{:?}", self.current))),
992 };
993 self.advance()?;
994 self.expect(Token::Assign)?;
996 let value = self.parse_non_let()?;
998 self.expect(Token::Semicolon)?;
1000 let body = self.parse_let()?;
1002 Ok(Ast::Let {
1003 name,
1004 value: Box::new(value),
1005 body: Box::new(body),
1006 })
1007 } else {
1008 self.parse_non_let()
1009 }
1010 }
1011
1012 fn parse_non_let(&mut self) -> Result<Ast, ParseError> {
1013 #[cfg(feature = "cond")]
1014 {
1015 self.parse_if()
1016 }
1017 #[cfg(not(feature = "cond"))]
1018 {
1019 self.parse_bit_or()
1020 }
1021 }
1022
1023 #[cfg(feature = "cond")]
1024 fn parse_if(&mut self) -> Result<Ast, ParseError> {
1025 if self.current == Token::If {
1026 self.advance()?;
1027 let cond = self.parse_or()?;
1028 self.expect(Token::Then)?;
1029 let then_expr = self.parse_if()?; self.expect(Token::Else)?;
1031 let else_expr = self.parse_if()?; Ok(Ast::If(
1033 Box::new(cond),
1034 Box::new(then_expr),
1035 Box::new(else_expr),
1036 ))
1037 } else {
1038 self.parse_or()
1039 }
1040 }
1041
1042 #[cfg(feature = "cond")]
1043 fn parse_or(&mut self) -> Result<Ast, ParseError> {
1044 let mut left = self.parse_and()?;
1045
1046 while self.current == Token::Or {
1047 self.advance()?;
1048 let right = self.parse_and()?;
1049 left = Ast::Or(Box::new(left), Box::new(right));
1050 }
1051
1052 Ok(left)
1053 }
1054
1055 #[cfg(feature = "cond")]
1056 fn parse_and(&mut self) -> Result<Ast, ParseError> {
1057 let mut left = self.parse_bit_or()?;
1058
1059 while self.current == Token::And {
1060 self.advance()?;
1061 let right = self.parse_bit_or()?;
1062 left = Ast::And(Box::new(left), Box::new(right));
1063 }
1064
1065 Ok(left)
1066 }
1067
1068 fn parse_bit_or(&mut self) -> Result<Ast, ParseError> {
1069 let mut left = self.parse_bit_and()?;
1070
1071 while self.current == Token::Pipe {
1072 self.advance()?;
1073 let right = self.parse_bit_and()?;
1074 left = Ast::BinOp(BinOp::BitOr, Box::new(left), Box::new(right));
1075 }
1076
1077 Ok(left)
1078 }
1079
1080 fn parse_bit_and(&mut self) -> Result<Ast, ParseError> {
1081 #[cfg(feature = "cond")]
1082 let mut left = self.parse_compare()?;
1083 #[cfg(not(feature = "cond"))]
1084 let mut left = self.parse_shift()?;
1085
1086 while self.current == Token::Ampersand {
1087 self.advance()?;
1088 #[cfg(feature = "cond")]
1089 let right = self.parse_compare()?;
1090 #[cfg(not(feature = "cond"))]
1091 let right = self.parse_shift()?;
1092 left = Ast::BinOp(BinOp::BitAnd, Box::new(left), Box::new(right));
1093 }
1094
1095 Ok(left)
1096 }
1097
1098 #[cfg(feature = "cond")]
1099 fn parse_compare(&mut self) -> Result<Ast, ParseError> {
1100 let left = self.parse_shift()?;
1101
1102 let op = match &self.current {
1103 Token::Lt => Some(CompareOp::Lt),
1104 Token::Le => Some(CompareOp::Le),
1105 Token::Gt => Some(CompareOp::Gt),
1106 Token::Ge => Some(CompareOp::Ge),
1107 Token::Eq => Some(CompareOp::Eq),
1108 Token::Ne => Some(CompareOp::Ne),
1109 _ => None,
1110 };
1111
1112 if let Some(op) = op {
1113 self.advance()?;
1114 let right = self.parse_shift()?;
1115 Ok(Ast::Compare(op, Box::new(left), Box::new(right)))
1116 } else {
1117 Ok(left)
1118 }
1119 }
1120
1121 fn parse_shift(&mut self) -> Result<Ast, ParseError> {
1122 let mut left = self.parse_add_sub()?;
1123
1124 loop {
1125 match &self.current {
1126 Token::Shl => {
1127 self.advance()?;
1128 let right = self.parse_add_sub()?;
1129 left = Ast::BinOp(BinOp::Shl, Box::new(left), Box::new(right));
1130 }
1131 Token::Shr => {
1132 self.advance()?;
1133 let right = self.parse_add_sub()?;
1134 left = Ast::BinOp(BinOp::Shr, Box::new(left), Box::new(right));
1135 }
1136 _ => break,
1137 }
1138 }
1139
1140 Ok(left)
1141 }
1142
1143 fn parse_add_sub(&mut self) -> Result<Ast, ParseError> {
1144 let mut left = self.parse_mul_div()?;
1145
1146 loop {
1147 match &self.current {
1148 Token::Plus => {
1149 self.advance()?;
1150 let right = self.parse_mul_div()?;
1151 left = Ast::BinOp(BinOp::Add, Box::new(left), Box::new(right));
1152 }
1153 Token::Minus => {
1154 self.advance()?;
1155 let right = self.parse_mul_div()?;
1156 left = Ast::BinOp(BinOp::Sub, Box::new(left), Box::new(right));
1157 }
1158 _ => break,
1159 }
1160 }
1161
1162 Ok(left)
1163 }
1164
1165 fn parse_mul_div(&mut self) -> Result<Ast, ParseError> {
1166 let mut left = self.parse_power()?;
1167
1168 loop {
1169 match &self.current {
1170 Token::Star => {
1171 self.advance()?;
1172 let right = self.parse_power()?;
1173 left = Ast::BinOp(BinOp::Mul, Box::new(left), Box::new(right));
1174 }
1175 Token::Slash => {
1176 self.advance()?;
1177 let right = self.parse_power()?;
1178 left = Ast::BinOp(BinOp::Div, Box::new(left), Box::new(right));
1179 }
1180 Token::Percent => {
1181 self.advance()?;
1182 let right = self.parse_power()?;
1183 left = Ast::BinOp(BinOp::Rem, Box::new(left), Box::new(right));
1184 }
1185 _ => break,
1186 }
1187 }
1188
1189 Ok(left)
1190 }
1191
1192 fn parse_power(&mut self) -> Result<Ast, ParseError> {
1193 let base = self.parse_unary()?;
1194
1195 if self.current == Token::Caret {
1196 self.advance()?;
1197 let exp = self.parse_power()?; Ok(Ast::BinOp(BinOp::Pow, Box::new(base), Box::new(exp)))
1199 } else {
1200 Ok(base)
1201 }
1202 }
1203
1204 fn parse_unary(&mut self) -> Result<Ast, ParseError> {
1205 match &self.current {
1206 Token::Minus => {
1207 self.advance()?;
1208 let inner = self.parse_unary()?;
1209 Ok(Ast::UnaryOp(UnaryOp::Neg, Box::new(inner)))
1210 }
1211 Token::Tilde => {
1212 self.advance()?;
1213 let inner = self.parse_unary()?;
1214 Ok(Ast::UnaryOp(UnaryOp::BitNot, Box::new(inner)))
1215 }
1216 #[cfg(feature = "cond")]
1217 Token::Not => {
1218 self.advance()?;
1219 let inner = self.parse_unary()?;
1220 Ok(Ast::UnaryOp(UnaryOp::Not, Box::new(inner)))
1221 }
1222 _ => self.parse_primary(),
1223 }
1224 }
1225
1226 fn parse_primary(&mut self) -> Result<Ast, ParseError> {
1227 match &self.current {
1228 Token::Number(n) => {
1229 let n = *n;
1230 self.advance()?;
1231 Ok(Ast::Num(n))
1232 }
1233 Token::Ident(name) => {
1234 let name = name.clone();
1235 self.advance()?;
1236
1237 #[cfg(feature = "func")]
1239 if self.current == Token::LParen {
1240 self.advance()?;
1241 let mut args = Vec::new();
1242 if self.current != Token::RParen {
1243 args.push(self.parse_expr()?);
1244 while self.current == Token::Comma {
1245 self.advance()?;
1246 args.push(self.parse_expr()?);
1247 }
1248 }
1249 self.expect(Token::RParen)?;
1250 return Ok(Ast::Call(name, args));
1251 }
1252
1253 Ok(Ast::Var(name))
1255 }
1256 Token::LParen => {
1257 self.advance()?;
1258 let inner = self.parse_expr()?;
1259 self.expect(Token::RParen)?;
1260 Ok(inner)
1261 }
1262 Token::Eof => Err(ParseError::UnexpectedEnd),
1263 _ => Err(ParseError::UnexpectedToken(format!("{:?}", self.current))),
1264 }
1265 }
1266}
1267
1268#[derive(Debug, Clone)]
1302pub struct Expr {
1303 ast: Ast,
1304}
1305
1306impl Expr {
1307 pub fn parse(input: &str) -> Result<Self, ParseError> {
1332 let mut parser = Parser::new(input)?;
1333 let ast = parser.parse_expr()?;
1334 if parser.current != Token::Eof {
1335 return Err(ParseError::UnexpectedToken(format!("{:?}", parser.current)));
1336 }
1337 Ok(Self { ast })
1338 }
1339
1340 pub fn ast(&self) -> &Ast {
1345 &self.ast
1346 }
1347
1348 #[cfg(feature = "introspect")]
1365 pub fn free_vars(&self) -> HashSet<&str> {
1366 self.ast.free_vars()
1367 }
1368
1369 #[cfg(feature = "func")]
1378 pub fn eval(
1379 &self,
1380 vars: &HashMap<String, f32>,
1381 funcs: &FunctionRegistry,
1382 ) -> Result<f32, EvalError> {
1383 eval_ast(&self.ast, vars, funcs)
1384 }
1385
1386 #[cfg(not(feature = "func"))]
1394 pub fn eval(&self, vars: &HashMap<String, f32>) -> Result<f32, EvalError> {
1395 eval_ast(&self.ast, vars)
1396 }
1397}
1398
1399#[cfg(feature = "func")]
1400fn eval_ast(
1401 ast: &Ast,
1402 vars: &HashMap<String, f32>,
1403 funcs: &FunctionRegistry,
1404) -> Result<f32, EvalError> {
1405 match ast {
1406 Ast::Num(n) => Ok(*n as f32),
1407 Ast::Var(name) => vars
1408 .get(name)
1409 .copied()
1410 .ok_or_else(|| EvalError::UnknownVariable(name.clone())),
1411 Ast::BinOp(op, l, r) => {
1412 let l = eval_ast(l, vars, funcs)?;
1413 let r = eval_ast(r, vars, funcs)?;
1414 match op {
1415 BinOp::Add => Ok(l + r),
1416 BinOp::Sub => Ok(l - r),
1417 BinOp::Mul => Ok(l * r),
1418 BinOp::Div => Ok(l / r),
1419 BinOp::Pow => Ok(l.powf(r)),
1420 BinOp::Rem => Ok(l % r),
1421 BinOp::BitAnd => Err(EvalError::UnsupportedOperation("&".to_string())),
1422 BinOp::BitOr => Err(EvalError::UnsupportedOperation("|".to_string())),
1423 BinOp::Shl => Err(EvalError::UnsupportedOperation("<<".to_string())),
1424 BinOp::Shr => Err(EvalError::UnsupportedOperation(">>".to_string())),
1425 }
1426 }
1427 Ast::UnaryOp(op, inner) => {
1428 let v = eval_ast(inner, vars, funcs)?;
1429 match op {
1430 UnaryOp::Neg => Ok(-v),
1431 UnaryOp::BitNot => Err(EvalError::UnsupportedOperation("~".to_string())),
1432 #[cfg(feature = "cond")]
1433 UnaryOp::Not => {
1434 if v == 0.0 {
1435 Ok(1.0)
1436 } else {
1437 Ok(0.0)
1438 }
1439 }
1440 }
1441 }
1442 #[cfg(feature = "cond")]
1443 Ast::Compare(op, l, r) => {
1444 let l = eval_ast(l, vars, funcs)?;
1445 let r = eval_ast(r, vars, funcs)?;
1446 let result = match op {
1447 CompareOp::Lt => l < r,
1448 CompareOp::Le => l <= r,
1449 CompareOp::Gt => l > r,
1450 CompareOp::Ge => l >= r,
1451 CompareOp::Eq => l == r,
1452 CompareOp::Ne => l != r,
1453 };
1454 Ok(if result { 1.0 } else { 0.0 })
1455 }
1456 #[cfg(feature = "cond")]
1457 Ast::And(l, r) => {
1458 let l = eval_ast(l, vars, funcs)?;
1459 if l == 0.0 {
1460 Ok(0.0) } else {
1462 let r = eval_ast(r, vars, funcs)?;
1463 Ok(if r != 0.0 { 1.0 } else { 0.0 })
1464 }
1465 }
1466 #[cfg(feature = "cond")]
1467 Ast::Or(l, r) => {
1468 let l = eval_ast(l, vars, funcs)?;
1469 if l != 0.0 {
1470 Ok(1.0) } else {
1472 let r = eval_ast(r, vars, funcs)?;
1473 Ok(if r != 0.0 { 1.0 } else { 0.0 })
1474 }
1475 }
1476 #[cfg(feature = "cond")]
1477 Ast::If(cond, then_expr, else_expr) => {
1478 let cond = eval_ast(cond, vars, funcs)?;
1479 if cond != 0.0 {
1480 eval_ast(then_expr, vars, funcs)
1481 } else {
1482 eval_ast(else_expr, vars, funcs)
1483 }
1484 }
1485 Ast::Call(name, args) => {
1486 let func = funcs
1487 .get(name)
1488 .ok_or_else(|| EvalError::UnknownFunction(name.clone()))?;
1489
1490 if args.len() != func.arg_count() {
1491 return Err(EvalError::WrongArgCount {
1492 func: name.clone(),
1493 expected: func.arg_count(),
1494 got: args.len(),
1495 });
1496 }
1497
1498 let arg_values: Vec<f32> = args
1499 .iter()
1500 .map(|a| eval_ast(a, vars, funcs))
1501 .collect::<Result<_, _>>()?;
1502
1503 Ok(func.call(&arg_values))
1504 }
1505 Ast::Let { name, value, body } => {
1506 let val = eval_ast(value, vars, funcs)?;
1507 let mut new_vars = vars.clone();
1508 new_vars.insert(name.clone(), val);
1509 eval_ast(body, &new_vars, funcs)
1510 }
1511 }
1512}
1513
1514#[cfg(not(feature = "func"))]
1515fn eval_ast(ast: &Ast, vars: &HashMap<String, f32>) -> Result<f32, EvalError> {
1516 match ast {
1517 Ast::Num(n) => Ok(*n as f32),
1518 Ast::Var(name) => vars
1519 .get(name)
1520 .copied()
1521 .ok_or_else(|| EvalError::UnknownVariable(name.clone())),
1522 Ast::BinOp(op, l, r) => {
1523 let l = eval_ast(l, vars)?;
1524 let r = eval_ast(r, vars)?;
1525 match op {
1526 BinOp::Add => Ok(l + r),
1527 BinOp::Sub => Ok(l - r),
1528 BinOp::Mul => Ok(l * r),
1529 BinOp::Div => Ok(l / r),
1530 BinOp::Pow => Ok(l.powf(r)),
1531 BinOp::Rem => Ok(l % r),
1532 BinOp::BitAnd => Err(EvalError::UnsupportedOperation("&".to_string())),
1533 BinOp::BitOr => Err(EvalError::UnsupportedOperation("|".to_string())),
1534 BinOp::Shl => Err(EvalError::UnsupportedOperation("<<".to_string())),
1535 BinOp::Shr => Err(EvalError::UnsupportedOperation(">>".to_string())),
1536 }
1537 }
1538 Ast::UnaryOp(op, inner) => {
1539 let v = eval_ast(inner, vars)?;
1540 match op {
1541 UnaryOp::Neg => Ok(-v),
1542 UnaryOp::BitNot => Err(EvalError::UnsupportedOperation("~".to_string())),
1543 #[cfg(feature = "cond")]
1544 UnaryOp::Not => {
1545 if v == 0.0 {
1546 Ok(1.0)
1547 } else {
1548 Ok(0.0)
1549 }
1550 }
1551 }
1552 }
1553 #[cfg(feature = "cond")]
1554 Ast::Compare(op, l, r) => {
1555 let l = eval_ast(l, vars)?;
1556 let r = eval_ast(r, vars)?;
1557 let result = match op {
1558 CompareOp::Lt => l < r,
1559 CompareOp::Le => l <= r,
1560 CompareOp::Gt => l > r,
1561 CompareOp::Ge => l >= r,
1562 CompareOp::Eq => l == r,
1563 CompareOp::Ne => l != r,
1564 };
1565 Ok(if result { 1.0 } else { 0.0 })
1566 }
1567 #[cfg(feature = "cond")]
1568 Ast::And(l, r) => {
1569 let l = eval_ast(l, vars)?;
1570 if l == 0.0 {
1571 Ok(0.0) } else {
1573 let r = eval_ast(r, vars)?;
1574 Ok(if r != 0.0 { 1.0 } else { 0.0 })
1575 }
1576 }
1577 #[cfg(feature = "cond")]
1578 Ast::Or(l, r) => {
1579 let l = eval_ast(l, vars)?;
1580 if l != 0.0 {
1581 Ok(1.0) } else {
1583 let r = eval_ast(r, vars)?;
1584 Ok(if r != 0.0 { 1.0 } else { 0.0 })
1585 }
1586 }
1587 #[cfg(feature = "cond")]
1588 Ast::If(cond, then_expr, else_expr) => {
1589 let cond = eval_ast(cond, vars)?;
1590 if cond != 0.0 {
1591 eval_ast(then_expr, vars)
1592 } else {
1593 eval_ast(else_expr, vars)
1594 }
1595 }
1596 Ast::Let { name, value, body } => {
1597 let val = eval_ast(value, vars)?;
1598 let mut new_vars = vars.clone();
1599 new_vars.insert(name.clone(), val);
1600 eval_ast(body, &new_vars)
1601 }
1602 }
1603}
1604
1605#[cfg(test)]
1610mod tests {
1611 use super::*;
1612
1613 #[cfg(feature = "func")]
1614 fn eval(expr_str: &str, vars: &[(&str, f32)]) -> f32 {
1615 let registry = FunctionRegistry::new();
1616 let expr = Expr::parse(expr_str).unwrap();
1617 let var_map: HashMap<String, f32> = vars.iter().map(|(k, v)| (k.to_string(), *v)).collect();
1618 expr.eval(&var_map, ®istry).unwrap()
1619 }
1620
1621 #[cfg(not(feature = "func"))]
1622 fn eval(expr_str: &str, vars: &[(&str, f32)]) -> f32 {
1623 let expr = Expr::parse(expr_str).unwrap();
1624 let var_map: HashMap<String, f32> = vars.iter().map(|(k, v)| (k.to_string(), *v)).collect();
1625 expr.eval(&var_map).unwrap()
1626 }
1627
1628 #[test]
1629 fn test_parse_number() {
1630 assert_eq!(eval("42", &[]), 42.0);
1631 }
1632
1633 #[test]
1634 fn test_parse_float() {
1635 assert!((eval("1.234", &[]) - 1.234).abs() < 0.001);
1636 }
1637
1638 #[test]
1639 fn test_parse_variable() {
1640 assert_eq!(eval("x", &[("x", 5.0)]), 5.0);
1641 assert_eq!(eval("foo", &[("foo", 3.0)]), 3.0);
1642 }
1643
1644 #[test]
1645 fn test_parse_add() {
1646 assert_eq!(eval("1 + 2", &[]), 3.0);
1647 }
1648
1649 #[test]
1650 fn test_parse_mul() {
1651 assert_eq!(eval("3 * 4", &[]), 12.0);
1652 }
1653
1654 #[test]
1655 fn test_precedence() {
1656 assert_eq!(eval("2 + 3 * 4", &[]), 14.0);
1657 }
1658
1659 #[test]
1660 fn test_parentheses() {
1661 assert_eq!(eval("(2 + 3) * 4", &[]), 20.0);
1662 }
1663
1664 #[test]
1665 fn test_negation() {
1666 assert_eq!(eval("-5", &[]), -5.0);
1667 }
1668
1669 #[test]
1670 fn test_power() {
1671 assert_eq!(eval("2 ^ 3", &[]), 8.0);
1672 }
1673
1674 #[cfg(feature = "func")]
1675 #[test]
1676 fn test_unknown_variable() {
1677 let registry = FunctionRegistry::new();
1678 let expr = Expr::parse("unknown").unwrap();
1679 let vars = HashMap::new();
1680 let result = expr.eval(&vars, ®istry);
1681 assert!(matches!(result, Err(EvalError::UnknownVariable(_))));
1682 }
1683
1684 #[cfg(not(feature = "func"))]
1685 #[test]
1686 fn test_unknown_variable() {
1687 let expr = Expr::parse("unknown").unwrap();
1688 let vars = HashMap::new();
1689 let result = expr.eval(&vars);
1690 assert!(matches!(result, Err(EvalError::UnknownVariable(_))));
1691 }
1692
1693 #[cfg(feature = "func")]
1694 #[test]
1695 fn test_unknown_function() {
1696 let registry = FunctionRegistry::new();
1697 let expr = Expr::parse("unknown(1)").unwrap();
1698 let vars = HashMap::new();
1699 let result = expr.eval(&vars, ®istry);
1700 assert!(matches!(result, Err(EvalError::UnknownFunction(_))));
1701 }
1702
1703 #[cfg(feature = "func")]
1704 #[test]
1705 fn test_custom_function() {
1706 struct Double;
1707 impl ExprFn for Double {
1708 fn name(&self) -> &str {
1709 "double"
1710 }
1711 fn arg_count(&self) -> usize {
1712 1
1713 }
1714 fn call(&self, args: &[f32]) -> f32 {
1715 args[0] * 2.0
1716 }
1717 }
1718
1719 let mut registry = FunctionRegistry::new();
1720 registry.register(Double);
1721
1722 let expr = Expr::parse("double(5)").unwrap();
1723 let vars = HashMap::new();
1724 assert_eq!(expr.eval(&vars, ®istry).unwrap(), 10.0);
1725 }
1726
1727 #[cfg(feature = "func")]
1728 #[test]
1729 fn test_zero_arg_function() {
1730 struct Pi;
1731 impl ExprFn for Pi {
1732 fn name(&self) -> &str {
1733 "pi"
1734 }
1735 fn arg_count(&self) -> usize {
1736 0
1737 }
1738 fn call(&self, _args: &[f32]) -> f32 {
1739 std::f32::consts::PI
1740 }
1741 }
1742
1743 let mut registry = FunctionRegistry::new();
1744 registry.register(Pi);
1745
1746 let expr = Expr::parse("pi()").unwrap();
1747 let vars = HashMap::new();
1748 assert!((expr.eval(&vars, ®istry).unwrap() - std::f32::consts::PI).abs() < 0.001);
1749 }
1750
1751 #[cfg(feature = "func")]
1752 #[test]
1753 fn test_wrong_arg_count() {
1754 struct OneArg;
1755 impl ExprFn for OneArg {
1756 fn name(&self) -> &str {
1757 "one_arg"
1758 }
1759 fn arg_count(&self) -> usize {
1760 1
1761 }
1762 fn call(&self, args: &[f32]) -> f32 {
1763 args[0]
1764 }
1765 }
1766
1767 let mut registry = FunctionRegistry::new();
1768 registry.register(OneArg);
1769
1770 let expr = Expr::parse("one_arg(1, 2)").unwrap();
1771 let vars = HashMap::new();
1772 let result = expr.eval(&vars, ®istry);
1773 assert!(matches!(result, Err(EvalError::WrongArgCount { .. })));
1774 }
1775
1776 #[cfg(feature = "func")]
1777 #[test]
1778 fn test_complex_expression() {
1779 struct Add;
1780 impl ExprFn for Add {
1781 fn name(&self) -> &str {
1782 "add"
1783 }
1784 fn arg_count(&self) -> usize {
1785 2
1786 }
1787 fn call(&self, args: &[f32]) -> f32 {
1788 args[0] + args[1]
1789 }
1790 }
1791
1792 let mut registry = FunctionRegistry::new();
1793 registry.register(Add);
1794
1795 let expr = Expr::parse("add(x * 2, y + 1)").unwrap();
1796 let vars: HashMap<String, f32> = [("x".to_string(), 3.0), ("y".to_string(), 4.0)].into();
1797 assert_eq!(expr.eval(&vars, ®istry).unwrap(), 11.0); }
1799
1800 #[cfg(feature = "cond")]
1802 #[test]
1803 fn test_compare_lt() {
1804 assert_eq!(eval("1 < 2", &[]), 1.0);
1805 assert_eq!(eval("2 < 1", &[]), 0.0);
1806 assert_eq!(eval("1 < 1", &[]), 0.0);
1807 }
1808
1809 #[cfg(feature = "cond")]
1810 #[test]
1811 fn test_compare_le() {
1812 assert_eq!(eval("1 <= 2", &[]), 1.0);
1813 assert_eq!(eval("2 <= 1", &[]), 0.0);
1814 assert_eq!(eval("1 <= 1", &[]), 1.0);
1815 }
1816
1817 #[cfg(feature = "cond")]
1818 #[test]
1819 fn test_compare_gt() {
1820 assert_eq!(eval("2 > 1", &[]), 1.0);
1821 assert_eq!(eval("1 > 2", &[]), 0.0);
1822 }
1823
1824 #[cfg(feature = "cond")]
1825 #[test]
1826 fn test_compare_ge() {
1827 assert_eq!(eval("2 >= 1", &[]), 1.0);
1828 assert_eq!(eval("1 >= 1", &[]), 1.0);
1829 }
1830
1831 #[cfg(feature = "cond")]
1832 #[test]
1833 fn test_compare_eq() {
1834 assert_eq!(eval("1 == 1", &[]), 1.0);
1835 assert_eq!(eval("1 == 2", &[]), 0.0);
1836 }
1837
1838 #[cfg(feature = "cond")]
1839 #[test]
1840 fn test_compare_ne() {
1841 assert_eq!(eval("1 != 2", &[]), 1.0);
1842 assert_eq!(eval("1 != 1", &[]), 0.0);
1843 }
1844
1845 #[cfg(feature = "cond")]
1847 #[test]
1848 fn test_and() {
1849 assert_eq!(eval("1 and 1", &[]), 1.0);
1850 assert_eq!(eval("1 and 0", &[]), 0.0);
1851 assert_eq!(eval("0 and 1", &[]), 0.0);
1852 assert_eq!(eval("0 and 0", &[]), 0.0);
1853 }
1854
1855 #[cfg(feature = "cond")]
1856 #[test]
1857 fn test_or() {
1858 assert_eq!(eval("1 or 1", &[]), 1.0);
1859 assert_eq!(eval("1 or 0", &[]), 1.0);
1860 assert_eq!(eval("0 or 1", &[]), 1.0);
1861 assert_eq!(eval("0 or 0", &[]), 0.0);
1862 }
1863
1864 #[cfg(feature = "cond")]
1865 #[test]
1866 fn test_not() {
1867 assert_eq!(eval("not 0", &[]), 1.0);
1868 assert_eq!(eval("not 1", &[]), 0.0);
1869 assert_eq!(eval("not 5", &[]), 0.0); }
1871
1872 #[cfg(feature = "cond")]
1874 #[test]
1875 fn test_if_then_else() {
1876 assert_eq!(eval("if 1 then 10 else 20", &[]), 10.0);
1877 assert_eq!(eval("if 0 then 10 else 20", &[]), 20.0);
1878 }
1879
1880 #[cfg(feature = "cond")]
1881 #[test]
1882 fn test_if_with_comparison() {
1883 assert_eq!(eval("if x > 5 then 1 else 0", &[("x", 10.0)]), 1.0);
1884 assert_eq!(eval("if x > 5 then 1 else 0", &[("x", 3.0)]), 0.0);
1885 }
1886
1887 #[cfg(feature = "cond")]
1888 #[test]
1889 fn test_nested_if() {
1890 assert_eq!(
1892 eval(
1893 "if x > 0 then if x > 10 then 2 else 1 else 0",
1894 &[("x", 15.0)]
1895 ),
1896 2.0
1897 );
1898 assert_eq!(
1899 eval(
1900 "if x > 0 then if x > 10 then 2 else 1 else 0",
1901 &[("x", 5.0)]
1902 ),
1903 1.0
1904 );
1905 assert_eq!(
1906 eval(
1907 "if x > 0 then if x > 10 then 2 else 1 else 0",
1908 &[("x", -1.0)]
1909 ),
1910 0.0
1911 );
1912 }
1913
1914 #[cfg(feature = "cond")]
1915 #[test]
1916 fn test_compound_boolean() {
1917 assert_eq!(eval("x > 0 and x < 10", &[("x", 5.0)]), 1.0);
1918 assert_eq!(eval("x > 0 and x < 10", &[("x", 15.0)]), 0.0);
1919 assert_eq!(eval("x < 0 or x > 10", &[("x", 5.0)]), 0.0);
1920 assert_eq!(eval("x < 0 or x > 10", &[("x", 15.0)]), 1.0);
1921 }
1922
1923 #[cfg(feature = "cond")]
1924 #[test]
1925 fn test_precedence_compare_vs_arithmetic() {
1926 assert_eq!(eval("1 + 2 < 4", &[]), 1.0);
1928 assert_eq!(eval("1 + 2 < 3", &[]), 0.0);
1929 }
1930
1931 #[cfg(feature = "cond")]
1932 #[test]
1933 fn test_precedence_and_vs_or() {
1934 assert_eq!(eval("1 or 0 and 0", &[]), 1.0); assert_eq!(eval("0 or 1 and 1", &[]), 1.0); assert_eq!(eval("0 or 0 and 1", &[]), 0.0); }
1939
1940 #[cfg(feature = "introspect")]
1942 #[test]
1943 fn test_free_vars_simple() {
1944 let expr = Expr::parse("x + y").unwrap();
1945 let vars = expr.free_vars();
1946 assert_eq!(vars.len(), 2);
1947 assert!(vars.contains("x"));
1948 assert!(vars.contains("y"));
1949 }
1950
1951 #[cfg(feature = "introspect")]
1952 #[test]
1953 fn test_free_vars_no_vars() {
1954 let expr = Expr::parse("1 + 2 * 3").unwrap();
1955 let vars = expr.free_vars();
1956 assert!(vars.is_empty());
1957 }
1958
1959 #[cfg(feature = "introspect")]
1960 #[test]
1961 fn test_free_vars_duplicates() {
1962 let expr = Expr::parse("x + x * x").unwrap();
1963 let vars = expr.free_vars();
1964 assert_eq!(vars.len(), 1);
1965 assert!(vars.contains("x"));
1966 }
1967
1968 #[cfg(all(feature = "introspect", feature = "func"))]
1969 #[test]
1970 fn test_free_vars_in_call() {
1971 let expr = Expr::parse("sin(x) + cos(y)").unwrap();
1972 let vars = expr.free_vars();
1973 assert_eq!(vars.len(), 2);
1974 assert!(vars.contains("x"));
1975 assert!(vars.contains("y"));
1976 }
1977
1978 #[cfg(all(feature = "introspect", feature = "cond"))]
1979 #[test]
1980 fn test_free_vars_in_conditional() {
1981 let expr = Expr::parse("if a > b then x else y").unwrap();
1982 let vars = expr.free_vars();
1983 assert_eq!(vars.len(), 4);
1984 assert!(vars.contains("a"));
1985 assert!(vars.contains("b"));
1986 assert!(vars.contains("x"));
1987 assert!(vars.contains("y"));
1988 }
1989
1990 #[test]
1992 fn test_let_simple() {
1993 assert_eq!(eval("let a = 1; a", &[]), 1.0);
1994 }
1995
1996 #[test]
1997 fn test_let_with_computation() {
1998 assert_eq!(eval("let a = 1 + 2; a * 2", &[]), 6.0);
1999 }
2000
2001 #[test]
2002 fn test_let_chained() {
2003 assert_eq!(eval("let a = 1; let b = 2; a + b", &[]), 3.0);
2004 }
2005
2006 #[test]
2007 fn test_let_with_var() {
2008 assert_eq!(eval("let a = x * 2; a + 1", &[("x", 5.0)]), 11.0);
2009 }
2010
2011 #[test]
2012 fn test_let_shadow() {
2013 assert_eq!(eval("let x = 10; x + 1", &[("x", 5.0)]), 11.0);
2015 }
2016
2017 #[test]
2018 fn test_let_uses_outer_in_value() {
2019 assert_eq!(eval("let x = x + 1; x", &[("x", 5.0)]), 6.0);
2021 }
2022
2023 #[cfg(feature = "introspect")]
2024 #[test]
2025 fn test_free_vars_in_let() {
2026 let expr = Expr::parse("let a = x; a + y").unwrap();
2028 let vars = expr.free_vars();
2029 assert_eq!(vars.len(), 2);
2030 assert!(vars.contains("x"));
2031 assert!(vars.contains("y"));
2032 assert!(!vars.contains("a"));
2033 }
2034
2035 #[cfg(feature = "introspect")]
2036 #[test]
2037 fn test_free_vars_in_let_shadow() {
2038 let expr = Expr::parse("let x = x + 1; x").unwrap();
2040 let vars = expr.free_vars();
2041 assert_eq!(vars.len(), 1);
2042 assert!(vars.contains("x"));
2043 }
2044
2045 #[test]
2047 fn test_ast_display_simple() {
2048 let expr = Expr::parse("1 + 2").unwrap();
2049 let s = expr.ast().to_string();
2050 assert_eq!(s, "(1 + 2)");
2051 }
2052
2053 #[test]
2054 fn test_ast_display_nested() {
2055 let expr = Expr::parse("1 + 2 * 3").unwrap();
2056 let s = expr.ast().to_string();
2057 assert_eq!(s, "(1 + (2 * 3))");
2059 }
2060
2061 #[test]
2062 fn test_ast_roundtrip() {
2063 let cases = [
2064 "1 + 2",
2065 "x * y",
2066 "1 + 2 * 3",
2067 "(1 + 2) * 3",
2068 "-x",
2069 "x ^ 2",
2070 "2 ^ 3 ^ 4", ];
2072 for case in cases {
2073 let expr1 = Expr::parse(case).unwrap();
2074 let stringified = expr1.ast().to_string();
2075 let expr2 = Expr::parse(&stringified).unwrap();
2076 let stringified2 = expr2.ast().to_string();
2077 assert_eq!(stringified, stringified2, "Roundtrip failed for: {}", case);
2078 }
2079 }
2080
2081 #[cfg(feature = "func")]
2082 #[test]
2083 fn test_ast_roundtrip_func() {
2084 let cases = ["sin(x)", "foo(a, b, c)", "f()"];
2085 for case in cases {
2086 let expr1 = Expr::parse(case).unwrap();
2087 let stringified = expr1.ast().to_string();
2088 let expr2 = Expr::parse(&stringified).unwrap();
2089 let stringified2 = expr2.ast().to_string();
2090 assert_eq!(stringified, stringified2, "Roundtrip failed for: {}", case);
2091 }
2092 }
2093
2094 #[cfg(feature = "cond")]
2095 #[test]
2096 fn test_ast_roundtrip_cond() {
2097 let cases = [
2098 "x < y",
2099 "x and y",
2100 "x or y",
2101 "not x",
2102 "if x then y else z",
2103 "if a > b then x else y",
2104 ];
2105 for case in cases {
2106 let expr1 = Expr::parse(case).unwrap();
2107 let stringified = expr1.ast().to_string();
2108 let expr2 = Expr::parse(&stringified).unwrap();
2109 let stringified2 = expr2.ast().to_string();
2110 assert_eq!(stringified, stringified2, "Roundtrip failed for: {}", case);
2111 }
2112 }
2113
2114 #[test]
2115 fn test_ast_roundtrip_let() {
2116 let cases = [
2117 "let a = 1; a",
2118 "let a = 1; let b = 2; a + b",
2119 "let x = y * 2; x + 1",
2120 ];
2121 for case in cases {
2122 let expr1 = Expr::parse(case).unwrap();
2123 let stringified = expr1.ast().to_string();
2124 let expr2 = Expr::parse(&stringified).unwrap();
2125 let stringified2 = expr2.ast().to_string();
2126 assert_eq!(stringified, stringified2, "Roundtrip failed for: {}", case);
2127 }
2128 }
2129}
2130
/// Property-based tests driven by the `proptest` crate. They feed random inputs
/// to the parser and evaluator to check for panics and number round-tripping.
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing small, syntactically valid expression strings: a
    /// lone number, a lone variable, or one parenthesized binary operation
    /// combining numbers and/or variables.
    ///
    /// Numbers are formatted with `{:.6}` (six fractional digits). Variables
    /// are drawn from the same names that `sample_vars` binds.
    fn expr_strategy() -> impl Strategy<Value = String> {
        let num = prop::num::f32::NORMAL.prop_map(|n| format!("{:.6}", n));
        let var = prop::sample::select(vec!["x", "y", "z", "a", "b"]).prop_map(String::from);
        let binop = prop::sample::select(vec!["+", "-", "*", "/"]);

        prop::strategy::Union::new(vec![
            num.clone().boxed(),
            var.clone().boxed(),
            (num.clone(), binop.clone(), num.clone())
                .prop_map(|(l, op, r)| format!("({} {} {})", l, op, r))
                .boxed(),
            (var.clone(), binop.clone(), num.clone())
                .prop_map(|(l, op, r)| format!("({} {} {})", l, op, r))
                .boxed(),
            // Last use of each strategy: move instead of cloning.
            (num, binop, var)
                .prop_map(|(l, op, r)| format!("({} {} {})", l, op, r))
                .boxed(),
        ])
    }

    /// Binds every variable name `expr_strategy` can emit. `x` and `y` are
    /// random and `z`, `a`, `b` are fixed, so evaluation of a generated
    /// expression never hits an unbound name.
    fn sample_vars(x: f32, y: f32) -> HashMap<String, f32> {
        [
            ("x".into(), x),
            ("y".into(), y),
            ("z".into(), 1.0),
            ("a".into(), 2.0),
            ("b".into(), 3.0),
        ]
        .into()
    }

    proptest! {
        // Arbitrary text may parse or fail, but `Expr::parse` must never panic.
        #[test]
        fn parse_never_panics(s in ".*") {
            let _ = Expr::parse(&s);
        }

        // Every string produced by `expr_strategy` is well-formed and must parse.
        #[test]
        fn valid_expr_parses(expr in expr_strategy()) {
            let result = Expr::parse(&expr);
            prop_assert!(result.is_ok(), "Failed to parse: {}", expr);
        }

        // A formatted literal parses back to (approximately) the same value.
        // Inputs that do not parse to a bare `Ast::Num` are skipped (e.g.
        // negative values, if the leading `-` becomes a negation node).
        #[test]
        fn number_roundtrip(n in prop::num::f64::NORMAL) {
            let expr_str = format!("{:.6}", n);
            if let Ok(expr) = Expr::parse(&expr_str) {
                if let Ast::Num(parsed) = expr.ast() {
                    // `{:.6}` rounds at 1e-6, so accept either a small absolute
                    // error or a small relative error.
                    let diff = (parsed - n).abs();
                    prop_assert!(diff < 0.001 || diff / n.abs() < 0.001,
                        "Number mismatch: {} vs {}", n, parsed);
                }
            }
        }

        // Evaluating any generated expression with all variables bound must
        // return (Ok or Err) without panicking. This includes division by
        // tiny values that format to `0.000000`.
        #[test]
        #[cfg(not(feature = "func"))]
        fn eval_with_vars_no_panic(
            x in prop::num::f32::NORMAL,
            y in prop::num::f32::NORMAL,
            expr in expr_strategy()
        ) {
            if let Ok(parsed) = Expr::parse(&expr) {
                let vars = sample_vars(x, y);
                let _ = parsed.eval(&vars);
            }
        }

        // Same as `eval_with_vars_no_panic`, for the registry-taking `eval`.
        #[test]
        #[cfg(feature = "func")]
        fn eval_with_vars_no_panic_func(
            x in prop::num::f32::NORMAL,
            y in prop::num::f32::NORMAL,
            expr in expr_strategy()
        ) {
            if let Ok(parsed) = Expr::parse(&expr) {
                let vars = sample_vars(x, y);
                let registry = FunctionRegistry::new();
                let _ = parsed.eval(&vars, &registry);
            }
        }

        // Double negation is the identity. For negative `n` the text gains a
        // third `-`; an odd number of negations of |n| still equals `n`.
        #[test]
        #[cfg(not(feature = "func"))]
        fn negation_inverse(n in prop::num::f32::NORMAL) {
            let expr = Expr::parse(&format!("--{:.6}", n)).unwrap();
            let vars = HashMap::new();
            let result = expr.eval(&vars).unwrap();
            prop_assert!((result - n).abs() < 0.01,
                "Double negation failed: --{} = {}", n, result);
        }

        // Same as `negation_inverse`, for the registry-taking `eval`.
        #[test]
        #[cfg(feature = "func")]
        fn negation_inverse_func(n in prop::num::f32::NORMAL) {
            let expr = Expr::parse(&format!("--{:.6}", n)).unwrap();
            let vars = HashMap::new();
            let registry = FunctionRegistry::new();
            let result = expr.eval(&vars, &registry).unwrap();
            prop_assert!((result - n).abs() < 0.01,
                "Double negation failed: --{} = {}", n, result);
        }
    }
}
2259}