Skip to main content

syn_sem/semantic/eval/
proc.rs

1use super::value::{Field, Fn, FnBody, FnInputs, Scalar, Value};
2use crate::{
3    etc::{
4        known,
5        syn::{SynPath, SynPathKind},
6    },
7    semantic::{
8        basic_traits::{RawScope, Scope, Scoping},
9        entry::GlobalCx,
10        infer,
11        tree::PathId,
12    },
13    Intern, Map, Result, TriResult,
14};
15use any_intern::Interned;
16use logic_eval_util::{str::StrPath, symbol::SymbolTable};
17use std::{collections::hash_map::Entry, mem};
18
19struct ValueWithCtrl<'gcx> {
20    value: Value<'gcx>,
21    is_return: bool,
22}
23
24impl<'gcx> From<Value<'gcx>> for ValueWithCtrl<'gcx> {
25    fn from(value: Value<'gcx>) -> Self {
26        ValueWithCtrl {
27            value,
28            is_return: false,
29        }
30    }
31}
32
33// === Host ===
34
35#[allow(unused_variables)]
36pub(crate) trait Host<'gcx>: Scoping {
37    fn find_type(&mut self, expr: &syn::Expr) -> TriResult<infer::Type<'gcx>, ()>;
38    fn find_fn(&mut self, name: StrPath, types: &[infer::Type<'gcx>]) -> Fn;
39    fn syn_path_to_value(&mut self, syn_path: SynPath) -> TriResult<Value<'gcx>, ()>;
40}
41
42struct HostWrapper<'a, H> {
43    inner: &'a mut H,
44    scope_stack: Vec<RawScope>,
45}
46
47impl<'a, 'gcx, H: Host<'gcx>> HostWrapper<'a, H> {
48    fn new(host: &'a mut H) -> Self {
49        Self {
50            inner: host,
51            scope_stack: Vec::new(),
52        }
53    }
54
55    fn eval_known_fn(&mut self, abs_path: &str, values: &[Value<'gcx>]) -> Option<Value<'gcx>> {
56        use known::apply;
57        use once_cell::sync::OnceCell;
58
59        type F = for<'a> fn(&[Value<'a>]) -> Result<Value<'a>>;
60
61        static FMAP: OnceCell<Map<&'static str, F>> = OnceCell::new();
62
63        let fmap = FMAP.get_or_init(|| {
64            let mut map: Map<&'static str, F> = Map::default();
65
66            map.insert(apply::NAME_ADD, |values: &[Value<'_>]| {
67                debug_assert_eq!(values.len(), 2);
68                values[0].try_add(&values[1])
69            });
70            map.insert(apply::NAME_SUB, |values: &[Value<'_>]| {
71                debug_assert_eq!(values.len(), 2);
72                values[0].try_sub(&values[1])
73            });
74            map.insert(apply::NAME_MUL, |values: &[Value<'_>]| {
75                debug_assert_eq!(values.len(), 2);
76                values[0].try_mul(&values[1])
77            });
78            map.insert(apply::NAME_DIV, |values: &[Value<'_>]| {
79                debug_assert_eq!(values.len(), 2);
80                values[0].try_div(&values[1])
81            });
82            map.insert(apply::NAME_REM, |values: &[Value<'_>]| {
83                debug_assert_eq!(values.len(), 2);
84                values[0].try_rem(&values[1])
85            });
86            map.insert(apply::NAME_BIT_XOR, |values: &[Value<'_>]| {
87                debug_assert_eq!(values.len(), 2);
88                values[0].try_bit_xor(&values[1])
89            });
90            map.insert(apply::NAME_BIT_AND, |values: &[Value<'_>]| {
91                debug_assert_eq!(values.len(), 2);
92                values[0].try_bit_and(&values[1])
93            });
94            map.insert(apply::NAME_BIT_OR, |values: &[Value<'_>]| {
95                debug_assert_eq!(values.len(), 2);
96                values[0].try_bit_or(&values[1])
97            });
98            map.insert(apply::NAME_SHL, |values: &[Value<'_>]| {
99                debug_assert_eq!(values.len(), 2);
100                values[0].try_shl(&values[1])
101            });
102            map.insert(apply::NAME_SHR, |values: &[Value<'_>]| {
103                debug_assert_eq!(values.len(), 2);
104                values[0].try_shr(&values[1])
105            });
106            map.insert(apply::NAME_ADD_ASSIGN, |values: &[Value<'_>]| {
107                debug_assert_eq!(values.len(), 2);
108                values[0].try_add(&values[1]) // Add instead of AddAssign
109            });
110            map.insert(apply::NAME_SUB_ASSIGN, |values: &[Value<'_>]| {
111                debug_assert_eq!(values.len(), 2);
112                values[0].try_sub(&values[1]) // Sub instead of SubAssign
113            });
114            map.insert(apply::NAME_MUL_ASSIGN, |values: &[Value<'_>]| {
115                debug_assert_eq!(values.len(), 2);
116                values[0].try_mul(&values[1]) // Mul instead of MulAssign
117            });
118            map.insert(apply::NAME_DIV_ASSIGN, |values: &[Value<'_>]| {
119                debug_assert_eq!(values.len(), 2);
120                values[0].try_div(&values[1]) // Div instead of DivAssign
121            });
122            map.insert(apply::NAME_REM_ASSIGN, |values: &[Value<'_>]| {
123                debug_assert_eq!(values.len(), 2);
124                values[0].try_rem(&values[1]) // Rem instead of RemAssign
125            });
126            map.insert(apply::NAME_BIT_XOR_ASSIGN, |values: &[Value<'_>]| {
127                debug_assert_eq!(values.len(), 2);
128                values[0].try_bit_xor(&values[1]) // BitXor instead of BitXorAssign
129            });
130            map.insert(apply::NAME_BIT_AND_ASSIGN, |values: &[Value<'_>]| {
131                debug_assert_eq!(values.len(), 2);
132                values[0].try_bit_and(&values[1]) // BitAnd instead of BitAndAssign
133            });
134            map.insert(apply::NAME_BIT_OR_ASSIGN, |values: &[Value<'_>]| {
135                debug_assert_eq!(values.len(), 2);
136                values[0].try_bit_or(&values[1]) // BitOr instead of BitOrAssign
137            });
138            map.insert(apply::NAME_SHL_ASSIGN, |values: &[Value<'_>]| {
139                debug_assert_eq!(values.len(), 2);
140                values[0].try_shl(&values[1]) // Shl instead of ShlAssign
141            });
142            map.insert(apply::NAME_SHR_ASSIGN, |values: &[Value<'_>]| {
143                debug_assert_eq!(values.len(), 2);
144                values[0].try_shr(&values[1]) // Shr instead of ShrAssign
145            });
146            map.insert(apply::NAME_NOT, |values: &[Value<'_>]| {
147                debug_assert_eq!(values.len(), 1);
148                values[0].try_not()
149            });
150            map.insert(apply::NAME_NEG, |values: &[Value<'_>]| {
151                debug_assert_eq!(values.len(), 1);
152                values[0].try_neg()
153            });
154
155            // TODO: Deref
156
157            map
158        });
159
160        let f = fmap.get(&abs_path).cloned()?;
161        f(values).ok()
162    }
163
164    fn on_enter_scope(&mut self, scope: Scope) {
165        self.inner.on_enter_scope(scope);
166        self.scope_stack.push(scope.into_raw());
167    }
168
169    fn on_exit_scope(&mut self) {
170        let raw_scope = self.scope_stack.pop().unwrap();
171        let exit_scope = Scope::from_raw(raw_scope);
172        self.inner.on_exit_scope(exit_scope);
173
174        if let Some(raw_scope) = self.scope_stack.last() {
175            let reenter_scope = Scope::from_raw(*raw_scope);
176            self.inner.on_enter_scope(reenter_scope);
177        }
178    }
179}
180
181impl<'gcx, H: Host<'gcx>> Host<'gcx> for HostWrapper<'_, H> {
182    fn find_type(&mut self, expr: &syn::Expr) -> TriResult<infer::Type<'gcx>, ()> {
183        self.inner.find_type(expr)
184    }
185
186    fn find_fn(&mut self, name: StrPath, types: &[infer::Type<'gcx>]) -> Fn {
187        self.inner.find_fn(name, types)
188    }
189
190    fn syn_path_to_value(&mut self, syn_path: SynPath) -> TriResult<Value<'gcx>, ()> {
191        self.inner.syn_path_to_value(syn_path)
192    }
193}
194
195impl<'gcx, H: Host<'gcx>> Scoping for HostWrapper<'_, H> {
196    fn on_enter_scope(&mut self, scope: Scope) {
197        <Self>::on_enter_scope(self, scope)
198    }
199
200    fn on_exit_scope(&mut self, _: Scope) {
201        <Self>::on_exit_scope(self)
202    }
203}
204
205// === Evaluator ===
206
207#[derive(Debug)]
208pub(crate) struct Evaluator<'gcx> {
209    gcx: &'gcx GlobalCx<'gcx>,
210    symbols: SymbolTable<Interned<'gcx, str>, Value<'gcx>>,
211}
212
213impl<'gcx> Evaluator<'gcx> {
214    pub(crate) fn new(gcx: &'gcx GlobalCx<'gcx>) -> Self {
215        Self {
216            gcx,
217            symbols: SymbolTable::default(),
218        }
219    }
220
221    pub(crate) fn eval_expr<H: Host<'gcx>>(
222        &mut self,
223        host: &mut H,
224        expr: &syn::Expr,
225    ) -> TriResult<Value<'gcx>, ()> {
226        self.symbols.clear();
227
228        let mut cx = EvalCx {
229            gcx: self.gcx,
230            symbols: &mut self.symbols,
231            host: HostWrapper::new(host),
232        };
233
234        cx.eval_expr(expr).map(|ex| ex.value)
235    }
236}
237
238struct EvalCx<'a, 'gcx, H> {
239    gcx: &'gcx GlobalCx<'gcx>,
240    symbols: &'a mut SymbolTable<Interned<'gcx, str>, Value<'gcx>>,
241    host: HostWrapper<'a, H>,
242}
243
244impl<'a, 'gcx, H: Host<'gcx>> EvalCx<'a, 'gcx, H> {
245    fn eval_expr(&mut self, expr: &syn::Expr) -> TriResult<ValueWithCtrl<'gcx>, ()> {
246        match expr {
247            syn::Expr::Array(v) => self.eval_expr_array(v).map(ValueWithCtrl::from),
248            syn::Expr::Assign(v) => self.eval_expr_assign(v).map(ValueWithCtrl::from),
249            syn::Expr::Async(_) => panic!("`async` is not supported"),
250            syn::Expr::Await(_) => panic!("`await` is not supported"),
251            syn::Expr::Binary(v) => self.eval_expr_binary(v).map(ValueWithCtrl::from),
252            syn::Expr::Block(v) => self.eval_block(&v.block).map(ValueWithCtrl::from),
253            syn::Expr::Break(v) => todo!("{v:#?}"),
254            syn::Expr::Call(v) => self.eval_expr_call(v).map(ValueWithCtrl::from),
255            syn::Expr::Cast(v) => self.eval_expr(&v.expr),
256            syn::Expr::Closure(_) => panic!("`closure` is not supported"),
257            syn::Expr::Const(v) => self.eval_block(&v.block).map(ValueWithCtrl::from),
258            syn::Expr::Let(v) => todo!("{v:#?}"),
259            syn::Expr::Lit(v) => self.eval_lit(&v.lit, expr).map(ValueWithCtrl::from),
260            syn::Expr::Loop(v) => todo!("{v:#?}"),
261            syn::Expr::Macro(_) => Ok(Value::Unit.into()),
262            syn::Expr::Match(v) => todo!("{v:#?}"),
263            syn::Expr::MethodCall(v) => todo!("{v:#?}"),
264            syn::Expr::Paren(v) => self.eval_expr_paren(v),
265            syn::Expr::Path(v) => self.eval_expr_path(v).map(ValueWithCtrl::from),
266            syn::Expr::Range(v) => todo!("{v:#?}"),
267            syn::Expr::RawAddr(v) => todo!("{v:#?}"),
268            syn::Expr::Reference(v) => todo!("{v:#?}"),
269            syn::Expr::Repeat(v) => todo!("{v:#?}"),
270            syn::Expr::Return(v) => todo!("{v:#?}"),
271            syn::Expr::Struct(v) => self.eval_expr_struct(v).map(ValueWithCtrl::from),
272            syn::Expr::Try(v) => todo!("{v:#?}"),
273            syn::Expr::TryBlock(v) => todo!("{v:#?}"),
274            syn::Expr::Tuple(v) => todo!("{v:#?}"),
275            syn::Expr::Unary(v) => self.eval_expr_unary(v).map(ValueWithCtrl::from),
276            syn::Expr::Unsafe(v) => todo!("{v:#?}"),
277            syn::Expr::Verbatim(v) => todo!("{v:#?}"),
278            syn::Expr::While(v) => todo!("{v:#?}"),
279            syn::Expr::Yield(v) => todo!("{v:#?}"),
280            _ => todo!(),
281        }
282    }
283
284    fn eval_expr_array(&mut self, expr_arr: &syn::ExprArray) -> TriResult<Value<'gcx>, ()> {
285        let fields = expr_arr
286            .elems
287            .iter()
288            .enumerate()
289            .map(|(i, elem)| {
290                self.eval_expr(elem).map(|ex| Field {
291                    name: self.gcx.intern_str(&i.to_string()),
292                    value: ex.value,
293                })
294            })
295            .collect::<TriResult<Vec<Field<'gcx>>, ()>>()?;
296        Ok(Value::Composed(fields))
297    }
298
299    fn eval_expr_assign(&mut self, expr_assign: &syn::ExprAssign) -> TriResult<Value<'gcx>, ()> {
300        let rv = self.eval_expr(&expr_assign.right)?.value;
301        self.update_symbol_by_expr(&expr_assign.left, rv);
302        Ok(Value::Unit)
303    }
304
305    fn eval_expr_binary(&mut self, expr_bin: &syn::ExprBinary) -> TriResult<Value<'gcx>, ()> {
306        use known::apply::*;
307
308        return match expr_bin.op {
309            syn::BinOp::Add(_) => bin(self, expr_bin, NAME_ADD),
310            syn::BinOp::Sub(_) => bin(self, expr_bin, NAME_SUB),
311            syn::BinOp::Mul(_) => bin(self, expr_bin, NAME_MUL),
312            syn::BinOp::Div(_) => bin(self, expr_bin, NAME_DIV),
313            syn::BinOp::Rem(_) => bin(self, expr_bin, NAME_REM),
314            syn::BinOp::BitXor(_) => bin(self, expr_bin, NAME_BIT_XOR),
315            syn::BinOp::BitAnd(_) => bin(self, expr_bin, NAME_BIT_AND),
316            syn::BinOp::BitOr(_) => bin(self, expr_bin, NAME_BIT_OR),
317            syn::BinOp::Shl(_) => bin(self, expr_bin, NAME_SHL),
318            syn::BinOp::Shr(_) => bin(self, expr_bin, NAME_SHR),
319            syn::BinOp::AddAssign(_) => bin_assign(self, expr_bin, NAME_ADD_ASSIGN),
320            syn::BinOp::SubAssign(_) => bin_assign(self, expr_bin, NAME_SUB_ASSIGN),
321            syn::BinOp::MulAssign(_) => bin_assign(self, expr_bin, NAME_MUL_ASSIGN),
322            syn::BinOp::DivAssign(_) => bin_assign(self, expr_bin, NAME_DIV_ASSIGN),
323            syn::BinOp::RemAssign(_) => bin_assign(self, expr_bin, NAME_REM_ASSIGN),
324            syn::BinOp::BitXorAssign(_) => bin_assign(self, expr_bin, NAME_BIT_XOR_ASSIGN),
325            syn::BinOp::BitAndAssign(_) => bin_assign(self, expr_bin, NAME_BIT_AND_ASSIGN),
326            syn::BinOp::BitOrAssign(_) => bin_assign(self, expr_bin, NAME_BIT_OR_ASSIGN),
327            syn::BinOp::ShlAssign(_) => bin_assign(self, expr_bin, NAME_SHL_ASSIGN),
328            syn::BinOp::ShrAssign(_) => bin_assign(self, expr_bin, NAME_SHR_ASSIGN),
329            _ => unreachable!(),
330        };
331
332        // === Internal helper functions ===
333
334        fn bin<'gcx, H: Host<'gcx>>(
335            this: &mut EvalCx<'_, 'gcx, H>,
336            expr_bin: &syn::ExprBinary,
337            name: &str,
338        ) -> TriResult<Value<'gcx>, ()> {
339            let lv = this.eval_expr(&expr_bin.left)?.value;
340            let rv = this.eval_expr(&expr_bin.right)?.value;
341            let values = [lv, rv];
342            if let Some(res) = this.host.eval_known_fn(name, &values) {
343                return Ok(res);
344            }
345
346            let lty = this.host.find_type(&expr_bin.left)?;
347            let rty = this.host.find_type(&expr_bin.right)?;
348            let f = this.host.find_fn(StrPath::absolute(name), &[lty, rty]);
349            this.apply_to_fn(f, &values)
350        }
351
352        fn bin_assign<'gcx, H: Host<'gcx>>(
353            this: &mut EvalCx<'_, 'gcx, H>,
354            expr_bin: &syn::ExprBinary,
355            name: &str,
356        ) -> TriResult<Value<'gcx>, ()> {
357            let value = bin(this, expr_bin, name)?;
358            this.update_symbol_by_expr(&expr_bin.left, value);
359            Ok(Value::Unit)
360        }
361    }
362
363    fn eval_expr_call(&mut self, expr_call: &syn::ExprCall) -> TriResult<Value<'gcx>, ()> {
364        let args = expr_call
365            .args
366            .iter()
367            .map(|arg| self.eval_expr(arg).map(|ex| ex.value))
368            .collect::<TriResult<Vec<_>, ()>>()?;
369
370        match self.eval_expr(&expr_call.func)?.value {
371            // Ordinary function call
372            Value::Fn(f) => self.apply_to_fn(f, &args),
373            // Constructor
374            Value::Composed(fields) => {
375                let field_names = fields.into_iter().map(|field| field.name);
376                let value = self.apply_to_constructor(field_names, &args);
377                Ok(value)
378            }
379            _ => unreachable!(),
380        }
381    }
382
383    fn eval_expr_paren(
384        &mut self,
385        expr_paren: &syn::ExprParen,
386    ) -> TriResult<ValueWithCtrl<'gcx>, ()> {
387        self.eval_expr(&expr_paren.expr)
388    }
389
390    fn eval_expr_path(&mut self, expr_path: &syn::ExprPath) -> TriResult<Value<'gcx>, ()> {
391        if expr_path.qself.is_none() {
392            if let Some(ident) = expr_path.path.get_ident() {
393                if let Some(value) = self.symbols.get(&*ident.to_string()) {
394                    return Ok(value.clone());
395                }
396            }
397        }
398
399        let syn_path = SynPath {
400            kind: SynPathKind::Expr,
401            qself: expr_path.qself.as_ref(),
402            path: &expr_path.path,
403        };
404        self.host.syn_path_to_value(syn_path)
405    }
406
407    fn eval_expr_struct(&mut self, expr_struct: &syn::ExprStruct) -> TriResult<Value<'gcx>, ()> {
408        let fields = expr_struct
409            .fields
410            .iter()
411            .map(|field| self.eval_field_value(field))
412            .collect::<TriResult<Vec<Field>, ()>>()?;
413        Ok(Value::Composed(fields))
414    }
415
416    fn eval_expr_unary(&mut self, expr_unary: &syn::ExprUnary) -> TriResult<Value<'gcx>, ()> {
417        use known::apply::*;
418
419        let name = match expr_unary.op {
420            syn::UnOp::Deref(_) => todo!(),
421            syn::UnOp::Not(_) => NAME_NOT,
422            syn::UnOp::Neg(_) => NAME_NEG,
423            _ => unreachable!(),
424        };
425
426        let v = self.eval_expr(&expr_unary.expr)?.value;
427        let values = [v];
428        if let Some(res) = self.host.eval_known_fn(name, &values) {
429            return Ok(res);
430        }
431
432        let ty = self.host.find_type(&expr_unary.expr)?;
433        let f = self.host.find_fn(StrPath::absolute(name), &[ty]);
434        self.apply_to_fn(f, &values)
435    }
436
437    fn eval_block(&mut self, block: &syn::Block) -> TriResult<Value<'gcx>, ()> {
438        self.host.on_enter_scope(Scope::Block(block));
439        self.symbols.push_transparent_block();
440
441        let mut last_value = Value::Unit;
442        for stmt in &block.stmts {
443            let ValueWithCtrl {
444                value, is_return, ..
445            } = self.eval_stmt(stmt)?;
446            last_value = value;
447            if is_return {
448                break;
449            }
450        }
451
452        self.symbols.pop_block();
453        self.host.on_exit_scope();
454        Ok(last_value)
455    }
456
457    fn eval_stmt(&mut self, stmt: &syn::Stmt) -> TriResult<ValueWithCtrl<'gcx>, ()> {
458        let value = match stmt {
459            syn::Stmt::Local(v) => {
460                self.eval_local(v)?;
461                Value::Unit.into()
462            }
463            syn::Stmt::Item(_) => Value::Unit.into(),
464            syn::Stmt::Expr(v, _) => self.eval_expr(v)?,
465            syn::Stmt::Macro(_) => Value::Unit.into(),
466        };
467        Ok(value)
468    }
469
470    fn eval_local(&mut self, local: &syn::Local) -> TriResult<(), ()> {
471        // Evaluates rhs first due to the shadowing.
472        let rhs = local
473            .init
474            .as_ref()
475            .map(|init| self.eval_expr(&init.expr).map(|ex| ex.value))
476            .unwrap_or(Ok(Value::Unit))?;
477        self.push_symbol_by_pat(&local.pat, rhs);
478        Ok(())
479    }
480
481    fn eval_lit(&mut self, lit: &syn::Lit, expr: &syn::Expr) -> TriResult<Value<'gcx>, ()> {
482        use infer::{Type, TypeScalar::*};
483
484        let ty = self.host.find_type(expr)?;
485
486        let value = match lit {
487            syn::Lit::Int(v) => match ty {
488                Type::Scalar(Int { .. }) => {
489                    let v = v.base10_parse().unwrap();
490                    Value::Scalar(Scalar::Int(v))
491                }
492                Type::Scalar(I8) => {
493                    let v = v.base10_parse().unwrap();
494                    Value::Scalar(Scalar::I8(v))
495                }
496                Type::Scalar(I16) => {
497                    let v = v.base10_parse().unwrap();
498                    Value::Scalar(Scalar::I16(v))
499                }
500                Type::Scalar(I32) => {
501                    let v = v.base10_parse().unwrap();
502                    Value::Scalar(Scalar::I32(v))
503                }
504                Type::Scalar(I64) => {
505                    let v = v.base10_parse().unwrap();
506                    Value::Scalar(Scalar::I64(v))
507                }
508                Type::Scalar(I128) => {
509                    let v = v.base10_parse().unwrap();
510                    Value::Scalar(Scalar::I128(v))
511                }
512                Type::Scalar(Isize) => {
513                    let v = v.base10_parse().unwrap();
514                    Value::Scalar(Scalar::Isize(v))
515                }
516                Type::Scalar(U8) => {
517                    let v = v.base10_parse().unwrap();
518                    Value::Scalar(Scalar::U8(v))
519                }
520                Type::Scalar(U16) => {
521                    let v = v.base10_parse().unwrap();
522                    Value::Scalar(Scalar::U16(v))
523                }
524                Type::Scalar(U32) => {
525                    let v = v.base10_parse().unwrap();
526                    Value::Scalar(Scalar::U32(v))
527                }
528                Type::Scalar(U64) => {
529                    let v = v.base10_parse().unwrap();
530                    Value::Scalar(Scalar::U64(v))
531                }
532                Type::Scalar(U128) => {
533                    let v = v.base10_parse().unwrap();
534                    Value::Scalar(Scalar::U128(v))
535                }
536                Type::Scalar(Usize) => {
537                    let v = v.base10_parse().unwrap();
538                    Value::Scalar(Scalar::Usize(v))
539                }
540                _ => panic!("An integer does not match with the given type: {ty:?}"),
541            },
542            syn::Lit::Float(v) => match ty {
543                Type::Scalar(Float { .. }) => {
544                    let v = v.base10_parse().unwrap();
545                    Value::Scalar(Scalar::Float(v))
546                }
547                Type::Scalar(F32) => {
548                    let v = v.base10_parse().unwrap();
549                    Value::Scalar(Scalar::F32(v))
550                }
551                Type::Scalar(F64) => {
552                    let v = v.base10_parse().unwrap();
553                    Value::Scalar(Scalar::F64(v))
554                }
555                _ => panic!("A floating point does not match with the given type: {ty:?}"),
556            },
557            syn::Lit::Bool(v) => match ty {
558                Type::Scalar(Bool) => {
559                    let v = v.value();
560                    Value::Scalar(Scalar::Bool(v))
561                }
562                _ => panic!("A boolean does not match with the given type: {ty:?}"),
563            },
564            _ => panic!("not supported yet"),
565        };
566        Ok(value)
567    }
568
569    fn eval_field_value(&mut self, field_value: &syn::FieldValue) -> TriResult<Field<'gcx>, ()> {
570        let name = match &field_value.member {
571            syn::Member::Named(ident) => ident.to_string(),
572            syn::Member::Unnamed(i) => i.index.to_string(),
573        };
574        let value = self.eval_expr(&field_value.expr)?.value;
575        Ok(Field {
576            name: self.gcx.intern_str(&name),
577            value,
578        })
579    }
580
581    fn push_symbol_by_pat(&mut self, pat: &syn::Pat, value: Value<'gcx>) {
582        match pat {
583            syn::Pat::Ident(v) => {
584                let name = self.gcx.intern_str(&v.ident.to_string());
585                self.symbols.push(name, value);
586            }
587            syn::Pat::Type(v) => self.push_symbol_by_pat(&v.pat, value),
588            o => todo!("{o:#?}"),
589        }
590    }
591
592    fn update_symbol_by_expr(&mut self, lhs: &syn::Expr, rhs: Value<'gcx>) {
593        match lhs {
594            syn::Expr::Path(v) => self.update_symbol_by_expr_path(v, rhs),
595            o => todo!("{o:?}"),
596        }
597    }
598
599    fn update_symbol_by_expr_path(&mut self, lhs: &syn::ExprPath, rhs: Value<'gcx>) {
600        assert!(lhs.qself.is_none());
601
602        let lhs = lhs.path.get_ident().unwrap();
603        let name = lhs.to_string();
604        let value = self.symbols.get_mut(&*name).unwrap();
605        *value = rhs;
606    }
607
608    /// Applies the given values to the function.
609    fn apply_to_fn(&mut self, f: Fn, args: &[Value<'gcx>]) -> TriResult<Value<'gcx>, ()> {
610        self.symbols.push_opaque_block();
611
612        match f.inputs {
613            FnInputs::Params(inputs) => {
614                debug_assert_eq!(inputs.len(), args.len());
615                for (arg, value) in inputs.iter().cloned().zip(args) {
616                    let arg = unsafe { arg.as_ref().unwrap() };
617                    match arg {
618                        syn::FnArg::Receiver(_) => todo!(),
619                        syn::FnArg::Typed(v) => self.push_symbol_by_pat(&v.pat, value.clone()),
620                    }
621                }
622            }
623        }
624
625        let value = match f.body {
626            FnBody::Block(block) => {
627                let block = unsafe { block.as_ref().unwrap() };
628                self.eval_block(block)
629            }
630        };
631
632        self.symbols.pop_block();
633        value
634    }
635
636    fn apply_to_constructor<I>(&mut self, mut field_names: I, args: &[Value<'gcx>]) -> Value<'gcx>
637    where
638        I: Iterator<Item = Interned<'gcx, str>>,
639    {
640        let mut fields = Vec::new();
641        let mut args = args.iter();
642
643        while let (Some(field_name), Some(arg)) = (field_names.next(), args.next()) {
644            fields.push(Field {
645                name: field_name,
646                value: arg.clone(),
647            });
648        }
649
650        assert!(field_names.next().is_none());
651        assert!(args.next().is_none());
652
653        Value::Composed(fields)
654    }
655}
656
657#[derive(Debug, Default)]
658pub struct Evaluated<'gcx> {
659    /// Evaluated values that mapped to an expression or a path item.
660    mapped_values: Vec<Value<'gcx>>,
661
662    /// Mapping between an expression and an index to [`Self::mapped_values`].
663    ptr_map: Map<*const syn::Expr, usize>,
664
665    /// Mapping between an path item and an index to [`Self::mapped_values`].
666    pid_map: Map<PathId, usize>,
667}
668
669impl<'gcx> Evaluated<'gcx> {
670    pub(crate) fn new() -> Self {
671        Self {
672            mapped_values: Vec::new(),
673            ptr_map: Map::default(),
674            pid_map: Map::default(),
675        }
676    }
677
678    pub fn get_mapped_value_by_expr_ptr(&self, ptr: *const syn::Expr) -> Option<&Value<'gcx>> {
679        self.ptr_map
680            .get(&ptr)
681            .map(|index| &self.mapped_values[*index])
682    }
683
684    pub fn get_mapped_value_by_path_id(&self, pid: PathId) -> Option<&Value<'gcx>> {
685        self.pid_map
686            .get(&pid)
687            .map(|index| &self.mapped_values[*index])
688    }
689
690    pub(crate) fn get_value_by_expr(&self, expr: &syn::Expr) -> Option<&Value<'gcx>> {
691        self.get_mapped_value_by_expr_ptr(expr)
692    }
693
694    /// Inserts an expression pointer and its evaluated value.
695    ///
696    /// You can find the value using the expression pointer later.
697    pub(crate) fn insert_mapped_value(
698        &mut self,
699        ptr: *const syn::Expr,
700        value: Value<'gcx>,
701    ) -> Option<Value<'gcx>> {
702        match self.ptr_map.entry(ptr) {
703            Entry::Occupied(entry) => {
704                let index = *entry.get();
705                let old_value = mem::replace(&mut self.mapped_values[index], value);
706                Some(old_value)
707            }
708            Entry::Vacant(entry) => {
709                self.mapped_values.push(value);
710                entry.insert(self.mapped_values.len() - 1);
711                None
712            }
713        }
714    }
715
716    /// Inserts an expression pointer with path id and its evaluated value.
717    ///
718    /// You can find the value using the expression pointer or path id later.
719    pub(crate) fn insert_mapped_value2(
720        &mut self,
721        ptr: *const syn::Expr,
722        pid: PathId,
723        value: Value<'gcx>,
724    ) -> Option<Value<'gcx>> {
725        match (self.ptr_map.entry(ptr), self.pid_map.entry(pid)) {
726            (Entry::Occupied(ptr_entry), Entry::Occupied(pid_entry)) => {
727                debug_assert_eq!(ptr_entry.get(), pid_entry.get());
728                let index = *ptr_entry.get();
729                let old_value = mem::replace(&mut self.mapped_values[index], value);
730                Some(old_value)
731            }
732            (Entry::Occupied(ptr_entry), Entry::Vacant(pid_entry)) => {
733                let index = *ptr_entry.get();
734                pid_entry.insert(index);
735                let old_value = mem::replace(&mut self.mapped_values[index], value);
736                Some(old_value)
737            }
738            (Entry::Vacant(ptr_entry), Entry::Occupied(pid_entry)) => {
739                let index = *pid_entry.get();
740                ptr_entry.insert(index);
741                let old_value = mem::replace(&mut self.mapped_values[index], value);
742                Some(old_value)
743            }
744            (Entry::Vacant(ptr_entry), Entry::Vacant(pid_entry)) => {
745                self.mapped_values.push(value);
746                ptr_entry.insert(self.mapped_values.len() - 1);
747                pid_entry.insert(self.mapped_values.len() - 1);
748                None
749            }
750        }
751    }
752}
753
#[cfg(test)]
mod tests {
    use super::{Evaluator, Host};
    use crate::{
        etc::syn::SynPath,
        semantic::{
            basic_traits::EvaluateArrayLength,
            entry::GlobalCx,
            eval::{
                test_help::TestEvalHost,
                value::{Fn, Scalar, Value},
            },
            infer::{
                self,
                test_help::{test_inferer, TestInferLogicHost},
                Inferer,
            },
            logic::{self, test_help::test_logic, Logic},
        },
        Intern, Result, TriResult, TriResultHelper,
    };
    use logic_eval_util::str::StrPath;
    use syn_locator::{Find, LocateEntry};

