1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
//! A simple, comprehensive scientific calculator library
//! 
//! An easy-to-use scientific calculator crate that evaluates mathematical 
//! expressions in strings. Includes variable assignment and recall, comprehensive
//! built-in functions and constants, and elegant error handling. The `calculate()`
//! function is passed a `Context` object that is used for keeping state between
//! calculations such as user-defined variables.
//! 
//! # Example
//! 
//! ```
//! # use sci_calc::{calculate, context::Context};
//! let mut ctx = Context::new();
//! assert_eq!(calculate("5 + 5", &mut ctx), Ok(10.0));
//! ```

use std::fmt;
use lalrpop_util::{lalrpop_mod, ParseError};
use libm::tgamma;

pub mod context;
use context::*;

mod ast;
use ast::*;

// defining lalrpop's parsing module
lalrpop_mod!(grammar);

/// Attempts to calculate a string containing a mathematical expression
/// 
/// On top of the input string, this function also takes a mutable reference to
/// a `Context` object. This object keeps track of state between calls, such as
/// user-defined variables and previous answers. Returns a result containing the
/// solution to the expression if successful, or a `CalcError` struct if not.
///
/// When the input is an assignment, the solution is stored in the named
/// variable; otherwise it is recorded as the previous answer (`ctx.prev_ans`).
///
/// # Errors
///
/// - [`CalcErrorType::ParserError`] if the input cannot be parsed. Where the
///   parser reports a position, the message echoes the input with a marker
///   under the offending character.
/// - [`CalcErrorType::UndefinedIdentifier`] for unknown functions or variables.
/// - Any error reported by the context while calling a function, looking up a
///   variable, or assigning the result to a variable.
/// 
/// # Example
/// 
/// ```
/// # use sci_calc::{calculate, context::Context};
/// # let mut ctx = Context::new();
/// assert_eq!(calculate("5 + 5", &mut ctx), Ok(10.0));
/// ```
pub fn calculate(input_str: &str, ctx: &mut Context) -> Result<f64, CalcError> {
	// Builds a parser error message that echoes `input` and places a
	// `└── here` marker under the byte offset `byte_pos`.
	fn pointer_msg(kind: &str, input: &str, byte_pos: usize) -> String {
		// lalrpop's built-in lexer reports locations as byte offsets. Count the
		// characters before that offset so the marker stays aligned when the
		// input contains multi-byte UTF-8 characters. An offset that is out of
		// range (or not on a char boundary) falls back to the raw byte offset.
		let column = input
			.get(..byte_pos)
			.map_or(byte_pos, |prefix| prefix.chars().count());
		let pad = " ".repeat(column);
		format!("{kind}\n| {input}\n| {pad}└── here")
	}

	// Drop a single trailing line terminator ("\n" or "\r\n"), e.g. from a
	// line read from stdin, so it does not reach the parser or get echoed back
	// inside error messages.
	let input_str = input_str
		.strip_suffix("\r\n")
		.or_else(|| input_str.strip_suffix('\n'))
		.unwrap_or(input_str);

	// invoking grammar parser generated by lalrpop
	let parser = grammar::targetParser::new();
	let (tree, assignment) = match parser.parse(input_str) {
		Ok(parsed) => parsed,
		Err(e) => {
			let msg = match &e {
				ParseError::InvalidToken { location } => {
					pointer_msg("Invalid token", input_str, *location)
				}
				ParseError::UnrecognizedEof { .. } => String::from("Unexpected EOI"),
				ParseError::UnrecognizedToken { token, .. } => {
					pointer_msg("Unexpected token", input_str, token.0)
				}
				ParseError::ExtraToken { token } => {
					pointer_msg("Extra token", input_str, token.0)
				}
				// remaining variants (e.g. errors raised by grammar actions)
				_ => String::from("Parser error"),
			};
			return Err(CalcError {
				error_type: CalcErrorType::ParserError,
				msg,
			});
		}
	};

	let solution = evaluate_ast(*tree, ctx)?;

	if let Some(var) = assignment {
		// assignment: store the result in the named variable
		ctx.assign_var(&var, solution)?;
	} else {
		// plain expression: remember it as the previous answer (`ans`)
		ctx.prev_ans = Some(solution);
	}

	Ok(solution)
}

