1use std::collections::HashMap;
44use std::fmt;
45use std::iter::Peekable;
46use std::str::Chars;
47
48#[derive(Debug, Clone, PartialEq)]
50pub enum Expr {
51 Literal(f64),
53 Column(String),
55 BinaryOp {
57 op: BinaryOp,
58 left: Box<Expr>,
59 right: Box<Expr>,
60 },
61 UnaryOp { op: UnaryOp, expr: Box<Expr> },
63 FnCall { name: String, args: Vec<Expr> },
65}
66
/// Binary arithmetic operators recognised by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
    /// `^` (exponentiation)
    Pow,
}

impl fmt::Display for BinaryOp {
    /// Writes the operator's source symbol, e.g. `+` or `^`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match *self {
            BinaryOp::Add => "+",
            BinaryOp::Sub => "-",
            BinaryOp::Mul => "*",
            BinaryOp::Div => "/",
            BinaryOp::Mod => "%",
            BinaryOp::Pow => "^",
        };
        f.write_str(symbol)
    }
}
90
/// Prefix (unary) operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation, written as a leading `-`.
    Neg,
}
96
/// Lexical tokens produced by the lexer; internal to parsing.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Numeric literal, already converted to `f64`.
    Number(f64),
    /// Identifier: a column name or a function name.
    Ident(String),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `%`
    Percent,
    /// `^`
    Caret,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
    /// End of input.
    Eof,
}
113
/// Streaming tokenizer over an expression string.
struct Lexer<'a> {
    /// Remaining input; peekable so the end of a token can be seen without consuming it.
    chars: Peekable<Chars<'a>>,
}
118
119impl<'a> Lexer<'a> {
120 fn new(input: &'a str) -> Self {
121 Self {
122 chars: input.chars().peekable(),
123 }
124 }
125
126 fn next_token(&mut self) -> Result<Token, CalcError> {
127 self.skip_whitespace();
128
129 match self.chars.peek() {
130 None => Ok(Token::Eof),
131 Some(&c) => match c {
132 '+' => {
133 self.chars.next();
134 Ok(Token::Plus)
135 }
136 '-' => {
137 self.chars.next();
138 Ok(Token::Minus)
139 }
140 '*' => {
141 self.chars.next();
142 Ok(Token::Star)
143 }
144 '/' => {
145 self.chars.next();
146 Ok(Token::Slash)
147 }
148 '%' => {
149 self.chars.next();
150 Ok(Token::Percent)
151 }
152 '^' => {
153 self.chars.next();
154 Ok(Token::Caret)
155 }
156 '(' => {
157 self.chars.next();
158 Ok(Token::LParen)
159 }
160 ')' => {
161 self.chars.next();
162 Ok(Token::RParen)
163 }
164 ',' => {
165 self.chars.next();
166 Ok(Token::Comma)
167 }
168 '0'..='9' | '.' => self.number(),
169 'a'..='z' | 'A'..='Z' | '_' | '$' => self.ident(),
170 _ => Err(CalcError::UnexpectedChar(c)),
171 },
172 }
173 }
174
175 fn skip_whitespace(&mut self) {
176 while let Some(&c) = self.chars.peek() {
177 if c.is_whitespace() {
178 self.chars.next();
179 } else {
180 break;
181 }
182 }
183 }
184
185 fn number(&mut self) -> Result<Token, CalcError> {
186 let mut s = String::new();
187 let mut has_dot = false;
188
189 while let Some(&c) = self.chars.peek() {
190 if c.is_ascii_digit() {
191 s.push(c);
192 self.chars.next();
193 } else if c == '.' && !has_dot {
194 has_dot = true;
195 s.push(c);
196 self.chars.next();
197 } else if c == 'e' || c == 'E' {
198 s.push(c);
200 self.chars.next();
201 if let Some(&sign) = self.chars.peek()
202 && (sign == '+' || sign == '-')
203 {
204 s.push(sign);
205 self.chars.next();
206 }
207 } else {
208 break;
209 }
210 }
211
212 s.parse::<f64>()
213 .map(Token::Number)
214 .map_err(|_| CalcError::InvalidNumber(s))
215 }
216
217 fn ident(&mut self) -> Result<Token, CalcError> {
218 let mut s = String::new();
219
220 while let Some(&c) = self.chars.peek() {
221 if c.is_alphanumeric() || c == '_' || c == '$' {
222 s.push(c);
223 self.chars.next();
224 } else {
225 break;
226 }
227 }
228
229 Ok(Token::Ident(s))
230 }
231}
232
233pub struct Parser<'a> {
235 lexer: Lexer<'a>,
236 current: Token,
237}
238
239impl<'a> Parser<'a> {
240 pub fn new(input: &'a str) -> Result<Self, CalcError> {
242 let mut lexer = Lexer::new(input);
243 let current = lexer.next_token()?;
244 Ok(Self { lexer, current })
245 }
246
247 pub fn parse(&mut self) -> Result<Expr, CalcError> {
249 let expr = self.expression()?;
250 if self.current != Token::Eof {
251 return Err(CalcError::UnexpectedToken(format!("{:?}", self.current)));
252 }
253 Ok(expr)
254 }
255
256 fn advance(&mut self) -> Result<(), CalcError> {
257 self.current = self.lexer.next_token()?;
258 Ok(())
259 }
260
261 fn expression(&mut self) -> Result<Expr, CalcError> {
262 self.additive()
263 }
264
265 fn additive(&mut self) -> Result<Expr, CalcError> {
266 let mut left = self.multiplicative()?;
267
268 loop {
269 let op = match &self.current {
270 Token::Plus => BinaryOp::Add,
271 Token::Minus => BinaryOp::Sub,
272 _ => break,
273 };
274 self.advance()?;
275 let right = self.multiplicative()?;
276 left = Expr::BinaryOp {
277 op,
278 left: Box::new(left),
279 right: Box::new(right),
280 };
281 }
282
283 Ok(left)
284 }
285
286 fn multiplicative(&mut self) -> Result<Expr, CalcError> {
287 let mut left = self.power()?;
288
289 loop {
290 let op = match &self.current {
291 Token::Star => BinaryOp::Mul,
292 Token::Slash => BinaryOp::Div,
293 Token::Percent => BinaryOp::Mod,
294 _ => break,
295 };
296 self.advance()?;
297 let right = self.power()?;
298 left = Expr::BinaryOp {
299 op,
300 left: Box::new(left),
301 right: Box::new(right),
302 };
303 }
304
305 Ok(left)
306 }
307
308 fn power(&mut self) -> Result<Expr, CalcError> {
309 let left = self.unary()?;
310
311 if self.current == Token::Caret {
312 self.advance()?;
313 let right = self.power()?; return Ok(Expr::BinaryOp {
315 op: BinaryOp::Pow,
316 left: Box::new(left),
317 right: Box::new(right),
318 });
319 }
320
321 Ok(left)
322 }
323
324 fn unary(&mut self) -> Result<Expr, CalcError> {
325 if self.current == Token::Minus {
326 self.advance()?;
327 let expr = self.unary()?;
328 return Ok(Expr::UnaryOp {
329 op: UnaryOp::Neg,
330 expr: Box::new(expr),
331 });
332 }
333
334 self.primary()
335 }
336
337 fn primary(&mut self) -> Result<Expr, CalcError> {
338 match self.current.clone() {
339 Token::Number(n) => {
340 self.advance()?;
341 Ok(Expr::Literal(n))
342 }
343 Token::Ident(name) => {
344 self.advance()?;
345 if self.current == Token::LParen {
346 self.advance()?;
348 let args = self.arguments()?;
349 if self.current != Token::RParen {
350 return Err(CalcError::ExpectedToken(")".into()));
351 }
352 self.advance()?;
353 Ok(Expr::FnCall { name, args })
354 } else {
355 Ok(Expr::Column(name))
357 }
358 }
359 Token::LParen => {
360 self.advance()?;
361 let expr = self.expression()?;
362 if self.current != Token::RParen {
363 return Err(CalcError::ExpectedToken(")".into()));
364 }
365 self.advance()?;
366 Ok(expr)
367 }
368 _ => Err(CalcError::UnexpectedToken(format!("{:?}", self.current))),
369 }
370 }
371
372 fn arguments(&mut self) -> Result<Vec<Expr>, CalcError> {
373 let mut args = Vec::new();
374
375 if self.current == Token::RParen {
376 return Ok(args);
377 }
378
379 args.push(self.expression()?);
380
381 while self.current == Token::Comma {
382 self.advance()?;
383 args.push(self.expression()?);
384 }
385
386 Ok(args)
387 }
388}
389
/// Errors produced while lexing, parsing, or evaluating an expression.
#[derive(Debug, Clone, PartialEq)]
pub enum CalcError {
    /// The lexer met a character that cannot start a token.
    UnexpectedChar(char),
    /// A numeric literal could not be parsed as `f64`; holds the offending text.
    InvalidNumber(String),
    /// The parser met something it did not expect; holds a description of it.
    UnexpectedToken(String),
    /// A specific token (such as `)`) was required but missing.
    ExpectedToken(String),
    /// A referenced column is absent from the row context.
    UndefinedColumn(String),
    /// A function name is not recognised.
    UndefinedFunction(String),
    /// Division or remainder by zero.
    DivisionByZero,
    /// A function was called with the wrong number of arguments.
    InvalidArgCount {
        name: String,
        expected: usize,
        got: usize,
    },
    /// A mathematical domain error, e.g. the square root of a negative number.
    MathError(String),
    /// Evaluation exceeded the evaluator's step budget.
    Timeout,
}

