Skip to main content

tensorlogic_compiler/
bytecode.rs

1//! Stack-based bytecode VM for TensorLogic expressions.
2//!
3//! This module provides a compiler from [`TLExpr`] to a flat [`BytecodeProgram`]
4//! and a lightweight virtual machine that executes it. Repeated evaluation of
5//! compiled expressions is faster than recursive interpretation because the
6//! expression tree is only traversed once during compilation; subsequent
7//! executions only walk the flat instruction array.
8//!
9//! # Quick Start
10//!
11//! ```rust
12//! use tensorlogic_compiler::bytecode::{compile, execute, VmEnv, VmValue};
13//! use tensorlogic_ir::TLExpr;
14//!
15//! // Compile 2.0 + 3.0
16//! let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
17//! let program = compile(&expr).unwrap();
18//!
19//! let env = VmEnv::new();
20//! let result = execute(&program, &env).unwrap();
21//! assert_eq!(result, VmValue::Num(5.0));
22//! ```
23
24use std::collections::HashMap;
25use tensorlogic_ir::TLExpr;
26
27// ─────────────────────────────────────────────────────────────────────────────
28// Instruction set
29// ─────────────────────────────────────────────────────────────────────────────
30
/// Stack-based VM instruction.
///
/// Instructions operate on a `Vec<VmValue>` stack and an instruction pointer.
/// All binary operations pop two values and push one; all unary operations pop
/// one and push one. For binary operations the right-hand operand `b` sits on
/// top of the stack and is therefore popped first.
#[derive(Debug, Clone, PartialEq)]
pub enum Instruction {
    // ── Stack management ──────────────────────────────────────────────────
    /// Push a numeric constant onto the stack.
    PushNum(f64),
    /// Push a boolean constant onto the stack.
    PushBool(bool),
    /// Push a symbol literal onto the stack.
    PushSym(String),
    /// Discard the top-of-stack value.
    Pop,
    /// Duplicate the top-of-stack value.
    Dup,

    // ── Arithmetic (binary, both args must be Num) ────────────────────────
    /// Pop b, pop a; push a + b.
    Add,
    /// Pop b, pop a; push a - b.
    Sub,
    /// Pop b, pop a; push a * b.
    Mul,
    /// Pop b, pop a; push a / b (error on zero divisor).
    Div,
    /// Pop b, pop a; push a raised to the power b (floating-point `powf`).
    Pow,
    /// Pop b, pop a; push a % b (floating-point remainder).
    ///
    /// Unlike `Div`, the executor performs no zero-divisor check here, so a
    /// zero `b` produces NaN rather than an error.
    Mod,
    /// Pop a; push -a.
    Neg,
    /// Pop a; push |a|.
    Abs,
    /// Pop a; push √a.
    Sqrt,
    /// Pop a; push e^a.
    Exp,
    /// Pop a; push ln(a).
    Log,
    /// Pop b, pop a; push min(a, b).
    Min,
    /// Pop b, pop a; push max(a, b).
    Max,

    // ── Comparison (result pushed as Bool) ────────────────────────────────
    /// Pop b, pop a; push a == b.
    Eq,
    /// Pop b, pop a; push a != b.
    Ne,
    /// Pop b, pop a; push a < b.
    Lt,
    /// Pop b, pop a; push a <= b.
    Le,
    /// Pop b, pop a; push a > b.
    Gt,
    /// Pop b, pop a; push a >= b.
    Ge,

    // ── Boolean logic ─────────────────────────────────────────────────────
    /// Pop b, pop a; push a && b (both must be truthy-compatible).
    And,
    /// Pop b, pop a; push a || b.
    Or,
    /// Pop a; push !a.
    Not,

    // ── Control flow ──────────────────────────────────────────────────────
    //
    // Jump targets are absolute indices into the program's instruction list.
    // The compiler first emits a jump with a placeholder target of 0 and then
    // fixes it up with `BytecodeProgram::patch_jump`.
    /// If TOS is falsy, jump to absolute instruction index; otherwise fall through.
    /// TOS is consumed.
    JumpIfFalse(usize),
    /// If TOS is truthy, jump to absolute instruction index; otherwise fall through.
    /// TOS is consumed.
    JumpIfTrue(usize),
    /// Unconditional jump to absolute instruction index.
    Jump(usize),

    // ── Variables ─────────────────────────────────────────────────────────
    /// Push the value of the named variable from the execution environment.
    LoadVar(String),
    /// Pop TOS and bind it to the named variable in the execution environment.
    StoreVar(String),

    // ── Fuzzy operations ──────────────────────────────────────────────────
    /// Product t-norm: pop b, pop a; push a * b.
    TNorm,
    /// Probabilistic sum t-conorm: pop b, pop a; push a + b - a*b.
    TCoNorm,
    /// Standard fuzzy NOT: pop a; push 1.0 - a.
    FuzzyNot,

