Skip to main content

sochdb_query/
calc.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Calculator Expression Evaluator (Task 13)
19//!
20//! Safe mathematical expression evaluator for agentic use cases:
21//! - Sandboxed evaluation (no code injection)
22//! - Column references for computed fields
23//! - Built-in math functions (abs, sqrt, pow, etc.)
24//!
25//! ## Grammar (Recursive Descent)
26//!
27//! ```text
28//! expr     → term (('+' | '-') term)*
29//! term     → factor (('*' | '/' | '%') factor)*
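//! factor   → unary ('^' factor)?
//! unary    → '-' unary | primary
//! primary  → NUMBER | COLUMN | '(' expr ')' | function
//! function → IDENT '(' (expr (',' expr)*)? ')'
//! ```
//!
//! `^` is right-associative (`2 ^ 3 ^ 2` is `2 ^ 9`), and unary minus binds tighter
//! than `^`, so `-2 ^ 2` evaluates as `(-2) ^ 2 = 4`.
//!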
36//! ## Security Model
37//!
38//! - No variable assignment (immutable)
39//! - No loops (single-pass evaluation)
40//! - No function definitions (allowlist only)
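//! - Step budget: evaluation stops with `CalcError::Timeout` after a fixed number of
//!   steps (10,000 by default); there is no wall-clock timeout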
42
43use std::collections::HashMap;
44use std::fmt;
45use std::iter::Peekable;
46use std::str::Chars;
47
48/// Expression AST node
49#[derive(Debug, Clone, PartialEq)]
50pub enum Expr {
51    /// Literal number
52    Literal(f64),
53    /// Column reference
54    Column(String),
55    /// Binary operation
56    BinaryOp {
57        op: BinaryOp,
58        left: Box<Expr>,
59        right: Box<Expr>,
60    },
61    /// Unary operation
62    UnaryOp { op: UnaryOp, expr: Box<Expr> },
63    /// Function call
64    FnCall { name: String, args: Vec<Expr> },
65}
66
/// Infix arithmetic operators recognised by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`); a zero divisor is an evaluation error.
    Div,
    /// Remainder (`%`), computed with `f64 %`; a zero divisor is an evaluation error.
    Mod,
    /// Exponentiation (`^`), parsed right-associatively.
    Pow,
}
77
78impl fmt::Display for BinaryOp {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        match self {
81            BinaryOp::Add => write!(f, "+"),
82            BinaryOp::Sub => write!(f, "-"),
83            BinaryOp::Mul => write!(f, "*"),
84            BinaryOp::Div => write!(f, "/"),
85            BinaryOp::Mod => write!(f, "%"),
86            BinaryOp::Pow => write!(f, "^"),
87        }
88    }
89}
90
/// Prefix operators recognised by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-x`).
    Neg,
}
96
/// A lexical token produced by the lexer.
///
/// The parser embeds a token's `Debug` form in its error messages, so variant names
/// are user-visible.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Numeric literal, already converted to `f64`.
    Number(f64),
    /// Identifier: a column name, or a function name when followed by `(`.
    Ident(String),
    /// `+`
    Plus,
    /// `-` (binary subtraction or unary negation)
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `%`
    Percent,
    /// `^`
    Caret,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
    /// End of input.
    Eof,
}
113
/// Character-level scanner that splits an expression string into tokens.
struct Lexer<'a> {
    /// Unconsumed input, with one character of lookahead.
    chars: Peekable<Chars<'a>>,
}
118
119impl<'a> Lexer<'a> {
120    fn new(input: &'a str) -> Self {
121        Self {
122            chars: input.chars().peekable(),
123        }
124    }
125
126    fn next_token(&mut self) -> Result<Token, CalcError> {
127        self.skip_whitespace();
128
129        match self.chars.peek() {
130            None => Ok(Token::Eof),
131            Some(&c) => match c {
132                '+' => {
133                    self.chars.next();
134                    Ok(Token::Plus)
135                }
136                '-' => {
137                    self.chars.next();
138                    Ok(Token::Minus)
139                }
140                '*' => {
141                    self.chars.next();
142                    Ok(Token::Star)
143                }
144                '/' => {
145                    self.chars.next();
146                    Ok(Token::Slash)
147                }
148                '%' => {
149                    self.chars.next();
150                    Ok(Token::Percent)
151                }
152                '^' => {
153                    self.chars.next();
154                    Ok(Token::Caret)
155                }
156                '(' => {
157                    self.chars.next();
158                    Ok(Token::LParen)
159                }
160                ')' => {
161                    self.chars.next();
162                    Ok(Token::RParen)
163                }
164                ',' => {
165                    self.chars.next();
166                    Ok(Token::Comma)
167                }
168                '0'..='9' | '.' => self.number(),
169                'a'..='z' | 'A'..='Z' | '_' | '$' => self.ident(),
170                _ => Err(CalcError::UnexpectedChar(c)),
171            },
172        }
173    }
174
175    fn skip_whitespace(&mut self) {
176        while let Some(&c) = self.chars.peek() {
177            if c.is_whitespace() {
178                self.chars.next();
179            } else {
180                break;
181            }
182        }
183    }
184
185    fn number(&mut self) -> Result<Token, CalcError> {
186        let mut s = String::new();
187        let mut has_dot = false;
188
189        while let Some(&c) = self.chars.peek() {
190            if c.is_ascii_digit() {
191                s.push(c);
192                self.chars.next();
193            } else if c == '.' && !has_dot {
194                has_dot = true;
195                s.push(c);
196                self.chars.next();
197            } else if c == 'e' || c == 'E' {
198                // Scientific notation
199                s.push(c);
200                self.chars.next();
201                if let Some(&sign) = self.chars.peek()
202                    && (sign == '+' || sign == '-')
203                {
204                    s.push(sign);
205                    self.chars.next();
206                }
207            } else {
208                break;
209            }
210        }
211
212        s.parse::<f64>()
213            .map(Token::Number)
214            .map_err(|_| CalcError::InvalidNumber(s))
215    }
216
217    fn ident(&mut self) -> Result<Token, CalcError> {
218        let mut s = String::new();
219
220        while let Some(&c) = self.chars.peek() {
221            if c.is_alphanumeric() || c == '_' || c == '$' {
222                s.push(c);
223                self.chars.next();
224            } else {
225                break;
226            }
227        }
228
229        Ok(Token::Ident(s))
230    }
231}
232
233/// Expression parser
234pub struct Parser<'a> {
235    lexer: Lexer<'a>,
236    current: Token,
237}
238
239impl<'a> Parser<'a> {
240    /// Create a new parser
241    pub fn new(input: &'a str) -> Result<Self, CalcError> {
242        let mut lexer = Lexer::new(input);
243        let current = lexer.next_token()?;
244        Ok(Self { lexer, current })
245    }
246
247    /// Parse the expression
248    pub fn parse(&mut self) -> Result<Expr, CalcError> {
249        let expr = self.expression()?;
250        if self.current != Token::Eof {
251            return Err(CalcError::UnexpectedToken(format!("{:?}", self.current)));
252        }
253        Ok(expr)
254    }
255
256    fn advance(&mut self) -> Result<(), CalcError> {
257        self.current = self.lexer.next_token()?;
258        Ok(())
259    }
260
261    fn expression(&mut self) -> Result<Expr, CalcError> {
262        self.additive()
263    }
264
265    fn additive(&mut self) -> Result<Expr, CalcError> {
266        let mut left = self.multiplicative()?;
267
268        loop {
269            let op = match &self.current {
270                Token::Plus => BinaryOp::Add,
271                Token::Minus => BinaryOp::Sub,
272                _ => break,
273            };
274            self.advance()?;
275            let right = self.multiplicative()?;
276            left = Expr::BinaryOp {
277                op,
278                left: Box::new(left),
279                right: Box::new(right),
280            };
281        }
282
283        Ok(left)
284    }
285
286    fn multiplicative(&mut self) -> Result<Expr, CalcError> {
287        let mut left = self.power()?;
288
289        loop {
290            let op = match &self.current {
291                Token::Star => BinaryOp::Mul,
292                Token::Slash => BinaryOp::Div,
293                Token::Percent => BinaryOp::Mod,
294                _ => break,
295            };
296            self.advance()?;
297            let right = self.power()?;
298            left = Expr::BinaryOp {
299                op,
300                left: Box::new(left),
301                right: Box::new(right),
302            };
303        }
304
305        Ok(left)
306    }
307
308    fn power(&mut self) -> Result<Expr, CalcError> {
309        let left = self.unary()?;
310
311        if self.current == Token::Caret {
312            self.advance()?;
313            let right = self.power()?; // Right associative
314            return Ok(Expr::BinaryOp {
315                op: BinaryOp::Pow,
316                left: Box::new(left),
317                right: Box::new(right),
318            });
319        }
320
321        Ok(left)
322    }
323
324    fn unary(&mut self) -> Result<Expr, CalcError> {
325        if self.current == Token::Minus {
326            self.advance()?;
327            let expr = self.unary()?;
328            return Ok(Expr::UnaryOp {
329                op: UnaryOp::Neg,
330                expr: Box::new(expr),
331            });
332        }
333
334        self.primary()
335    }
336
337    fn primary(&mut self) -> Result<Expr, CalcError> {
338        match self.current.clone() {
339            Token::Number(n) => {
340                self.advance()?;
341                Ok(Expr::Literal(n))
342            }
343            Token::Ident(name) => {
344                self.advance()?;
345                if self.current == Token::LParen {
346                    // Function call
347                    self.advance()?;
348                    let args = self.arguments()?;
349                    if self.current != Token::RParen {
350                        return Err(CalcError::ExpectedToken(")".into()));
351                    }
352                    self.advance()?;
353                    Ok(Expr::FnCall { name, args })
354                } else {
355                    // Column reference
356                    Ok(Expr::Column(name))
357                }
358            }
359            Token::LParen => {
360                self.advance()?;
361                let expr = self.expression()?;
362                if self.current != Token::RParen {
363                    return Err(CalcError::ExpectedToken(")".into()));
364                }
365                self.advance()?;
366                Ok(expr)
367            }
368            _ => Err(CalcError::UnexpectedToken(format!("{:?}", self.current))),
369        }
370    }
371
372    fn arguments(&mut self) -> Result<Vec<Expr>, CalcError> {
373        let mut args = Vec::new();
374
375        if self.current == Token::RParen {
376            return Ok(args);
377        }
378
379        args.push(self.expression()?);
380
381        while self.current == Token::Comma {
382            self.advance()?;
383            args.push(self.expression()?);
384        }
385
386        Ok(args)
387    }
388}
389
/// Errors produced while lexing, parsing or evaluating an expression.
#[derive(Debug, Clone, PartialEq)]
pub enum CalcError {
    /// The lexer met a character that cannot begin any token.
    UnexpectedChar(char),
    /// A numeric literal could not be converted to `f64`; holds the scanned text.
    InvalidNumber(String),
    /// The parser met a token it did not expect; holds a description of the token.
    UnexpectedToken(String),
    /// The parser required a specific token (such as `)`) that was missing.
    ExpectedToken(String),
    /// A column reference has no value in the row context.
    UndefinedColumn(String),
    /// A call names a function outside the built-in allowlist.
    UndefinedFunction(String),
    /// Division or remainder with a zero divisor.
    DivisionByZero,
    /// A built-in function was called with the wrong number of arguments.
    InvalidArgCount {
        /// Function name as written in the call.
        name: String,
        /// Argument count reported as expected (the smallest accepted count for
        /// functions with a flexible arity).
        expected: usize,
        /// Number of arguments actually supplied.
        got: usize,
    },
    /// A mathematically invalid operation, such as the square root of a negative number.
    MathError(String),
    /// A resource limit was exceeded, for example the evaluator's step budget.
    Timeout,
}
408
409impl fmt::Display for CalcError {
410    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
411        match self {
412            CalcError::UnexpectedChar(c) => write!(f, "Unexpected character: {}", c),
413            CalcError::InvalidNumber(s) => write!(f, "Invalid number: {}", s),
414            CalcError::UnexpectedToken(s) => write!(f, "Unexpected token: {}", s),
415            CalcError::ExpectedToken(s) => write!(f, "Expected token: {}", s),
416            CalcError::UndefinedColumn(s) => write!(f, "Undefined column: {}", s),
417            CalcError::UndefinedFunction(s) => write!(f, "Undefined function: {}", s),
418            CalcError::DivisionByZero => write!(f, "Division by zero"),
419            CalcError::InvalidArgCount {
420                name,
421                expected,
422                got,
423            } => {
424                write!(
425                    f,
426                    "Function {} expects {} args, got {}",
427                    name, expected, got
428                )
429            }
430            CalcError::MathError(s) => write!(f, "Math error: {}", s),
431            CalcError::Timeout => write!(f, "Evaluation timeout"),
432        }
433    }
434}
435
436impl std::error::Error for CalcError {}
437
/// Column values for a single row, keyed by column name.
///
/// During evaluation, `Expr::Column` names are looked up here; a missing key produces
/// `CalcError::UndefinedColumn`.
pub type RowContext = HashMap<String, f64>;
440
/// Tree-walking evaluator for [`Expr`] with a step budget.
///
/// Each visited AST node costs one step. Once more than `max_steps` steps have been
/// taken, evaluation stops with [`CalcError::Timeout`].
pub struct Evaluator {
    /// Upper bound on evaluation steps before giving up.
    max_steps: usize,
    /// Steps counted so far.
    steps: usize,
}
448
449impl Evaluator {
450    /// Create a new evaluator
451    pub fn new() -> Self {
452        Self {
453            max_steps: 10000,
454            steps: 0,
455        }
456    }
457
458    /// Create with custom step limit
459    pub fn with_max_steps(max_steps: usize) -> Self {
460        Self {
461            max_steps,
462            steps: 0,
463        }
464    }
465
466    /// Evaluate expression with row context
467    pub fn eval(&mut self, expr: &Expr, ctx: &RowContext) -> Result<f64, CalcError> {
468        self.steps += 1;
469        if self.steps > self.max_steps {
470            return Err(CalcError::Timeout);
471        }
472
473        match expr {
474            Expr::Literal(n) => Ok(*n),
475
476            Expr::Column(name) => ctx
477                .get(name)
478                .copied()
479                .ok_or_else(|| CalcError::UndefinedColumn(name.clone())),
480
481            Expr::BinaryOp { op, left, right } => {
482                let l = self.eval(left, ctx)?;
483                let r = self.eval(right, ctx)?;
484
485                match op {
486                    BinaryOp::Add => Ok(l + r),
487                    BinaryOp::Sub => Ok(l - r),
488                    BinaryOp::Mul => Ok(l * r),
489                    BinaryOp::Div => {
490                        if r == 0.0 {
491                            Err(CalcError::DivisionByZero)
492                        } else {
493                            Ok(l / r)
494                        }
495                    }
496                    BinaryOp::Mod => {
497                        if r == 0.0 {
498                            Err(CalcError::DivisionByZero)
499                        } else {
500                            Ok(l % r)
501                        }
502                    }
503                    BinaryOp::Pow => Ok(l.powf(r)),
504                }
505            }
506
507            Expr::UnaryOp { op, expr } => {
508                let v = self.eval(expr, ctx)?;
509                match op {
510                    UnaryOp::Neg => Ok(-v),
511                }
512            }
513
514            Expr::FnCall { name, args } => self.call_function(name, args, ctx),
515        }
516    }
517
518    /// Call a built-in function
519    fn call_function(
520        &mut self,
521        name: &str,
522        args: &[Expr],
523        ctx: &RowContext,
524    ) -> Result<f64, CalcError> {
525        let evaluated: Result<Vec<f64>, CalcError> =
526            args.iter().map(|a| self.eval(a, ctx)).collect();
527        let args = evaluated?;
528
529        match name.to_lowercase().as_str() {
530            // Single argument functions
531            "abs" => {
532                check_args(name, &args, 1)?;
533                Ok(args[0].abs())
534            }
535            "sqrt" => {
536                check_args(name, &args, 1)?;
537                if args[0] < 0.0 {
538                    Err(CalcError::MathError("sqrt of negative number".into()))
539                } else {
540                    Ok(args[0].sqrt())
541                }
542            }
543            "floor" => {
544                check_args(name, &args, 1)?;
545                Ok(args[0].floor())
546            }
547            "ceil" => {
548                check_args(name, &args, 1)?;
549                Ok(args[0].ceil())
550            }
551            "round" => {
552                if args.len() == 1 {
553                    Ok(args[0].round())
554                } else if args.len() == 2 {
555                    let factor = 10f64.powi(args[1] as i32);
556                    Ok((args[0] * factor).round() / factor)
557                } else {
558                    Err(CalcError::InvalidArgCount {
559                        name: name.into(),
560                        expected: 1,
561                        got: args.len(),
562                    })
563                }
564            }
565            "sin" => {
566                check_args(name, &args, 1)?;
567                Ok(args[0].sin())
568            }
569            "cos" => {
570                check_args(name, &args, 1)?;
571                Ok(args[0].cos())
572            }
573            "tan" => {
574                check_args(name, &args, 1)?;
575                Ok(args[0].tan())
576            }
577            "exp" => {
578                check_args(name, &args, 1)?;
579                Ok(args[0].exp())
580            }
581            "ln" | "log" => {
582                check_args(name, &args, 1)?;
583                if args[0] <= 0.0 {
584                    Err(CalcError::MathError("log of non-positive number".into()))
585                } else {
586                    Ok(args[0].ln())
587                }
588            }
589            "log10" => {
590                check_args(name, &args, 1)?;
591                if args[0] <= 0.0 {
592                    Err(CalcError::MathError("log of non-positive number".into()))
593                } else {
594                    Ok(args[0].log10())
595                }
596            }
597            "log2" => {
598                check_args(name, &args, 1)?;
599                if args[0] <= 0.0 {
600                    Err(CalcError::MathError("log of non-positive number".into()))
601                } else {
602                    Ok(args[0].log2())
603                }
604            }
605
606            // Two argument functions
607            "pow" => {
608                check_args(name, &args, 2)?;
609                Ok(args[0].powf(args[1]))
610            }
611            "min" => {
612                check_args(name, &args, 2)?;
613                Ok(args[0].min(args[1]))
614            }
615            "max" => {
616                check_args(name, &args, 2)?;
617                Ok(args[0].max(args[1]))
618            }
619            "atan2" => {
620                check_args(name, &args, 2)?;
621                Ok(args[0].atan2(args[1]))
622            }
623
624            // Variadic functions
625            "sum" => Ok(args.iter().sum()),
626            "avg" => {
627                if args.is_empty() {
628                    Err(CalcError::InvalidArgCount {
629                        name: name.into(),
630                        expected: 1,
631                        got: 0,
632                    })
633                } else {
634                    Ok(args.iter().sum::<f64>() / args.len() as f64)
635                }
636            }
637
638            // Conditional
639            "if" => {
640                check_args(name, &args, 3)?;
641                if args[0] != 0.0 {
642                    Ok(args[1])
643                } else {
644                    Ok(args[2])
645                }
646            }
647
648            _ => Err(CalcError::UndefinedFunction(name.into())),
649        }
650    }
651}
652
653impl Default for Evaluator {
654    fn default() -> Self {
655        Self::new()
656    }
657}
658
659fn check_args(name: &str, args: &[f64], expected: usize) -> Result<(), CalcError> {
660    if args.len() != expected {
661        Err(CalcError::InvalidArgCount {
662            name: name.into(),
663            expected,
664            got: args.len(),
665        })
666    } else {
667        Ok(())
668    }
669}
670
671/// Parse and evaluate an expression in one step
672pub fn calculate(expr: &str, ctx: &RowContext) -> Result<f64, CalcError> {
673    let mut parser = Parser::new(expr)?;
674    let ast = parser.parse()?;
675    let mut evaluator = Evaluator::new();
676    evaluator.eval(&ast, ctx)
677}
678
679/// Parse an expression without evaluating
680pub fn parse_expr(expr: &str) -> Result<Expr, CalcError> {
681    let mut parser = Parser::new(expr)?;
682    parser.parse()
683}
684
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within a small relative tolerance of `expected`.
    ///
    /// Results that pass through several `f64` operations (e.g. `99.99 * 5.0`) are not
    /// guaranteed to land on the same double as the decimal literal they are compared
    /// with, so exact `assert_eq!` on them is brittle.
    fn assert_close(actual: f64, expected: f64) {
        let tolerance = 1e-9 * expected.abs().max(1.0);
        assert!(
            (actual - expected).abs() <= tolerance,
            "{} is not within {} of {}",
            actual,
            tolerance,
            expected
        );
    }

    #[test]
    fn test_basic_arithmetic() {
        let ctx = RowContext::new();

        assert_eq!(calculate("2 + 3", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("10 - 4", &ctx).unwrap(), 6.0);
        assert_eq!(calculate("3 * 4", &ctx).unwrap(), 12.0);
        assert_eq!(calculate("15 / 3", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("7 % 4", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("2 ^ 3", &ctx).unwrap(), 8.0);
    }

    #[test]
    fn test_operator_precedence() {
        let ctx = RowContext::new();

        assert_eq!(calculate("2 + 3 * 4", &ctx).unwrap(), 14.0);
        assert_eq!(calculate("(2 + 3) * 4", &ctx).unwrap(), 20.0);
        assert_eq!(calculate("2 * 3 + 4", &ctx).unwrap(), 10.0);
        assert_eq!(calculate("10 - 2 * 3", &ctx).unwrap(), 4.0);
        // `-` and `/` are left-associative.
        assert_eq!(calculate("10 - 4 - 3", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("64 / 4 / 2", &ctx).unwrap(), 8.0);
        // `^` binds tighter than `*` and is right-associative: 2 ^ (3 ^ 2) = 2 ^ 9.
        assert_close(calculate("2 * 3 ^ 2", &ctx).unwrap(), 18.0);
        assert_close(calculate("2 ^ 3 ^ 2", &ctx).unwrap(), 512.0);
    }

    #[test]
    fn test_unary_minus() {
        let ctx = RowContext::new();

        assert_eq!(calculate("-5", &ctx).unwrap(), -5.0);
        assert_eq!(calculate("--5", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("3 + -2", &ctx).unwrap(), 1.0);
        assert_eq!(calculate("-3 * -2", &ctx).unwrap(), 6.0);
    }

    #[test]
    fn test_column_references() {
        let mut ctx = RowContext::new();
        ctx.insert("price".into(), 99.99);
        ctx.insert("quantity".into(), 5.0);
        ctx.insert("tax_rate".into(), 0.15);

        assert_close(calculate("price * quantity", &ctx).unwrap(), 499.95);
        assert_close(
            calculate("price * quantity * (1 + tax_rate)", &ctx).unwrap(),
            574.9425,
        );
    }

    #[test]
    fn test_functions() {
        let ctx = RowContext::new();

        assert_eq!(calculate("abs(-5)", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("sqrt(16)", &ctx).unwrap(), 4.0);
        assert_eq!(calculate("floor(3.7)", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("ceil(3.2)", &ctx).unwrap(), 4.0);
        assert_eq!(calculate("round(3.5)", &ctx).unwrap(), 4.0);
        // 3.14 here is the expected rounding result, not an approximation of PI.
        #[allow(clippy::approx_constant)]
        {
            assert_eq!(calculate("round(3.14159, 2)", &ctx).unwrap(), 3.14);
        }
        assert_eq!(calculate("min(3, 5)", &ctx).unwrap(), 3.0);
        assert_eq!(calculate("max(3, 5)", &ctx).unwrap(), 5.0);
        assert_eq!(calculate("pow(2, 10)", &ctx).unwrap(), 1024.0);
    }

    #[test]
    fn test_trig_functions() {
        let mut ctx = RowContext::new();
        ctx.insert("pi".into(), std::f64::consts::PI);

        assert_close(calculate("sin(0)", &ctx).unwrap(), 0.0);
        assert_close(calculate("cos(0)", &ctx).unwrap(), 1.0);
        assert_close(calculate("sin(pi / 2)", &ctx).unwrap(), 1.0);
        assert_close(calculate("cos(pi)", &ctx).unwrap(), -1.0);
    }

    #[test]
    fn test_conditional() {
        let mut ctx = RowContext::new();
        ctx.insert("score".into(), 85.0);

        // Comparison operators are not implemented in the lexer, so conditions are plain
        // numbers: non-zero means true, zero means false.
        assert_eq!(calculate("if(1, 10, 20)", &ctx).unwrap(), 10.0);
        assert_eq!(calculate("if(0, 10, 20)", &ctx).unwrap(), 20.0);

        // A column can be used directly as the condition (85 != 0 means true).
        assert_eq!(calculate("if(score, 1, 0)", &ctx).unwrap(), 1.0);
    }

    #[test]
    fn test_variadic_functions() {
        let ctx = RowContext::new();

        assert_eq!(calculate("sum(1, 2, 3, 4)", &ctx).unwrap(), 10.0);
        assert_eq!(calculate("avg(2, 4, 6)", &ctx).unwrap(), 4.0);
        assert_eq!(calculate("sum()", &ctx).unwrap(), 0.0);
        assert_eq!(
            calculate("avg()", &ctx),
            Err(CalcError::InvalidArgCount {
                name: "avg".into(),
                expected: 1,
                got: 0,
            })
        );
    }

    #[test]
    fn test_scientific_notation() {
        let ctx = RowContext::new();

        assert_eq!(calculate("1e3", &ctx).unwrap(), 1000.0);
        assert_eq!(calculate("1.5e-2", &ctx).unwrap(), 0.015);
        assert_eq!(calculate("2.5E+1", &ctx).unwrap(), 25.0);
    }

    #[test]
    fn test_division_by_zero() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("1 / 0", &ctx),
            Err(CalcError::DivisionByZero)
        ));
        assert!(matches!(
            calculate("5 % 0", &ctx),
            Err(CalcError::DivisionByZero)
        ));
    }

    #[test]
    fn test_undefined_column() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("undefined_col + 1", &ctx),
            Err(CalcError::UndefinedColumn(_))
        ));
    }

    #[test]
    fn test_undefined_function() {
        let ctx = RowContext::new();

        assert!(matches!(
            calculate("unknown_func(1)", &ctx),
            Err(CalcError::UndefinedFunction(_))
        ));
    }

    #[test]
    fn test_syntax_errors() {
        let ctx = RowContext::new();

        assert_eq!(
            calculate("2 # 3", &ctx),
            Err(CalcError::UnexpectedChar('#'))
        );
        assert_eq!(
            calculate("(1 + 2", &ctx),
            Err(CalcError::ExpectedToken(")".into()))
        );
        assert!(matches!(
            calculate("1 2", &ctx),
            Err(CalcError::UnexpectedToken(_))
        ));
        assert!(matches!(
            calculate("", &ctx),
            Err(CalcError::UnexpectedToken(_))
        ));
    }

    #[test]
    fn test_argument_validation() {
        let ctx = RowContext::new();

        assert_eq!(
            calculate("abs(1, 2)", &ctx),
            Err(CalcError::InvalidArgCount {
                name: "abs".into(),
                expected: 1,
                got: 2,
            })
        );
        assert!(matches!(
            calculate("sqrt(-1)", &ctx),
            Err(CalcError::MathError(_))
        ));
    }

    #[test]
    fn test_complex_expression() {
        let mut ctx = RowContext::new();
        ctx.insert("revenue".into(), 1000.0);
        ctx.insert("cost".into(), 600.0);
        ctx.insert("tax".into(), 0.15);

        // After-tax profit.
        let result = calculate("(revenue - cost) * (1 - tax)", &ctx).unwrap();
        assert_close(result, 340.0);
    }
}