1use super::value::{Field, Fn, FnBody, FnInputs, Scalar, Value};
2use crate::{
3 etc::{
4 known,
5 syn::{SynPath, SynPathKind},
6 },
7 semantic::{
8 basic_traits::{RawScope, Scope, Scoping},
9 entry::GlobalCx,
10 infer,
11 tree::PathId,
12 },
13 Intern, Map, Result, TriResult,
14};
15use any_intern::Interned;
16use logic_eval_util::{str::StrPath, symbol::SymbolTable};
17use std::{collections::hash_map::Entry, mem};
18
19struct ValueWithCtrl<'gcx> {
20 value: Value<'gcx>,
21 is_return: bool,
22}
23
24impl<'gcx> From<Value<'gcx>> for ValueWithCtrl<'gcx> {
25 fn from(value: Value<'gcx>) -> Self {
26 ValueWithCtrl {
27 value,
28 is_return: false,
29 }
30 }
31}
32
33#[allow(unused_variables)]
36pub(crate) trait Host<'gcx>: Scoping {
37 fn find_type(&mut self, expr: &syn::Expr) -> TriResult<infer::Type<'gcx>, ()>;
38 fn find_fn(&mut self, name: StrPath, types: &[infer::Type<'gcx>]) -> Fn;
39 fn syn_path_to_value(&mut self, syn_path: SynPath) -> TriResult<Value<'gcx>, ()>;
40}
41
42struct HostWrapper<'a, H> {
43 inner: &'a mut H,
44 scope_stack: Vec<RawScope>,
45}
46
47impl<'a, 'gcx, H: Host<'gcx>> HostWrapper<'a, H> {
48 fn new(host: &'a mut H) -> Self {
49 Self {
50 inner: host,
51 scope_stack: Vec::new(),
52 }
53 }
54
55 fn eval_known_fn(&mut self, abs_path: &str, values: &[Value<'gcx>]) -> Option<Value<'gcx>> {
56 use known::apply;
57 use once_cell::sync::OnceCell;
58
59 type F = for<'a> fn(&[Value<'a>]) -> Result<Value<'a>>;
60
61 static FMAP: OnceCell<Map<&'static str, F>> = OnceCell::new();
62
63 let fmap = FMAP.get_or_init(|| {
64 let mut map: Map<&'static str, F> = Map::default();
65
66 map.insert(apply::NAME_ADD, |values: &[Value<'_>]| {
67 debug_assert_eq!(values.len(), 2);
68 values[0].try_add(&values[1])
69 });
70 map.insert(apply::NAME_SUB, |values: &[Value<'_>]| {
71 debug_assert_eq!(values.len(), 2);
72 values[0].try_sub(&values[1])
73 });
74 map.insert(apply::NAME_MUL, |values: &[Value<'_>]| {
75 debug_assert_eq!(values.len(), 2);
76 values[0].try_mul(&values[1])
77 });
78 map.insert(apply::NAME_DIV, |values: &[Value<'_>]| {
79 debug_assert_eq!(values.len(), 2);
80 values[0].try_div(&values[1])
81 });
82 map.insert(apply::NAME_REM, |values: &[Value<'_>]| {
83 debug_assert_eq!(values.len(), 2);
84 values[0].try_rem(&values[1])
85 });
86 map.insert(apply::NAME_BIT_XOR, |values: &[Value<'_>]| {
87 debug_assert_eq!(values.len(), 2);
88 values[0].try_bit_xor(&values[1])
89 });
90 map.insert(apply::NAME_BIT_AND, |values: &[Value<'_>]| {
91 debug_assert_eq!(values.len(), 2);
92 values[0].try_bit_and(&values[1])
93 });
94 map.insert(apply::NAME_BIT_OR, |values: &[Value<'_>]| {
95 debug_assert_eq!(values.len(), 2);
96 values[0].try_bit_or(&values[1])
97 });
98 map.insert(apply::NAME_SHL, |values: &[Value<'_>]| {
99 debug_assert_eq!(values.len(), 2);
100 values[0].try_shl(&values[1])
101 });
102 map.insert(apply::NAME_SHR, |values: &[Value<'_>]| {
103 debug_assert_eq!(values.len(), 2);
104 values[0].try_shr(&values[1])
105 });
106 map.insert(apply::NAME_ADD_ASSIGN, |values: &[Value<'_>]| {
107 debug_assert_eq!(values.len(), 2);
108 values[0].try_add(&values[1]) });
110 map.insert(apply::NAME_SUB_ASSIGN, |values: &[Value<'_>]| {
111 debug_assert_eq!(values.len(), 2);
112 values[0].try_sub(&values[1]) });
114 map.insert(apply::NAME_MUL_ASSIGN, |values: &[Value<'_>]| {
115 debug_assert_eq!(values.len(), 2);
116 values[0].try_mul(&values[1]) });
118 map.insert(apply::NAME_DIV_ASSIGN, |values: &[Value<'_>]| {
119 debug_assert_eq!(values.len(), 2);
120 values[0].try_div(&values[1]) });
122 map.insert(apply::NAME_REM_ASSIGN, |values: &[Value<'_>]| {
123 debug_assert_eq!(values.len(), 2);
124 values[0].try_rem(&values[1]) });
126 map.insert(apply::NAME_BIT_XOR_ASSIGN, |values: &[Value<'_>]| {
127 debug_assert_eq!(values.len(), 2);
128 values[0].try_bit_xor(&values[1]) });
130 map.insert(apply::NAME_BIT_AND_ASSIGN, |values: &[Value<'_>]| {
131 debug_assert_eq!(values.len(), 2);
132 values[0].try_bit_and(&values[1]) });
134 map.insert(apply::NAME_BIT_OR_ASSIGN, |values: &[Value<'_>]| {
135 debug_assert_eq!(values.len(), 2);
136 values[0].try_bit_or(&values[1]) });
138 map.insert(apply::NAME_SHL_ASSIGN, |values: &[Value<'_>]| {
139 debug_assert_eq!(values.len(), 2);
140 values[0].try_shl(&values[1]) });
142 map.insert(apply::NAME_SHR_ASSIGN, |values: &[Value<'_>]| {
143 debug_assert_eq!(values.len(), 2);
144 values[0].try_shr(&values[1]) });
146 map.insert(apply::NAME_NOT, |values: &[Value<'_>]| {
147 debug_assert_eq!(values.len(), 1);
148 values[0].try_not()
149 });
150 map.insert(apply::NAME_NEG, |values: &[Value<'_>]| {
151 debug_assert_eq!(values.len(), 1);
152 values[0].try_neg()
153 });
154
155 map
158 });
159
160 let f = fmap.get(&abs_path).cloned()?;
161 f(values).ok()
162 }
163
164 fn on_enter_scope(&mut self, scope: Scope) {
165 self.inner.on_enter_scope(scope);
166 self.scope_stack.push(scope.into_raw());
167 }
168
169 fn on_exit_scope(&mut self) {
170 let raw_scope = self.scope_stack.pop().unwrap();
171 let exit_scope = Scope::from_raw(raw_scope);
172 self.inner.on_exit_scope(exit_scope);
173
174 if let Some(raw_scope) = self.scope_stack.last() {
175 let reenter_scope = Scope::from_raw(*raw_scope);
176 self.inner.on_enter_scope(reenter_scope);
177 }
178 }
179}
180
181impl<'gcx, H: Host<'gcx>> Host<'gcx> for HostWrapper<'_, H> {
182 fn find_type(&mut self, expr: &syn::Expr) -> TriResult<infer::Type<'gcx>, ()> {
183 self.inner.find_type(expr)
184 }
185
186 fn find_fn(&mut self, name: StrPath, types: &[infer::Type<'gcx>]) -> Fn {
187 self.inner.find_fn(name, types)
188 }
189
190 fn syn_path_to_value(&mut self, syn_path: SynPath) -> TriResult<Value<'gcx>, ()> {
191 self.inner.syn_path_to_value(syn_path)
192 }
193}
194
195impl<'gcx, H: Host<'gcx>> Scoping for HostWrapper<'_, H> {
196 fn on_enter_scope(&mut self, scope: Scope) {
197 <Self>::on_enter_scope(self, scope)
198 }
199
200 fn on_exit_scope(&mut self, _: Scope) {
201 <Self>::on_exit_scope(self)
202 }
203}
204
205#[derive(Debug)]
208pub(crate) struct Evaluator<'gcx> {
209 gcx: &'gcx GlobalCx<'gcx>,
210 symbols: SymbolTable<Interned<'gcx, str>, Value<'gcx>>,
211}
212
213impl<'gcx> Evaluator<'gcx> {
214 pub(crate) fn new(gcx: &'gcx GlobalCx<'gcx>) -> Self {
215 Self {
216 gcx,
217 symbols: SymbolTable::default(),
218 }
219 }
220
221 pub(crate) fn eval_expr<H: Host<'gcx>>(
222 &mut self,
223 host: &mut H,
224 expr: &syn::Expr,
225 ) -> TriResult<Value<'gcx>, ()> {
226 self.symbols.clear();
227
228 let mut cx = EvalCx {
229 gcx: self.gcx,
230 symbols: &mut self.symbols,
231 host: HostWrapper::new(host),
232 };
233
234 cx.eval_expr(expr).map(|ex| ex.value)
235 }
236}
237
238struct EvalCx<'a, 'gcx, H> {
239 gcx: &'gcx GlobalCx<'gcx>,
240 symbols: &'a mut SymbolTable<Interned<'gcx, str>, Value<'gcx>>,
241 host: HostWrapper<'a, H>,
242}
243
244impl<'a, 'gcx, H: Host<'gcx>> EvalCx<'a, 'gcx, H> {
245 fn eval_expr(&mut self, expr: &syn::Expr) -> TriResult<ValueWithCtrl<'gcx>, ()> {
246 match expr {
247 syn::Expr::Array(v) => self.eval_expr_array(v).map(ValueWithCtrl::from),
248 syn::Expr::Assign(v) => self.eval_expr_assign(v).map(ValueWithCtrl::from),
249 syn::Expr::Async(_) => panic!("`async` is not supported"),
250 syn::Expr::Await(_) => panic!("`await` is not supported"),
251 syn::Expr::Binary(v) => self.eval_expr_binary(v).map(ValueWithCtrl::from),
252 syn::Expr::Block(v) => self.eval_block(&v.block).map(ValueWithCtrl::from),
253 syn::Expr::Break(v) => todo!("{v:#?}"),
254 syn::Expr::Call(v) => self.eval_expr_call(v).map(ValueWithCtrl::from),
255 syn::Expr::Cast(v) => self.eval_expr(&v.expr),
256 syn::Expr::Closure(_) => panic!("`closure` is not supported"),
257 syn::Expr::Const(v) => self.eval_block(&v.block).map(ValueWithCtrl::from),
258 syn::Expr::Let(v) => todo!("{v:#?}"),
259 syn::Expr::Lit(v) => self.eval_lit(&v.lit, expr).map(ValueWithCtrl::from),
260 syn::Expr::Loop(v) => todo!("{v:#?}"),
261 syn::Expr::Macro(_) => Ok(Value::Unit.into()),
262 syn::Expr::Match(v) => todo!("{v:#?}"),
263 syn::Expr::MethodCall(v) => todo!("{v:#?}"),
264 syn::Expr::Paren(v) => self.eval_expr_paren(v),
265 syn::Expr::Path(v) => self.eval_expr_path(v).map(ValueWithCtrl::from),
266 syn::Expr::Range(v) => todo!("{v:#?}"),
267 syn::Expr::RawAddr(v) => todo!("{v:#?}"),
268 syn::Expr::Reference(v) => todo!("{v:#?}"),
269 syn::Expr::Repeat(v) => todo!("{v:#?}"),
270 syn::Expr::Return(v) => todo!("{v:#?}"),
271 syn::Expr::Struct(v) => self.eval_expr_struct(v).map(ValueWithCtrl::from),
272 syn::Expr::Try(v) => todo!("{v:#?}"),
273 syn::Expr::TryBlock(v) => todo!("{v:#?}"),
274 syn::Expr::Tuple(v) => todo!("{v:#?}"),
275 syn::Expr::Unary(v) => self.eval_expr_unary(v).map(ValueWithCtrl::from),
276 syn::Expr::Unsafe(v) => todo!("{v:#?}"),
277 syn::Expr::Verbatim(v) => todo!("{v:#?}"),
278 syn::Expr::While(v) => todo!("{v:#?}"),
279 syn::Expr::Yield(v) => todo!("{v:#?}"),
280 _ => todo!(),
281 }
282 }
283
284 fn eval_expr_array(&mut self, expr_arr: &syn::ExprArray) -> TriResult<Value<'gcx>, ()> {
285 let fields = expr_arr
286 .elems
287 .iter()
288 .enumerate()
289 .map(|(i, elem)| {
290 self.eval_expr(elem).map(|ex| Field {
291 name: self.gcx.intern_str(&i.to_string()),
292 value: ex.value,
293 })
294 })
295 .collect::<TriResult<Vec<Field<'gcx>>, ()>>()?;
296 Ok(Value::Composed(fields))
297 }
298
299 fn eval_expr_assign(&mut self, expr_assign: &syn::ExprAssign) -> TriResult<Value<'gcx>, ()> {
300 let rv = self.eval_expr(&expr_assign.right)?.value;
301 self.update_symbol_by_expr(&expr_assign.left, rv);
302 Ok(Value::Unit)
303 }
304
305 fn eval_expr_binary(&mut self, expr_bin: &syn::ExprBinary) -> TriResult<Value<'gcx>, ()> {
306 use known::apply::*;
307
308 return match expr_bin.op {
309 syn::BinOp::Add(_) => bin(self, expr_bin, NAME_ADD),
310 syn::BinOp::Sub(_) => bin(self, expr_bin, NAME_SUB),
311 syn::BinOp::Mul(_) => bin(self, expr_bin, NAME_MUL),
312 syn::BinOp::Div(_) => bin(self, expr_bin, NAME_DIV),
313 syn::BinOp::Rem(_) => bin(self, expr_bin, NAME_REM),
314 syn::BinOp::BitXor(_) => bin(self, expr_bin, NAME_BIT_XOR),
315 syn::BinOp::BitAnd(_) => bin(self, expr_bin, NAME_BIT_AND),
316 syn::BinOp::BitOr(_) => bin(self, expr_bin, NAME_BIT_OR),
317 syn::BinOp::Shl(_) => bin(self, expr_bin, NAME_SHL),
318 syn::BinOp::Shr(_) => bin(self, expr_bin, NAME_SHR),
319 syn::BinOp::AddAssign(_) => bin_assign(self, expr_bin, NAME_ADD_ASSIGN),
320 syn::BinOp::SubAssign(_) => bin_assign(self, expr_bin, NAME_SUB_ASSIGN),
321 syn::BinOp::MulAssign(_) => bin_assign(self, expr_bin, NAME_MUL_ASSIGN),
322 syn::BinOp::DivAssign(_) => bin_assign(self, expr_bin, NAME_DIV_ASSIGN),
323 syn::BinOp::RemAssign(_) => bin_assign(self, expr_bin, NAME_REM_ASSIGN),
324 syn::BinOp::BitXorAssign(_) => bin_assign(self, expr_bin, NAME_BIT_XOR_ASSIGN),
325 syn::BinOp::BitAndAssign(_) => bin_assign(self, expr_bin, NAME_BIT_AND_ASSIGN),
326 syn::BinOp::BitOrAssign(_) => bin_assign(self, expr_bin, NAME_BIT_OR_ASSIGN),
327 syn::BinOp::ShlAssign(_) => bin_assign(self, expr_bin, NAME_SHL_ASSIGN),
328 syn::BinOp::ShrAssign(_) => bin_assign(self, expr_bin, NAME_SHR_ASSIGN),
329 _ => unreachable!(),
330 };
331
332 fn bin<'gcx, H: Host<'gcx>>(
335 this: &mut EvalCx<'_, 'gcx, H>,
336 expr_bin: &syn::ExprBinary,
337 name: &str,
338 ) -> TriResult<Value<'gcx>, ()> {
339 let lv = this.eval_expr(&expr_bin.left)?.value;
340 let rv = this.eval_expr(&expr_bin.right)?.value;
341 let values = [lv, rv];
342 if let Some(res) = this.host.eval_known_fn(name, &values) {
343 return Ok(res);
344 }
345
346 let lty = this.host.find_type(&expr_bin.left)?;
347 let rty = this.host.find_type(&expr_bin.right)?;
348 let f = this.host.find_fn(StrPath::absolute(name), &[lty, rty]);
349 this.apply_to_fn(f, &values)
350 }
351
352 fn bin_assign<'gcx, H: Host<'gcx>>(
353 this: &mut EvalCx<'_, 'gcx, H>,
354 expr_bin: &syn::ExprBinary,
355 name: &str,
356 ) -> TriResult<Value<'gcx>, ()> {
357 let value = bin(this, expr_bin, name)?;
358 this.update_symbol_by_expr(&expr_bin.left, value);
359 Ok(Value::Unit)
360 }
361 }
362
363 fn eval_expr_call(&mut self, expr_call: &syn::ExprCall) -> TriResult<Value<'gcx>, ()> {
364 let args = expr_call
365 .args
366 .iter()
367 .map(|arg| self.eval_expr(arg).map(|ex| ex.value))
368 .collect::<TriResult<Vec<_>, ()>>()?;
369
370 match self.eval_expr(&expr_call.func)?.value {
371 Value::Fn(f) => self.apply_to_fn(f, &args),
373 Value::Composed(fields) => {
375 let field_names = fields.into_iter().map(|field| field.name);
376 let value = self.apply_to_constructor(field_names, &args);
377 Ok(value)
378 }
379 _ => unreachable!(),
380 }
381 }
382
383 fn eval_expr_paren(
384 &mut self,
385 expr_paren: &syn::ExprParen,
386 ) -> TriResult<ValueWithCtrl<'gcx>, ()> {
387 self.eval_expr(&expr_paren.expr)
388 }
389
390 fn eval_expr_path(&mut self, expr_path: &syn::ExprPath) -> TriResult<Value<'gcx>, ()> {
391 if expr_path.qself.is_none() {
392 if let Some(ident) = expr_path.path.get_ident() {
393 if let Some(value) = self.symbols.get(&*ident.to_string()) {
394 return Ok(value.clone());
395 }
396 }
397 }
398
399 let syn_path = SynPath {
400 kind: SynPathKind::Expr,
401 qself: expr_path.qself.as_ref(),
402 path: &expr_path.path,
403 };
404 self.host.syn_path_to_value(syn_path)
405 }
406
407 fn eval_expr_struct(&mut self, expr_struct: &syn::ExprStruct) -> TriResult<Value<'gcx>, ()> {
408 let fields = expr_struct
409 .fields
410 .iter()
411 .map(|field| self.eval_field_value(field))
412 .collect::<TriResult<Vec<Field>, ()>>()?;
413 Ok(Value::Composed(fields))
414 }
415
416 fn eval_expr_unary(&mut self, expr_unary: &syn::ExprUnary) -> TriResult<Value<'gcx>, ()> {
417 use known::apply::*;
418
419 let name = match expr_unary.op {
420 syn::UnOp::Deref(_) => todo!(),
421 syn::UnOp::Not(_) => NAME_NOT,
422 syn::UnOp::Neg(_) => NAME_NEG,
423 _ => unreachable!(),
424 };
425
426 let v = self.eval_expr(&expr_unary.expr)?.value;
427 let values = [v];
428 if let Some(res) = self.host.eval_known_fn(name, &values) {
429 return Ok(res);
430 }
431
432 let ty = self.host.find_type(&expr_unary.expr)?;
433 let f = self.host.find_fn(StrPath::absolute(name), &[ty]);
434 self.apply_to_fn(f, &values)
435 }
436
437 fn eval_block(&mut self, block: &syn::Block) -> TriResult<Value<'gcx>, ()> {
438 self.host.on_enter_scope(Scope::Block(block));
439 self.symbols.push_transparent_block();
440
441 let mut last_value = Value::Unit;
442 for stmt in &block.stmts {
443 let ValueWithCtrl {
444 value, is_return, ..
445 } = self.eval_stmt(stmt)?;
446 last_value = value;
447 if is_return {
448 break;
449 }
450 }
451
452 self.symbols.pop_block();
453 self.host.on_exit_scope();
454 Ok(last_value)
455 }
456
457 fn eval_stmt(&mut self, stmt: &syn::Stmt) -> TriResult<ValueWithCtrl<'gcx>, ()> {
458 let value = match stmt {
459 syn::Stmt::Local(v) => {
460 self.eval_local(v)?;
461 Value::Unit.into()
462 }
463 syn::Stmt::Item(_) => Value::Unit.into(),
464 syn::Stmt::Expr(v, _) => self.eval_expr(v)?,
465 syn::Stmt::Macro(_) => Value::Unit.into(),
466 };
467 Ok(value)
468 }
469
470 fn eval_local(&mut self, local: &syn::Local) -> TriResult<(), ()> {
471 let rhs = local
473 .init
474 .as_ref()
475 .map(|init| self.eval_expr(&init.expr).map(|ex| ex.value))
476 .unwrap_or(Ok(Value::Unit))?;
477 self.push_symbol_by_pat(&local.pat, rhs);
478 Ok(())
479 }
480
481 fn eval_lit(&mut self, lit: &syn::Lit, expr: &syn::Expr) -> TriResult<Value<'gcx>, ()> {
482 use infer::{Type, TypeScalar::*};
483
484 let ty = self.host.find_type(expr)?;
485
486 let value = match lit {
487 syn::Lit::Int(v) => match ty {
488 Type::Scalar(Int { .. }) => {
489 let v = v.base10_parse().unwrap();
490 Value::Scalar(Scalar::Int(v))
491 }
492 Type::Scalar(I8) => {
493 let v = v.base10_parse().unwrap();
494 Value::Scalar(Scalar::I8(v))
495 }
496 Type::Scalar(I16) => {
497 let v = v.base10_parse().unwrap();
498 Value::Scalar(Scalar::I16(v))
499 }
500 Type::Scalar(I32) => {
501 let v = v.base10_parse().unwrap();
502 Value::Scalar(Scalar::I32(v))
503 }
504 Type::Scalar(I64) => {
505 let v = v.base10_parse().unwrap();
506 Value::Scalar(Scalar::I64(v))
507 }
508 Type::Scalar(I128) => {
509 let v = v.base10_parse().unwrap();
510 Value::Scalar(Scalar::I128(v))
511 }
512 Type::Scalar(Isize) => {
513 let v = v.base10_parse().unwrap();
514 Value::Scalar(Scalar::Isize(v))
515 }
516 Type::Scalar(U8) => {
517 let v = v.base10_parse().unwrap();
518 Value::Scalar(Scalar::U8(v))
519 }
520 Type::Scalar(U16) => {
521 let v = v.base10_parse().unwrap();
522 Value::Scalar(Scalar::U16(v))
523 }
524 Type::Scalar(U32) => {
525 let v = v.base10_parse().unwrap();
526 Value::Scalar(Scalar::U32(v))
527 }
528 Type::Scalar(U64) => {
529 let v = v.base10_parse().unwrap();
530 Value::Scalar(Scalar::U64(v))
531 }
532 Type::Scalar(U128) => {
533 let v = v.base10_parse().unwrap();
534 Value::Scalar(Scalar::U128(v))
535 }
536 Type::Scalar(Usize) => {
537 let v = v.base10_parse().unwrap();
538 Value::Scalar(Scalar::Usize(v))
539 }
540 _ => panic!("An integer does not match with the given type: {ty:?}"),
541 },
542 syn::Lit::Float(v) => match ty {
543 Type::Scalar(Float { .. }) => {
544 let v = v.base10_parse().unwrap();
545 Value::Scalar(Scalar::Float(v))
546 }
547 Type::Scalar(F32) => {
548 let v = v.base10_parse().unwrap();
549 Value::Scalar(Scalar::F32(v))
550 }
551 Type::Scalar(F64) => {
552 let v = v.base10_parse().unwrap();
553 Value::Scalar(Scalar::F64(v))
554 }
555 _ => panic!("A floating point does not match with the given type: {ty:?}"),
556 },
557 syn::Lit::Bool(v) => match ty {
558 Type::Scalar(Bool) => {
559 let v = v.value();
560 Value::Scalar(Scalar::Bool(v))
561 }
562 _ => panic!("A boolean does not match with the given type: {ty:?}"),
563 },
564 _ => panic!("not supported yet"),
565 };
566 Ok(value)
567 }
568
569 fn eval_field_value(&mut self, field_value: &syn::FieldValue) -> TriResult<Field<'gcx>, ()> {
570 let name = match &field_value.member {
571 syn::Member::Named(ident) => ident.to_string(),
572 syn::Member::Unnamed(i) => i.index.to_string(),
573 };
574 let value = self.eval_expr(&field_value.expr)?.value;
575 Ok(Field {
576 name: self.gcx.intern_str(&name),
577 value,
578 })
579 }
580
581 fn push_symbol_by_pat(&mut self, pat: &syn::Pat, value: Value<'gcx>) {
582 match pat {
583 syn::Pat::Ident(v) => {
584 let name = self.gcx.intern_str(&v.ident.to_string());
585 self.symbols.push(name, value);
586 }
587 syn::Pat::Type(v) => self.push_symbol_by_pat(&v.pat, value),
588 o => todo!("{o:#?}"),
589 }
590 }
591
592 fn update_symbol_by_expr(&mut self, lhs: &syn::Expr, rhs: Value<'gcx>) {
593 match lhs {
594 syn::Expr::Path(v) => self.update_symbol_by_expr_path(v, rhs),
595 o => todo!("{o:?}"),
596 }
597 }
598
599 fn update_symbol_by_expr_path(&mut self, lhs: &syn::ExprPath, rhs: Value<'gcx>) {
600 assert!(lhs.qself.is_none());
601
602 let lhs = lhs.path.get_ident().unwrap();
603 let name = lhs.to_string();
604 let value = self.symbols.get_mut(&*name).unwrap();
605 *value = rhs;
606 }
607
608 fn apply_to_fn(&mut self, f: Fn, args: &[Value<'gcx>]) -> TriResult<Value<'gcx>, ()> {
610 self.symbols.push_opaque_block();
611
612 match f.inputs {
613 FnInputs::Params(inputs) => {
614 debug_assert_eq!(inputs.len(), args.len());
615 for (arg, value) in inputs.iter().cloned().zip(args) {
616 let arg = unsafe { arg.as_ref().unwrap() };
617 match arg {
618 syn::FnArg::Receiver(_) => todo!(),
619 syn::FnArg::Typed(v) => self.push_symbol_by_pat(&v.pat, value.clone()),
620 }
621 }
622 }
623 }
624
625 let value = match f.body {
626 FnBody::Block(block) => {
627 let block = unsafe { block.as_ref().unwrap() };
628 self.eval_block(block)
629 }
630 };
631
632 self.symbols.pop_block();
633 value
634 }
635
636 fn apply_to_constructor<I>(&mut self, mut field_names: I, args: &[Value<'gcx>]) -> Value<'gcx>
637 where
638 I: Iterator<Item = Interned<'gcx, str>>,
639 {
640 let mut fields = Vec::new();
641 let mut args = args.iter();
642
643 while let (Some(field_name), Some(arg)) = (field_names.next(), args.next()) {
644 fields.push(Field {
645 name: field_name,
646 value: arg.clone(),
647 });
648 }
649
650 assert!(field_names.next().is_none());
651 assert!(args.next().is_none());
652
653 Value::Composed(fields)
654 }
655}
656
657#[derive(Debug, Default)]
658pub struct Evaluated<'gcx> {
659 mapped_values: Vec<Value<'gcx>>,
661
662 ptr_map: Map<*const syn::Expr, usize>,
664
665 pid_map: Map<PathId, usize>,
667}
668
669impl<'gcx> Evaluated<'gcx> {
670 pub(crate) fn new() -> Self {
671 Self {
672 mapped_values: Vec::new(),
673 ptr_map: Map::default(),
674 pid_map: Map::default(),
675 }
676 }
677
678 pub fn get_mapped_value_by_expr_ptr(&self, ptr: *const syn::Expr) -> Option<&Value<'gcx>> {
679 self.ptr_map
680 .get(&ptr)
681 .map(|index| &self.mapped_values[*index])
682 }
683
684 pub fn get_mapped_value_by_path_id(&self, pid: PathId) -> Option<&Value<'gcx>> {
685 self.pid_map
686 .get(&pid)
687 .map(|index| &self.mapped_values[*index])
688 }
689
690 pub(crate) fn get_value_by_expr(&self, expr: &syn::Expr) -> Option<&Value<'gcx>> {
691 self.get_mapped_value_by_expr_ptr(expr)
692 }
693
694 pub(crate) fn insert_mapped_value(
698 &mut self,
699 ptr: *const syn::Expr,
700 value: Value<'gcx>,
701 ) -> Option<Value<'gcx>> {
702 match self.ptr_map.entry(ptr) {
703 Entry::Occupied(entry) => {
704 let index = *entry.get();
705 let old_value = mem::replace(&mut self.mapped_values[index], value);
706 Some(old_value)
707 }
708 Entry::Vacant(entry) => {
709 self.mapped_values.push(value);
710 entry.insert(self.mapped_values.len() - 1);
711 None
712 }
713 }
714 }
715
716 pub(crate) fn insert_mapped_value2(
720 &mut self,
721 ptr: *const syn::Expr,
722 pid: PathId,
723 value: Value<'gcx>,
724 ) -> Option<Value<'gcx>> {
725 match (self.ptr_map.entry(ptr), self.pid_map.entry(pid)) {
726 (Entry::Occupied(ptr_entry), Entry::Occupied(pid_entry)) => {
727 debug_assert_eq!(ptr_entry.get(), pid_entry.get());
728 let index = *ptr_entry.get();
729 let old_value = mem::replace(&mut self.mapped_values[index], value);
730 Some(old_value)
731 }
732 (Entry::Occupied(ptr_entry), Entry::Vacant(pid_entry)) => {
733 let index = *ptr_entry.get();
734 pid_entry.insert(index);
735 let old_value = mem::replace(&mut self.mapped_values[index], value);
736 Some(old_value)
737 }
738 (Entry::Vacant(ptr_entry), Entry::Occupied(pid_entry)) => {
739 let index = *pid_entry.get();
740 ptr_entry.insert(index);
741 let old_value = mem::replace(&mut self.mapped_values[index], value);
742 Some(old_value)
743 }
744 (Entry::Vacant(ptr_entry), Entry::Vacant(pid_entry)) => {
745 self.mapped_values.push(value);
746 ptr_entry.insert(self.mapped_values.len() - 1);
747 pid_entry.insert(self.mapped_values.len() - 1);
748 None
749 }
750 }
751 }
752}
753
#[cfg(test)]
mod tests {
    use super::{Evaluator, Host};
    use crate::{
        etc::syn::SynPath,
        semantic::{
            basic_traits::EvaluateArrayLength,
            entry::GlobalCx,
            eval::{
                test_help::TestEvalHost,
                value::{Fn, Scalar, Value},
            },
            infer::{
                self,
                test_help::{test_inferer, TestInferLogicHost},
                Inferer,
            },
            logic::{self, test_help::test_logic, Logic},
        },
        Intern, Result, TriResult, TriResultHelper,
    };
    use logic_eval_util::str::StrPath;
    use syn_locator::{Find, LocateEntry};