    // ── Termination ───────────────────────────────────────────────────────
    /// Stop execution; TOS is the result.
    ///
    /// The public compile entry points always append this as the final
    /// instruction of a program.
    Halt,
}
128
129// ─────────────────────────────────────────────────────────────────────────────
130// BytecodeProgram
131// ─────────────────────────────────────────────────────────────────────────────
132
133/// A compiled, flat sequence of [`Instruction`]s.
134#[derive(Debug, Clone)]
135pub struct BytecodeProgram {
136    /// The ordered list of instructions.
137    pub instructions: Vec<Instruction>,
138}
139
140impl Default for BytecodeProgram {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl BytecodeProgram {
147    /// Create an empty program.
148    pub fn new() -> Self {
149        Self {
150            instructions: Vec::new(),
151        }
152    }
153
154    /// Append an instruction and return its absolute index.
155    pub fn push(&mut self, instr: Instruction) -> usize {
156        let idx = self.instructions.len();
157        self.instructions.push(instr);
158        idx
159    }
160
161    /// Patch the jump target of the instruction at `idx`.
162    ///
163    /// Panics in debug builds if `idx` does not contain a jump instruction.
164    pub fn patch_jump(&mut self, idx: usize, target: usize) {
165        match &mut self.instructions[idx] {
166            Instruction::JumpIfFalse(t) | Instruction::JumpIfTrue(t) | Instruction::Jump(t) => {
167                *t = target;
168            }
169            other => {
170                debug_assert!(
171                    false,
172                    "patch_jump called on non-jump instruction: {:?}",
173                    other
174                );
175            }
176        }
177    }
178
179    /// Return the number of instructions in the program.
180    pub fn len(&self) -> usize {
181        self.instructions.len()
182    }
183
184    /// Return `true` if the program contains no instructions.
185    pub fn is_empty(&self) -> bool {
186        self.instructions.is_empty()
187    }
188}
189
190// ─────────────────────────────────────────────────────────────────────────────
191// VmValue
192// ─────────────────────────────────────────────────────────────────────────────
193
194/// A runtime value on the VM stack.
195#[derive(Debug, Clone, PartialEq)]
196pub enum VmValue {
197    /// A 64-bit floating-point number.
198    Num(f64),
199    /// A boolean flag.
200    Bool(bool),
201    /// A symbol (named constant), used in pattern matching.
202    Sym(String),
203}
204
205impl VmValue {
206    /// Extract the numeric payload or return a [`VmError::TypeMismatch`].
207    pub fn as_num(&self) -> Result<f64, VmError> {
208        match self {
209            VmValue::Num(n) => Ok(*n),
210            VmValue::Bool(_) => Err(VmError::TypeMismatch {
211                expected: "Num",
212                got: "Bool",
213            }),
214            VmValue::Sym(_) => Err(VmError::TypeMismatch {
215                expected: "Num",
216                got: "Sym",
217            }),
218        }
219    }
220
221    /// Extract the boolean payload or return a [`VmError::TypeMismatch`].
222    pub fn as_bool(&self) -> Result<bool, VmError> {
223        match self {
224            VmValue::Bool(b) => Ok(*b),
225            VmValue::Num(_) => Err(VmError::TypeMismatch {
226                expected: "Bool",
227                got: "Num",
228            }),
229            VmValue::Sym(_) => Err(VmError::TypeMismatch {
230                expected: "Bool",
231                got: "Sym",
232            }),
233        }
234    }
235
236    /// Numeric: non-zero is truthy; Boolean: direct value; Symbol: always truthy.
237    pub fn is_truthy(&self) -> bool {
238        match self {
239            VmValue::Num(n) => *n != 0.0,
240            VmValue::Bool(b) => *b,
241            VmValue::Sym(s) => !s.is_empty(),
242        }
243    }
244
245    /// Return a static type name string for error messages.
246    #[allow(dead_code)]
247    fn type_name(&self) -> &'static str {
248        match self {
249            VmValue::Num(_) => "Num",
250            VmValue::Bool(_) => "Bool",
251            VmValue::Sym(_) => "Sym",
252        }
253    }
254}
255
256// ─────────────────────────────────────────────────────────────────────────────
257// VmError
258// ─────────────────────────────────────────────────────────────────────────────
259
/// Failure modes of bytecode execution.
#[derive(Debug)]
pub enum VmError {
    /// The stack was empty when an instruction needed to pop or peek a value.
    StackUnderflow,
    /// An operand had a different runtime type than the instruction requires.
    TypeMismatch {
        /// Name of the type the instruction required.
        expected: &'static str,
        /// Name of the type that was found instead.
        got: &'static str,
    },
    /// A `LoadVar` instruction named a variable missing from the environment.
    UnboundVariable(String),
    /// A `Div` instruction saw a divisor equal to zero.
    DivisionByZero,
    /// The instruction pointer moved outside the program; carries that pointer.
    InvalidInstruction(usize),
    /// `execute` was handed a program without any instructions.
    ProgramEmpty,
}

impl std::fmt::Display for VmError {
    /// Render the short, lower-case message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::StackUnderflow => f.write_str("VM stack underflow"),
            Self::TypeMismatch { expected, got } => {
                write!(f, "type mismatch: expected {expected}, got {got}")
            }
            Self::UnboundVariable(name) => write!(f, "unbound variable: '{name}'"),
            Self::DivisionByZero => f.write_str("division by zero"),
            Self::InvalidInstruction(ip) => write!(f, "invalid instruction pointer: {ip}"),
            Self::ProgramEmpty => f.write_str("program contains no instructions"),
        }
    }
}

impl std::error::Error for VmError {}
302
303// ─────────────────────────────────────────────────────────────────────────────
304// VmEnv
305// ─────────────────────────────────────────────────────────────────────────────
306
307/// Variable environment passed to the VM at execution time.
308///
309/// The VM does **not** modify the caller's environment; `StoreVar` writes into
310/// a local clone that is discarded when execution ends.
311#[derive(Debug, Clone, Default)]
312pub struct VmEnv {
313    vars: HashMap<String, VmValue>,
314}
315
316impl VmEnv {
317    /// Create an empty environment.
318    pub fn new() -> Self {
319        Self {
320            vars: HashMap::new(),
321        }
322    }
323
324    /// Bind a variable to an arbitrary [`VmValue`].
325    pub fn set(&mut self, name: impl Into<String>, val: VmValue) {
326        self.vars.insert(name.into(), val);
327    }
328
329    /// Convenience helper: bind a variable to a numeric value.
330    pub fn set_num(&mut self, name: impl Into<String>, val: f64) {
331        self.set(name, VmValue::Num(val));
332    }
333
334    /// Convenience helper: bind a variable to a boolean value.
335    pub fn set_bool(&mut self, name: impl Into<String>, val: bool) {
336        self.set(name, VmValue::Bool(val));
337    }
338
339    /// Look up a variable by name.
340    pub fn get(&self, name: &str) -> Option<&VmValue> {
341        self.vars.get(name)
342    }
343
344    /// Return the number of bindings in the environment.
345    pub fn len(&self) -> usize {
346        self.vars.len()
347    }
348
349    /// Return `true` if the environment contains no bindings.
350    pub fn is_empty(&self) -> bool {
351        self.vars.is_empty()
352    }
353}
354
355// ─────────────────────────────────────────────────────────────────────────────
356// VmStats
357// ─────────────────────────────────────────────────────────────────────────────
358
/// Execution statistics collected during a single VM run.
///
/// All counters start at zero (`Default`) and are returned alongside the
/// result by `execute_with_stats`.
#[derive(Debug, Default, Clone)]
pub struct VmStats {
    /// Total number of instructions dispatched (including the final `Halt`).
    /// The counter is bumped once per fetched instruction, before it runs.
    pub instructions_executed: usize,
    /// The highest stack depth observed at any point during execution.
    pub max_stack_depth: usize,
    /// Number of conditional or unconditional jumps that were actually taken.
    pub jumps_taken: usize,
}
369
370// ─────────────────────────────────────────────────────────────────────────────
371// CompileError
372// ─────────────────────────────────────────────────────────────────────────────
373
/// Failure modes of lowering an expression to bytecode.
#[derive(Debug)]
pub enum CompileError {
    /// The expression uses a variant the bytecode compiler cannot lower
    /// (e.g. quantifiers, modal operators, lambda, …); carries a description.
    UnsupportedExpr(String),
    /// The expression nests deeper than the configured `max_depth` allows.
    MaxDepthExceeded,
}