/// Recursively evaluates the abstract syntax tree generated by the lalrpop
/// parser.
///
/// Variables and functions are resolved through `ctx`, which is only read,
/// never modified.
///
/// Sub-expressions are evaluated left to right, and the first error
/// encountered is returned unchanged. Unknown functions and variables produce
/// a [`CalcErrorType::UndefinedIdentifier`] error. Errors produced by the
/// context itself while calling a function or looking up a variable are passed
/// through unchanged.
fn evaluate_ast(root: Expr, ctx: &Context) -> Result<f64, CalcError> {
	match root {
		Expr::Num(n) => Ok(n),
		Expr::Op(left_e, op, right_e) => {
			let lhs = evaluate_ast(*left_e, ctx)?;
			let rhs = evaluate_ast(*right_e, ctx)?;
			let res = match op {
				Operation::Add => lhs + rhs,
				Operation::Sub => lhs - rhs,
				Operation::Mul => lhs * rhs,
				// plain IEEE-754 division: a zero divisor yields ±inf or NaN, not an error
				Operation::Div => lhs / rhs,
				Operation::FloorDiv => (lhs / rhs).floor(),
				// Rust's `%` on f64 is the truncating remainder (result takes the
				// sign of `lhs`), not a floored modulo
				Operation::Mod => lhs % rhs,
				Operation::Exp => lhs.powf(rhs),
			};
			Ok(res)
		}
		Expr::Func(name, arg_list) => {
			// evaluate arguments in order, stopping at the first failure
			let mut args: Vec<f64> = Vec::new();
			for arg in arg_list {
				args.push(evaluate_ast(*arg, ctx)?);
			}
			// `None` from the context is reported as an unknown function
			ctx.try_function(&name, args).unwrap_or_else(|| {
				Err(CalcError {
					error_type: CalcErrorType::UndefinedIdentifier,
					msg: format!("Unknown function \"{name}()\""),
				})
			})
		}
		Expr::Var(name) => {
			// `None` from the context is reported as an unknown variable
			ctx.lookup_var(&name).unwrap_or_else(|| {
				Err(CalcError {
					error_type: CalcErrorType::UndefinedIdentifier,
					msg: format!("Unknown variable \"{name}\""),
				})
			})
		}
		Expr::Fac(e) => {
			// n! computed as Gamma(n + 1), so non-integer operands are accepted;
			// the operand's domain is not checked here
			let num = evaluate_ast(*e, ctx)?;
			Ok(tgamma(num + 1.0))
		}
	}
}

/// Error returned when an expression cannot be parsed or evaluated.
///
/// Its `Display` output is `"<error type>: <message>"` followed by a trailing
/// newline. It implements [`std::error::Error`], so it can be boxed as
/// `Box<dyn std::error::Error>` or propagated with `?` by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CalcError {
	/// Broad type of error
	pub error_type: CalcErrorType,
	/// Human-readable description of the error
	pub msg: String,
}
impl fmt::Display for CalcError {
	fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
		// The trailing newline is part of the established output format and is
		// kept for compatibility.
		writeln!(formatter, "{}: {}", self.error_type, self.msg)
	}
}
impl std::error::Error for CalcError {}

/// Error types
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalcErrorType {
	/// Error generated during parsing of the input
	ParserError,
	/// Error generated when encountering an unknown variable or function
	UndefinedIdentifier,
	/// Error generated when assigning a value to a variable fails
	AssignmentError,
	/// Error generated from passing invalid arguments into a function
	ArgumentError,
	/// Error generated during calculation
	CalculationError,
}
impl fmt::Display for CalcErrorType {
	fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
		let label = match self {
			Self::ParserError => "Parser error",
			Self::UndefinedIdentifier => "Undefined identifier",
			Self::AssignmentError => "Assignment error",
			Self::ArgumentError => "Argument error",
			Self::CalculationError => "Calculation error",
		};
		// `pad` (unlike `write!(f, "{}", ..)`) honours width/alignment flags
		// such as `{:<20}`; without flags the output is identical.
		formatter.pad(label)
	}
}