    /// Parses `code` as an expression and registers it with `syn_locator`
    /// under `mod.rs`, discarding anything located before.
    fn parse(code: &str) -> syn::Expr {
        syn_locator::enable_thread_local(true);
        syn_locator::clear();

        let expr = syn::parse_str::<syn::Expr>(code).unwrap();
        std::pin::Pin::new(&expr)
            .locate_as_entry("mod.rs", code)
            .unwrap();
        expr
    }

    #[test]
    fn test_eval_operators() {
        /// Infers types for `expr`, then evaluates it.
        fn eval<'gcx, H: infer::Host<'gcx> + logic::Host<'gcx>>(
            inferer: &mut Inferer<'gcx>,
            evaluator: &mut Evaluator<'gcx>,
            logic: &mut Logic<'gcx>,
            infer_logic_host: &mut H,
            expr: &syn::Expr,
        ) -> Result<Value<'gcx>> {
            inferer
                .infer_expr(logic, infer_logic_host, expr, None)
                .elevate_err()?;
            let mut eval_host = TestEvalHost::new(inferer);
            evaluator.eval_expr(&mut eval_host, expr).elevate_err()
        }

        let gcx = GlobalCx::default();
        let mut inferer = test_inferer(&gcx);
        let mut evaluator = Evaluator::new(&gcx);
        let mut logic = test_logic(&gcx);
        let mut infer_logic_host = TestInferLogicHost::new(&gcx);

        // Expands to `let` statements in this function's body, so every parsed
        // expression stays alive until the end of the test.
        macro_rules! assert_evaluates_to {
            ($code:expr, $expected:expr) => {
                let expr = parse($code);
                let value = eval(
                    &mut inferer,
                    &mut evaluator,
                    &mut logic,
                    &mut infer_logic_host,
                    &expr,
                )
                .unwrap();
                assert_eq!(value, Value::Scalar($expected));
            };
        }

        // Binary operators.
        assert_evaluates_to!("{ 1 + 2 }", Scalar::Int(1 + 2));
        assert_evaluates_to!("{ 3 - 2 }", Scalar::Int(3 - 2));
        assert_evaluates_to!("{ 2 * 3 }", Scalar::Int(2 * 3));
        assert_evaluates_to!("{ 6 / 3 }", Scalar::Int(6 / 3));
        assert_evaluates_to!("{ 3 % 2 }", Scalar::Int(3 % 2));
        assert_evaluates_to!("{ 1 ^ 2 }", Scalar::Int(1 ^ 2));
        assert_evaluates_to!("{ 1 & 2 }", Scalar::Int(1 & 2));
        assert_evaluates_to!("{ 1 | 2 }", Scalar::Int(1 | 2));
        assert_evaluates_to!("{ 1 << 2 }", Scalar::Int(1 << 2));
        assert_evaluates_to!("{ 4 >> 2 }", Scalar::Int(4 >> 2));

        // Compound assignment operators.
        assert_evaluates_to!("{ let mut a = 1; a += 2; a }", Scalar::Int(1 + 2));
        assert_evaluates_to!("{ let mut a = 3; a -= 2; a }", Scalar::Int(3 - 2));
        assert_evaluates_to!("{ let mut a = 2; a *= 3; a }", Scalar::Int(2 * 3));
        assert_evaluates_to!("{ let mut a = 6; a /= 3; a }", Scalar::Int(6 / 3));
        assert_evaluates_to!("{ let mut a = 3; a %= 2; a }", Scalar::Int(3 % 2));
        assert_evaluates_to!("{ let mut a = 1; a ^= 2; a }", Scalar::Int(1 ^ 2));
        assert_evaluates_to!("{ let mut a = 1; a &= 2; a }", Scalar::Int(1 & 2));
        assert_evaluates_to!("{ let mut a = 1; a |= 2; a }", Scalar::Int(1 | 2));
        assert_evaluates_to!("{ let mut a = 1; a <<= 2; a }", Scalar::Int(1 << 2));
        assert_evaluates_to!("{ let mut a = 4; a >>= 2; a }", Scalar::Int(4 >> 2));

        // Unary operators.
        assert_evaluates_to!("{ !false }", Scalar::Bool(true));
        assert_evaluates_to!("{ let mut a = 1; a = -a; a }", Scalar::Int(-1));

        // Precedence and parentheses.
        assert_evaluates_to!("{ 1 + 2 * 3 + 4 * 5 }", Scalar::Int(1 + 2 * 3 + 4 * 5));
        assert_evaluates_to!(
            "{ (1 + 2) * 3 + 4 * 5 }",
            Scalar::Int((1 + 2) * 3 + 4 * 5)
        );
    }

    #[test]
    fn test_eval_function_call() {
        let code = r#"{
            fn f(x: i32) -> i32 { x * 2 }
            f(3)
        }"#;

        /// Evaluation host that resolves the path `f` to the function item
        /// declared inside `expr`.
        struct FnCallEvalHost<'a, 'gcx> {
            inferer: &'a mut Inferer<'gcx>,
            expr: &'a syn::Expr,
        }

        impl<'gcx> Host<'gcx> for FnCallEvalHost<'_, 'gcx> {
            fn find_type(&mut self, expr: &syn::Expr) -> TriResult<infer::Type<'gcx>, ()> {
                Ok(self.inferer.get_type(expr).unwrap().clone())
            }

            fn find_fn(&mut self, _: StrPath, _: &[infer::Type<'gcx>]) -> Fn {
                panic!()
            }

            fn syn_path_to_value(&mut self, path: SynPath) -> TriResult<Value<'gcx>, ()> {
                match path.path.get_ident().unwrap().to_string().as_str() {
                    "f" => {
                        let item: &syn::ItemFn =
                            self.expr.find("fn f(x: i32) -> i32 { x * 2 }").unwrap();
                        Ok(Value::Fn(Fn::from_signature_and_block(
                            &item.sig,
                            &item.block,
                        )))
                    }
                    _ => unreachable!(),
                }
            }
        }

        crate::impl_empty_scoping!(FnCallEvalHost<'_, '_>);

        /// Inference host that resolves every path to a type named `f` with
        /// two `i32` parameters.
        struct FnCallInferHost<'gcx> {
            gcx: &'gcx GlobalCx<'gcx>,
        }

        impl<'gcx> infer::Host<'gcx> for FnCallInferHost<'gcx> {
            fn syn_path_to_type(
                &mut self,
                _: SynPath,
                types: &mut infer::UniqueTypes,
            ) -> TriResult<infer::Type<'gcx>, ()> {
                use infer::{Param, Type, TypeScalar};

                let tid_i32 = types.insert_type(Type::Scalar(TypeScalar::I32));

                let name = self.gcx.intern_str("f");
                let params = [
                    Param::Other {
                        name: self.gcx.intern_str("0"),
                        tid: tid_i32,
                    },
                    Param::Other {
                        name: self.gcx.intern_str("1"),
                        tid: tid_i32,
                    },
                ];

                Ok(Type::Named(infer::TypeNamed {
                    name,
                    params: params.into(),
                }))
            }
        }

        impl<'gcx> EvaluateArrayLength<'gcx> for FnCallInferHost<'gcx> {
            fn eval_array_len(&mut self, _: &syn::Expr) -> TriResult<crate::ArrayLen, ()> {
                unreachable!()
            }
        }

        crate::impl_empty_scoping!(FnCallInferHost<'_>);
        crate::impl_empty_method_host!(FnCallInferHost<'_>);

        let gcx = GlobalCx::default();
        let mut inferer = test_inferer(&gcx);
        let mut evaluator = Evaluator::new(&gcx);
        let mut logic = test_logic(&gcx);
        let mut infer_logic_host = TestInferLogicHost::new(&gcx);
        infer_logic_host.override_infer_host(FnCallInferHost { gcx: &gcx });

        let expr = parse(code);
        inferer
            .infer_expr(&mut logic, &mut infer_logic_host, &expr, None)
            .unwrap();

        let mut eval_host = FnCallEvalHost {
            inferer: &mut inferer,
            expr: &expr,
        };
        let value = evaluator.eval_expr(&mut eval_host, &expr).unwrap();

        assert_eq!(value, Value::Scalar(Scalar::I32(3 * 2)));
    }
}