1use somni_parser::{
2 ast::{
3 Body, Expression, Function, If, LeftHandExpression, LiteralValue, Loop,
4 RightHandExpression, Statement, TypeHint, VariableDefinition,
5 },
6 lexer,
7 parser::DefaultTypeSet,
8 Location,
9};
10
11use crate::{
12 value::LoadStore, EvalError, ExprContext, FunctionCallError, Type, TypeSet, TypedValue,
13};
14
15pub struct ExpressionVisitor<'a, C, T = DefaultTypeSet> {
17 pub context: &'a mut C,
19 pub source: &'a str,
21 pub _marker: std::marker::PhantomData<T>,
23}
24
25impl<'a, C, T> ExpressionVisitor<'a, C, T>
26where
27 C: ExprContext<T>,
28 T: TypeSet,
29{
30 fn visit_variable(&mut self, variable: &lexer::Token) -> Result<TypedValue<T>, EvalError> {
31 let name = variable.source(self.source);
32 self.context.try_load_variable(name).ok_or(EvalError {
33 message: format!("Variable {name} was not found").into_boxed_str(),
34 location: variable.location,
35 })
36 }
37
38 pub fn visit_expression(
40 &mut self,
41 expression: &Expression<T::Parser>,
42 ) -> Result<TypedValue<T>, EvalError> {
43 let result = match expression {
44 Expression::Expression { expression } => {
45 self.visit_right_hand_expression(expression)?
46 }
47 Expression::Assignment {
48 left_expr,
49 operator: _,
50 right_expr,
51 } => {
52 let rhs = self.visit_right_hand_expression(right_expr)?;
53 let assign_result = match left_expr {
54 LeftHandExpression::Deref { name, .. } => {
55 let address =
56 self.visit_right_hand_expression(&RightHandExpression::Variable {
57 variable: *name,
58 })?;
59 self.context.assign_address(address, &rhs)
60 }
61 LeftHandExpression::Name { variable } => {
62 let name = variable.source(self.source);
63 self.context.assign_variable(name, &rhs)
64 }
65 };
66
67 if let Err(error) = assign_result {
68 return Err(EvalError {
69 message: error,
70 location: expression.location(),
71 });
72 }
73
74 TypedValue::Void
75 }
76 };
77
78 Ok(result)
79 }
80
81 pub fn visit_right_hand_expression(
83 &mut self,
84 expression: &RightHandExpression<T::Parser>,
85 ) -> Result<TypedValue<T>, EvalError> {
86 let result = match expression {
87 RightHandExpression::Variable { variable } => self.visit_variable(variable)?,
88 RightHandExpression::Literal { value } => match &value.value {
89 LiteralValue::Integer(value) => TypedValue::<T>::MaybeSignedInt(*value),
90 LiteralValue::Float(value) => TypedValue::<T>::Float(*value),
91 LiteralValue::String(value) => value.store(self.context.type_context()),
92 LiteralValue::Boolean(value) => TypedValue::<T>::Bool(*value),
93 },
94 RightHandExpression::UnaryOperator { name, operand } => {
95 match name.source(self.source) {
96 "!" => {
97 let operand = self.visit_right_hand_expression(operand)?;
98
99 match TypedValue::<T>::not(self.context.type_context(), operand) {
100 Ok(r) => r,
101 Err(error) => {
102 return Err(EvalError {
103 message: format!("Failed to evaluate expression: {error}")
104 .into_boxed_str(),
105 location: expression.location(),
106 });
107 }
108 }
109 }
110
111 "-" => {
112 let value = self.visit_right_hand_expression(operand)?;
113 let ty = value.type_of();
114 TypedValue::<T>::negate(self.context.type_context(), value).map_err(
115 |e| EvalError {
116 message: format!("Cannot negate {ty}: {e}").into_boxed_str(),
117 location: operand.location(),
118 },
119 )?
120 }
121
122 "&" => {
123 let RightHandExpression::Variable { variable } = operand.as_ref() else {
124 return Err(EvalError {
125 message: String::from(
126 "Cannot take address of non-variable expression",
127 )
128 .into_boxed_str(),
129 location: operand.location(),
130 });
131 };
132
133 let name = variable.source(self.source);
134 self.context.address_of(name)
135 }
136 "*" => {
137 let address = self.visit_right_hand_expression(operand)?;
138 self.context.at_address(address).map_err(|e| EvalError {
139 message: format!("Failed to load variable from address: {e}")
140 .into_boxed_str(),
141 location: operand.location(),
142 })?
143 }
144 _ => {
145 return Err(EvalError {
146 message: format!(
147 "Unknown unary operator: {}",
148 name.source(self.source)
149 )
150 .into_boxed_str(),
151 location: expression.location(),
152 });
153 }
154 }
155 }
156 RightHandExpression::BinaryOperator { name, operands } => {
157 let lhs = self.visit_right_hand_expression(&operands[0])?;
158
159 let short_circuiting = ["&&", "||"];
160 let operator = name.source(self.source);
161
162 if short_circuiting.contains(&operator) {
164 return match operator {
165 "&&" if lhs == TypedValue::<T>::Bool(false) => Ok(TypedValue::Bool(false)),
166 "||" if lhs == TypedValue::<T>::Bool(true) => Ok(TypedValue::Bool(true)),
167 _ => self.visit_right_hand_expression(&operands[1]),
168 };
169 }
170
171 let rhs = self.visit_right_hand_expression(&operands[1])?;
173 let type_context = self.context.type_context();
174 let result = match operator {
175 "+" => TypedValue::<T>::add(type_context, lhs, rhs),
176 "-" => TypedValue::<T>::subtract(type_context, lhs, rhs),
177 "*" => TypedValue::<T>::multiply(type_context, lhs, rhs),
178 "/" => TypedValue::<T>::divide(type_context, lhs, rhs),
179 "%" => TypedValue::<T>::modulo(type_context, lhs, rhs),
180 "<" => TypedValue::<T>::less_than(type_context, lhs, rhs),
181 ">" => TypedValue::<T>::less_than(type_context, rhs, lhs),
182 "<=" => TypedValue::<T>::less_than_or_equal(type_context, lhs, rhs),
183 ">=" => TypedValue::<T>::less_than_or_equal(type_context, rhs, lhs),
184 "==" => TypedValue::<T>::equals(type_context, lhs, rhs),
185 "!=" => TypedValue::<T>::not_equals(type_context, lhs, rhs),
186 "|" => TypedValue::<T>::bitwise_or(type_context, lhs, rhs),
187 "^" => TypedValue::<T>::bitwise_xor(type_context, lhs, rhs),
188 "&" => TypedValue::<T>::bitwise_and(type_context, lhs, rhs),
189 "<<" => TypedValue::<T>::shift_left(type_context, lhs, rhs),
190 ">>" => TypedValue::<T>::shift_right(type_context, lhs, rhs),
191
192 other => {
193 return Err(EvalError {
194 message: format!("Unknown binary operator: {other}").into_boxed_str(),
195 location: expression.location(),
196 });
197 }
198 };
199
200 match result {
201 Ok(r) => r,
202 Err(error) => {
203 return Err(EvalError {
204 message: format!("Failed to evaluate expression: {error}")
205 .into_boxed_str(),
206 location: expression.location(),
207 });
208 }
209 }
210 }
211 RightHandExpression::FunctionCall { name, arguments } => {
212 let function_name = name.source(self.source);
213 let mut args = Vec::with_capacity(arguments.len());
214 for arg in arguments {
215 args.push(self.visit_right_hand_expression(arg)?);
216 }
217
218 match self.context.call_function(function_name, &args) {
219 Ok(result) => result,
220 Err(FunctionCallError::IncorrectArgumentCount { expected }) => {
221 return Err(EvalError {
222 message: format!(
223 "{function_name} takes {expected} arguments, {} given",
224 args.len()
225 )
226 .into_boxed_str(),
227 location: expression.location(),
228 });
229 }
230 Err(FunctionCallError::IncorrectArgumentType { idx, expected }) => {
231 return Err(EvalError {
232 message: format!(
233 "{function_name} expects argument {idx} to be {expected}, got {}",
234 args[idx].type_of()
235 )
236 .into_boxed_str(),
237 location: arguments[idx].location(),
238 });
239 }
240 Err(FunctionCallError::FunctionNotFound) => {
241 return Err(EvalError {
242 message: format!("Function {function_name} is not found")
243 .into_boxed_str(),
244 location: expression.location(),
245 });
246 }
247 Err(FunctionCallError::Other(error)) => {
248 return Err(EvalError {
249 message: format!("Failed to call {function_name}: {error}")
250 .into_boxed_str(),
251 location: expression.location(),
252 });
253 }
254 }
255 }
256 };
257
258 Ok(result)
259 }
260
261 fn typecheck_with_hint(
262 &self,
263 value: TypedValue<T>,
264 hint: Option<TypeHint>,
265 ) -> Result<TypedValue<T>, EvalError> {
266 let Some(hint) = hint else {
267 return Ok(value);
269 };
270
271 let ty =
272 Type::from_name(hint.type_name.source(self.source)).map_err(|message| EvalError {
273 message,
274 location: hint.type_name.location,
275 })?;
276
277 self.typecheck(value, ty, hint.type_name.location)
278 }
279
280 fn typecheck(
281 &self,
282 value: TypedValue<T>,
283 hint: Type,
284 location: Location,
285 ) -> Result<TypedValue<T>, EvalError> {
286 match (value, hint) {
287 (value, hint) if value.type_of() == hint => Ok(value),
288 (TypedValue::MaybeSignedInt(val), Type::Int) => Ok(TypedValue::Int(val)),
289 (TypedValue::MaybeSignedInt(val), Type::SignedInt) => Ok(TypedValue::<T>::SignedInt(
290 T::to_signed(val).map_err(|_| EvalError {
291 message: format!("Failed to cast {val:?} to signed int").into_boxed_str(),
292 location,
293 })?,
294 )),
295 (value, hint) => Err(EvalError {
296 message: format!("Expected {hint}, got {}", value.type_of()).into_boxed_str(),
297 location,
298 }),
299 }
300 }
301
302 pub fn visit_function(
304 &mut self,
305 function: &Function<T::Parser>,
306 args: &[TypedValue<T>],
307 ) -> Result<TypedValue<T>, EvalError> {
308 for (arg, arg_value) in function.arguments.iter().zip(args.iter()) {
309 let arg_name = arg.name.source(self.source);
310
311 let arg_value = self.typecheck_with_hint(arg_value.clone(), Some(arg.arg_type))?;
312
313 self.context.declare(arg_name, arg_value);
314 }
315
316 let retval = match self.visit_body(&function.body)? {
317 StatementResult::Return(typed_value) | StatementResult::ImplicitReturn(typed_value) => {
318 typed_value
319 }
320 StatementResult::EndOfBody => TypedValue::Void,
321 StatementResult::LoopBreak | StatementResult::LoopContinue => todo!(),
322 };
323
324 let retval =
325 self.typecheck_with_hint(retval, function.return_decl.as_ref().map(|d| d.return_type))?;
326
327 Ok(retval)
328 }
329
330 fn visit_body(&mut self, body: &Body<T::Parser>) -> Result<StatementResult<T>, EvalError> {
331 self.context.open_scope();
332
333 let mut body_result = StatementResult::EndOfBody;
334 for statement in body.statements.iter() {
335 if let Some(retval) = self.visit_statement(statement)? {
336 body_result = retval;
337 match body_result {
338 StatementResult::ImplicitReturn(_) => {}
339 _ => break,
340 }
341 } else {
342 body_result = StatementResult::EndOfBody;
344 }
345 }
346
347 self.context.close_scope();
348 Ok(body_result)
349 }
350
351 fn visit_statement(
352 &mut self,
353 statement: &Statement<T::Parser>,
354 ) -> Result<Option<StatementResult<T>>, EvalError> {
355 match statement {
356 Statement::Return(return_with_value) => {
357 return self
358 .visit_right_hand_expression(&return_with_value.expression)
359 .map(|rv| Some(StatementResult::Return(rv)));
360 }
361 Statement::ImplicitReturn(expression) => {
362 return self
363 .visit_right_hand_expression(expression)
364 .map(|rv| Some(StatementResult::ImplicitReturn(rv)));
365 }
366 Statement::EmptyReturn(_) => {
367 return Ok(Some(StatementResult::Return(TypedValue::Void)));
368 }
369 Statement::If(if_statement) => return self.visit_if(if_statement),
370 Statement::Loop(loop_statement) => return self.visit_loop(loop_statement),
371 Statement::Break(_) => return Ok(Some(StatementResult::LoopBreak)),
372 Statement::Continue(_) => return Ok(Some(StatementResult::LoopContinue)),
373 Statement::Scope(body) => {
374 return self.visit_body(body).map(|r| match r {
375 StatementResult::EndOfBody => None,
376 r => Some(r),
377 })
378 }
379 Statement::VariableDefinition(variable_definition) => {
380 self.visit_declaration(variable_definition)?;
381 }
382 Statement::Expression { expression, .. } => {
383 self.visit_expression(expression)?;
384 }
385 }
386
387 Ok(None)
388 }
389
390 fn visit_declaration(&mut self, decl: &VariableDefinition<T::Parser>) -> Result<(), EvalError> {
391 let name = decl.identifier.source(self.source);
392 let value = self.visit_right_hand_expression(&decl.initializer)?;
393
394 let value = self.typecheck_with_hint(value, decl.type_token)?;
395
396 self.context.declare(name, value);
397
398 Ok(())
399 }
400
401 fn visit_if(
402 &mut self,
403 if_statement: &If<T::Parser>,
404 ) -> Result<Option<StatementResult<T>>, EvalError> {
405 let condition = self.visit_right_hand_expression(&if_statement.condition)?;
406
407 let condition = self.typecheck(condition, Type::Bool, if_statement.condition.location())?;
408
409 let body = if condition == TypedValue::Bool(true) {
410 &if_statement.body
411 } else if let Some(ref else_branch) = if_statement.else_branch {
412 &else_branch.else_body
413 } else {
414 return Ok(None);
416 };
417
418 let retval = match self.visit_body(body)? {
419 StatementResult::EndOfBody => None,
420 other => Some(other),
421 };
422 Ok(retval)
423 }
424
425 fn visit_loop(
426 &mut self,
427 loop_statement: &Loop<T::Parser>,
428 ) -> Result<Option<StatementResult<T>>, EvalError> {
429 loop {
430 match self.visit_body(&loop_statement.body)? {
431 ret @ StatementResult::Return(_) => return Ok(Some(ret)),
432 StatementResult::LoopBreak => return Ok(None),
433 StatementResult::LoopContinue
434 | StatementResult::EndOfBody
435 | StatementResult::ImplicitReturn(_) => {}
436 }
437 }
438 }
439}
440
441enum StatementResult<T: TypeSet> {
442 Return(TypedValue<T>),
443 ImplicitReturn(TypedValue<T>),
444 LoopBreak,
445 LoopContinue,
446 EndOfBody,
447}