1use std::collections::HashMap;
41use std::fmt;
42use std::iter::Peekable;
43use std::str::Chars;
44
45#[derive(Debug, Clone, PartialEq)]
47pub enum Expr {
48 Literal(f64),
50 Column(String),
52 BinaryOp {
54 op: BinaryOp,
55 left: Box<Expr>,
56 right: Box<Expr>,
57 },
58 UnaryOp { op: UnaryOp, expr: Box<Expr> },
60 FnCall { name: String, args: Vec<Expr> },
62}
63
/// Infix arithmetic operators recognised by the parser.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Division, written `/`.
    Div,
    /// Remainder, written `%`.
    Mod,
    /// Exponentiation, written `^` (right-associative).
    Pow,
}
74
75impl fmt::Display for BinaryOp {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 match self {
78 BinaryOp::Add => write!(f, "+"),
79 BinaryOp::Sub => write!(f, "-"),
80 BinaryOp::Mul => write!(f, "*"),
81 BinaryOp::Div => write!(f, "/"),
82 BinaryOp::Mod => write!(f, "%"),
83 BinaryOp::Pow => write!(f, "^"),
84 }
85 }
86}
87
/// Prefix operators recognised by the parser.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnaryOp {
    /// Arithmetic negation, written `-x`.
    Neg,
}
93
/// Lexical tokens produced by `Lexer::next_token`.
#[derive(Clone, Debug, PartialEq)]
enum Token {
    /// A numeric literal, already converted to `f64`.
    Number(f64),
    /// A column or function name.
    Ident(String),
    /// `+`
    Plus,
    /// `-` (binary subtraction or unary negation, decided by the parser)
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `%`
    Percent,
    /// `^`
    Caret,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,` separating function arguments
    Comma,
    /// End of input.
    Eof,
}
110
/// Character-level scanner that turns an expression string into [`Token`]s.
struct Lexer<'a> {
    /// Remaining input; one character of lookahead decides which token comes next.
    chars: Peekable<Chars<'a>>,
}
115
116impl<'a> Lexer<'a> {
117 fn new(input: &'a str) -> Self {
118 Self {
119 chars: input.chars().peekable(),
120 }
121 }
122
123 fn next_token(&mut self) -> Result<Token, CalcError> {
124 self.skip_whitespace();
125
126 match self.chars.peek() {
127 None => Ok(Token::Eof),
128 Some(&c) => match c {
129 '+' => {
130 self.chars.next();
131 Ok(Token::Plus)
132 }
133 '-' => {
134 self.chars.next();
135 Ok(Token::Minus)
136 }
137 '*' => {
138 self.chars.next();
139 Ok(Token::Star)
140 }
141 '/' => {
142 self.chars.next();
143 Ok(Token::Slash)
144 }
145 '%' => {
146 self.chars.next();
147 Ok(Token::Percent)
148 }
149 '^' => {
150 self.chars.next();
151 Ok(Token::Caret)
152 }
153 '(' => {
154 self.chars.next();
155 Ok(Token::LParen)
156 }
157 ')' => {
158 self.chars.next();
159 Ok(Token::RParen)
160 }
161 ',' => {
162 self.chars.next();
163 Ok(Token::Comma)
164 }
165 '0'..='9' | '.' => self.number(),
166 'a'..='z' | 'A'..='Z' | '_' | '$' => self.ident(),
167 _ => Err(CalcError::UnexpectedChar(c)),
168 },
169 }
170 }
171
172 fn skip_whitespace(&mut self) {
173 while let Some(&c) = self.chars.peek() {
174 if c.is_whitespace() {
175 self.chars.next();
176 } else {
177 break;
178 }
179 }
180 }
181
182 fn number(&mut self) -> Result<Token, CalcError> {
183 let mut s = String::new();
184 let mut has_dot = false;
185
186 while let Some(&c) = self.chars.peek() {
187 if c.is_ascii_digit() {
188 s.push(c);
189 self.chars.next();
190 } else if c == '.' && !has_dot {
191 has_dot = true;
192 s.push(c);
193 self.chars.next();
194 } else if c == 'e' || c == 'E' {
195 s.push(c);
197 self.chars.next();
198 if let Some(&sign) = self.chars.peek()
199 && (sign == '+' || sign == '-')
200 {
201 s.push(sign);
202 self.chars.next();
203 }
204 } else {
205 break;
206 }
207 }
208
209 s.parse::<f64>()
210 .map(Token::Number)
211 .map_err(|_| CalcError::InvalidNumber(s))
212 }
213
214 fn ident(&mut self) -> Result<Token, CalcError> {
215 let mut s = String::new();
216
217 while let Some(&c) = self.chars.peek() {
218 if c.is_alphanumeric() || c == '_' || c == '$' {
219 s.push(c);
220 self.chars.next();
221 } else {
222 break;
223 }
224 }
225
226 Ok(Token::Ident(s))
227 }
228}
229
230pub struct Parser<'a> {
232 lexer: Lexer<'a>,
233 current: Token,
234}
235
236impl<'a> Parser<'a> {
237 pub fn new(input: &'a str) -> Result<Self, CalcError> {
239 let mut lexer = Lexer::new(input);
240 let current = lexer.next_token()?;
241 Ok(Self { lexer, current })
242 }
243
244 pub fn parse(&mut self) -> Result<Expr, CalcError> {
246 let expr = self.expression()?;
247 if self.current != Token::Eof {
248 return Err(CalcError::UnexpectedToken(format!("{:?}", self.current)));
249 }
250 Ok(expr)
251 }
252
253 fn advance(&mut self) -> Result<(), CalcError> {
254 self.current = self.lexer.next_token()?;
255 Ok(())
256 }
257
258 fn expression(&mut self) -> Result<Expr, CalcError> {
259 self.additive()
260 }
261
262 fn additive(&mut self) -> Result<Expr, CalcError> {
263 let mut left = self.multiplicative()?;
264
265 loop {
266 let op = match &self.current {
267 Token::Plus => BinaryOp::Add,
268 Token::Minus => BinaryOp::Sub,
269 _ => break,
270 };
271 self.advance()?;
272 let right = self.multiplicative()?;
273 left = Expr::BinaryOp {
274 op,
275 left: Box::new(left),
276 right: Box::new(right),
277 };
278 }
279
280 Ok(left)
281 }
282
283 fn multiplicative(&mut self) -> Result<Expr, CalcError> {
284 let mut left = self.power()?;
285
286 loop {
287 let op = match &self.current {
288 Token::Star => BinaryOp::Mul,
289 Token::Slash => BinaryOp::Div,
290 Token::Percent => BinaryOp::Mod,
291 _ => break,
292 };
293 self.advance()?;
294 let right = self.power()?;
295 left = Expr::BinaryOp {
296 op,
297 left: Box::new(left),
298 right: Box::new(right),
299 };
300 }
301
302 Ok(left)
303 }
304
305 fn power(&mut self) -> Result<Expr, CalcError> {
306 let left = self.unary()?;
307
308 if self.current == Token::Caret {
309 self.advance()?;
310 let right = self.power()?; return Ok(Expr::BinaryOp {
312 op: BinaryOp::Pow,
313 left: Box::new(left),
314 right: Box::new(right),
315 });
316 }
317
318 Ok(left)
319 }
320
321 fn unary(&mut self) -> Result<Expr, CalcError> {
322 if self.current == Token::Minus {
323 self.advance()?;
324 let expr = self.unary()?;
325 return Ok(Expr::UnaryOp {
326 op: UnaryOp::Neg,
327 expr: Box::new(expr),
328 });
329 }
330
331 self.primary()
332 }
333
334 fn primary(&mut self) -> Result<Expr, CalcError> {
335 match self.current.clone() {
336 Token::Number(n) => {
337 self.advance()?;
338 Ok(Expr::Literal(n))
339 }
340 Token::Ident(name) => {
341 self.advance()?;
342 if self.current == Token::LParen {
343 self.advance()?;
345 let args = self.arguments()?;
346 if self.current != Token::RParen {
347 return Err(CalcError::ExpectedToken(")".into()));
348 }
349 self.advance()?;
350 Ok(Expr::FnCall { name, args })
351 } else {
352 Ok(Expr::Column(name))
354 }
355 }
356 Token::LParen => {
357 self.advance()?;
358 let expr = self.expression()?;
359 if self.current != Token::RParen {
360 return Err(CalcError::ExpectedToken(")".into()));
361 }
362 self.advance()?;
363 Ok(expr)
364 }
365 _ => Err(CalcError::UnexpectedToken(format!("{:?}", self.current))),
366 }
367 }
368
369 fn arguments(&mut self) -> Result<Vec<Expr>, CalcError> {
370 let mut args = Vec::new();
371
372 if self.current == Token::RParen {
373 return Ok(args);
374 }
375
376 args.push(self.expression()?);
377
378 while self.current == Token::Comma {
379 self.advance()?;
380 args.push(self.expression()?);
381 }
382
383 Ok(args)
384 }
385}
386
/// Every failure the lexer, parser or evaluator can report.
#[derive(Clone, Debug, PartialEq)]
pub enum CalcError {
    /// The lexer met a character that cannot start any token.
    UnexpectedChar(char),
    /// Scanned numeric text that does not parse as an `f64`.
    InvalidNumber(String),
    /// The parser met a token it cannot use here (Debug rendering of the token).
    UnexpectedToken(String),
    /// A required token (e.g. `")"`) was missing.
    ExpectedToken(String),
    /// A column reference not present in the row context.
    UndefinedColumn(String),
    /// A call to a function name the evaluator does not know.
    UndefinedFunction(String),
    /// `/` or `%` with a zero right-hand operand.
    DivisionByZero,
    /// A function was called with the wrong number of arguments.
    InvalidArgCount {
        /// Function name as written by the caller.
        name: String,
        /// Number of arguments the function accepts.
        expected: usize,
        /// Number of arguments actually supplied.
        got: usize,
    },
    /// A domain error inside a math function (e.g. square root of a negative).
    MathError(String),
    /// The evaluator exceeded its step budget.
    Timeout,
}
405
406impl fmt::Display for CalcError {
407 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
408 match self {
409 CalcError::UnexpectedChar(c) => write!(f, "Unexpected character: {}", c),
410 CalcError::InvalidNumber(s) => write!(f, "Invalid number: {}", s),
411 CalcError::UnexpectedToken(s) => write!(f, "Unexpected token: {}", s),
412 CalcError::ExpectedToken(s) => write!(f, "Expected token: {}", s),
413 CalcError::UndefinedColumn(s) => write!(f, "Undefined column: {}", s),
414 CalcError::UndefinedFunction(s) => write!(f, "Undefined function: {}", s),
415 CalcError::DivisionByZero => write!(f, "Division by zero"),
416 CalcError::InvalidArgCount {
417 name,
418 expected,
419 got,
420 } => {
421 write!(
422 f,
423 "Function {} expects {} args, got {}",
424 name, expected, got
425 )
426 }
427 CalcError::MathError(s) => write!(f, "Math error: {}", s),
428 CalcError::Timeout => write!(f, "Evaluation timeout"),
429 }
430 }
431}
432
433impl std::error::Error for CalcError {}
434
/// Column values for one row, keyed by column name; consulted when an
/// `Expr::Column` is evaluated.
pub type RowContext = std::collections::HashMap<String, f64>;
437
/// Tree-walking evaluator for [`Expr`] with a step budget that guards against
/// runaway evaluation.
pub struct Evaluator {
    /// Largest permitted value of `steps` before evaluation fails with `Timeout`.
    max_steps: usize,
    /// Number of AST nodes visited so far.
    steps: usize,
}
445
446impl Evaluator {
447 pub fn new() -> Self {
449 Self {
450 max_steps: 10000,
451 steps: 0,
452 }
453 }
454
455 pub fn with_max_steps(max_steps: usize) -> Self {
457 Self {
458 max_steps,
459 steps: 0,
460 }
461 }
462
463 pub fn eval(&mut self, expr: &Expr, ctx: &RowContext) -> Result<f64, CalcError> {
465 self.steps += 1;
466 if self.steps > self.max_steps {
467 return Err(CalcError::Timeout);
468 }
469
470 match expr {
471 Expr::Literal(n) => Ok(*n),
472
473 Expr::Column(name) => ctx
474 .get(name)
475 .copied()
476 .ok_or_else(|| CalcError::UndefinedColumn(name.clone())),
477
478 Expr::BinaryOp { op, left, right } => {
479 let l = self.eval(left, ctx)?;
480 let r = self.eval(right, ctx)?;
481
482 match op {
483 BinaryOp::Add => Ok(l + r),
484 BinaryOp::Sub => Ok(l - r),
485 BinaryOp::Mul => Ok(l * r),
486 BinaryOp::Div => {
487 if r == 0.0 {
488 Err(CalcError::DivisionByZero)
489 } else {
490 Ok(l / r)
491 }
492 }
493 BinaryOp::Mod => {
494 if r == 0.0 {
495 Err(CalcError::DivisionByZero)
496 } else {
497 Ok(l % r)
498 }
499 }
500 BinaryOp::Pow => Ok(l.powf(r)),
501 }
502 }
503
504 Expr::UnaryOp { op, expr } => {
505 let v = self.eval(expr, ctx)?;
506 match op {
507 UnaryOp::Neg => Ok(-v),
508 }
509 }
510
511 Expr::FnCall { name, args } => self.call_function(name, args, ctx),
512 }
513 }
514
515 fn call_function(
517 &mut self,
518 name: &str,
519 args: &[Expr],
520 ctx: &RowContext,
521 ) -> Result<f64, CalcError> {
522 let evaluated: Result<Vec<f64>, CalcError> =
523 args.iter().map(|a| self.eval(a, ctx)).collect();
524 let args = evaluated?;
525
526 match name.to_lowercase().as_str() {
527 "abs" => {
529 check_args(name, &args, 1)?;
530 Ok(args[0].abs())
531 }
532 "sqrt" => {
533 check_args(name, &args, 1)?;
534 if args[0] < 0.0 {
535 Err(CalcError::MathError("sqrt of negative number".into()))
536 } else {
537 Ok(args[0].sqrt())
538 }
539 }
540 "floor" => {
541 check_args(name, &args, 1)?;
542 Ok(args[0].floor())
543 }
544 "ceil" => {
545 check_args(name, &args, 1)?;
546 Ok(args[0].ceil())
547 }
548 "round" => {
549 if args.len() == 1 {
550 Ok(args[0].round())
551 } else if args.len() == 2 {
552 let factor = 10f64.powi(args[1] as i32);
553 Ok((args[0] * factor).round() / factor)
554 } else {
555 Err(CalcError::InvalidArgCount {
556 name: name.into(),
557 expected: 1,
558 got: args.len(),
559 })
560 }
561 }
562 "sin" => {
563 check_args(name, &args, 1)?;
564 Ok(args[0].sin())
565 }
566 "cos" => {
567 check_args(name, &args, 1)?;
568 Ok(args[0].cos())
569 }
570 "tan" => {
571 check_args(name, &args, 1)?;
572 Ok(args[0].tan())
573 }
574 "exp" => {
575 check_args(name, &args, 1)?;
576 Ok(args[0].exp())
577 }
578 "ln" | "log" => {
579 check_args(name, &args, 1)?;
580 if args[0] <= 0.0 {
581 Err(CalcError::MathError("log of non-positive number".into()))
582 } else {
583 Ok(args[0].ln())
584 }
585 }
586 "log10" => {
587 check_args(name, &args, 1)?;
588 if args[0] <= 0.0 {
589 Err(CalcError::MathError("log of non-positive number".into()))
590 } else {
591 Ok(args[0].log10())
592 }
593 }
594 "log2" => {
595 check_args(name, &args, 1)?;
596 if args[0] <= 0.0 {
597 Err(CalcError::MathError("log of non-positive number".into()))
598 } else {
599 Ok(args[0].log2())
600 }
601 }
602
603 "pow" => {
605 check_args(name, &args, 2)?;
606 Ok(args[0].powf(args[1]))
607 }
608 "min" => {
609 check_args(name, &args, 2)?;
610 Ok(args[0].min(args[1]))
611 }
612 "max" => {
613 check_args(name, &args, 2)?;
614 Ok(args[0].max(args[1]))
615 }
616 "atan2" => {
617 check_args(name, &args, 2)?;
618 Ok(args[0].atan2(args[1]))
619 }
620
621 "sum" => Ok(args.iter().sum()),
623 "avg" => {
624 if args.is_empty() {
625 Err(CalcError::InvalidArgCount {
626 name: name.into(),
627 expected: 1,
628 got: 0,
629 })
630 } else {
631 Ok(args.iter().sum::<f64>() / args.len() as f64)
632 }
633 }
634
635 "if" => {
637 check_args(name, &args, 3)?;
638 if args[0] != 0.0 {
639 Ok(args[1])
640 } else {
641 Ok(args[2])
642 }
643 }
644
645 _ => Err(CalcError::UndefinedFunction(name.into())),
646 }
647 }
648}
649
650impl Default for Evaluator {
651 fn default() -> Self {
652 Self::new()
653 }
654}
655
656fn check_args(name: &str, args: &[f64], expected: usize) -> Result<(), CalcError> {
657 if args.len() != expected {
658 Err(CalcError::InvalidArgCount {
659 name: name.into(),
660 expected,
661 got: args.len(),
662 })
663 } else {
664 Ok(())
665 }
666}
667
668pub fn calculate(expr: &str, ctx: &RowContext) -> Result<f64, CalcError> {
670 let mut parser = Parser::new(expr)?;
671 let ast = parser.parse()?;
672 let mut evaluator = Evaluator::new();
673 evaluator.eval(&ast, ctx)
674}
675
676pub fn parse_expr(expr: &str) -> Result<Expr, CalcError> {
678 let mut parser = Parser::new(expr)?;
679 parser.parse()
680}
681
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree within a small absolute tolerance.
    ///
    /// Used where the expected value is a decimal literal that binary floating
    /// point cannot represent exactly, so exact equality would be fragile.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_basic_arithmetic() {
        let ctx = RowContext::new();

        assert_eq!(calculate("2 + 3", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("10 - 4", &ctx).unwrap(), 6.0);
        assert_eq!(calculate("3 * 4", &ctx).unwrap(), 12.0);
        assert_eq!(calculate("15 / 3", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("7 % 4", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("2 ^ 3", &ctx).unwrap(), 8.0);
    }

    #[test]
    fn test_operator_precedence() {
        let ctx = RowContext::new();

        assert_eq!(calculate("2 + 3 * 4", &ctx).unwrap(), 14.0);
        assert_eq!(calculate("(2 + 3) * 4", &ctx).unwrap(), 20.0);
        assert_eq!(calculate("2 * 3 + 4", &ctx).unwrap(), 10.0);
        assert_eq!(calculate("10 - 2 * 3", &ctx).unwrap(), 4.0);
    }

    #[test]
    fn test_power_is_right_associative() {
        let ctx = RowContext::new();

        // 2^(3^2) = 512, whereas (2^3)^2 would be 64.
        assert_eq!(calculate("2 ^ 3 ^ 2", &ctx).unwrap(), 512.0);
        assert_eq!(calculate("2 ^ -1", &ctx).unwrap(), 0.5);
    }

    #[test]
    fn test_unary_minus() {
        let ctx = RowContext::new();

        assert_eq!(calculate("-5", &ctx).unwrap(), -5.0);
        assert_eq!(calculate("--5", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("3 + -2", &ctx).unwrap(), 1.0);
        assert_eq!(calculate("-3 * -2", &ctx).unwrap(), 6.0);
    }

    #[test]
    fn test_column_references() {
        let mut ctx = RowContext::new();
        ctx.insert("price".into(), 99.99);
        ctx.insert("quantity".into(), 5.0);
        ctx.insert("tax_rate".into(), 0.15);

        assert_close(calculate("price * quantity", &ctx).unwrap(), 499.95);
        assert_close(
            calculate("price * quantity * (1 + tax_rate)", &ctx).unwrap(),
            574.9425,
        );
    }

    // 3.14 below is the expected result of rounding 3.14159, not an
    // approximation of PI.
    #[allow(clippy::approx_constant)]
    #[test]
    fn test_functions() {
        let ctx = RowContext::new();

        assert_eq!(calculate("abs(-5)", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("sqrt(16)", &ctx).unwrap(), 4.0);
        assert_eq!(calculate("floor(3.7)", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("ceil(3.2)", &ctx).unwrap(), 4.0);
        assert_eq!(calculate("round(3.5)", &ctx).unwrap(), 4.0);
        assert_close(calculate("round(3.14159, 2)", &ctx).unwrap(), 3.14);
        assert_eq!(calculate("min(3, 5)", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("max(3, 5)", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("pow(2, 10)", &ctx).unwrap(), 1024.0);
    }

    #[test]
    fn test_function_names_are_case_insensitive() {
        let ctx = RowContext::new();

        assert_eq!(calculate("ABS(-2)", &ctx).unwrap(), 2.0);
        assert_eq!(calculate("Sqrt(9)", &ctx).unwrap(), 3.0);
    }

    #[test]
    fn test_invalid_arg_count() {
        let ctx = RowContext::new();

        assert_eq!(
            calculate("sqrt(1, 2)", &ctx),
            Err(CalcError::InvalidArgCount {
                name: "sqrt".into(),
                expected: 1,
                got: 2,
            })
        );
    }

    #[test]
    fn test_trig_functions() {
        let ctx = RowContext::new();

        assert!((calculate("sin(0)", &ctx).unwrap() - 0.0).abs() < 1e-10);
        assert!((calculate("cos(0)", &ctx).unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_conditional() {
        let mut ctx = RowContext::new();
        ctx.insert("score".into(), 85.0);

        assert_eq!(calculate("if(1, 10, 20)", &ctx).unwrap(), 10.0);
        assert_eq!(calculate("if(0, 10, 20)", &ctx).unwrap(), 20.0);

        assert_eq!(calculate("if(score, 1, 0)", &ctx).unwrap(), 1.0);
    }

    #[test]
    fn test_variadic_functions() {
        let ctx = RowContext::new();

        assert_eq!(calculate("sum(1, 2, 3, 4)", &ctx).unwrap(), 10.0);
        assert_eq!(calculate("avg(2, 4, 6)", &ctx).unwrap(), 4.0);
    }

    #[test]
    fn test_scientific_notation() {
        let ctx = RowContext::new();

        assert_eq!(calculate("1e3", &ctx).unwrap(), 1000.0);
        // Both literals denote the same real number, so they parse to the same f64.
        assert_eq!(calculate("1.5e-2", &ctx).unwrap(), 0.015);
    }

    #[test]
    fn test_parse_errors() {
        let ctx = RowContext::new();

        assert_eq!(
            calculate("(1", &ctx),
            Err(CalcError::ExpectedToken(")".into()))
        );
        assert!(matches!(
            calculate("1 +", &ctx),
            Err(CalcError::UnexpectedToken(_))
        ));
        assert_eq!(
            calculate("2 # 3", &ctx),
            Err(CalcError::UnexpectedChar('#'))
        );
        assert_eq!(
            calculate("1e", &ctx),
            Err(CalcError::InvalidNumber("1e".into()))
        );
    }

    #[test]
    fn test_parse_expr_builds_ast() {
        assert_eq!(
            parse_expr("a + 1").unwrap(),
            Expr::BinaryOp {
                op: BinaryOp::Add,
                left: Box::new(Expr::Column("a".into())),
                right: Box::new(Expr::Literal(1.0)),
            }
        );
    }

    #[test]
    fn test_evaluator_step_limit() {
        let ctx = RowContext::new();
        // Three nodes: the addition and its two literals.
        let ast = parse_expr("1 + 2").unwrap();

        assert_eq!(
            Evaluator::with_max_steps(2).eval(&ast, &ctx),
            Err(CalcError::Timeout)
        );
        assert_eq!(Evaluator::with_max_steps(3).eval(&ast, &ctx), Ok(3.0));
    }

    #[test]
    fn test_display() {
        assert_eq!(BinaryOp::Pow.to_string(), "^");
        assert_eq!(CalcError::DivisionByZero.to_string(), "Division by zero");
        assert_eq!(
            CalcError::InvalidArgCount {
                name: "sqrt".into(),
                expected: 1,
                got: 2,
            }
            .to_string(),
            "Function sqrt expects 1 args, got 2"
        );
    }

    #[test]
    fn test_division_by_zero() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("1 / 0", &ctx),
            Err(CalcError::DivisionByZero)
        ));
        assert!(matches!(
            calculate("5 % 0", &ctx),
            Err(CalcError::DivisionByZero)
        ));
    }

    #[test]
    fn test_undefined_column() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("undefined_col + 1", &ctx),
            Err(CalcError::UndefinedColumn(_))
        ));
    }

    #[test]
    fn test_undefined_function() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("unknown_func(1)", &ctx),
            Err(CalcError::UndefinedFunction(_))
        ));
    }

    #[test]
    fn test_complex_expression() {
        let mut ctx = RowContext::new();
        ctx.insert("revenue".into(), 1000.0);
        ctx.insert("cost".into(), 600.0);
        ctx.insert("tax".into(), 0.15);

        let result = calculate("(revenue - cost) * (1 - tax)", &ctx).unwrap();
        assert!((result - 340.0).abs() < 1e-10);
    }
}