1use crate::{
2 ast::Expression,
3 environment::{Environment, FunctionResult},
4 error::{Error, Result},
5 operator::Operator,
6 value::Value,
7};
8
9pub fn check_variables_and_functions(
16 env: &impl Environment,
17 expression: &Expression,
18) -> Result<()> {
19 match expression {
20 Expression::Unary { right, operator: _ } => check_variables_and_functions(env, right),
21 Expression::Binary {
22 left,
23 right,
24 operator: _,
25 } => check_variables_and_functions(env, left)
26 .and_then(|()| check_variables_and_functions(env, right)),
27 Expression::Ternary {
28 left,
29 middle,
30 right,
31 operator: _,
32 } => check_variables_and_functions(env, left)
33 .and_then(|()| check_variables_and_functions(env, middle))
34 .and_then(|()| check_variables_and_functions(env, right)),
35 Expression::Array {
36 expressions: values,
37 } => check_expressions(env, values),
38 Expression::Variable { name } => {
39 if env.variable_exists(name) {
40 Ok(())
41 } else {
42 Err(Error::MissingVariable(name.clone()))
43 }
44 }
45 Expression::Call { name, params } => {
46 let param_count = params.len();
47
48 match env.function_exists(name, param_count) {
49 FunctionResult::Exists { pure: _ } => check_expressions(env, params),
50 FunctionResult::NotFound => Err(Error::MissingFunction(name.clone())),
51 FunctionResult::WrongArity { min, max } => Err(Error::ParamCountMismatch(
52 name.clone(),
53 param_count,
54 min,
55 max,
56 )),
57 }
58 }
59 Expression::Literal { value: _ } => Ok(()),
60 }
61}
62
63fn check_expressions(env: &impl Environment, expressions: &[Expression]) -> Result<()> {
64 expressions
65 .iter()
66 .try_for_each(|expression| check_variables_and_functions(env, expression))
67}
68
69pub fn check_boolean_result(ast: &Expression) -> Result<()> {
87 match ast {
88 Expression::Unary { right: _, operator } => match operator {
89 Operator::Not => Ok(()),
90 _ => Err(Error::InvalidUnaryOperator(*operator)),
91 },
92 Expression::Binary {
93 left: _,
94 right: _,
95 operator,
96 } => match operator {
97 Operator::Greater
98 | Operator::GreaterEqual
99 | Operator::Less
100 | Operator::LessEqual
101 | Operator::Equal
102 | Operator::NotEqual
103 | Operator::And
104 | Operator::Or
105 | Operator::Xor => Ok(()),
106 _ => Err(Error::InvalidBinaryOperator(*operator)),
107 },
108 Expression::Ternary {
109 left,
110 middle,
111 right,
112 operator,
113 } => match operator {
114 Operator::TernaryCondition => {
115 check_boolean_result(left)
118 .and_then(|()| check_boolean_result(middle))
119 .and_then(|()| check_boolean_result(right))
120 }
121 _ => Err(Error::InvalidTernaryOperator(*operator)),
122 },
123 Expression::Array { expressions: _ } => Err(Error::LiteralNotBoolean),
124 Expression::Literal { value } => match value {
125 Value::Boolean(_) => Ok(()),
126 _ => Err(Error::LiteralNotBoolean),
127 },
128 Expression::Variable { name: _ } | Expression::Call { name: _, params: _ } => {
129 Ok(()) }
131 }
132}
133
#[cfg(test)]
mod test {
    use crate::{
        ast::Expression,
        environment::StaticEnvironment,
        function::{Arity, Function},
        operator::Operator,
        stdlib::NativeResult,
        validate::Error,
        value::Value,
    };

    use super::{check_boolean_result, check_variables_and_functions};

    /// Builds a numeric literal node.
    fn number(value: f64) -> Expression {
        Expression::Literal {
            value: Value::Number(value),
        }
    }

    /// Builds a boolean literal node.
    fn boolean(value: bool) -> Expression {
        Expression::Literal {
            value: Value::Boolean(value),
        }
    }

    /// Native function stand-in; validation must never invoke it.
    fn dummy_function(_params: &[Value]) -> NativeResult {
        unreachable!()
    }

    // ---- check_variables_and_functions ----

    #[test]
    fn valid() {
        let ast = Expression::Binary {
            left: Box::new(number(10.0)),
            right: Box::new(number(10.0)),
            operator: Operator::Plus,
        };

        let result = check_variables_and_functions(&StaticEnvironment::default(), &ast);

        assert_eq!(Ok(()), result);
    }