impl std::fmt::Display for CompileError {
    /// Render the short, lower-case message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MaxDepthExceeded => {
                f.write_str("expression depth exceeds configured maximum")
            }
            Self::UnsupportedExpr(desc) => {
                write!(f, "unsupported expression in bytecode compiler: {desc}")
            }
        }
    }
}

impl std::error::Error for CompileError {}
398
399// ─────────────────────────────────────────────────────────────────────────────
400// Compiler internals
401// ─────────────────────────────────────────────────────────────────────────────
402
/// Internal compilation state passed through recursive descent.
///
/// One `Compiler` is created per compilation; its `program` is handed back to
/// the caller once the root expression has been lowered.
struct Compiler {
    /// Instructions emitted so far, in execution order.
    program: BytecodeProgram,
    /// Largest nesting depth accepted before compilation fails with
    /// `CompileError::MaxDepthExceeded` (the root expression is depth 0).
    max_depth: usize,
}
408
409impl Compiler {
410    fn new(max_depth: usize) -> Self {
411        Self {
412            program: BytecodeProgram::new(),
413            max_depth,
414        }
415    }
416
417    /// Recursively compile `expr` into `self.program`, respecting `depth`.
418    fn compile_expr(&mut self, expr: &TLExpr, depth: usize) -> Result<(), CompileError> {
419        if depth > self.max_depth {
420            return Err(CompileError::MaxDepthExceeded);
421        }
422
423        match expr {
424            // ── Numeric literal ───────────────────────────────────────────
425            TLExpr::Constant(c) => {
426                self.program.push(Instruction::PushNum(*c));
427            }
428
429            // ── Zero-arity predicates are treated as variable loads ────────
430            TLExpr::Pred { name, args } if args.is_empty() => {
431                self.program.push(Instruction::LoadVar(name.clone()));
432            }
433
434            // ── Arithmetic ────────────────────────────────────────────────
435            TLExpr::Add(a, b) => {
436                self.compile_expr(a, depth + 1)?;
437                self.compile_expr(b, depth + 1)?;
438                self.program.push(Instruction::Add);
439            }
440            TLExpr::Sub(a, b) => {
441                self.compile_expr(a, depth + 1)?;
442                self.compile_expr(b, depth + 1)?;
443                self.program.push(Instruction::Sub);
444            }
445            TLExpr::Mul(a, b) => {
446                self.compile_expr(a, depth + 1)?;
447                self.compile_expr(b, depth + 1)?;
448                self.program.push(Instruction::Mul);
449            }
450            TLExpr::Div(a, b) => {
451                self.compile_expr(a, depth + 1)?;
452                self.compile_expr(b, depth + 1)?;
453                self.program.push(Instruction::Div);
454            }
455            TLExpr::Pow(a, b) => {
456                self.compile_expr(a, depth + 1)?;
457                self.compile_expr(b, depth + 1)?;
458                self.program.push(Instruction::Pow);
459            }
460            TLExpr::Mod(a, b) => {
461                self.compile_expr(a, depth + 1)?;
462                self.compile_expr(b, depth + 1)?;
463                self.program.push(Instruction::Mod);
464            }
465            TLExpr::Abs(a) => {
466                self.compile_expr(a, depth + 1)?;
467                self.program.push(Instruction::Abs);
468            }
469            TLExpr::Sqrt(a) => {
470                self.compile_expr(a, depth + 1)?;
471                self.program.push(Instruction::Sqrt);
472            }
473            TLExpr::Exp(a) => {
474                self.compile_expr(a, depth + 1)?;
475                self.program.push(Instruction::Exp);
476            }
477            TLExpr::Log(a) => {
478                self.compile_expr(a, depth + 1)?;
479                self.program.push(Instruction::Log);
480            }
481            TLExpr::Min(a, b) => {
482                self.compile_expr(a, depth + 1)?;
483                self.compile_expr(b, depth + 1)?;
484                self.program.push(Instruction::Min);
485            }
486            TLExpr::Max(a, b) => {
487                self.compile_expr(a, depth + 1)?;
488                self.compile_expr(b, depth + 1)?;
489                self.program.push(Instruction::Max);
490            }
491
492            // ── Comparison ────────────────────────────────────────────────
493            TLExpr::Eq(a, b) => {
494                self.compile_expr(a, depth + 1)?;
495                self.compile_expr(b, depth + 1)?;
496                self.program.push(Instruction::Eq);
497            }
498            TLExpr::Lt(a, b) => {
499                self.compile_expr(a, depth + 1)?;
500                self.compile_expr(b, depth + 1)?;
501                self.program.push(Instruction::Lt);
502            }
503            TLExpr::Gt(a, b) => {
504                self.compile_expr(a, depth + 1)?;
505                self.compile_expr(b, depth + 1)?;
506                self.program.push(Instruction::Gt);
507            }
508            TLExpr::Lte(a, b) => {
509                self.compile_expr(a, depth + 1)?;
510                self.compile_expr(b, depth + 1)?;
511                self.program.push(Instruction::Le);
512            }
513            TLExpr::Gte(a, b) => {
514                self.compile_expr(a, depth + 1)?;
515                self.compile_expr(b, depth + 1)?;
516                self.program.push(Instruction::Ge);
517            }
518
519            // ── Boolean logic with short-circuit jumps ────────────────────
520            //
521            // And(a, b):
522            //   compile(a)
523            //   JumpIfFalse(end)   ; pops a; if false pushes Bool(false) and jumps to end
524            //                      ; if true falls through (a consumed; stack unchanged)
525            //   compile(b)         ; result of b is the final value (a was truthy)
526            //   Not, Not           ; coerce numeric b to Bool if needed (identity on Bool)
527            //   end:
528            //
529            // Note: JumpIfFalse CONSUMES the condition value from the stack. When it
530            // does NOT jump (a is truthy) the stack is one entry shorter, so we only
531            // need b's result at the end. When it DOES jump it pushes Bool(false) at
532            // the target, so both paths leave exactly one value on the stack.
533            TLExpr::And(a, b) => {
534                self.compile_expr(a, depth + 1)?;
535                // Emit placeholder jump; we'll patch the target after compiling b.
536                let jump_idx = self.program.push(Instruction::JumpIfFalse(0));
537                // a was truthy (and consumed); compile b — its value is the And result.
538                self.compile_expr(b, depth + 1)?;
539                // Coerce b's result to Bool so both paths leave a Bool on the stack.
540                self.program.push(Instruction::Not);
541                self.program.push(Instruction::Not);
542                let end = self.program.len();
543                self.program.patch_jump(jump_idx, end);
544            }
545
546            // Or(a, b):
547            //   compile(a)
548            //   JumpIfTrue(end)    ; pops a; if true pushes Bool(true) and jumps to end
549            //   compile(b)         ; result of b is the final value (a was falsy)
550            //   Not, Not           ; coerce to Bool
551            //   end:
552            TLExpr::Or(a, b) => {
553                self.compile_expr(a, depth + 1)?;
554                let jump_idx = self.program.push(Instruction::JumpIfTrue(0));
555                self.compile_expr(b, depth + 1)?;
556                self.program.push(Instruction::Not);
557                self.program.push(Instruction::Not);
558                let end = self.program.len();
559                self.program.patch_jump(jump_idx, end);
560            }
561
562            TLExpr::Not(a) => {
563                self.compile_expr(a, depth + 1)?;
564                self.program.push(Instruction::Not);
565            }
566
567            // ── Conditional ───────────────────────────────────────────────
568            //
569            // IfThenElse(cond, t, f):
570            //   compile(cond)
571            //   JumpIfFalse(else_branch)
572            //   compile(t)
573            //   Jump(end)
574            //   else_branch: compile(f)
575            //   end:
576            TLExpr::IfThenElse {
577                condition,
578                then_branch,
579                else_branch,
580            } => {
581                self.compile_expr(condition, depth + 1)?;
582                let jf_idx = self.program.push(Instruction::JumpIfFalse(0));
583                self.compile_expr(then_branch, depth + 1)?;
584                let jump_idx = self.program.push(Instruction::Jump(0));
585                // patch the JumpIfFalse to point here
586                let else_start = self.program.len();
587                self.program.patch_jump(jf_idx, else_start);
588                self.compile_expr(else_branch, depth + 1)?;
589                let end = self.program.len();
590                self.program.patch_jump(jump_idx, end);
591            }
592
593            // ── Let binding ───────────────────────────────────────────────
594            //
595            // Let { var, value, body }:
596            //   compile(value)
597            //   StoreVar(var)
598            //   compile(body)
599            TLExpr::Let { var, value, body } => {
600                self.compile_expr(value, depth + 1)?;
601                self.program.push(Instruction::StoreVar(var.clone()));
602                self.compile_expr(body, depth + 1)?;
603            }
604
605            // ── Fuzzy operations ──────────────────────────────────────────
606            //
607            // The bytecode VM uses a single product t-norm / probabilistic-sum
608            // t-conorm regardless of the `kind` tag.  More sophisticated
609            // dispatch can be added later without changing the instruction set.
610            TLExpr::TNorm { left, right, .. } => {
611                self.compile_expr(left, depth + 1)?;
612                self.compile_expr(right, depth + 1)?;
613                self.program.push(Instruction::TNorm);
614            }
615            TLExpr::TCoNorm { left, right, .. } => {
616                self.compile_expr(left, depth + 1)?;
617                self.compile_expr(right, depth + 1)?;
618                self.program.push(Instruction::TCoNorm);
619            }
620            TLExpr::FuzzyNot { expr: inner, .. } => {
621                self.compile_expr(inner, depth + 1)?;
622                self.program.push(Instruction::FuzzyNot);
623            }
624
625            // ── Symbol literal ────────────────────────────────────────────
626            TLExpr::SymbolLiteral(s) => {
627                self.program.push(Instruction::PushSym(s.clone()));
628            }
629
630            // ── Pattern matching ──────────────────────────────────────────
631            //
632            // Lower `Match { scrutinee, arms }` to nested IfThenElse at the
633            // bytecode level.  The scrutinee is compiled once and stored in a
634            // fresh temporary variable to avoid re-evaluation.
635            TLExpr::Match { scrutinee, arms } => {
636                if arms.is_empty() {
637                    return Err(CompileError::UnsupportedExpr(
638                        "Match with no arms".to_string(),
639                    ));
640                }
641                // Store scrutinee in a fresh temp.
642                self.compile_expr(scrutinee, depth + 1)?;
643                let tmp = format!("__match_scrutinee_{depth}");
644                self.program.push(Instruction::StoreVar(tmp.clone()));
645
646                // Build nested IfThenElse from the arms (last arm = wildcard).
647                let (wildcard_body, non_wildcard) = arms
648                    .split_last()
649                    .ok_or_else(|| CompileError::UnsupportedExpr("Empty Match arms".into()))?;
650
651                // Inline compile the chain: iterate non-wildcard arms in
652                // reverse, wrapping around the accumulated else-branch.
653                // We achieve this by emitting the structure directly.
654                self.emit_match_chain(&tmp, non_wildcard, &wildcard_body.1, depth)?;
655            }
656
657            // ── Unsupported variants ──────────────────────────────────────
658            other => {
659                return Err(CompileError::UnsupportedExpr(format!("{:?}", other)));
660            }
661        }
662
663        Ok(())
664    }
665
666    /// Emit a chain of conditional jumps for the non-wildcard arms of a Match.
667    ///
668    /// The scrutinee has already been stored in `scrutinee_var`.
669    /// `arms` is the slice of (Pattern, body) excluding the wildcard tail.
670    /// `else_body` is the wildcard-arm body.
671    fn emit_match_chain(
672        &mut self,
673        scrutinee_var: &str,
674        arms: &[(tensorlogic_ir::MatchPattern, Box<TLExpr>)],
675        else_body: &TLExpr,
676        depth: usize,
677    ) -> Result<(), CompileError> {
678        if arms.is_empty() {
679            // Only wildcard — compile the else body directly.
680            return self.compile_expr(else_body, depth + 1);
681        }
682
683        // Emit current arm: if scrutinee == rhs { body } else { rest }
684        let (pat, body) = &arms[0];
685        let remaining = &arms[1..];
686
687        // Condition: load scrutinee, push rhs, compare.
688        self.program
689            .push(Instruction::LoadVar(scrutinee_var.to_string()));
690        match pat {
691            tensorlogic_ir::MatchPattern::ConstNumber(n) => {
692                self.program.push(Instruction::PushNum(*n));
693            }
694            tensorlogic_ir::MatchPattern::ConstSymbol(s) => {
695                self.program.push(Instruction::PushSym(s.clone()));
696            }
697            tensorlogic_ir::MatchPattern::Wildcard => {
698                return Err(CompileError::UnsupportedExpr(
699                    "Wildcard in non-tail position".into(),
700                ));
701            }
702        }
703        self.program.push(Instruction::Eq);
704
705        // JumpIfFalse → else branch
706        let jf_idx = self.program.push(Instruction::JumpIfFalse(0));
707        // Then branch: compile arm body
708        self.compile_expr(body, depth + 1)?;
709        // Jump past else
710        let jump_idx = self.program.push(Instruction::Jump(0));
711        // Else branch:
712        let else_start = self.program.len();
713        self.program.patch_jump(jf_idx, else_start);
714        // Recurse for remaining arms
715        self.emit_match_chain(scrutinee_var, remaining, else_body, depth)?;
716        let end = self.program.len();
717        self.program.patch_jump(jump_idx, end);
718
719        Ok(())
720    }
721}
722
723// ─────────────────────────────────────────────────────────────────────────────
724// Public compile API
725// ─────────────────────────────────────────────────────────────────────────────
726
/// Default maximum expression depth allowed during compilation.
///
/// This is the limit [`compile`] passes to [`compile_with_config`].
pub const DEFAULT_MAX_DEPTH: usize = 512;
729
730/// Compile a [`TLExpr`] to a [`BytecodeProgram`].
731///
732/// The final instruction appended is always [`Instruction::Halt`].
733///
734/// # Errors
735///
736/// Returns [`CompileError::UnsupportedExpr`] when the expression tree contains
737/// variants that cannot be translated to bytecode (e.g. quantifiers, modal
738/// operators, lambda abstractions).
739///
740/// Returns [`CompileError::MaxDepthExceeded`] if the default depth limit
741/// ([`DEFAULT_MAX_DEPTH`]) is surpassed.
742pub fn compile(expr: &TLExpr) -> Result<BytecodeProgram, CompileError> {
743    compile_with_config(expr, DEFAULT_MAX_DEPTH)
744}
745
746/// Compile a [`TLExpr`] to a [`BytecodeProgram`] with an explicit depth limit.
747///
748/// `max_depth` controls how deeply the compiler will recurse into nested
749/// sub-expressions before emitting [`CompileError::MaxDepthExceeded`].
750pub fn compile_with_config(
751    expr: &TLExpr,
752    max_depth: usize,
753) -> Result<BytecodeProgram, CompileError> {
754    let mut compiler = Compiler::new(max_depth);
755    compiler.compile_expr(expr, 0)?;
756    compiler.program.push(Instruction::Halt);
757    Ok(compiler.program)
758}
759
760// ─────────────────────────────────────────────────────────────────────────────
761// VM executor
762// ─────────────────────────────────────────────────────────────────────────────
763
764/// Execute a [`BytecodeProgram`] and return the top-of-stack value after `Halt`.
765///
766/// A mutable local copy of `env` is used so that [`Instruction::StoreVar`]
767/// does not modify the caller's environment.
768pub fn execute(program: &BytecodeProgram, env: &VmEnv) -> Result<VmValue, VmError> {
769    let (val, _stats) = execute_with_stats(program, env)?;
770    Ok(val)
771}
772
773/// Execute a [`BytecodeProgram`] and return both the result and execution statistics.
774pub fn execute_with_stats(
775    program: &BytecodeProgram,
776    env: &VmEnv,
777) -> Result<(VmValue, VmStats), VmError> {
778    if program.is_empty() {
779        return Err(VmError::ProgramEmpty);
780    }
781
782    let mut stack: Vec<VmValue> = Vec::with_capacity(16);
783    // Local mutable copy so StoreVar doesn't mutate the caller's env.
784    let mut local_env = env.clone();
785    let mut ip: usize = 0;
786    let mut stats = VmStats::default();
787
788    loop {
789        if ip >= program.instructions.len() {
790            return Err(VmError::InvalidInstruction(ip));
791        }
792
793        let instr = &program.instructions[ip];
794        stats.instructions_executed += 1;
795
796        match instr {
797            // ── Stack management ─────────────────────────────────────────
798            Instruction::PushNum(n) => {
799                stack.push(VmValue::Num(*n));
800                ip += 1;
801            }
802            Instruction::PushBool(b) => {
803                stack.push(VmValue::Bool(*b));
804                ip += 1;
805            }
806            Instruction::PushSym(s) => {
807                stack.push(VmValue::Sym(s.clone()));
808                ip += 1;
809            }
810            Instruction::Pop => {
811                stack.pop().ok_or(VmError::StackUnderflow)?;
812                ip += 1;
813            }
814            Instruction::Dup => {
815                let top = stack.last().ok_or(VmError::StackUnderflow)?.clone();
816                stack.push(top);
817                ip += 1;
818            }
819
820            // ── Arithmetic ───────────────────────────────────────────────
821            Instruction::Add => {
822                let b = pop_num(&mut stack)?;
823                let a = pop_num(&mut stack)?;
824                stack.push(VmValue::Num(a + b));
825                ip += 1;
826            }
827            Instruction::Sub => {
828                let b = pop_num(&mut stack)?;
829                let a = pop_num(&mut stack)?;
830                stack.push(VmValue::Num(a - b));
831                ip += 1;
832            }
833            Instruction::Mul => {
834                let b = pop_num(&mut stack)?;
835                let a = pop_num(&mut stack)?;
836                stack.push(VmValue::Num(a * b));
837                ip += 1;
838            }
839            Instruction::Div => {
840                let b = pop_num(&mut stack)?;
841                let a = pop_num(&mut stack)?;
842                if b == 0.0 {
843                    return Err(VmError::DivisionByZero);
844                }
845                stack.push(VmValue::Num(a / b));
846                ip += 1;
847            }
848            Instruction::Pow => {
849                let b = pop_num(&mut stack)?;
850                let a = pop_num(&mut stack)?;
851                stack.push(VmValue::Num(a.powf(b)));
852                ip += 1;
853            }
854            Instruction::Mod => {
855                let b = pop_num(&mut stack)?;
856                let a = pop_num(&mut stack)?;
857                stack.push(VmValue::Num(a % b));
858                ip += 1;
859            }
860            Instruction::Neg => {
861                let a = pop_num(&mut stack)?;
862                stack.push(VmValue::Num(-a));
863                ip += 1;
864            }
865            Instruction::Abs => {
866                let a = pop_num(&mut stack)?;
867                stack.push(VmValue::Num(a.abs()));
868                ip += 1;
869            }
870            Instruction::Sqrt => {
871                let a = pop_num(&mut stack)?;
872                stack.push(VmValue::Num(a.sqrt()));
873                ip += 1;
874            }
875            Instruction::Exp => {
876                let a = pop_num(&mut stack)?;
877                stack.push(VmValue::Num(a.exp()));
878                ip += 1;
879            }
880            Instruction::Log => {
881                let a = pop_num(&mut stack)?;
882                stack.push(VmValue::Num(a.ln()));
883                ip += 1;
884            }
885            Instruction::Min => {
886                let b = pop_num(&mut stack)?;
887                let a = pop_num(&mut stack)?;
888                stack.push(VmValue::Num(a.min(b)));
889                ip += 1;
890            }
891            Instruction::Max => {
892                let b = pop_num(&mut stack)?;
893                let a = pop_num(&mut stack)?;
894                stack.push(VmValue::Num(a.max(b)));
895                ip += 1;
896            }
897
898            // ── Comparison ───────────────────────────────────────────────
899            Instruction::Eq => {
900                let b = pop_value(&mut stack)?;
901                let a = pop_value(&mut stack)?;
902                stack.push(VmValue::Bool(values_equal(&a, &b)));
903                ip += 1;
904            }
905            Instruction::Ne => {
906                let b = pop_value(&mut stack)?;
907                let a = pop_value(&mut stack)?;
908                stack.push(VmValue::Bool(!values_equal(&a, &b)));
909                ip += 1;
910            }
911            Instruction::Lt => {
912                let b = pop_num(&mut stack)?;
913                let a = pop_num(&mut stack)?;
914                stack.push(VmValue::Bool(a < b));
915                ip += 1;
916            }
917            Instruction::Le => {
918                let b = pop_num(&mut stack)?;
919                let a = pop_num(&mut stack)?;
920                stack.push(VmValue::Bool(a <= b));
921                ip += 1;
922            }
923            Instruction::Gt => {
924                let b = pop_num(&mut stack)?;
925                let a = pop_num(&mut stack)?;
926                stack.push(VmValue::Bool(a > b));
927                ip += 1;
928            }
929            Instruction::Ge => {
930                let b = pop_num(&mut stack)?;
931                let a = pop_num(&mut stack)?;
932                stack.push(VmValue::Bool(a >= b));
933                ip += 1;
934            }
935
936            // ── Boolean logic ─────────────────────────────────────────────
937            Instruction::And => {
938                let b = pop_value(&mut stack)?;
939                let a = pop_value(&mut stack)?;
940                stack.push(VmValue::Bool(a.is_truthy() && b.is_truthy()));
941                ip += 1;
942            }
943            Instruction::Or => {
944                let b = pop_value(&mut stack)?;
945                let a = pop_value(&mut stack)?;
946                stack.push(VmValue::Bool(a.is_truthy() || b.is_truthy()));
947                ip += 1;
948            }
949            Instruction::Not => {
950                let a = pop_value(&mut stack)?;
951                stack.push(VmValue::Bool(!a.is_truthy()));
952                ip += 1;
953            }
954
955            // ── Control flow ──────────────────────────────────────────────
956            Instruction::JumpIfFalse(target) => {
957                let target = *target;
958                let cond = pop_value(&mut stack)?;
959                if !cond.is_truthy() {
960                    // Push a Bool(false) so the result is still on the stack at the
961                    // jump target — the caller is responsible for having set up the
962                    // stack correctly, but we preserve the false value here so that
963                    // short-circuit And returns the correct result.
964                    stack.push(VmValue::Bool(false));
965                    ip = target;
966                    stats.jumps_taken += 1;
967                } else {
968                    ip += 1;
969                }
970            }
971            Instruction::JumpIfTrue(target) => {
972                let target = *target;
973                let cond = pop_value(&mut stack)?;
974                if cond.is_truthy() {
975                    // Similarly preserve the true value for short-circuit Or.
976                    stack.push(VmValue::Bool(true));
977                    ip = target;
978                    stats.jumps_taken += 1;
979                } else {
980                    ip += 1;
981                }
982            }
983            Instruction::Jump(target) => {
984                ip = *target;
985                stats.jumps_taken += 1;
986            }
987
988            // ── Variables ─────────────────────────────────────────────────
989            Instruction::LoadVar(name) => {
990                let val = local_env
991                    .get(name)
992                    .ok_or_else(|| VmError::UnboundVariable(name.clone()))?
993                    .clone();
994                stack.push(val);
995                ip += 1;
996            }
997            Instruction::StoreVar(name) => {
998                let val = pop_value(&mut stack)?;
999                local_env.set(name.clone(), val);
1000                ip += 1;
1001            }
1002
1003            // ── Fuzzy operations ──────────────────────────────────────────
1004            Instruction::TNorm => {
1005                let b = pop_num(&mut stack)?;
1006                let a = pop_num(&mut stack)?;
1007                // Product t-norm: T(a, b) = a * b
1008                stack.push(VmValue::Num(a * b));
1009                ip += 1;
1010            }
1011            Instruction::TCoNorm => {
1012                let b = pop_num(&mut stack)?;
1013                let a = pop_num(&mut stack)?;
1014                // Probabilistic sum: S(a, b) = a + b - a*b
1015                stack.push(VmValue::Num(a + b - a * b));
1016                ip += 1;
1017            }
1018            Instruction::FuzzyNot => {
1019                let a = pop_num(&mut stack)?;
1020                // Standard fuzzy NOT: N(a) = 1 - a
1021                stack.push(VmValue::Num(1.0 - a));
1022                ip += 1;
1023            }
1024
1025            // ── Termination ───────────────────────────────────────────────
1026            Instruction::Halt => {
1027                let result = stack.pop().ok_or(VmError::StackUnderflow)?;
1028                // Update final stats
1029                if stats.max_stack_depth < stack.len() + 1 {
1030                    stats.max_stack_depth = stack.len() + 1;
1031                }
1032                return Ok((result, stats));
1033            }
1034        }
1035
1036        // Track maximum stack depth after each instruction.
1037        if stack.len() > stats.max_stack_depth {
1038            stats.max_stack_depth = stack.len();
1039        }
1040    }
1041}

