1use std::{collections::HashMap, fmt::Display};
2
3mod ast;
4pub mod func;
5pub mod input;
6pub mod lex;
7pub mod parse;
8
9use input::{HasSpan, Span};
10
11#[derive(Debug)]
12pub enum Error {
13 Parse(parse::Error),
14 UnknownVar(Span, String),
15 UnknownFunc(Span, String),
16 FuncArgCount {
17 span: Span,
18 name: String,
19 expected: func::ArgCount,
20 actual: u32,
21 },
22 ZeroDiv(Span),
23}
24
25impl HasSpan for Error {
26 fn span(&self) -> Span {
27 match self {
28 Error::Parse(err) => err.span(),
29 Error::UnknownVar(span, _) => *span,
30 Error::UnknownFunc(span, _) => *span,
31 Error::FuncArgCount { span, .. } => *span,
32 Error::ZeroDiv(span) => *span,
33 }
34 }
35}
36
37impl Display for Error {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 Error::Parse(err) => err.fmt(f),
41 Error::UnknownVar(_, name) => {
42 write!(f, "Variable `{name}` is unknown")
43 }
44 Error::UnknownFunc(_, name) => {
45 write!(f, "Function `{name}` is unknown")
46 }
47 Error::FuncArgCount {
48 name,
49 expected,
50 actual,
51 ..
52 } => {
53 write!(
54 f,
55 "Function `{name}` expects {expected}, but received {actual}"
56 )
57 }
58 Error::ZeroDiv(_) => {
59 write!(f, "Division by zero")
60 }
61 }
62 }
63}
64
65impl From<parse::Error> for Error {
66 fn from(e: parse::Error) -> Self {
67 Error::Parse(e)
68 }
69}
70
/// Outcome of evaluating one line: which symbol was bound, and to what value.
#[derive(Clone, Debug, PartialEq)]
pub struct Eval {
    /// Name of the variable that received the result.
    pub sym: String,
    /// The computed value.
    pub val: f64,
}
76
77#[derive(Debug, Clone)]
78pub struct TermCalc {
79 vars: HashMap<String, f64>,
80 funcs: HashMap<String, func::Func>,
81}
82
83impl Default for TermCalc {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl TermCalc {
90 pub fn new() -> Self {
91 let mut vars = HashMap::new();
92 vars.insert("pi".to_string(), std::f64::consts::PI);
93 vars.insert("e".to_string(), std::f64::consts::E);
94 let funcs = func::all_funcs()
95 .into_iter()
96 .map(|f| (f.name.clone(), f))
97 .collect();
98 TermCalc { vars, funcs }
99 }
100
101 pub fn get_var(&self, name: &str) -> Option<f64> {
102 self.vars.get(name).copied()
103 }
104
105 pub fn eval_line<S: AsRef<str>>(&mut self, line: S) -> Result<Eval, Error> {
106 let line = line.as_ref();
107 let item = parse::parse_line(line.chars())?;
108 self.eval_item(item)
109 }
110
111 fn eval_item(&mut self, item: ast::Item) -> Result<Eval, Error> {
112 let (sym, expr) = match item.kind {
113 ast::ItemKind::Assign(sym, expr) => (sym, expr),
114 ast::ItemKind::Expr(ast::Expr {
116 span,
117 kind: ast::ExprKind::Var(sym),
118 }) => (
119 sym.clone(),
120 ast::Expr {
121 span,
122 kind: ast::ExprKind::Var(sym),
123 },
124 ),
125 ast::ItemKind::Expr(expr) => ("ans".to_string(), expr),
126 };
127
128 let val = self.eval_expr(expr)?;
129 self.vars.insert(sym.clone(), val);
130 Ok(Eval { sym, val })
131 }
132
133 fn eval_expr(&self, expr: ast::Expr) -> Result<f64, Error> {
134 let ast::Expr { span, kind } = expr;
135 match kind {
136 ast::ExprKind::Num(n) => Ok(n),
137 ast::ExprKind::Var(s) => match self.vars.get(&s) {
138 Some(n) => Ok(*n),
139 None => Err(Error::UnknownVar(span, s)),
140 },
141 ast::ExprKind::BinOp(op, lhs, rhs) => {
142 let lhs = self.eval_expr(*lhs)?;
143 let rhs = self.eval_expr(*rhs)?;
144 match op {
145 ast::BinOp::Add => Ok(lhs + rhs),
146 ast::BinOp::Sub => Ok(lhs - rhs),
147 ast::BinOp::Mul => Ok(lhs * rhs),
148 ast::BinOp::Div if rhs == 0.0 => Err(Error::ZeroDiv(span)),
149 ast::BinOp::Div => Ok(lhs / rhs),
150 ast::BinOp::Mod if rhs == 0.0 => Err(Error::ZeroDiv(span)),
151 ast::BinOp::Mod => Ok(lhs % rhs),
152 ast::BinOp::Pow => Ok(lhs.powf(rhs)),
153 }
154 }
155 ast::ExprKind::UnOp(op, expr) => {
156 let val = self.eval_expr(*expr)?;
157 match op {
158 ast::UnOp::Plus => Ok(val),
159 ast::UnOp::Minus => Ok(-val),
160 }
161 }
162 ast::ExprKind::Call {
163 name_span,
164 name,
165 args,
166 } => self.eval_call(span, name_span, name, args),
167 }
168 }
169
170 fn eval_call(
171 &self,
172 span: Span,
173 name_span: Span,
174 name: String,
175 args: Vec<ast::Expr>,
176 ) -> Result<f64, Error> {
177 let func = match self.funcs.get(&name) {
178 Some(f) => f,
179 None => return Err(Error::UnknownFunc(name_span, name)),
180 };
181 let args = self.eval_args(span, func, args)?;
182 let f = func.eval;
183 Ok(f(args))
184 }
185
186 fn eval_args(
187 &self,
188 span: Span,
189 func: &func::Func,
190 args: Vec<ast::Expr>,
191 ) -> Result<func::Args, Error> {
192 if !func.arg_count.check(args.len()) {
193 return Err(Error::FuncArgCount {
194 span,
195 name: func.name.clone(),
196 expected: func.arg_count,
197 actual: args.len() as _,
198 });
199 }
200 let mut args = args.into_iter();
201 Ok(match func.arg_count {
202 func::ArgCount::One => {
203 let arg = self.eval_expr(args.next().unwrap())?;
204 func::Args::One(arg)
205 }
206 func::ArgCount::Two => {
207 let arg1 = self.eval_expr(args.next().unwrap())?;
208 let arg2 = self.eval_expr(args.next().unwrap())?;
209 func::Args::Two(arg1, arg2)
210 }
211 func::ArgCount::Atleast(..) => {
212 let args = args
213 .map(|e| self.eval_expr(e))
214 .collect::<Result<Vec<_>, _>>()?;
215 func::Args::Dyn(args)
216 }
217 })
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::{Eval, TermCalc};
    use approx::{assert_relative_eq, AbsDiffEq, RelativeEq, UlpsEq};

    // Approximate comparisons for `Eval`: the symbol must match exactly, the
    // value is compared with the corresponding `f64` approximation.

    impl AbsDiffEq for Eval {
        type Epsilon = f64;

        fn default_epsilon() -> Self::Epsilon {
            f64::default_epsilon()
        }

        fn abs_diff_eq(&self, rhs: &Self, epsilon: Self::Epsilon) -> bool {
            self.sym == rhs.sym && f64::abs_diff_eq(&self.val, &rhs.val, epsilon)
        }
    }

    impl RelativeEq for Eval {
        fn default_max_relative() -> Self::Epsilon {
            f64::default_max_relative()
        }

        fn relative_eq(
            &self,
            rhs: &Self,
            epsilon: Self::Epsilon,
            max_relative: Self::Epsilon,
        ) -> bool {
            self.sym == rhs.sym && f64::relative_eq(&self.val, &rhs.val, epsilon, max_relative)
        }
    }

    impl UlpsEq for Eval {
        fn default_max_ulps() -> u32 {
            f64::default_max_ulps()
        }

        fn ulps_eq(&self, rhs: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
            self.sym == rhs.sym && f64::ulps_eq(&self.val, &rhs.val, epsilon, max_ulps)
        }
    }

    /// Builds the `Eval` a test case expects.
    fn bound(sym: &str, val: f64) -> Eval {
        Eval {
            sym: sym.to_string(),
            val,
        }
    }

    #[test]
    fn test_eval_line() {
        // The lines share one calculator on purpose: later lines read `x` and
        // `ans` produced by earlier ones.
        let mut tc = TermCalc::new();

        assert_eq!(tc.eval_line("1").unwrap(), bound("ans", 1.0));
        assert_relative_eq!(
            tc.eval_line("sin(pi/2)").unwrap(),
            bound("ans", 1.0),
            epsilon = f64::EPSILON,
        );
        assert_relative_eq!(
            tc.eval_line("x = cos(pi)").unwrap(),
            bound("x", -1.0),
            epsilon = f64::EPSILON,
        );
        assert_relative_eq!(
            tc.eval_line("y = x + ans").unwrap(),
            bound("y", 0.0),
            epsilon = f64::EPSILON,
        );
        assert_relative_eq!(
            tc.eval_line("10 - 1 - 2 - 3").unwrap(),
            bound("ans", 4.0),
            epsilon = f64::EPSILON,
        );
    }
}