    /// Parses `code` into a `syn::Expr` and registers it with `syn_locator`
    /// under the file name `"mod.rs"`.
    ///
    /// Thread-local locating is enabled and `syn_locator::clear()` is called
    /// before registering, so each call starts from a freshly cleared locator.
    ///
    /// # Panics
    ///
    /// Panics if `code` does not parse as an expression, or if locating fails.
    fn parse(code: &str) -> syn::Expr {
        syn_locator::enable_thread_local(true);
        syn_locator::clear();

        let expr: syn::Expr = syn::parse_str(code)
            .unwrap_or_else(|e| panic!("failed to parse `{}`: {}", code, e));
        let pinned = std::pin::Pin::new(&expr);
        pinned.locate_as_entry("mod.rs", code).unwrap();
        expr
    }

    #[test]
    fn test_eval_operators() {
        /// Runs type inference on `expr`, then evaluates it with a
        /// `TestEvalHost` backed by the same inferer.
        fn eval<'gcx, H: infer::Host<'gcx> + logic::Host<'gcx>>(
            inferer: &mut Inferer<'gcx>,
            evaluator: &mut Evaluator<'gcx>,
            logic: &mut Logic<'gcx>,
            infer_logic_host: &mut H,
            expr: &syn::Expr,
        ) -> Result<Value<'gcx>> {
            inferer
                .infer_expr(logic, infer_logic_host, expr, None)
                .elevate_err()?;
            let mut eval_host = TestEvalHost::new(inferer);
            evaluator.eval_expr(&mut eval_host, expr).elevate_err()
        }

        let gcx = GlobalCx::default();
        let mut inferer = test_inferer(&gcx);
        let mut evaluator = Evaluator::new(&gcx);
        let mut logic = test_logic(&gcx);
        let mut infer_logic_host = TestInferLogicHost::new(&gcx);

        // Parses `$code`, infers and evaluates it with the shared state above,
        // and asserts the result equals `$expected`. Failures name the code.
        //
        // This is a macro rather than a loop over a table on purpose: each
        // expansion introduces its own `let expr`, so every parsed tree stays
        // alive until the end of the test, exactly as separate `let` bindings
        // would. NOTE(review): this avoids freeing/reusing syntax-tree memory
        // mid-test in case inferer/evaluator/locator state is keyed by node
        // address — not verified, kept as the conservative choice.
        macro_rules! assert_eval {
            ($code:expr, $expected:expr) => {
                let code: &str = $code;
                let expr = parse(code);
                let value = eval(
                    &mut inferer,
                    &mut evaluator,
                    &mut logic,
                    &mut infer_logic_host,
                    &expr,
                )
                .unwrap_or_else(|e| panic!("failed to evaluate `{}`: {:?}", code, e));
                assert_eq!(value, $expected, "unexpected value for `{}`", code);
            };
        }

        // Binary arithmetic and bitwise operators.
        assert_eval!("{ 1 + 2 }", Value::Scalar(Scalar::Int(1 + 2)));
        assert_eval!("{ 3 - 2 }", Value::Scalar(Scalar::Int(3 - 2)));
        assert_eval!("{ 2 * 3 }", Value::Scalar(Scalar::Int(2 * 3)));
        assert_eval!("{ 6 / 3 }", Value::Scalar(Scalar::Int(6 / 3)));
        assert_eval!("{ 3 % 2 }", Value::Scalar(Scalar::Int(3 % 2)));
        assert_eval!("{ 1 ^ 2 }", Value::Scalar(Scalar::Int(1 ^ 2)));
        assert_eval!("{ 1 & 2 }", Value::Scalar(Scalar::Int(1 & 2)));
        assert_eval!("{ 1 | 2 }", Value::Scalar(Scalar::Int(1 | 2)));
        assert_eval!("{ 1 << 2 }", Value::Scalar(Scalar::Int(1 << 2)));
        assert_eval!("{ 4 >> 2 }", Value::Scalar(Scalar::Int(4 >> 2)));

        // Compound assignment operators.
        assert_eval!("{ let mut a = 1; a += 2; a }", Value::Scalar(Scalar::Int(1 + 2)));
        assert_eval!("{ let mut a = 3; a -= 2; a }", Value::Scalar(Scalar::Int(3 - 2)));
        assert_eval!("{ let mut a = 2; a *= 3; a }", Value::Scalar(Scalar::Int(2 * 3)));
        assert_eval!("{ let mut a = 6; a /= 3; a }", Value::Scalar(Scalar::Int(6 / 3)));
        assert_eval!("{ let mut a = 3; a %= 2; a }", Value::Scalar(Scalar::Int(3 % 2)));
        assert_eval!("{ let mut a = 1; a ^= 2; a }", Value::Scalar(Scalar::Int(1 ^ 2)));
        assert_eval!("{ let mut a = 1; a &= 2; a }", Value::Scalar(Scalar::Int(1 & 2)));
        assert_eval!("{ let mut a = 1; a |= 2; a }", Value::Scalar(Scalar::Int(1 | 2)));
        assert_eval!("{ let mut a = 1; a <<= 2; a }", Value::Scalar(Scalar::Int(1 << 2)));
        assert_eval!("{ let mut a = 4; a >>= 2; a }", Value::Scalar(Scalar::Int(4 >> 2)));

        // Unary operators.
        assert_eval!("{ !false }", Value::Scalar(Scalar::Bool(true)));
        assert_eval!("{ let mut a = 1; a = -a; a }", Value::Scalar(Scalar::Int(-1)));

        // Operator priority (syn crate is already giving us syntax tree
        // concerning this priority)
        assert_eval!(
            "{ 1 + 2 * 3 + 4 * 5 }",
            Value::Scalar(Scalar::Int(1 + 2 * 3 + 4 * 5))
        );
        assert_eval!(
            "{ (1 + 2) * 3 + 4 * 5 }",
            Value::Scalar(Scalar::Int((1 + 2) * 3 + 4 * 5))
        );
    }

    #[test]
    fn test_eval_function_call() {
        let code = r#"{
            fn f(x: i32) -> i32 { x * 2 }
            f(3)
        }"#;

        /// Evaluation host for this test. Types come from the inferer; the only
        /// resolvable path is `f`, which maps to the function item declared
        /// inside `expr`.
        ///
        /// Named distinctly from the module-level `TestEvalHost` import so the
        /// two are not confused.
        struct FnCallEvalHost<'a, 'gcx> {
            inferer: &'a mut Inferer<'gcx>,
            expr: &'a syn::Expr,
        }

        impl<'gcx> Host<'gcx> for FnCallEvalHost<'_, 'gcx> {
            fn find_type(&mut self, expr: &syn::Expr) -> TriResult<infer::Type<'gcx>, ()> {
                let ty = self.inferer.get_type(expr).unwrap().clone();
                Ok(ty)
            }

            fn find_fn(&mut self, _: StrPath, _: &[infer::Type<'gcx>]) -> Fn {
                panic!("`find_fn` is not expected to be called in this test")
            }

            fn syn_path_to_value(&mut self, path: SynPath) -> TriResult<Value<'gcx>, ()> {
                let ident = path.path.get_ident().unwrap().to_string();
                if ident == "f" {
                    // Keep in sync with the item inside `code`.
                    // NOTE(review): assumes `Find::find` matches on exact
                    // source text.
                    let fn_code = "fn f(x: i32) -> i32 { x * 2 }";
                    let f: &syn::ItemFn = self.expr.find(fn_code).unwrap();
                    let f = Fn::from_signature_and_block(&f.sig, &f.block);
                    Ok(Value::Fn(f))
                } else {
                    unreachable!()
                }
            }
        }

        crate::impl_empty_scoping!(FnCallEvalHost<'_, '_>);

        /// Inference host for this test. It ignores the requested path and
        /// always answers with a named type `f` carrying two `i32` params named
        /// `"0"` and `"1"`.
        /// NOTE(review): presumably these encode `f`'s argument and return
        /// types — confirm against the `infer` module.
        struct TestInferHost<'gcx> {
            gcx: &'gcx GlobalCx<'gcx>,
        }

        impl<'gcx> infer::Host<'gcx> for TestInferHost<'gcx> {
            fn syn_path_to_type(
                &mut self,
                _: SynPath,
                types: &mut infer::UniqueTypes,
            ) -> TriResult<infer::Type<'gcx>, ()> {
                use infer::{Param, Type, TypeScalar};

                let tid_i32 = types.insert_type(Type::Scalar(TypeScalar::I32));

                let res = infer::Type::Named(infer::TypeNamed {
                    name: self.gcx.intern_str("f"),
                    params: [
                        Param::Other {
                            name: self.gcx.intern_str("0"),
                            tid: tid_i32,
                        },
                        Param::Other {
                            name: self.gcx.intern_str("1"),
                            tid: tid_i32,
                        },
                    ]
                    .into(),
                });
                Ok(res)
            }
        }

        // `code` contains no array types, so this is never expected to run.
        impl<'gcx> EvaluateArrayLength<'gcx> for TestInferHost<'gcx> {
            fn eval_array_len(&mut self, _: &syn::Expr) -> TriResult<crate::ArrayLen, ()> {
                unreachable!()
            }
        }

        crate::impl_empty_scoping!(TestInferHost<'_>);
        crate::impl_empty_method_host!(TestInferHost<'_>);

        let gcx = GlobalCx::default();
        let mut inferer = test_inferer(&gcx);
        let mut evaluator = Evaluator::new(&gcx);
        let mut logic = test_logic(&gcx);
        let mut infer_logic_host = TestInferLogicHost::new(&gcx);
        infer_logic_host.override_infer_host(TestInferHost { gcx: &gcx });

        let expr = parse(code);
        inferer
            .infer_expr(&mut logic, &mut infer_logic_host, &expr, None)
            .unwrap();
        let mut eval_host = FnCallEvalHost {
            inferer: &mut inferer,
            expr: &expr,
        };
        let value = evaluator.eval_expr(&mut eval_host, &expr).unwrap();

        // `f(3)` where `f(x) = x * 2`; `f` is declared to return `i32`.
        assert_eq!(value, Value::Scalar(Scalar::I32(3 * 2)));
    }
}