// ─────────────────────────────────────────────────────────────────────────────
// Private helpers
// ─────────────────────────────────────────────────────────────────────────────

1047/// Pop the top-of-stack value (any type).
1048#[inline]
1049fn pop_value(stack: &mut Vec<VmValue>) -> Result<VmValue, VmError> {
1050    stack.pop().ok_or(VmError::StackUnderflow)
1051}
1053/// Pop the top-of-stack value and coerce it to `f64`.
1054#[inline]
1055fn pop_num(stack: &mut Vec<VmValue>) -> Result<f64, VmError> {
1056    let val = stack.pop().ok_or(VmError::StackUnderflow)?;
1057    match val {
1058        VmValue::Num(n) => Ok(n),
1059        VmValue::Bool(_) => Err(VmError::TypeMismatch {
1060            expected: "Num",
1061            got: "Bool",
1062        }),
1063        VmValue::Sym(_) => Err(VmError::TypeMismatch {
1064            expected: "Num",
1065            got: "Sym",
1066        }),
1067    }
1068}
1070/// Value equality that is aware of the two possible [`VmValue`] variants.
1071#[inline]
1072fn values_equal(a: &VmValue, b: &VmValue) -> bool {
1073    match (a, b) {
1074        (VmValue::Num(x), VmValue::Num(y)) => x == y,
1075        (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
1076        (VmValue::Sym(x), VmValue::Sym(y)) => x == y,
1077        _ => false,
1078    }
1079}

// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{FuzzyNegationKind, TCoNormKind, TLExpr, TNormKind};

    // ── Shared helpers ─────────────────────────────────────────────────────

    /// Shorthand for a numeric literal expression.
    fn num(value: f64) -> TLExpr {
        TLExpr::Constant(value)
    }

    /// Assemble a program from a hand-written instruction sequence.
    fn program_of(instructions: Vec<Instruction>) -> BytecodeProgram {
        let mut program = BytecodeProgram::new();
        for instruction in instructions {
            program.push(instruction);
        }
        program
    }

    /// Compile `expr` and run it against `env`, panicking on any failure.
    fn eval_env(expr: TLExpr, env: &VmEnv) -> VmValue {
        let program = compile(&expr).expect("compile failed");
        execute(&program, env).expect("execute failed")
    }

    /// Compile `expr` and run it against an empty environment.
    fn eval(expr: TLExpr) -> VmValue {
        eval_env(expr, &VmEnv::new())
    }

    // ── 1. Constant compile shape ──────────────────────────────────────────
    #[test]
    fn test_compile_constant_shape() {
        let pi = std::f64::consts::PI;
        let program = compile(&num(pi)).expect("compile failed");
        assert_eq!(program.len(), 2, "should be [PushNum(PI), Halt]");
        assert_eq!(program.instructions[0], Instruction::PushNum(pi));
        assert_eq!(program.instructions[1], Instruction::Halt);
    }

    // ── 2. Execute single PushNum ──────────────────────────────────────────
    #[test]
    fn test_execute_push_num() {
        let program = program_of(vec![Instruction::PushNum(5.0), Instruction::Halt]);
        let outcome = execute(&program, &VmEnv::new()).expect("execute failed");
        assert_eq!(outcome, VmValue::Num(5.0));
    }

    // ── 3. Add ────────────────────────────────────────────────────────────
    #[test]
    fn test_add() {
        let sum = eval(TLExpr::add(num(2.0), num(3.0)));
        assert_eq!(sum, VmValue::Num(5.0));
    }

    // ── 4. Sub ────────────────────────────────────────────────────────────
    #[test]
    fn test_sub() {
        let difference = eval(TLExpr::sub(num(10.0), num(4.0)));
        assert_eq!(difference, VmValue::Num(6.0));
    }

    // ── 5. Mul ────────────────────────────────────────────────────────────
    #[test]
    fn test_mul() {
        let product = eval(TLExpr::mul(num(3.0), num(4.0)));
        assert_eq!(product, VmValue::Num(12.0));
    }

    // ── 6. Div ────────────────────────────────────────────────────────────
    #[test]
    fn test_div() {
        let quotient = eval(TLExpr::div(num(10.0), num(2.0)));
        assert_eq!(quotient, VmValue::Num(5.0));
    }

    // ── 7. Pow ────────────────────────────────────────────────────────────
    #[test]
    fn test_pow() {
        let power = eval(TLExpr::pow(num(2.0), num(8.0)));
        assert_eq!(power, VmValue::Num(256.0));
    }

    // ── 8. Eq true ────────────────────────────────────────────────────────
    #[test]
    fn test_eq_true() {
        let verdict = eval(TLExpr::eq(num(3.0), num(3.0)));
        assert_eq!(verdict, VmValue::Bool(true));
    }

    // ── 9. Lt true ────────────────────────────────────────────────────────
    #[test]
    fn test_lt_true() {
        let verdict = eval(TLExpr::lt(num(1.0), num(2.0)));
        assert_eq!(verdict, VmValue::Bool(true));
    }

    // ── 10. And false ─────────────────────────────────────────────────────
    #[test]
    fn test_and_false() {
        let holds = TLExpr::eq(num(1.0), num(1.0));
        let fails = TLExpr::eq(num(1.0), num(2.0));
        assert_eq!(eval(TLExpr::and(holds, fails)), VmValue::Bool(false));
    }

    // ── 11. Or true ───────────────────────────────────────────────────────
    #[test]
    fn test_or_true() {
        let fails = TLExpr::eq(num(1.0), num(2.0));
        let holds = TLExpr::eq(num(1.0), num(1.0));
        assert_eq!(eval(TLExpr::or(fails, holds)), VmValue::Bool(true));
    }

    // ── 12. Not false → true ──────────────────────────────────────────────
    #[test]
    fn test_not_false_to_true() {
        let fails = TLExpr::eq(num(1.0), num(2.0));
        assert_eq!(eval(TLExpr::negate(fails)), VmValue::Bool(true));
    }

    // ── 13. Short-circuit And: jump taken when first arg is false ─────────
    #[test]
    fn test_short_circuit_and_jump() {
        // The left operand is false, so the right operand must be skipped.
        let left = TLExpr::eq(num(1.0), num(2.0)); // false
        let right = TLExpr::eq(num(3.0), num(3.0)); // true (never reached)
        let program = compile(&TLExpr::and(left, right)).expect("compile failed");
        let (outcome, stats) =
            execute_with_stats(&program, &VmEnv::new()).expect("execute failed");
        assert_eq!(outcome, VmValue::Bool(false));
        assert!(stats.jumps_taken > 0, "JumpIfFalse should have been taken");
    }

    // ── 14. Short-circuit Or: jump taken when first arg is true ───────────
    #[test]
    fn test_short_circuit_or_jump() {
        let left = TLExpr::eq(num(1.0), num(1.0)); // true
        let right = TLExpr::eq(num(3.0), num(4.0)); // false (never reached)
        let program = compile(&TLExpr::or(left, right)).expect("compile failed");
        let (outcome, stats) =
            execute_with_stats(&program, &VmEnv::new()).expect("execute failed");
        assert_eq!(outcome, VmValue::Bool(true));
        assert!(stats.jumps_taken > 0, "JumpIfTrue should have been taken");
    }

    // ── 15. IfThenElse true branch ────────────────────────────────────────
    #[test]
    fn test_ite_true_branch() {
        let condition = TLExpr::eq(num(1.0), num(1.0));
        let chosen = eval(TLExpr::if_then_else(condition, num(1.0), num(2.0)));
        assert_eq!(chosen, VmValue::Num(1.0));
    }

    // ── 16. IfThenElse false branch ───────────────────────────────────────
    #[test]
    fn test_ite_false_branch() {
        let condition = TLExpr::eq(num(1.0), num(2.0));
        let chosen = eval(TLExpr::if_then_else(condition, num(1.0), num(2.0)));
        assert_eq!(chosen, VmValue::Num(2.0));
    }

    // ── 17. LoadVar retrieves value from VmEnv ────────────────────────────
    #[test]
    fn test_load_var() {
        let mut env = VmEnv::new();
        env.set_num("x", 42.0);
        let loaded = eval_env(TLExpr::pred("x", vec![]), &env);
        assert_eq!(loaded, VmValue::Num(42.0));
    }

    // ── 18. Let binding roundtrip ─────────────────────────────────────────
    #[test]
    fn test_let_binding() {
        // let y = 7.0 in y * 2.0
        let body = TLExpr::mul(TLExpr::pred("y", vec![]), num(2.0));
        let binding = TLExpr::Let {
            var: "y".to_string(),
            value: Box::new(num(7.0)),
            body: Box::new(body),
        };
        assert_eq!(eval_env(binding, &VmEnv::new()), VmValue::Num(14.0));
    }

    // ── 19. VmError::StackUnderflow ───────────────────────────────────────
    #[test]
    fn test_stack_underflow() {
        // `Add` runs with no operands on the stack.
        let program = program_of(vec![Instruction::Add, Instruction::Halt]);
        let err = execute(&program, &VmEnv::new()).unwrap_err();
        assert!(
            matches!(err, VmError::StackUnderflow),
            "expected StackUnderflow, got {:?}",
            err
        );
    }

    // ── 20. VmError::UnboundVariable ──────────────────────────────────────
    #[test]
    fn test_unbound_variable() {
        let program = program_of(vec![
            Instruction::LoadVar("missing".to_string()),
            Instruction::Halt,
        ]);
        let err = execute(&program, &VmEnv::new()).unwrap_err();
        assert!(
            matches!(err, VmError::UnboundVariable(_)),
            "expected UnboundVariable, got {:?}",
            err
        );
    }

    // ── 21. VmStats.instructions_executed > 0 ─────────────────────────────
    #[test]
    fn test_stats_instructions_executed() {
        let program = compile(&num(1.0)).expect("compile failed");
        let (_value, stats) =
            execute_with_stats(&program, &VmEnv::new()).expect("execute failed");
        assert!(stats.instructions_executed > 0);
    }

    // ── 22. VmStats.max_stack_depth = 1 for simple push+halt ──────────────
    #[test]
    fn test_stats_max_stack_depth_single_push() {
        let program = program_of(vec![Instruction::PushNum(99.0), Instruction::Halt]);
        let (_value, stats) =
            execute_with_stats(&program, &VmEnv::new()).expect("execute failed");
        assert_eq!(stats.max_stack_depth, 1, "single push should give depth 1");
    }

    // ── 23. TNorm(0.5, 0.5) = 0.25 ───────────────────────────────────────
    #[test]
    fn test_tnorm_product() {
        let result = eval(TLExpr::TNorm {
            kind: TNormKind::Product,
            left: Box::new(num(0.5)),
            right: Box::new(num(0.5)),
        });
        if let VmValue::Num(n) = result {
            assert!((n - 0.25).abs() < 1e-10, "expected 0.25, got {}", n);
        } else {
            panic!("expected Num, got {:?}", result);
        }
    }

    // ── 24. FuzzyNot(0.3) = 0.7 ──────────────────────────────────────────
    #[test]
    fn test_fuzzy_not() {
        let result = eval(TLExpr::FuzzyNot {
            kind: FuzzyNegationKind::Standard,
            expr: Box::new(num(0.3)),
        });
        if let VmValue::Num(n) = result {
            assert!((n - 0.7).abs() < 1e-10, "expected 0.7, got {}", n);
        } else {
            panic!("expected Num, got {:?}", result);
        }
    }

    // ── Bonus 25. TCoNorm(0.5, 0.5) = 0.75 ───────────────────────────────
    #[test]
    fn test_tconorm() {
        let result = eval(TLExpr::TCoNorm {
            kind: TCoNormKind::ProbabilisticSum,
            left: Box::new(num(0.5)),
            right: Box::new(num(0.5)),
        });
        if let VmValue::Num(n) = result {
            // 0.5 + 0.5 - 0.5*0.5 = 0.75
            assert!((n - 0.75).abs() < 1e-10, "expected 0.75, got {}", n);
        } else {
            panic!("expected Num, got {:?}", result);
        }
    }

    // ── Bonus 26. Nested arithmetic depth ─────────────────────────────────
    #[test]
    fn test_nested_arithmetic() {
        // (1 + 2) * (3 + 4) = 21
        let left = TLExpr::add(num(1.0), num(2.0));
        let right = TLExpr::add(num(3.0), num(4.0));
        assert_eq!(eval(TLExpr::mul(left, right)), VmValue::Num(21.0));
    }

    // ── Bonus 27. DivisionByZero error ────────────────────────────────────
    #[test]
    fn test_division_by_zero() {
        let program = program_of(vec![
            Instruction::PushNum(1.0),
            Instruction::PushNum(0.0),
            Instruction::Div,
            Instruction::Halt,
        ]);
        let err = execute(&program, &VmEnv::new()).unwrap_err();
        assert!(
            matches!(err, VmError::DivisionByZero),
            "expected DivisionByZero, got {:?}",
            err
        );
    }

    // ── Bonus 28. Abs of negative number ──────────────────────────────────
    #[test]
    fn test_abs() {
        let magnitude = eval(TLExpr::Abs(Box::new(num(-5.0))));
        assert_eq!(magnitude, VmValue::Num(5.0));
    }

    // ── Bonus 29. Compile unsupported expr returns error ──────────────────
    #[test]
    fn test_compile_unsupported_forall() {
        use tensorlogic_ir::Term;
        let quantified = TLExpr::forall("x", "D", TLExpr::pred("P", vec![Term::var("x")]));
        let err = compile(&quantified).unwrap_err();
        assert!(
            matches!(err, CompileError::UnsupportedExpr(_)),
            "expected UnsupportedExpr, got {:?}",
            err
        );
    }

    // ── Bonus 30. Max depth exceeded ──────────────────────────────────────
    #[test]
    fn test_max_depth_exceeded() {
        // Three nested Adds, compiled with `1` passed to `compile_with_config`.
        let innermost = TLExpr::add(num(1.0), num(1.0));
        let middle = TLExpr::add(innermost, num(1.0));
        let outermost = TLExpr::add(middle, num(1.0));
        let err = compile_with_config(&outermost, 1).unwrap_err();
        assert!(
            matches!(err, CompileError::MaxDepthExceeded),
            "expected MaxDepthExceeded, got {:?}",
            err
        );
    }
}