impl fmt::Display for CalcError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnexpectedChar(c) => write!(f, "Unexpected character: {c}"),
            Self::InvalidNumber(text) => write!(f, "Invalid number: {text}"),
            Self::UnexpectedToken(tok) => write!(f, "Unexpected token: {tok}"),
            Self::ExpectedToken(tok) => write!(f, "Expected token: {tok}"),
            Self::UndefinedColumn(col) => write!(f, "Undefined column: {col}"),
            Self::UndefinedFunction(func) => write!(f, "Undefined function: {func}"),
            Self::DivisionByZero => f.write_str("Division by zero"),
            Self::InvalidArgCount {
                name,
                expected,
                got,
            } => write!(f, "Function {name} expects {expected} args, got {got}"),
            Self::MathError(msg) => write!(f, "Math error: {msg}"),
            Self::Timeout => f.write_str("Evaluation timeout"),
        }
    }
}

impl std::error::Error for CalcError {}
437
/// One row of input data: a map from column name to value, used to resolve column references.
pub type RowContext = HashMap<String, f64>;
440
/// Evaluates expression trees against a row of column values, enforcing a step budget.
pub struct Evaluator {
    /// Largest number of expression nodes that may be visited before evaluation fails
    /// with `CalcError::Timeout`.
    max_steps: usize,
    /// Number of nodes visited so far, compared against `max_steps`.
    steps: usize,
}
448
449impl Evaluator {
450 pub fn new() -> Self {
452 Self {
453 max_steps: 10000,
454 steps: 0,
455 }
456 }
457
458 pub fn with_max_steps(max_steps: usize) -> Self {
460 Self {
461 max_steps,
462 steps: 0,
463 }
464 }
465
466 pub fn eval(&mut self, expr: &Expr, ctx: &RowContext) -> Result<f64, CalcError> {
468 self.steps += 1;
469 if self.steps > self.max_steps {
470 return Err(CalcError::Timeout);
471 }
472
473 match expr {
474 Expr::Literal(n) => Ok(*n),
475
476 Expr::Column(name) => ctx
477 .get(name)
478 .copied()
479 .ok_or_else(|| CalcError::UndefinedColumn(name.clone())),
480
481 Expr::BinaryOp { op, left, right } => {
482 let l = self.eval(left, ctx)?;
483 let r = self.eval(right, ctx)?;
484
485 match op {
486 BinaryOp::Add => Ok(l + r),
487 BinaryOp::Sub => Ok(l - r),
488 BinaryOp::Mul => Ok(l * r),
489 BinaryOp::Div => {
490 if r == 0.0 {
491 Err(CalcError::DivisionByZero)
492 } else {
493 Ok(l / r)
494 }
495 }
496 BinaryOp::Mod => {
497 if r == 0.0 {
498 Err(CalcError::DivisionByZero)
499 } else {
500 Ok(l % r)
501 }
502 }
503 BinaryOp::Pow => Ok(l.powf(r)),
504 }
505 }
506
507 Expr::UnaryOp { op, expr } => {
508 let v = self.eval(expr, ctx)?;
509 match op {
510 UnaryOp::Neg => Ok(-v),
511 }
512 }
513
514 Expr::FnCall { name, args } => self.call_function(name, args, ctx),
515 }
516 }
517
518 fn call_function(
520 &mut self,
521 name: &str,
522 args: &[Expr],
523 ctx: &RowContext,
524 ) -> Result<f64, CalcError> {
525 let evaluated: Result<Vec<f64>, CalcError> =
526 args.iter().map(|a| self.eval(a, ctx)).collect();
527 let args = evaluated?;
528
529 match name.to_lowercase().as_str() {
530 "abs" => {
532 check_args(name, &args, 1)?;
533 Ok(args[0].abs())
534 }
535 "sqrt" => {
536 check_args(name, &args, 1)?;
537 if args[0] < 0.0 {
538 Err(CalcError::MathError("sqrt of negative number".into()))
539 } else {
540 Ok(args[0].sqrt())
541 }
542 }
543 "floor" => {
544 check_args(name, &args, 1)?;
545 Ok(args[0].floor())
546 }
547 "ceil" => {
548 check_args(name, &args, 1)?;
549 Ok(args[0].ceil())
550 }
551 "round" => {
552 if args.len() == 1 {
553 Ok(args[0].round())
554 } else if args.len() == 2 {
555 let factor = 10f64.powi(args[1] as i32);
556 Ok((args[0] * factor).round() / factor)
557 } else {
558 Err(CalcError::InvalidArgCount {
559 name: name.into(),
560 expected: 1,
561 got: args.len(),
562 })
563 }
564 }
565 "sin" => {
566 check_args(name, &args, 1)?;
567 Ok(args[0].sin())
568 }
569 "cos" => {
570 check_args(name, &args, 1)?;
571 Ok(args[0].cos())
572 }
573 "tan" => {
574 check_args(name, &args, 1)?;
575 Ok(args[0].tan())
576 }
577 "exp" => {
578 check_args(name, &args, 1)?;
579 Ok(args[0].exp())
580 }
581 "ln" | "log" => {
582 check_args(name, &args, 1)?;
583 if args[0] <= 0.0 {
584 Err(CalcError::MathError("log of non-positive number".into()))
585 } else {
586 Ok(args[0].ln())
587 }
588 }
589 "log10" => {
590 check_args(name, &args, 1)?;
591 if args[0] <= 0.0 {
592 Err(CalcError::MathError("log of non-positive number".into()))
593 } else {
594 Ok(args[0].log10())
595 }
596 }
597 "log2" => {
598 check_args(name, &args, 1)?;
599 if args[0] <= 0.0 {
600 Err(CalcError::MathError("log of non-positive number".into()))
601 } else {
602 Ok(args[0].log2())
603 }
604 }
605
606 "pow" => {
608 check_args(name, &args, 2)?;
609 Ok(args[0].powf(args[1]))
610 }
611 "min" => {
612 check_args(name, &args, 2)?;
613 Ok(args[0].min(args[1]))
614 }
615 "max" => {
616 check_args(name, &args, 2)?;
617 Ok(args[0].max(args[1]))
618 }
619 "atan2" => {
620 check_args(name, &args, 2)?;
621 Ok(args[0].atan2(args[1]))
622 }
623
624 "sum" => Ok(args.iter().sum()),
626 "avg" => {
627 if args.is_empty() {
628 Err(CalcError::InvalidArgCount {
629 name: name.into(),
630 expected: 1,
631 got: 0,
632 })
633 } else {
634 Ok(args.iter().sum::<f64>() / args.len() as f64)
635 }
636 }
637
638 "if" => {
640 check_args(name, &args, 3)?;
641 if args[0] != 0.0 {
642 Ok(args[1])
643 } else {
644 Ok(args[2])
645 }
646 }
647
648 _ => Err(CalcError::UndefinedFunction(name.into())),
649 }
650 }
651}
652
653impl Default for Evaluator {
654 fn default() -> Self {
655 Self::new()
656 }
657}
658
659fn check_args(name: &str, args: &[f64], expected: usize) -> Result<(), CalcError> {
660 if args.len() != expected {
661 Err(CalcError::InvalidArgCount {
662 name: name.into(),
663 expected,
664 got: args.len(),
665 })
666 } else {
667 Ok(())
668 }
669}
670
671pub fn calculate(expr: &str, ctx: &RowContext) -> Result<f64, CalcError> {
673 let mut parser = Parser::new(expr)?;
674 let ast = parser.parse()?;
675 let mut evaluator = Evaluator::new();
676 evaluator.eval(&ast, ctx)
677}
678
679pub fn parse_expr(expr: &str) -> Result<Expr, CalcError> {
681 let mut parser = Parser::new(expr)?;
682 parser.parse()
683}
684
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floats agree to within `1e-9`; used for inexact decimal arithmetic.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_basic_arithmetic() {
        let ctx = RowContext::new();

        assert_eq!(calculate("2 + 3", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("10 - 4", &ctx).unwrap(), 6.0);
        assert_eq!(calculate("3 * 4", &ctx).unwrap(), 12.0);
        assert_eq!(calculate("15 / 3", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("7 % 4", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("2 ^ 3", &ctx).unwrap(), 8.0);
    }

    #[test]
    fn test_operator_precedence() {
        let ctx = RowContext::new();

        assert_eq!(calculate("2 + 3 * 4", &ctx).unwrap(), 14.0);
        assert_eq!(calculate("(2 + 3) * 4", &ctx).unwrap(), 20.0);
        assert_eq!(calculate("2 * 3 + 4", &ctx).unwrap(), 10.0);
        assert_eq!(calculate("10 - 2 * 3", &ctx).unwrap(), 4.0);
    }

    #[test]
    fn test_associativity() {
        let ctx = RowContext::new();

        // `+`/`-` are left-associative.
        assert_eq!(calculate("10 - 4 - 3", &ctx).unwrap(), 3.0);
        // `^` is right-associative: 2 ^ (3 ^ 2).
        assert_eq!(calculate("2 ^ 3 ^ 2", &ctx).unwrap(), 512.0);
        // Unary minus binds tighter than `^`: this is (-2) ^ 2.
        assert_eq!(calculate("-2 ^ 2", &ctx).unwrap(), 4.0);
        assert_eq!(calculate("2 ^ -1", &ctx).unwrap(), 0.5);
    }

    #[test]
    fn test_unary_minus() {
        let ctx = RowContext::new();

        assert_eq!(calculate("-5", &ctx).unwrap(), -5.0);
        assert_eq!(calculate("--5", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("3 + -2", &ctx).unwrap(), 1.0);
        assert_eq!(calculate("-3 * -2", &ctx).unwrap(), 6.0);
    }

    #[test]
    fn test_column_references() {
        let mut ctx = RowContext::new();
        ctx.insert("price".into(), 99.99);
        ctx.insert("quantity".into(), 5.0);
        ctx.insert("tax_rate".into(), 0.15);
        ctx.insert("$discount".into(), 10.0);

        assert_close(calculate("price * quantity", &ctx).unwrap(), 499.95);
        assert_close(
            calculate("price * quantity * (1 + tax_rate)", &ctx).unwrap(),
            574.9425,
        );
        // Identifiers may start with and contain `$`.
        assert_close(calculate("price - $discount", &ctx).unwrap(), 89.99);
    }

    #[test]
    fn test_functions() {
        let ctx = RowContext::new();

        assert_eq!(calculate("abs(-5)", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("sqrt(16)", &ctx).unwrap(), 4.0);
        assert_eq!(calculate("floor(3.7)", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("ceil(3.2)", &ctx).unwrap(), 4.0);
        assert_eq!(calculate("round(3.5)", &ctx).unwrap(), 4.0);
        // 3.14 is the expected rounding result here, not an approximation of π.
        #[allow(clippy::approx_constant)]
        {
            assert_eq!(calculate("round(3.14159, 2)", &ctx).unwrap(), 3.14);
        }
        assert_eq!(calculate("min(3, 5)", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("max(3, 5)", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("pow(2, 10)", &ctx).unwrap(), 1024.0);
        // Function names are case-insensitive.
        assert_eq!(calculate("ABS(-2)", &ctx).unwrap(), 2.0);
    }

    #[test]
    fn test_trig_functions() {
        let ctx = RowContext::new();

        assert_close(calculate("sin(0)", &ctx).unwrap(), 0.0);
        assert_close(calculate("cos(0)", &ctx).unwrap(), 1.0);
        assert_close(calculate("tan(0)", &ctx).unwrap(), 0.0);
        assert_close(
            calculate("atan2(1, 1)", &ctx).unwrap(),
            std::f64::consts::FRAC_PI_4,
        );
    }

    #[test]
    fn test_conditional() {
        let mut ctx = RowContext::new();
        ctx.insert("score".into(), 85.0);

        assert_eq!(calculate("if(1, 10, 20)", &ctx).unwrap(), 10.0);
        assert_eq!(calculate("if(0, 10, 20)", &ctx).unwrap(), 20.0);

        assert_eq!(calculate("if(score, 1, 0)", &ctx).unwrap(), 1.0);
    }

    #[test]
    fn test_variadic_functions() {
        let ctx = RowContext::new();

        assert_eq!(calculate("sum(1, 2, 3, 4)", &ctx).unwrap(), 10.0);
        assert_eq!(calculate("avg(2, 4, 6)", &ctx).unwrap(), 4.0);
        assert_eq!(calculate("sum()", &ctx).unwrap(), 0.0);
        assert!(matches!(
            calculate("avg()", &ctx),
            Err(CalcError::InvalidArgCount { got: 0, .. })
        ));
    }

    #[test]
    fn test_scientific_notation() {
        let ctx = RowContext::new();

        assert_eq!(calculate("1e3", &ctx).unwrap(), 1000.0);
        assert_eq!(calculate("1.5e-2", &ctx).unwrap(), 0.015);
    }

    #[test]
    fn test_division_by_zero() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("1 / 0", &ctx),
            Err(CalcError::DivisionByZero)
        ));
        assert!(matches!(
            calculate("5 % 0", &ctx),
            Err(CalcError::DivisionByZero)
        ));
    }

    #[test]
    fn test_undefined_column() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("undefined_col + 1", &ctx),
            Err(CalcError::UndefinedColumn(_))
        ));
    }

    #[test]
    fn test_undefined_function() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("unknown_func(1)", &ctx),
            Err(CalcError::UndefinedFunction(_))
        ));
    }

    #[test]
    fn test_syntax_errors() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("2 +", &ctx),
            Err(CalcError::UnexpectedToken(_))
        ));
        assert!(matches!(
            calculate("1 2", &ctx),
            Err(CalcError::UnexpectedToken(_))
        ));
        assert_eq!(
            calculate("(1 + 2", &ctx),
            Err(CalcError::ExpectedToken(")".into()))
        );
        assert_eq!(calculate("1 # 2", &ctx), Err(CalcError::UnexpectedChar('#')));
        assert_eq!(
            calculate("1e", &ctx),
            Err(CalcError::InvalidNumber("1e".into()))
        );
    }

    #[test]
    fn test_function_errors() {
        let ctx = RowContext::new();

        assert_eq!(
            calculate("sqrt(1, 2)", &ctx),
            Err(CalcError::InvalidArgCount {
                name: "sqrt".into(),
                expected: 1,
                got: 2,
            })
        );
        assert!(matches!(
            calculate("sqrt(-1)", &ctx),
            Err(CalcError::MathError(_))
        ));
        assert!(matches!(
            calculate("ln(0)", &ctx),
            Err(CalcError::MathError(_))
        ));
    }

    #[test]
    fn test_step_budget() {
        let ctx = RowContext::new();
        // `(1 + 2) + 3` has five nodes.
        let ast = parse_expr("1 + 2 + 3").unwrap();

        assert_eq!(
            Evaluator::with_max_steps(2).eval(&ast, &ctx),
            Err(CalcError::Timeout)
        );
        assert_eq!(Evaluator::with_max_steps(5).eval(&ast, &ctx), Ok(6.0));
    }

    #[test]
    fn test_parse_expr_ast() {
        assert_eq!(
            parse_expr("a + 1"),
            Ok(Expr::BinaryOp {
                op: BinaryOp::Add,
                left: Box::new(Expr::Column("a".into())),
                right: Box::new(Expr::Literal(1.0)),
            })
        );
    }

    #[test]
    fn test_display() {
        assert_eq!(BinaryOp::Pow.to_string(), "^");
        assert_eq!(CalcError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            CalcError::InvalidArgCount {
                name: "min".into(),
                expected: 2,
                got: 1,
            }
            .to_string(),
            "Function min expects 2 args, got 1"
        );
    }

    #[test]
    fn test_complex_expression() {
        let mut ctx = RowContext::new();
        ctx.insert("revenue".into(), 1000.0);
        ctx.insert("cost".into(), 600.0);
        ctx.insert("tax".into(), 0.15);

        let result = calculate("(revenue - cost) * (1 - tax)", &ctx).unwrap();
        assert_close(result, 340.0);
    }
}