Skip to main content

tidepool_eval/
eval.rs

1use crate::env::Env;
2use crate::error::EvalError;
3use crate::heap::{Heap, ThunkState};
4use crate::value::Value;
5use tidepool_repr::{AltCon, CoreExpr, CoreFrame, DataConId, DataConTable, Literal, PrimOpKind, VarId};
6
7/// Create an environment pre-populated with data constructor functions.
8/// Each constructor with arity N becomes a `ConFun(tag, N, [])` value
9/// bound to its worker VarId, so that `Var` references to constructors
10/// in the expression tree resolve correctly.
11pub fn env_from_datacon_table(table: &DataConTable) -> Env {
12    let mut env = Env::new();
13    for dc in table.iter() {
14        let var = VarId(dc.id.0);
15        if dc.rep_arity == 0 {
16            // Nullary constructor: just a Con value
17            env.insert(var, Value::Con(dc.id, vec![]));
18        } else {
19            env.insert(var, Value::ConFun(dc.id, dc.rep_arity as usize, vec![]));
20        }
21    }
22    env
23}
24
25/// Evaluate a CoreExpr to a Value.
26pub fn eval(expr: &CoreExpr, env: &Env, heap: &mut dyn Heap) -> Result<Value, EvalError> {
27    if expr.nodes.is_empty() {
28        return Err(EvalError::TypeMismatch {
29            expected: "non-empty expression",
30            got: crate::error::ValueKind::Other("empty tree".into()),
31        });
32    }
33    let res = eval_at(expr, expr.nodes.len() - 1, env, heap)?;
34    force(res, heap)
35}
36
37/// Force a thunk to a value.
38pub fn force(val: Value, heap: &mut dyn Heap) -> Result<Value, EvalError> {
39    match val {
40        Value::ThunkRef(id) => {
41            match heap.read(id).clone() {
42                ThunkState::Evaluated(v) => force(v, heap),
43                ThunkState::BlackHole => Err(EvalError::InfiniteLoop(id)),
44                ThunkState::Unevaluated(env, expr) => {
45                    heap.write(id, ThunkState::BlackHole);
46                    match eval(&expr, &env, heap) {
47                        Ok(result) => {
48                            heap.write(id, ThunkState::Evaluated(result.clone()));
49                            Ok(result)
50                        }
51                        Err(err) => {
52                            // Restore state on error to avoid masking original failure
53                            // with InfiniteLoop on subsequent forces.
54                            heap.write(id, ThunkState::Unevaluated(env, expr));
55                            Err(err)
56                        }
57                    }
58                }
59            }
60        }
61        other => Ok(other),
62    }
63}
64
65/// Recursively force a value — forces all thunks inside constructors,
66/// producing a fully-evaluated tree with no `ThunkRef` values.
67pub fn deep_force(val: Value, heap: &mut dyn Heap) -> Result<Value, EvalError> {
68    match val {
69        Value::ThunkRef(id) => {
70            let forced = force(Value::ThunkRef(id), heap)?;
71            deep_force(forced, heap)
72        }
73        Value::Con(tag, fields) => {
74            let mut forced_fields = Vec::with_capacity(fields.len());
75            for f in fields {
76                forced_fields.push(deep_force(f, heap)?);
77            }
78            Ok(Value::Con(tag, forced_fields))
79        }
80        Value::ConFun(tag, arity, args) => {
81            let mut forced_args = Vec::with_capacity(args.len());
82            for a in args {
83                forced_args.push(deep_force(a, heap)?);
84            }
85            Ok(Value::ConFun(tag, arity, forced_args))
86        }
87        other => Ok(other), // Lit, Closure, JoinCont — already values
88    }
89}
90
91/// Evaluate the node at `idx` in the expression tree.
92fn eval_at(
93    expr: &CoreExpr,
94    idx: usize,
95    env: &Env,
96    heap: &mut dyn Heap,
97) -> Result<Value, EvalError> {
98    match &expr.nodes[idx] {
99        CoreFrame::Var(v) => env.get(v).cloned().ok_or(EvalError::UnboundVar(*v)),
100        CoreFrame::Lit(lit) => Ok(Value::Lit(lit.clone())),
101        CoreFrame::App { fun, arg } => {
102            let fun_val = force(eval_at(expr, *fun, env, heap)?, heap)?;
103            let arg_val = eval_at(expr, *arg, env, heap)?;
104            match fun_val {
105                Value::Closure(clos_env, binder, body) => {
106                    let mut new_env = clos_env;
107                    new_env.insert(binder, arg_val);
108                    eval(&body, &new_env, heap)
109                }
110                Value::ConFun(tag, arity, mut args) => {
111                    args.push(arg_val);
112                    if args.len() == arity {
113                        // Force all fields when constructor is saturated
114                        let mut forced_args = Vec::with_capacity(args.len());
115                        for a in args {
116                            forced_args.push(force(a, heap)?);
117                        }
118                        Ok(Value::Con(tag, forced_args))
119                    } else {
120                        Ok(Value::ConFun(tag, arity, args))
121                    }
122                }
123                _ => Err(EvalError::NotAFunction),
124            }
125        }
126        CoreFrame::Lam { binder, body } => {
127            let body_expr = expr.extract_subtree(*body);
128            Ok(Value::Closure(env.clone(), *binder, body_expr))
129        }
130        CoreFrame::LetNonRec { binder, rhs, body } => {
131            let rhs_val = if matches!(&expr.nodes[*rhs], CoreFrame::Lam { .. }) {
132                eval_at(expr, *rhs, env, heap)? // Lambdas are values
133            } else {
134                let thunk_id = heap.alloc(env.clone(), expr.extract_subtree(*rhs));
135                Value::ThunkRef(thunk_id)
136            };
137            let new_env = env.update(*binder, rhs_val);
138            eval_at(expr, *body, &new_env, heap)
139        }
140        CoreFrame::LetRec { bindings, body } => {
141            let mut new_env = env.clone();
142            let mut thunks = Vec::new();
143
144            // 1. Allocate thunks for all binders to allow full knot-tying.
145            // (Spec: non-lambdas -> ThunkRef, but for knot-tying lambdas also need to be accessible)
146            for (binder, rhs_idx) in bindings {
147                let tid = heap.alloc(Env::new(), CoreExpr { nodes: vec![] });
148                new_env = new_env.update(*binder, Value::ThunkRef(tid));
149                thunks.push((*binder, tid, *rhs_idx));
150            }
151
152            // 2. Evaluate lambda RHSes and back-patch thunks. Update env with Closures.
153            for (binder, tid, rhs_idx) in &thunks {
154                if matches!(&expr.nodes[*rhs_idx], CoreFrame::Lam { .. }) {
155                    let lam_val = eval_at(expr, *rhs_idx, &new_env, heap)?;
156                    heap.write(*tid, ThunkState::Evaluated(lam_val.clone()));
157                    new_env = new_env.update(*binder, lam_val);
158                } else {
159                    let rhs_subtree = expr.extract_subtree(*rhs_idx);
160                    heap.write(*tid, ThunkState::Unevaluated(new_env.clone(), rhs_subtree));
161                }
162            }
163
164            eval_at(expr, *body, &new_env, heap)
165        }
166        CoreFrame::Con { tag, fields } => {
167            let mut field_vals = Vec::with_capacity(fields.len());
168            for &f in fields {
169                field_vals.push(eval_at(expr, f, env, heap)?);
170            }
171            Ok(Value::Con(*tag, field_vals))
172        }
173        CoreFrame::Case {
174            scrutinee,
175            binder,
176            alts,
177        } => {
178            let scrut_val = force(eval_at(expr, *scrutinee, env, heap)?, heap)?;
179            let case_env = env.update(*binder, scrut_val.clone());
180
181            // Try specific alternatives first; Default is a fallback, not positional.
182            // GHC Core can place DEFAULT first in the alt list.
183            let mut default_alt = None;
184            for alt in alts {
185                match &alt.con {
186                    AltCon::DataAlt(tag) => {
187                        if let Value::Con(con_tag, fields) = &scrut_val {
188                            if con_tag == tag {
189                                if fields.len() != alt.binders.len() {
190                                    return Err(EvalError::ArityMismatch {
191                                        context: "case binders",
192                                        expected: alt.binders.len(),
193                                        got: fields.len(),
194                                    });
195                                }
196                                let mut alt_env = case_env;
197                                for (b, v) in alt.binders.iter().zip(fields.iter()) {
198                                    alt_env = alt_env.update(*b, v.clone());
199                                }
200                                return eval_at(expr, alt.body, &alt_env, heap);
201                            }
202                        }
203                    }
204                    AltCon::LitAlt(lit) => {
205                        if let Value::Lit(l) = &scrut_val {
206                            if l == lit {
207                                return eval_at(expr, alt.body, &case_env, heap);
208                            }
209                        }
210                    }
211                    AltCon::Default => {
212                        default_alt = Some(alt);
213                    }
214                }
215            }
216            if let Some(alt) = default_alt {
217                return eval_at(expr, alt.body, &case_env, heap);
218            }
219            Err(EvalError::NoMatchingAlt)
220        }
221        CoreFrame::PrimOp { op, args } => {
222            let mut arg_vals = Vec::with_capacity(args.len());
223            for &arg in args {
224                let val = force(eval_at(expr, arg, env, heap)?, heap)?;
225                arg_vals.push(val);
226            }
227            dispatch_primop(*op, arg_vals)
228        }
229        CoreFrame::Join {
230            label,
231            params,
232            rhs,
233            body,
234        } => {
235            let join_val = Value::JoinCont(params.clone(), expr.extract_subtree(*rhs), env.clone());
236            let join_var = VarId(label.0 | (1u64 << 63)); // high bit distinguishes join labels
237            let new_env = env.update(join_var, join_val);
238            eval_at(expr, *body, &new_env, heap)
239        }
240        CoreFrame::Jump { label, args } => {
241            let join_var = VarId(label.0 | (1u64 << 63));
242            match env.get(&join_var) {
243                Some(Value::JoinCont(params, rhs_expr, join_env)) => {
244                    if params.len() != args.len() {
245                        return Err(EvalError::ArityMismatch {
246                            context: "arguments",
247                            expected: params.len(),
248                            got: args.len(),
249                        });
250                    }
251                    let params = params.clone();
252                    let rhs_expr = rhs_expr.clone();
253                    let mut new_env = join_env.clone();
254                    for (param, arg_idx) in params.iter().zip(args.iter()) {
255                        let arg_val = eval_at(expr, *arg_idx, env, heap)?;
256                        new_env = new_env.update(*param, arg_val);
257                    }
258                    eval(&rhs_expr, &new_env, heap)
259                }
260                _ => Err(EvalError::UnboundJoin(*label)),
261            }
262        }
263    }
264}
265
266fn dispatch_primop(op: PrimOpKind, args: Vec<Value>) -> Result<Value, EvalError> {
267    match op {
268        PrimOpKind::IntAdd => {
269            let (a, b) = bin_op_int(op, &args)?;
270            Ok(Value::Lit(Literal::LitInt(a.wrapping_add(b))))
271        }
272        PrimOpKind::IntSub => {
273            let (a, b) = bin_op_int(op, &args)?;
274            Ok(Value::Lit(Literal::LitInt(a.wrapping_sub(b))))
275        }
276        PrimOpKind::IntMul => {
277            let (a, b) = bin_op_int(op, &args)?;
278            Ok(Value::Lit(Literal::LitInt(a.wrapping_mul(b))))
279        }
280        PrimOpKind::IntNegate => {
281            if args.len() != 1 {
282                return Err(EvalError::ArityMismatch {
283                    context: "arguments",
284                    expected: 1,
285                    got: args.len(),
286                });
287            }
288            let a = expect_int(&args[0])?;
289            Ok(Value::Lit(Literal::LitInt(a.wrapping_neg())))
290        }
291        PrimOpKind::IntEq => cmp_int(op, &args, |a, b| a == b),
292        PrimOpKind::IntNe => cmp_int(op, &args, |a, b| a != b),
293        PrimOpKind::IntLt => cmp_int(op, &args, |a, b| a < b),
294        PrimOpKind::IntLe => cmp_int(op, &args, |a, b| a <= b),
295        PrimOpKind::IntGt => cmp_int(op, &args, |a, b| a > b),
296        PrimOpKind::IntGe => cmp_int(op, &args, |a, b| a >= b),
297
298        PrimOpKind::WordAdd => {
299            let (a, b) = bin_op_word(op, &args)?;
300            Ok(Value::Lit(Literal::LitWord(a.wrapping_add(b))))
301        }
302        PrimOpKind::WordSub => {
303            let (a, b) = bin_op_word(op, &args)?;
304            Ok(Value::Lit(Literal::LitWord(a.wrapping_sub(b))))
305        }
306        PrimOpKind::WordMul => {
307            let (a, b) = bin_op_word(op, &args)?;
308            Ok(Value::Lit(Literal::LitWord(a.wrapping_mul(b))))
309        }
310        PrimOpKind::WordEq => cmp_word(op, &args, |a, b| a == b),
311        PrimOpKind::WordNe => cmp_word(op, &args, |a, b| a != b),
312        PrimOpKind::WordLt => cmp_word(op, &args, |a, b| a < b),
313        PrimOpKind::WordLe => cmp_word(op, &args, |a, b| a <= b),
314        PrimOpKind::WordGt => cmp_word(op, &args, |a, b| a > b),
315        PrimOpKind::WordGe => cmp_word(op, &args, |a, b| a >= b),
316
317        PrimOpKind::DoubleAdd => {
318            let (a, b) = bin_op_double(op, &args)?;
319            Ok(Value::Lit(Literal::LitDouble((a + b).to_bits())))
320        }
321        PrimOpKind::DoubleSub => {
322            let (a, b) = bin_op_double(op, &args)?;
323            Ok(Value::Lit(Literal::LitDouble((a - b).to_bits())))
324        }
325        PrimOpKind::DoubleMul => {
326            let (a, b) = bin_op_double(op, &args)?;
327            Ok(Value::Lit(Literal::LitDouble((a * b).to_bits())))
328        }
329        PrimOpKind::DoubleDiv => {
330            let (a, b) = bin_op_double(op, &args)?;
331            Ok(Value::Lit(Literal::LitDouble((a / b).to_bits())))
332        }
333        PrimOpKind::DoubleEq => cmp_double(op, &args, |a, b| a == b),
334        PrimOpKind::DoubleNe => cmp_double(op, &args, |a, b| a != b),
335        PrimOpKind::DoubleLt => cmp_double(op, &args, |a, b| a < b),
336        PrimOpKind::DoubleLe => cmp_double(op, &args, |a, b| a <= b),
337        PrimOpKind::DoubleGt => cmp_double(op, &args, |a, b| a > b),
338        PrimOpKind::DoubleGe => cmp_double(op, &args, |a, b| a >= b),
339
340        PrimOpKind::CharEq => cmp_char(op, &args, |a, b| a == b),
341        PrimOpKind::CharNe => cmp_char(op, &args, |a, b| a != b),
342        PrimOpKind::CharLt => cmp_char(op, &args, |a, b| a < b),
343        PrimOpKind::CharLe => cmp_char(op, &args, |a, b| a <= b),
344        PrimOpKind::CharGt => cmp_char(op, &args, |a, b| a > b),
345        PrimOpKind::CharGe => cmp_char(op, &args, |a, b| a >= b),
346
347        PrimOpKind::SeqOp => {
348            if args.len() != 2 {
349                return Err(EvalError::ArityMismatch {
350                    context: "arguments",
351                    expected: 2,
352                    got: args.len(),
353                });
354            }
355            Ok(args[1].clone())
356        }
357        PrimOpKind::DataToTag => {
358            if args.len() != 1 {
359                return Err(EvalError::ArityMismatch {
360                    context: "arguments",
361                    expected: 1,
362                    got: args.len(),
363                });
364            }
365            if let Value::Con(DataConId(tag), _) = &args[0] {
366                Ok(Value::Lit(Literal::LitInt(*tag as i64)))
367            } else {
368                Err(EvalError::TypeMismatch {
369                    expected: "Data constructor",
370                    got: crate::error::ValueKind::Other(format!("{:?}", args[0])),
371                })
372            }
373        }
374        PrimOpKind::IntQuot => {
375            let (a, b) = bin_op_int(op, &args)?;
376            Ok(Value::Lit(Literal::LitInt(a.wrapping_div(b))))
377        }
378        PrimOpKind::IntRem => {
379            let (a, b) = bin_op_int(op, &args)?;
380            Ok(Value::Lit(Literal::LitInt(a.wrapping_rem(b))))
381        }
382        PrimOpKind::Chr => {
383            if args.len() != 1 {
384                return Err(EvalError::ArityMismatch {
385                    context: "arguments",
386                    expected: 1,
387                    got: args.len(),
388                });
389            }
390            let n = expect_int(&args[0])?;
391            Ok(Value::Lit(Literal::LitChar(
392                char::from_u32(n as u32).unwrap_or('\0'),
393            )))
394        }
395        PrimOpKind::Ord => {
396            if args.len() != 1 {
397                return Err(EvalError::ArityMismatch {
398                    context: "arguments",
399                    expected: 1,
400                    got: args.len(),
401                });
402            }
403            let c = expect_char(&args[0])?;
404            Ok(Value::Lit(Literal::LitInt(c as i64)))
405        }
406        PrimOpKind::IndexArray | PrimOpKind::TagToEnum => Err(EvalError::UnsupportedPrimOp(op)),
407    }
408}
409
410fn expect_int(v: &Value) -> Result<i64, EvalError> {
411    if let Value::Lit(Literal::LitInt(n)) = v {
412        Ok(*n)
413    } else {
414        Err(EvalError::TypeMismatch {
415            expected: "Int#",
416            got: crate::error::ValueKind::Other(format!("{:?}", v)),
417        })
418    }
419}
420
421fn expect_word(v: &Value) -> Result<u64, EvalError> {
422    if let Value::Lit(Literal::LitWord(n)) = v {
423        Ok(*n)
424    } else {
425        Err(EvalError::TypeMismatch {
426            expected: "Word#",
427            got: crate::error::ValueKind::Other(format!("{:?}", v)),
428        })
429    }
430}
431
432fn expect_double(v: &Value) -> Result<f64, EvalError> {
433    if let Value::Lit(Literal::LitDouble(bits)) = v {
434        Ok(f64::from_bits(*bits))
435    } else {
436        Err(EvalError::TypeMismatch {
437            expected: "Double#",
438            got: crate::error::ValueKind::Other(format!("{:?}", v)),
439        })
440    }
441}
442
443fn expect_char(v: &Value) -> Result<char, EvalError> {
444    if let Value::Lit(Literal::LitChar(c)) = v {
445        Ok(*c)
446    } else {
447        Err(EvalError::TypeMismatch {
448            expected: "Char#",
449            got: crate::error::ValueKind::Other(format!("{:?}", v)),
450        })
451    }
452}
453
454fn bin_op_int(_op: PrimOpKind, args: &[Value]) -> Result<(i64, i64), EvalError> {
455    if args.len() != 2 {
456        return Err(EvalError::ArityMismatch {
457            context: "arguments",
458            expected: 2,
459            got: args.len(),
460        });
461    }
462    Ok((expect_int(&args[0])?, expect_int(&args[1])?))
463}
464
465fn bin_op_word(_op: PrimOpKind, args: &[Value]) -> Result<(u64, u64), EvalError> {
466    if args.len() != 2 {
467        return Err(EvalError::ArityMismatch {
468            context: "arguments",
469            expected: 2,
470            got: args.len(),
471        });
472    }
473    Ok((expect_word(&args[0])?, expect_word(&args[1])?))
474}
475
476fn bin_op_double(_op: PrimOpKind, args: &[Value]) -> Result<(f64, f64), EvalError> {
477    if args.len() != 2 {
478        return Err(EvalError::ArityMismatch {
479            context: "arguments",
480            expected: 2,
481            got: args.len(),
482        });
483    }
484    Ok((expect_double(&args[0])?, expect_double(&args[1])?))
485}
486
487fn cmp_int(
488    op: PrimOpKind,
489    args: &[Value],
490    f: impl Fn(i64, i64) -> bool,
491) -> Result<Value, EvalError> {
492    let (a, b) = bin_op_int(op, args)?;
493    Ok(Value::Lit(Literal::LitInt(if f(a, b) { 1 } else { 0 })))
494}
495
496fn cmp_word(
497    op: PrimOpKind,
498    args: &[Value],
499    f: impl Fn(u64, u64) -> bool,
500) -> Result<Value, EvalError> {
501    let (a, b) = bin_op_word(op, args)?;
502    Ok(Value::Lit(Literal::LitInt(if f(a, b) { 1 } else { 0 })))
503}
504
505fn cmp_double(
506    op: PrimOpKind,
507    args: &[Value],
508    f: impl Fn(f64, f64) -> bool,
509) -> Result<Value, EvalError> {
510    let (a, b) = bin_op_double(op, args)?;
511    Ok(Value::Lit(Literal::LitInt(if f(a, b) { 1 } else { 0 })))
512}
513
514fn cmp_char(
515    _op: PrimOpKind,
516    args: &[Value],
517    f: impl Fn(char, char) -> bool,
518) -> Result<Value, EvalError> {
519    if args.len() != 2 {
520        return Err(EvalError::ArityMismatch {
521            context: "arguments",
522            expected: 2,
523            got: args.len(),
524        });
525    }
526    let a = expect_char(&args[0])?;
527    let b = expect_char(&args[1])?;
528    Ok(Value::Lit(Literal::LitInt(if f(a, b) { 1 } else { 0 })))
529}
530
531#[cfg(test)]
532mod tests {
533    use super::*;
534    use tidepool_repr::{Alt, AltCon, CoreFrame, DataConId, JoinId, Literal, RecursiveTree, VarId};
535
536    #[test]
537    fn test_eval_lit() {
538        let expr = RecursiveTree {
539            nodes: vec![CoreFrame::Lit(Literal::LitInt(42))],
540        };
541        let mut heap = crate::heap::VecHeap::new();
542        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
543        if let Value::Lit(Literal::LitInt(n)) = res {
544            assert_eq!(n, 42);
545        } else {
546            panic!("Expected LitInt(42), got {:?}", res);
547        }
548    }
549
550    #[test]
551    fn test_eval_var() {
552        let expr = RecursiveTree {
553            nodes: vec![CoreFrame::Var(VarId(1))],
554        };
555        let mut env = Env::new();
556        env.insert(VarId(1), Value::Lit(Literal::LitInt(42)));
557        let mut heap = crate::heap::VecHeap::new();
558        let res = eval(&expr, &env, &mut heap).unwrap();
559        if let Value::Lit(Literal::LitInt(n)) = res {
560            assert_eq!(n, 42);
561        } else {
562            panic!("Expected LitInt(42), got {:?}", res);
563        }
564    }
565
566    #[test]
567    fn test_eval_unbound_var() {
568        let expr = RecursiveTree {
569            nodes: vec![CoreFrame::Var(VarId(1))],
570        };
571        let mut heap = crate::heap::VecHeap::new();
572        let res = eval(&expr, &Env::new(), &mut heap);
573        assert!(matches!(res, Err(EvalError::UnboundVar(VarId(1)))));
574    }
575
576    #[test]
577    fn test_eval_lam_identity() {
578        let nodes = vec![
579            CoreFrame::Var(VarId(1)),
580            CoreFrame::Lam {
581                binder: VarId(1),
582                body: 0,
583            },
584            CoreFrame::Lit(Literal::LitInt(42)),
585            CoreFrame::App { fun: 1, arg: 2 },
586        ];
587        let expr = CoreExpr { nodes };
588        let mut heap = crate::heap::VecHeap::new();
589        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
590        if let Value::Lit(Literal::LitInt(n)) = res {
591            assert_eq!(n, 42);
592        } else {
593            panic!("Expected LitInt(42), got {:?}", res);
594        }
595    }
596
597    #[test]
598    fn test_eval_let_nonrec() {
599        let nodes = vec![
600            CoreFrame::Lit(Literal::LitInt(1)),
601            CoreFrame::Var(VarId(1)),
602            CoreFrame::LetNonRec {
603                binder: VarId(1),
604                rhs: 0,
605                body: 1,
606            },
607        ];
608        let expr = CoreExpr { nodes };
609        let mut heap = crate::heap::VecHeap::new();
610        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
611        if let Value::Lit(Literal::LitInt(n)) = res {
612            assert_eq!(n, 1);
613        } else {
614            panic!("Expected LitInt(1), got {:?}", res);
615        }
616    }
617
618    #[test]
619    fn test_eval_con() {
620        let nodes = vec![
621            CoreFrame::Lit(Literal::LitInt(42)),
622            CoreFrame::Con {
623                tag: DataConId(1),
624                fields: vec![0],
625            },
626        ];
627        let expr = CoreExpr { nodes };
628        let mut heap = crate::heap::VecHeap::new();
629        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
630        if let Value::Con(tag, fields) = res {
631            assert_eq!(tag.0, 1);
632            assert_eq!(fields.len(), 1);
633            if let Value::Lit(Literal::LitInt(n)) = fields[0] {
634                assert_eq!(n, 42);
635            } else {
636                panic!("Expected LitInt(42)");
637            }
638        } else {
639            panic!("Expected Con");
640        }
641    }
642
643    #[test]
644    fn test_eval_primop_add() {
645        let nodes = vec![
646            CoreFrame::Lit(Literal::LitInt(1)),
647            CoreFrame::Lit(Literal::LitInt(2)),
648            CoreFrame::PrimOp {
649                op: PrimOpKind::IntAdd,
650                args: vec![0, 1],
651            },
652        ];
653        let expr = CoreExpr { nodes };
654        let mut heap = crate::heap::VecHeap::new();
655        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
656        if let Value::Lit(Literal::LitInt(n)) = res {
657            assert_eq!(n, 3);
658        } else {
659            panic!("Expected LitInt(3)");
660        }
661    }
662
663    #[test]
664    fn test_eval_case_data() {
665        let nodes = vec![
666            CoreFrame::Lit(Literal::LitInt(42)),
667            CoreFrame::Con {
668                tag: DataConId(1),
669                fields: vec![0],
670            },
671            CoreFrame::Var(VarId(10)),
672            CoreFrame::Case {
673                scrutinee: 1,
674                binder: VarId(11),
675                alts: vec![Alt {
676                    con: AltCon::DataAlt(DataConId(1)),
677                    binders: vec![VarId(10)],
678                    body: 2,
679                }],
680            },
681        ];
682        let expr = CoreExpr { nodes };
683        let mut heap = crate::heap::VecHeap::new();
684        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
685        if let Value::Lit(Literal::LitInt(n)) = res {
686            assert_eq!(n, 42);
687        } else {
688            panic!("Expected LitInt(42)");
689        }
690    }
691
692    #[test]
693    fn test_eval_case_binder() {
694        let nodes = vec![
695            CoreFrame::Lit(Literal::LitInt(42)),
696            CoreFrame::Con {
697                tag: DataConId(1),
698                fields: vec![0],
699            },
700            CoreFrame::Var(VarId(11)),
701            CoreFrame::Case {
702                scrutinee: 1,
703                binder: VarId(11),
704                alts: vec![Alt {
705                    con: AltCon::DataAlt(DataConId(1)),
706                    binders: vec![VarId(10)],
707                    body: 2,
708                }],
709            },
710        ];
711        let expr = CoreExpr { nodes };
712        let mut heap = crate::heap::VecHeap::new();
713        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
714        if let Value::Con(tag, _) = res {
715            assert_eq!(tag.0, 1);
716        } else {
717            panic!("Expected Con");
718        }
719    }
720
721    #[test]
722    fn test_eval_case_lit_default() {
723        let nodes = vec![
724            CoreFrame::Lit(Literal::LitInt(3)),
725            CoreFrame::Lit(Literal::LitInt(10)),
726            CoreFrame::Lit(Literal::LitInt(20)),
727            CoreFrame::Lit(Literal::LitInt(30)),
728            CoreFrame::Case {
729                scrutinee: 0,
730                binder: VarId(10),
731                alts: vec![
732                    Alt {
733                        con: AltCon::LitAlt(Literal::LitInt(1)),
734                        binders: vec![],
735                        body: 1,
736                    },
737                    Alt {
738                        con: AltCon::LitAlt(Literal::LitInt(2)),
739                        binders: vec![],
740                        body: 2,
741                    },
742                    Alt {
743                        con: AltCon::Default,
744                        binders: vec![],
745                        body: 3,
746                    },
747                ],
748            },
749        ];
750        let expr = CoreExpr { nodes };
751        let mut heap = crate::heap::VecHeap::new();
752        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
753        if let Value::Lit(Literal::LitInt(n)) = res {
754            assert_eq!(n, 30);
755        } else {
756            panic!("Expected LitInt(30)");
757        }
758    }
759
760    #[test]
761    fn test_eval_data_to_tag() {
762        let nodes = vec![
763            CoreFrame::Con {
764                tag: DataConId(5),
765                fields: vec![],
766            },
767            CoreFrame::PrimOp {
768                op: PrimOpKind::DataToTag,
769                args: vec![0],
770            },
771        ];
772        let expr = CoreExpr { nodes };
773        let mut heap = crate::heap::VecHeap::new();
774        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
775        if let Value::Lit(Literal::LitInt(n)) = res {
776            assert_eq!(n, 5);
777        } else {
778            panic!("Expected LitInt(5)");
779        }
780    }
781
782    #[test]
783    fn test_eval_let_rec_forward_refs() {
784        // let { x = 1; y = x } in y
785        let nodes = vec![
786            CoreFrame::Lit(Literal::LitInt(1)), // 0
787            CoreFrame::Var(VarId(1)),           // 1: x
788            CoreFrame::Var(VarId(2)),           // 2: y
789            CoreFrame::LetRec {
790                bindings: vec![(VarId(1), 0), (VarId(2), 1)],
791                body: 2,
792            }, // 3
793        ];
794        let expr = CoreExpr { nodes };
795        let mut heap = crate::heap::VecHeap::new();
796        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
797        if let Value::Lit(Literal::LitInt(n)) = res {
798            assert_eq!(n, 1);
799        } else {
800            panic!("Expected LitInt(1)");
801        }
802    }
803
804    #[test]
805    fn test_eval_join_simple() {
806        // join j(x) = x in jump j(42)
807        let nodes = vec![
808            CoreFrame::Var(VarId(10)),           // 0: x
809            CoreFrame::Lit(Literal::LitInt(42)), // 1
810            CoreFrame::Jump {
811                label: JoinId(1),
812                args: vec![1],
813            }, // 2
814            CoreFrame::Join {
815                label: JoinId(1),
816                params: vec![VarId(10)],
817                rhs: 0,
818                body: 2,
819            }, // 3
820        ];
821        let expr = CoreExpr { nodes };
822        let mut heap = crate::heap::VecHeap::new();
823        let res = eval(&expr, &Env::new(), &mut heap).unwrap();
824        if let Value::Lit(Literal::LitInt(n)) = res {
825            assert_eq!(n, 42);
826        } else {
827            panic!("Expected LitInt(42), got {:?}", res);
828        }
829    }
830
#[test]
fn test_eval_join_multi_param() {
    // Program under test: `join j(x, y) = x + y in jump j(1, 2)`.
    // Both jump arguments must be bound to their parameters; expected sum is 3.
    let expr = CoreExpr {
        nodes: vec![
            CoreFrame::Var(VarId(10)), // node 0: x
            CoreFrame::Var(VarId(11)), // node 1: y
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![0, 1],
            }, // node 2: x + y (join rhs)
            CoreFrame::Lit(Literal::LitInt(1)), // node 3: first argument
            CoreFrame::Lit(Literal::LitInt(2)), // node 4: second argument
            CoreFrame::Jump {
                label: JoinId(1),
                args: vec![3, 4],
            }, // node 5: jump j(1, 2)
            CoreFrame::Join {
                label: JoinId(1),
                params: vec![VarId(10), VarId(11)],
                rhs: 2,
                body: 5,
            }, // node 6: root
        ],
    };
    let mut store = crate::heap::VecHeap::new();
    match eval(&expr, &Env::new(), &mut store).unwrap() {
        Value::Lit(Literal::LitInt(n)) => assert_eq!(n, 3),
        _ => panic!("Expected LitInt(3)"),
    }
}
863
#[test]
fn test_eval_unbound_jump() {
    // A lone `jump` to label 1 with no enclosing `join` that defines it:
    // evaluation must report the label as unbound.
    let expr = CoreExpr {
        nodes: vec![CoreFrame::Jump {
            label: JoinId(1),
            args: Vec::new(),
        }],
    };
    let mut store = crate::heap::VecHeap::new();
    let res = eval(&expr, &Env::new(), &mut store);
    assert!(matches!(res, Err(EvalError::UnboundJoin(JoinId(1)))));
}
875
#[test]
fn test_thunk_lazy() {
    // Program under test: `let x = <unbound> in 42`.
    // The right-hand side refers to a variable that is not in the (empty)
    // environment; the test expects 42, i.e. the binding is never evaluated
    // because the body does not use `x`.
    let expr = CoreExpr {
        nodes: vec![
            CoreFrame::Var(VarId(999)),          // node 0: unbound rhs
            CoreFrame::Lit(Literal::LitInt(42)), // node 1: body
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 0,
                body: 1,
            }, // node 2: root
        ],
    };
    let mut store = crate::heap::VecHeap::new();
    match eval(&expr, &Env::new(), &mut store).unwrap() {
        Value::Lit(Literal::LitInt(n)) => assert_eq!(n, 42),
        _ => panic!("Expected LitInt(42)"),
    }
}
897
#[test]
fn test_thunk_caching() {
    // Program under test: `let x = 1 + 1 in x + x`.
    // `x` is used twice in the body; the expected result is 4.
    // Note: only the final value is asserted here; the test does not
    // observe how many times the right-hand side is evaluated.
    let expr = CoreExpr {
        nodes: vec![
            CoreFrame::Lit(Literal::LitInt(1)), // node 0: literal 1
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![0, 0],
            }, // node 1: 1 + 1 (rhs of x)
            CoreFrame::Var(VarId(1)), // node 2: x
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![2, 2],
            }, // node 3: x + x (body)
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 1,
                body: 3,
            }, // node 4: root
        ],
    };
    let mut store = crate::heap::VecHeap::new();
    match eval(&expr, &Env::new(), &mut store).unwrap() {
        Value::Lit(Literal::LitInt(n)) => assert_eq!(n, 4),
        _ => panic!("Expected LitInt(4)"),
    }
}
927
#[test]
fn test_thunk_blackhole() {
    // Program under test: `letrec x = x in x`.
    // Node 0 serves as both the right-hand side of `x` and the body, so
    // forcing `x` demands `x` again; this must surface as `InfiniteLoop`.
    let expr = CoreExpr {
        nodes: vec![
            CoreFrame::Var(VarId(1)), // node 0: x
            CoreFrame::LetRec {
                bindings: vec![(VarId(1), 0)],
                body: 0,
            }, // node 1: root
        ],
    };
    let mut store = crate::heap::VecHeap::new();
    let res = eval(&expr, &Env::new(), &mut store);
    assert!(matches!(res, Err(EvalError::InfiniteLoop(_))));
}
943
#[test]
fn test_letrec_mutual_recursion() {
    // Program under test: `letrec { f = \a -> g a; g = \b -> b } in f 42`.
    // `f` refers to `g`, which is bound later in the same group; applying
    // `f` to 42 must return 42.
    let expr = CoreExpr {
        nodes: vec![
            CoreFrame::Var(VarId(11)),         // node 0: a
            CoreFrame::Var(VarId(2)),          // node 1: g
            CoreFrame::App { fun: 1, arg: 0 }, // node 2: g a
            CoreFrame::Lam {
                binder: VarId(11),
                body: 2,
            }, // node 3: \a -> g a (rhs of f)
            CoreFrame::Var(VarId(12)), // node 4: b
            CoreFrame::Lam {
                binder: VarId(12),
                body: 4,
            }, // node 5: \b -> b (rhs of g)
            CoreFrame::Var(VarId(1)),            // node 6: f
            CoreFrame::Lit(Literal::LitInt(42)), // node 7: argument
            CoreFrame::App { fun: 6, arg: 7 },   // node 8: f 42 (body)
            CoreFrame::LetRec {
                bindings: vec![(VarId(1), 3), (VarId(2), 5)],
                body: 8,
            }, // node 9: root
        ],
    };
    let mut store = crate::heap::VecHeap::new();
    match eval(&expr, &Env::new(), &mut store).unwrap() {
        Value::Lit(Literal::LitInt(n)) => assert_eq!(n, 42),
        _ => panic!("Expected LitInt(42)"),
    }
}
977
#[test]
fn test_eval_join_scoping() {
    // Program under test:
    //
    //   let y = 100 in
    //   join j(x) = x + y in
    //   let y = 200 in
    //   jump j(1)
    //
    // The join body must see the `y` that was in scope where the join was
    // defined (100), not the shadowing `y` at the jump site (200).
    // Expected result: 101, not 201.
    let expr = CoreExpr {
        nodes: vec![
            CoreFrame::Lit(Literal::LitInt(100)), // node 0: outer y's rhs
            CoreFrame::Var(VarId(10)),            // node 1: x
            CoreFrame::Var(VarId(20)),            // node 2: y (captured)
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            }, // node 3: x + y (join rhs)
            CoreFrame::Lit(Literal::LitInt(200)), // node 4: inner y's rhs
            CoreFrame::Lit(Literal::LitInt(1)),   // node 5: jump argument
            CoreFrame::Jump {
                label: JoinId(1),
                args: vec![5],
            }, // node 6: jump j(1)
            CoreFrame::LetNonRec {
                binder: VarId(20),
                rhs: 4,
                body: 6,
            }, // node 7: let y = 200 in jump j(1)
            CoreFrame::Join {
                label: JoinId(1),
                params: vec![VarId(10)],
                rhs: 3,
                body: 7,
            }, // node 8: join j(x) = x + y in ...
            CoreFrame::LetNonRec {
                binder: VarId(20),
                rhs: 0,
                body: 8,
            }, // node 9: root, let y = 100 in join ...
        ],
    };
    let mut store = crate::heap::VecHeap::new();
    match eval(&expr, &Env::new(), &mut store).unwrap() {
        Value::Lit(Literal::LitInt(n)) => assert_eq!(n, 101),
        _ => panic!("Expected LitInt(101)"),
    }
}
1025
#[test]
fn test_thunk_poison_restoration() {
    // Program under test: `let x = <unbound> in x`.
    //
    // The body must be a reference to `x` (VarId(1)) so that evaluating it
    // forces the binding for `x`, whose right-hand side fails with
    // `UnboundVar`. Previously the body index pointed at the unbound
    // variable node itself, so `x` was never demanded and the failing-thunk
    // path was not exercised at all.
    let nodes = vec![
        CoreFrame::Var(VarId(999)), // 0: unbound (rhs of x)
        CoreFrame::Var(VarId(1)),   // 1: x (body)
        CoreFrame::LetNonRec {
            binder: VarId(1),
            rhs: 0,
            body: 1,
        }, // 2: let x = unbound in x
    ];
    let expr = CoreExpr { nodes };
    // The same heap is reused for both evaluations below.
    let mut heap = crate::heap::VecHeap::new();

    // First force fails with UnboundVar
    let res1 = eval(&expr, &Env::new(), &mut heap);
    assert!(matches!(res1, Err(EvalError::UnboundVar(_))));

    // Second force should ALSO fail with UnboundVar, NOT InfiniteLoop (BlackHole)
    // NOTE(review): each `eval` call presumably allocates a fresh thunk for
    // `x`; confirm whether re-forcing the *same* thunk needs a dedicated test
    // that goes through the heap/`force` API directly.
    let res2 = eval(&expr, &Env::new(), &mut heap);
    assert!(matches!(res2, Err(EvalError::UnboundVar(_))));
}
1048
#[test]
fn test_eval_jump_arity_mismatch() {
    // Program under test: `join j(x) = x in jump j(1, 2)`.
    // The join point declares one parameter but the jump supplies two
    // arguments; evaluation must fail with `ArityMismatch`.
    let expr = CoreExpr {
        nodes: vec![
            CoreFrame::Var(VarId(10)),          // node 0: x (join rhs)
            CoreFrame::Lit(Literal::LitInt(1)), // node 1: first argument
            CoreFrame::Lit(Literal::LitInt(2)), // node 2: second argument
            CoreFrame::Jump {
                label: JoinId(1),
                args: vec![1, 2],
            }, // node 3: jump j(1, 2)
            CoreFrame::Join {
                label: JoinId(1),
                params: vec![VarId(10)],
                rhs: 0,
                body: 3,
            }, // node 4: root, join j(x) ...
        ],
    };
    let mut store = crate::heap::VecHeap::new();
    let res = eval(&expr, &Env::new(), &mut store);
    assert!(matches!(res, Err(EvalError::ArityMismatch { .. })));
}
1072}