tc/
lib.rs

1use std::{collections::HashMap, fmt::Display};
2
3mod ast;
4pub mod func;
5pub mod input;
6pub mod lex;
7pub mod parse;
8
9use input::{HasSpan, Span};
10
11#[derive(Debug)]
12pub enum Error {
13    Parse(parse::Error),
14    UnknownVar(Span, String),
15    UnknownFunc(Span, String),
16    FuncArgCount {
17        span: Span,
18        name: String,
19        expected: func::ArgCount,
20        actual: u32,
21    },
22    ZeroDiv(Span),
23}
24
25impl HasSpan for Error {
26    fn span(&self) -> Span {
27        match self {
28            Error::Parse(err) => err.span(),
29            Error::UnknownVar(span, _) => *span,
30            Error::UnknownFunc(span, _) => *span,
31            Error::FuncArgCount { span, .. } => *span,
32            Error::ZeroDiv(span) => *span,
33        }
34    }
35}
36
37impl Display for Error {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            Error::Parse(err) => err.fmt(f),
41            Error::UnknownVar(_, name) => {
42                write!(f, "Variable `{name}` is unknown")
43            }
44            Error::UnknownFunc(_, name) => {
45                write!(f, "Function `{name}` is unknown")
46            }
47            Error::FuncArgCount {
48                name,
49                expected,
50                actual,
51                ..
52            } => {
53                write!(
54                    f,
55                    "Function `{name}` expects {expected}, but received {actual}"
56                )
57            }
58            Error::ZeroDiv(_) => {
59                write!(f, "Division by zero")
60            }
61        }
62    }
63}
64
65impl From<parse::Error> for Error {
66    fn from(e: parse::Error) -> Self {
67        Error::Parse(e)
68    }
69}
70
/// Outcome of evaluating one input line.
#[derive(Clone, Debug, PartialEq)]
pub struct Eval {
    /// Symbol the value was stored under: the assignment target, the
    /// variable's own name for a bare variable reference, or `"ans"`.
    pub sym: String,
    /// The computed value.
    pub val: f64,
}
76
77#[derive(Debug, Clone)]
78pub struct TermCalc {
79    vars: HashMap<String, f64>,
80    funcs: HashMap<String, func::Func>,
81}
82
83impl Default for TermCalc {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl TermCalc {
90    pub fn new() -> Self {
91        let mut vars = HashMap::new();
92        vars.insert("pi".to_string(), std::f64::consts::PI);
93        vars.insert("e".to_string(), std::f64::consts::E);
94        let funcs = func::all_funcs()
95            .into_iter()
96            .map(|f| (f.name.clone(), f))
97            .collect();
98        TermCalc { vars, funcs }
99    }
100
101    pub fn get_var(&self, name: &str) -> Option<f64> {
102        self.vars.get(name).copied()
103    }
104
105    pub fn eval_line<S: AsRef<str>>(&mut self, line: S) -> Result<Eval, Error> {
106        let line = line.as_ref();
107        let item = parse::parse_line(line.chars())?;
108        self.eval_item(item)
109    }
110
111    fn eval_item(&mut self, item: ast::Item) -> Result<Eval, Error> {
112        let (sym, expr) = match item.kind {
113            ast::ItemKind::Assign(sym, expr) => (sym, expr),
114            // if item is a single var, we return its name as evaluation symbol
115            ast::ItemKind::Expr(ast::Expr {
116                span,
117                kind: ast::ExprKind::Var(sym),
118            }) => (
119                sym.clone(),
120                ast::Expr {
121                    span,
122                    kind: ast::ExprKind::Var(sym),
123                },
124            ),
125            ast::ItemKind::Expr(expr) => ("ans".to_string(), expr),
126        };
127
128        let val = self.eval_expr(expr)?;
129        self.vars.insert(sym.clone(), val);
130        Ok(Eval { sym, val })
131    }
132
133    fn eval_expr(&self, expr: ast::Expr) -> Result<f64, Error> {
134        let ast::Expr { span, kind } = expr;
135        match kind {
136            ast::ExprKind::Num(n) => Ok(n),
137            ast::ExprKind::Var(s) => match self.vars.get(&s) {
138                Some(n) => Ok(*n),
139                None => Err(Error::UnknownVar(span, s)),
140            },
141            ast::ExprKind::BinOp(op, lhs, rhs) => {
142                let lhs = self.eval_expr(*lhs)?;
143                let rhs = self.eval_expr(*rhs)?;
144                match op {
145                    ast::BinOp::Add => Ok(lhs + rhs),
146                    ast::BinOp::Sub => Ok(lhs - rhs),
147                    ast::BinOp::Mul => Ok(lhs * rhs),
148                    ast::BinOp::Div if rhs == 0.0 => Err(Error::ZeroDiv(span)),
149                    ast::BinOp::Div => Ok(lhs / rhs),
150                    ast::BinOp::Mod if rhs == 0.0 => Err(Error::ZeroDiv(span)),
151                    ast::BinOp::Mod => Ok(lhs % rhs),
152                    ast::BinOp::Pow => Ok(lhs.powf(rhs)),
153                }
154            }
155            ast::ExprKind::UnOp(op, expr) => {
156                let val = self.eval_expr(*expr)?;
157                match op {
158                    ast::UnOp::Plus => Ok(val),
159                    ast::UnOp::Minus => Ok(-val),
160                }
161            }
162            ast::ExprKind::Call {
163                name_span,
164                name,
165                args,
166            } => self.eval_call(span, name_span, name, args),
167        }
168    }
169
170    fn eval_call(
171        &self,
172        span: Span,
173        name_span: Span,
174        name: String,
175        args: Vec<ast::Expr>,
176    ) -> Result<f64, Error> {
177        let func = match self.funcs.get(&name) {
178            Some(f) => f,
179            None => return Err(Error::UnknownFunc(name_span, name)),
180        };
181        let args = self.eval_args(span, func, args)?;
182        let f = func.eval;
183        Ok(f(args))
184    }
185
186    fn eval_args(
187        &self,
188        span: Span,
189        func: &func::Func,
190        args: Vec<ast::Expr>,
191    ) -> Result<func::Args, Error> {
192        if !func.arg_count.check(args.len()) {
193            return Err(Error::FuncArgCount {
194                span,
195                name: func.name.clone(),
196                expected: func.arg_count,
197                actual: args.len() as _,
198            });
199        }
200        let mut args = args.into_iter();
201        Ok(match func.arg_count {
202            func::ArgCount::One => {
203                let arg = self.eval_expr(args.next().unwrap())?;
204                func::Args::One(arg)
205            }
206            func::ArgCount::Two => {
207                let arg1 = self.eval_expr(args.next().unwrap())?;
208                let arg2 = self.eval_expr(args.next().unwrap())?;
209                func::Args::Two(arg1, arg2)
210            }
211            func::ArgCount::Atleast(..) => {
212                let args = args
213                    .map(|e| self.eval_expr(e))
214                    .collect::<Result<Vec<_>, _>>()?;
215                func::Args::Dyn(args)
216            }
217        })
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::{Error, Eval, TermCalc};
    use approx::{assert_relative_eq, AbsDiffEq, RelativeEq, UlpsEq};

    // `approx` comparisons for `Eval`: symbols must match exactly; values are
    // compared with the usual floating-point tolerances.
    impl AbsDiffEq for Eval {
        type Epsilon = f64;
        fn default_epsilon() -> Self::Epsilon {
            f64::default_epsilon()
        }
        fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
            self.sym == other.sym && self.val.abs_diff_eq(&other.val, epsilon)
        }
    }

    impl RelativeEq for Eval {
        fn default_max_relative() -> Self::Epsilon {
            f64::default_max_relative()
        }
        fn relative_eq(
            &self,
            other: &Self,
            epsilon: Self::Epsilon,
            max_relative: Self::Epsilon,
        ) -> bool {
            self.sym == other.sym && self.val.relative_eq(&other.val, epsilon, max_relative)
        }
    }

    impl UlpsEq for Eval {
        fn default_max_ulps() -> u32 {
            f64::default_max_ulps()
        }
        fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
            self.sym == other.sym && self.val.ulps_eq(&other.val, epsilon, max_ulps)
        }
    }

    #[test]
    fn test_eval_line() {
        let mut tc = TermCalc::new();
        // integers have perfect precision, no need for relative_eq!
        assert_eq!(
            tc.eval_line("1").unwrap(),
            Eval {
                sym: "ans".to_string(),
                val: 1.0,
            }
        );
        assert_relative_eq!(
            tc.eval_line("sin(pi/2)").unwrap(),
            Eval {
                sym: "ans".to_string(),
                val: 1.0,
            },
            epsilon = f64::EPSILON,
        );
        assert_relative_eq!(
            tc.eval_line("x = cos(pi)").unwrap(),
            Eval {
                sym: "x".to_string(),
                val: -1.0,
            },
            epsilon = f64::EPSILON,
        );
        assert_relative_eq!(
            tc.eval_line("y = x + ans").unwrap(),
            Eval {
                sym: "y".to_string(),
                val: 0.0,
            },
            epsilon = f64::EPSILON,
        );
        assert_relative_eq!(
            tc.eval_line("10 - 1 - 2 - 3").unwrap(),
            Eval {
                sym: "ans".to_string(),
                val: 4.0,
            },
            epsilon = f64::EPSILON,
        );
    }

    #[test]
    fn test_eval_errors_leave_state_untouched() {
        let mut tc = TermCalc::new();
        tc.eval_line("1").unwrap();
        assert_eq!(tc.get_var("ans"), Some(1.0));

        // Division by zero is an error, and the assignment target is not bound.
        assert!(matches!(tc.eval_line("z = 1 / 0"), Err(Error::ZeroDiv(_))));
        assert_eq!(tc.get_var("z"), None);
        assert_eq!(
            tc.eval_line("1 / 0").unwrap_err().to_string(),
            "Division by zero"
        );

        assert!(matches!(
            tc.eval_line("foo"),
            Err(Error::UnknownVar(_, name)) if name == "foo"
        ));
        assert!(matches!(
            tc.eval_line("nope(1)"),
            Err(Error::UnknownFunc(_, name)) if name == "nope"
        ));

        // None of the failed lines overwrote the previous result.
        assert_eq!(tc.get_var("ans"), Some(1.0));
    }
}
303}