    #[test]
    fn valid_nested() {
        let ast = Expression::Binary {
            left: Box::new(number(10.0)),
            right: Box::new(Expression::Unary {
                right: Box::new(number(10.0)),
                operator: Operator::Minus,
            }),
            operator: Operator::Plus,
        };

        let result = check_variables_and_functions(&StaticEnvironment::default(), &ast);

        assert_eq!(Ok(()), result);
    }

    #[test]
    fn err_missing_variable() {
        let ast = Expression::Binary {
            left: Box::new(number(10.0)),
            right: Box::new(Expression::Variable {
                name: String::from("VAR_NAME"),
            }),
            operator: Operator::Plus,
        };

        let result = check_variables_and_functions(&StaticEnvironment::default(), &ast);

        assert_eq!(
            Err(Error::MissingVariable(String::from("VAR_NAME"))),
            result
        );
    }

    #[test]
    fn err_function_missing() {
        let ast = Expression::Binary {
            left: Box::new(number(10.0)),
            right: Box::new(Expression::Call {
                name: String::from("max"),
                params: vec![],
            }),
            operator: Operator::Plus,
        };

        let result = check_variables_and_functions(&StaticEnvironment::default(), &ast);

        assert_eq!(Err(Error::MissingFunction(String::from("max"))), result);
    }

    #[test]
    fn err_function_params_mismatch() {
        let ast = Expression::Binary {
            left: Box::new(number(10.0)),
            right: Box::new(Expression::Call {
                name: String::from("max"),
                params: vec![],
            }),
            operator: Operator::Plus,
        };

        let mut env = StaticEnvironment::default();
        env.add_function(Function::new(
            dummy_function,
            Arity::Polyadic {
                required: 2,
                optional: 0,
            },
            "max(left: Number, right: Number): Number",
        ));

        let result = check_variables_and_functions(&env, &ast);

        assert_eq!(
            Err(Error::ParamCountMismatch(String::from("max"), 0, 2, 2)),
            result
        );
    }

    #[test]
    fn err_function_nested_params() {
        let ast = Expression::Call {
            name: String::from("func"),
            params: vec![Expression::Variable {
                name: String::from("not_found"),
            }],
        };

        let mut env = StaticEnvironment::default();
        env.add_function(Function::new(
            dummy_function,
            Arity::Variadic,
            "func(...): Number",
        ));

        let result = check_variables_and_functions(&env, &ast);

        assert_eq!(
            Err(Error::MissingVariable(String::from("not_found"))),
            result
        );
    }

    // ---- check_boolean_result ----

    #[test]
    fn boolean_literal_is_boolean() {
        assert_eq!(Ok(()), check_boolean_result(&boolean(true)));
    }

    #[test]
    fn err_number_literal_not_boolean() {
        assert_eq!(
            Err(Error::LiteralNotBoolean),
            check_boolean_result(&number(1.0))
        );
    }

    #[test]
    fn comparison_is_boolean() {
        let ast = Expression::Binary {
            left: Box::new(number(1.0)),
            right: Box::new(number(2.0)),
            operator: Operator::Less,
        };

        assert_eq!(Ok(()), check_boolean_result(&ast));
    }

    #[test]
    fn err_arithmetic_not_boolean() {
        let ast = Expression::Binary {
            left: Box::new(number(1.0)),
            right: Box::new(number(2.0)),
            operator: Operator::Plus,
        };

        assert_eq!(
            Err(Error::InvalidBinaryOperator(Operator::Plus)),
            check_boolean_result(&ast)
        );
    }

    #[test]
    fn err_unary_minus_not_boolean() {
        let ast = Expression::Unary {
            right: Box::new(number(1.0)),
            operator: Operator::Minus,
        };

        assert_eq!(
            Err(Error::InvalidUnaryOperator(Operator::Minus)),
            check_boolean_result(&ast)
        );
    }

    #[test]
    fn variables_and_calls_are_accepted() {
        let variable = Expression::Variable {
            name: String::from("flag"),
        };
        let call = Expression::Call {
            name: String::from("func"),
            params: vec![],
        };

        assert_eq!(Ok(()), check_boolean_result(&variable));
        assert_eq!(Ok(()), check_boolean_result(&call));
    }

    #[test]
    fn err_ternary_branch_not_boolean() {
        let ast = Expression::Ternary {
            left: Box::new(boolean(true)),
            middle: Box::new(number(1.0)),
            right: Box::new(boolean(false)),
            operator: Operator::TernaryCondition,
        };

        assert_eq!(Err(Error::LiteralNotBoolean), check_boolean_result(&ast));
    }
}