1use std::collections::HashMap;
25use tensorlogic_ir::TLExpr;
26
/// A single instruction for the stack VM.
///
/// Operands are popped from the top of the value stack and results are
/// pushed back. For binary instructions the right-hand operand `b` is on top
/// of the stack (it was pushed last), so the VM computes `a <op> b`.
/// Jump operands are absolute indices into `BytecodeProgram::instructions`.
#[derive(Clone, Debug, PartialEq)]
pub enum Instruction {
    // ----- constants and stack manipulation -----
    /// Push a numeric constant.
    PushNum(f64),
    /// Push a boolean constant.
    PushBool(bool),
    /// Push a symbol (string) constant.
    PushSym(String),
    /// Discard the top of the stack.
    Pop,
    /// Push a copy of the top of the stack.
    Dup,

    // ----- numeric operations (operands must be `Num`) -----
    /// `a + b`.
    Add,
    /// `a - b`.
    Sub,
    /// `a * b`.
    Mul,
    /// `a / b`; the VM reports `VmError::DivisionByZero` when `b == 0.0`.
    Div,
    /// `a` raised to the power `b` (`f64::powf`).
    Pow,
    /// Floating-point remainder `a % b`.
    Mod,
    /// Negation `-a`.
    Neg,
    /// Absolute value of `a`.
    Abs,
    /// Square root of `a`.
    Sqrt,
    /// `e` raised to the power `a`.
    Exp,
    /// Natural logarithm of `a`.
    Log,
    /// Smaller of `a` and `b` (`f64::min`).
    Min,
    /// Larger of `a` and `b` (`f64::max`).
    Max,

    // ----- comparisons (push a `Bool`) -----
    /// Equality of any two values; values of different variants are unequal.
    Eq,
    /// Negation of `Eq`.
    Ne,
    /// `a < b` on numbers.
    Lt,
    /// `a <= b` on numbers.
    Le,
    /// `a > b` on numbers.
    Gt,
    /// `a >= b` on numbers.
    Ge,

    // ----- logic over truthiness (push a `Bool`) -----
    /// True when both operands are truthy.
    And,
    /// True when at least one operand is truthy.
    Or,
    /// True when the operand is not truthy.
    Not,

    // ----- control flow -----
    /// Pop a condition. If it is falsy, push `Bool(false)` and jump to the
    /// target; otherwise fall through to the next instruction. The pushed
    /// `false` serves as the short-circuit result of a compiled `And`.
    JumpIfFalse(usize),
    /// Pop a condition. If it is truthy, push `Bool(true)` and jump to the
    /// target; otherwise fall through. The pushed `true` serves as the
    /// short-circuit result of a compiled `Or`.
    JumpIfTrue(usize),
    /// Jump to the target unconditionally.
    Jump(usize),

    // ----- variables -----
    /// Push the value bound to the name (`VmError::UnboundVariable` if absent).
    LoadVar(String),
    /// Pop a value and bind it to the name for the remainder of the run.
    StoreVar(String),

    // ----- fuzzy logic (numeric operands) -----
    /// Product t-norm: `a * b`.
    TNorm,
    /// Probabilistic-sum t-conorm: `a + b - a * b`.
    TCoNorm,
    /// Standard fuzzy negation: `1 - a`.
    FuzzyNot,

    /// Pop the top of the stack and stop, returning it as the program result.
    Halt,
}
128
129#[derive(Debug, Clone)]
135pub struct BytecodeProgram {
136 pub instructions: Vec<Instruction>,
138}
139
140impl Default for BytecodeProgram {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl BytecodeProgram {
147 pub fn new() -> Self {
149 Self {
150 instructions: Vec::new(),
151 }
152 }
153
154 pub fn push(&mut self, instr: Instruction) -> usize {
156 let idx = self.instructions.len();
157 self.instructions.push(instr);
158 idx
159 }
160
161 pub fn patch_jump(&mut self, idx: usize, target: usize) {
165 match &mut self.instructions[idx] {
166 Instruction::JumpIfFalse(t) | Instruction::JumpIfTrue(t) | Instruction::Jump(t) => {
167 *t = target;
168 }
169 other => {
170 debug_assert!(
171 false,
172 "patch_jump called on non-jump instruction: {:?}",
173 other
174 );
175 }
176 }
177 }
178
179 pub fn len(&self) -> usize {
181 self.instructions.len()
182 }
183
184 pub fn is_empty(&self) -> bool {
186 self.instructions.is_empty()
187 }
188}
189
190#[derive(Debug, Clone, PartialEq)]
196pub enum VmValue {
197 Num(f64),
199 Bool(bool),
201 Sym(String),
203}
204
205impl VmValue {
206 pub fn as_num(&self) -> Result<f64, VmError> {
208 match self {
209 VmValue::Num(n) => Ok(*n),
210 VmValue::Bool(_) => Err(VmError::TypeMismatch {
211 expected: "Num",
212 got: "Bool",
213 }),
214 VmValue::Sym(_) => Err(VmError::TypeMismatch {
215 expected: "Num",
216 got: "Sym",
217 }),
218 }
219 }
220
221 pub fn as_bool(&self) -> Result<bool, VmError> {
223 match self {
224 VmValue::Bool(b) => Ok(*b),
225 VmValue::Num(_) => Err(VmError::TypeMismatch {
226 expected: "Bool",
227 got: "Num",
228 }),
229 VmValue::Sym(_) => Err(VmError::TypeMismatch {
230 expected: "Bool",
231 got: "Sym",
232 }),
233 }
234 }
235
236 pub fn is_truthy(&self) -> bool {
238 match self {
239 VmValue::Num(n) => *n != 0.0,
240 VmValue::Bool(b) => *b,
241 VmValue::Sym(s) => !s.is_empty(),
242 }
243 }
244
245 #[allow(dead_code)]
247 fn type_name(&self) -> &'static str {
248 match self {
249 VmValue::Num(_) => "Num",
250 VmValue::Bool(_) => "Bool",
251 VmValue::Sym(_) => "Sym",
252 }
253 }
254}
255
/// Errors raised while executing a bytecode program.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VmError {
    /// An instruction needed more operands than the stack held.
    StackUnderflow,
    /// An operand had the wrong runtime type.
    TypeMismatch {
        /// Variant name the instruction required.
        expected: &'static str,
        /// Variant name actually found.
        got: &'static str,
    },
    /// `LoadVar` referenced a name with no binding.
    UnboundVariable(String),
    /// An arithmetic instruction received a zero divisor.
    DivisionByZero,
    /// The instruction pointer (payload) fell outside the program.
    InvalidInstruction(usize),
    /// The program contains no instructions at all.
    ProgramEmpty,
}

impl std::fmt::Display for VmError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            VmError::StackUnderflow => write!(f, "VM stack underflow"),
            VmError::TypeMismatch { expected, got } => {
                write!(f, "type mismatch: expected {}, got {}", expected, got)
            }
            VmError::UnboundVariable(name) => {
                write!(f, "unbound variable: '{}'", name)
            }
            VmError::DivisionByZero => write!(f, "division by zero"),
            VmError::InvalidInstruction(ip) => {
                write!(f, "invalid instruction pointer: {}", ip)
            }
            VmError::ProgramEmpty => write!(f, "program contains no instructions"),
        }
    }
}

// Marker impl: `Debug` is derived and `Display` is implemented above.
impl std::error::Error for VmError {}
302
303#[derive(Debug, Clone, Default)]
312pub struct VmEnv {
313 vars: HashMap<String, VmValue>,
314}
315
316impl VmEnv {
317 pub fn new() -> Self {
319 Self {
320 vars: HashMap::new(),
321 }
322 }
323
324 pub fn set(&mut self, name: impl Into<String>, val: VmValue) {
326 self.vars.insert(name.into(), val);
327 }
328
329 pub fn set_num(&mut self, name: impl Into<String>, val: f64) {
331 self.set(name, VmValue::Num(val));
332 }
333
334 pub fn set_bool(&mut self, name: impl Into<String>, val: bool) {
336 self.set(name, VmValue::Bool(val));
337 }
338
339 pub fn get(&self, name: &str) -> Option<&VmValue> {
341 self.vars.get(name)
342 }
343
344 pub fn len(&self) -> usize {
346 self.vars.len()
347 }
348
349 pub fn is_empty(&self) -> bool {
351 self.vars.is_empty()
352 }
353}
354
/// Counters gathered while a program runs.
#[derive(Clone, Debug, Default)]
pub struct VmStats {
    /// Instructions dispatched, counting the final `Halt` (or the failing
    /// instruction when execution errors out).
    pub instructions_executed: usize,
    /// Largest stack size observed after any instruction.
    pub max_stack_depth: usize,
    /// Jumps actually taken: every `Jump`, plus conditional jumps whose
    /// condition fired.
    pub jumps_taken: usize,
}
369
/// Errors raised while lowering an expression to bytecode.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompileError {
    /// The expression (or a sub-expression) has no bytecode lowering; the
    /// payload describes it.
    UnsupportedExpr(String),
    /// The expression tree is nested deeper than the configured maximum.
    MaxDepthExceeded,
}

impl std::fmt::Display for CompileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CompileError::UnsupportedExpr(desc) => {
                write!(f, "unsupported expression in bytecode compiler: {}", desc)
            }
            CompileError::MaxDepthExceeded => {
                write!(f, "expression depth exceeds configured maximum")
            }
        }
    }
}

// Marker impl: `Debug` is derived and `Display` is implemented above.
impl std::error::Error for CompileError {}
398
399struct Compiler {
405 program: BytecodeProgram,
406 max_depth: usize,
407}
408
409impl Compiler {
410 fn new(max_depth: usize) -> Self {
411 Self {
412 program: BytecodeProgram::new(),
413 max_depth,
414 }
415 }
416
417 fn compile_expr(&mut self, expr: &TLExpr, depth: usize) -> Result<(), CompileError> {
419 if depth > self.max_depth {
420 return Err(CompileError::MaxDepthExceeded);
421 }
422
423 match expr {
424 TLExpr::Constant(c) => {
426 self.program.push(Instruction::PushNum(*c));
427 }
428
429 TLExpr::Pred { name, args } if args.is_empty() => {
431 self.program.push(Instruction::LoadVar(name.clone()));
432 }
433
434 TLExpr::Add(a, b) => {
436 self.compile_expr(a, depth + 1)?;
437 self.compile_expr(b, depth + 1)?;
438 self.program.push(Instruction::Add);
439 }
440 TLExpr::Sub(a, b) => {
441 self.compile_expr(a, depth + 1)?;
442 self.compile_expr(b, depth + 1)?;
443 self.program.push(Instruction::Sub);
444 }
445 TLExpr::Mul(a, b) => {
446 self.compile_expr(a, depth + 1)?;
447 self.compile_expr(b, depth + 1)?;
448 self.program.push(Instruction::Mul);
449 }
450 TLExpr::Div(a, b) => {
451 self.compile_expr(a, depth + 1)?;
452 self.compile_expr(b, depth + 1)?;
453 self.program.push(Instruction::Div);
454 }
455 TLExpr::Pow(a, b) => {
456 self.compile_expr(a, depth + 1)?;
457 self.compile_expr(b, depth + 1)?;
458 self.program.push(Instruction::Pow);
459 }
460 TLExpr::Mod(a, b) => {
461 self.compile_expr(a, depth + 1)?;
462 self.compile_expr(b, depth + 1)?;
463 self.program.push(Instruction::Mod);
464 }
465 TLExpr::Abs(a) => {
466 self.compile_expr(a, depth + 1)?;
467 self.program.push(Instruction::Abs);
468 }
469 TLExpr::Sqrt(a) => {
470 self.compile_expr(a, depth + 1)?;
471 self.program.push(Instruction::Sqrt);
472 }
473 TLExpr::Exp(a) => {
474 self.compile_expr(a, depth + 1)?;
475 self.program.push(Instruction::Exp);
476 }
477 TLExpr::Log(a) => {
478 self.compile_expr(a, depth + 1)?;
479 self.program.push(Instruction::Log);
480 }
481 TLExpr::Min(a, b) => {
482 self.compile_expr(a, depth + 1)?;
483 self.compile_expr(b, depth + 1)?;
484 self.program.push(Instruction::Min);
485 }
486 TLExpr::Max(a, b) => {
487 self.compile_expr(a, depth + 1)?;
488 self.compile_expr(b, depth + 1)?;
489 self.program.push(Instruction::Max);
490 }
491
492 TLExpr::Eq(a, b) => {
494 self.compile_expr(a, depth + 1)?;
495 self.compile_expr(b, depth + 1)?;
496 self.program.push(Instruction::Eq);
497 }
498 TLExpr::Lt(a, b) => {
499 self.compile_expr(a, depth + 1)?;
500 self.compile_expr(b, depth + 1)?;
501 self.program.push(Instruction::Lt);
502 }
503 TLExpr::Gt(a, b) => {
504 self.compile_expr(a, depth + 1)?;
505 self.compile_expr(b, depth + 1)?;
506 self.program.push(Instruction::Gt);
507 }
508 TLExpr::Lte(a, b) => {
509 self.compile_expr(a, depth + 1)?;
510 self.compile_expr(b, depth + 1)?;
511 self.program.push(Instruction::Le);
512 }
513 TLExpr::Gte(a, b) => {
514 self.compile_expr(a, depth + 1)?;
515 self.compile_expr(b, depth + 1)?;
516 self.program.push(Instruction::Ge);
517 }
518
519 TLExpr::And(a, b) => {
534 self.compile_expr(a, depth + 1)?;
535 let jump_idx = self.program.push(Instruction::JumpIfFalse(0));
537 self.compile_expr(b, depth + 1)?;
539 self.program.push(Instruction::Not);
541 self.program.push(Instruction::Not);
542 let end = self.program.len();
543 self.program.patch_jump(jump_idx, end);
544 }
545
546 TLExpr::Or(a, b) => {
553 self.compile_expr(a, depth + 1)?;
554 let jump_idx = self.program.push(Instruction::JumpIfTrue(0));
555 self.compile_expr(b, depth + 1)?;
556 self.program.push(Instruction::Not);
557 self.program.push(Instruction::Not);
558 let end = self.program.len();
559 self.program.patch_jump(jump_idx, end);
560 }
561
562 TLExpr::Not(a) => {
563 self.compile_expr(a, depth + 1)?;
564 self.program.push(Instruction::Not);
565 }
566
567 TLExpr::IfThenElse {
577 condition,
578 then_branch,
579 else_branch,
580 } => {
581 self.compile_expr(condition, depth + 1)?;
582 let jf_idx = self.program.push(Instruction::JumpIfFalse(0));
583 self.compile_expr(then_branch, depth + 1)?;
584 let jump_idx = self.program.push(Instruction::Jump(0));
585 let else_start = self.program.len();
587 self.program.patch_jump(jf_idx, else_start);
588 self.compile_expr(else_branch, depth + 1)?;
589 let end = self.program.len();
590 self.program.patch_jump(jump_idx, end);
591 }
592
593 TLExpr::Let { var, value, body } => {
600 self.compile_expr(value, depth + 1)?;
601 self.program.push(Instruction::StoreVar(var.clone()));
602 self.compile_expr(body, depth + 1)?;
603 }
604
605 TLExpr::TNorm { left, right, .. } => {
611 self.compile_expr(left, depth + 1)?;
612 self.compile_expr(right, depth + 1)?;
613 self.program.push(Instruction::TNorm);
614 }
615 TLExpr::TCoNorm { left, right, .. } => {
616 self.compile_expr(left, depth + 1)?;
617 self.compile_expr(right, depth + 1)?;
618 self.program.push(Instruction::TCoNorm);
619 }
620 TLExpr::FuzzyNot { expr: inner, .. } => {
621 self.compile_expr(inner, depth + 1)?;
622 self.program.push(Instruction::FuzzyNot);
623 }
624
625 TLExpr::SymbolLiteral(s) => {
627 self.program.push(Instruction::PushSym(s.clone()));
628 }
629
630 TLExpr::Match { scrutinee, arms } => {
636 if arms.is_empty() {
637 return Err(CompileError::UnsupportedExpr(
638 "Match with no arms".to_string(),
639 ));
640 }
641 self.compile_expr(scrutinee, depth + 1)?;
643 let tmp = format!("__match_scrutinee_{depth}");
644 self.program.push(Instruction::StoreVar(tmp.clone()));
645
646 let (wildcard_body, non_wildcard) = arms
648 .split_last()
649 .ok_or_else(|| CompileError::UnsupportedExpr("Empty Match arms".into()))?;
650
651 self.emit_match_chain(&tmp, non_wildcard, &wildcard_body.1, depth)?;
655 }
656
657 other => {
659 return Err(CompileError::UnsupportedExpr(format!("{:?}", other)));
660 }
661 }
662
663 Ok(())
664 }
665
666 fn emit_match_chain(
672 &mut self,
673 scrutinee_var: &str,
674 arms: &[(tensorlogic_ir::MatchPattern, Box<TLExpr>)],
675 else_body: &TLExpr,
676 depth: usize,
677 ) -> Result<(), CompileError> {
678 if arms.is_empty() {
679 return self.compile_expr(else_body, depth + 1);
681 }
682
683 let (pat, body) = &arms[0];
685 let remaining = &arms[1..];
686
687 self.program
689 .push(Instruction::LoadVar(scrutinee_var.to_string()));
690 match pat {
691 tensorlogic_ir::MatchPattern::ConstNumber(n) => {
692 self.program.push(Instruction::PushNum(*n));
693 }
694 tensorlogic_ir::MatchPattern::ConstSymbol(s) => {
695 self.program.push(Instruction::PushSym(s.clone()));
696 }
697 tensorlogic_ir::MatchPattern::Wildcard => {
698 return Err(CompileError::UnsupportedExpr(
699 "Wildcard in non-tail position".into(),
700 ));
701 }
702 }
703 self.program.push(Instruction::Eq);
704
705 let jf_idx = self.program.push(Instruction::JumpIfFalse(0));
707 self.compile_expr(body, depth + 1)?;
709 let jump_idx = self.program.push(Instruction::Jump(0));
711 let else_start = self.program.len();
713 self.program.patch_jump(jf_idx, else_start);
714 self.emit_match_chain(scrutinee_var, remaining, else_body, depth)?;
716 let end = self.program.len();
717 self.program.patch_jump(jump_idx, end);
718
719 Ok(())
720 }
721}
722
723pub const DEFAULT_MAX_DEPTH: usize = 512;
729
730pub fn compile(expr: &TLExpr) -> Result<BytecodeProgram, CompileError> {
743 compile_with_config(expr, DEFAULT_MAX_DEPTH)
744}
745
746pub fn compile_with_config(
751 expr: &TLExpr,
752 max_depth: usize,
753) -> Result<BytecodeProgram, CompileError> {
754 let mut compiler = Compiler::new(max_depth);
755 compiler.compile_expr(expr, 0)?;
756 compiler.program.push(Instruction::Halt);
757 Ok(compiler.program)
758}
759
760pub fn execute(program: &BytecodeProgram, env: &VmEnv) -> Result<VmValue, VmError> {
769 let (val, _stats) = execute_with_stats(program, env)?;
770 Ok(val)
771}
772
773pub fn execute_with_stats(
775 program: &BytecodeProgram,
776 env: &VmEnv,
777) -> Result<(VmValue, VmStats), VmError> {
778 if program.is_empty() {
779 return Err(VmError::ProgramEmpty);
780 }
781
782 let mut stack: Vec<VmValue> = Vec::with_capacity(16);
783 let mut local_env = env.clone();
785 let mut ip: usize = 0;
786 let mut stats = VmStats::default();
787
788 loop {
789 if ip >= program.instructions.len() {
790 return Err(VmError::InvalidInstruction(ip));
791 }
792
793 let instr = &program.instructions[ip];
794 stats.instructions_executed += 1;
795
796 match instr {
797 Instruction::PushNum(n) => {
799 stack.push(VmValue::Num(*n));
800 ip += 1;
801 }
802 Instruction::PushBool(b) => {
803 stack.push(VmValue::Bool(*b));
804 ip += 1;
805 }
806 Instruction::PushSym(s) => {
807 stack.push(VmValue::Sym(s.clone()));
808 ip += 1;
809 }
810 Instruction::Pop => {
811 stack.pop().ok_or(VmError::StackUnderflow)?;
812 ip += 1;
813 }
814 Instruction::Dup => {
815 let top = stack.last().ok_or(VmError::StackUnderflow)?.clone();
816 stack.push(top);
817 ip += 1;
818 }
819
820 Instruction::Add => {
822 let b = pop_num(&mut stack)?;
823 let a = pop_num(&mut stack)?;
824 stack.push(VmValue::Num(a + b));
825 ip += 1;
826 }
827 Instruction::Sub => {
828 let b = pop_num(&mut stack)?;
829 let a = pop_num(&mut stack)?;
830 stack.push(VmValue::Num(a - b));
831 ip += 1;
832 }
833 Instruction::Mul => {
834 let b = pop_num(&mut stack)?;
835 let a = pop_num(&mut stack)?;
836 stack.push(VmValue::Num(a * b));
837 ip += 1;
838 }
839 Instruction::Div => {
840 let b = pop_num(&mut stack)?;
841 let a = pop_num(&mut stack)?;
842 if b == 0.0 {
843 return Err(VmError::DivisionByZero);
844 }
845 stack.push(VmValue::Num(a / b));
846 ip += 1;
847 }
848 Instruction::Pow => {
849 let b = pop_num(&mut stack)?;
850 let a = pop_num(&mut stack)?;
851 stack.push(VmValue::Num(a.powf(b)));
852 ip += 1;
853 }
854 Instruction::Mod => {
855 let b = pop_num(&mut stack)?;
856 let a = pop_num(&mut stack)?;
857 stack.push(VmValue::Num(a % b));
858 ip += 1;
859 }
860 Instruction::Neg => {
861 let a = pop_num(&mut stack)?;
862 stack.push(VmValue::Num(-a));
863 ip += 1;
864 }
865 Instruction::Abs => {
866 let a = pop_num(&mut stack)?;
867 stack.push(VmValue::Num(a.abs()));
868 ip += 1;
869 }
870 Instruction::Sqrt => {
871 let a = pop_num(&mut stack)?;
872 stack.push(VmValue::Num(a.sqrt()));
873 ip += 1;
874 }
875 Instruction::Exp => {
876 let a = pop_num(&mut stack)?;
877 stack.push(VmValue::Num(a.exp()));
878 ip += 1;
879 }
880 Instruction::Log => {
881 let a = pop_num(&mut stack)?;
882 stack.push(VmValue::Num(a.ln()));
883 ip += 1;
884 }
885 Instruction::Min => {
886 let b = pop_num(&mut stack)?;
887 let a = pop_num(&mut stack)?;
888 stack.push(VmValue::Num(a.min(b)));
889 ip += 1;
890 }
891 Instruction::Max => {
892 let b = pop_num(&mut stack)?;
893 let a = pop_num(&mut stack)?;
894 stack.push(VmValue::Num(a.max(b)));
895 ip += 1;
896 }
897
898 Instruction::Eq => {
900 let b = pop_value(&mut stack)?;
901 let a = pop_value(&mut stack)?;
902 stack.push(VmValue::Bool(values_equal(&a, &b)));
903 ip += 1;
904 }
905 Instruction::Ne => {
906 let b = pop_value(&mut stack)?;
907 let a = pop_value(&mut stack)?;
908 stack.push(VmValue::Bool(!values_equal(&a, &b)));
909 ip += 1;
910 }
911 Instruction::Lt => {
912 let b = pop_num(&mut stack)?;
913 let a = pop_num(&mut stack)?;
914 stack.push(VmValue::Bool(a < b));
915 ip += 1;
916 }
917 Instruction::Le => {
918 let b = pop_num(&mut stack)?;
919 let a = pop_num(&mut stack)?;
920 stack.push(VmValue::Bool(a <= b));
921 ip += 1;
922 }
923 Instruction::Gt => {
924 let b = pop_num(&mut stack)?;
925 let a = pop_num(&mut stack)?;
926 stack.push(VmValue::Bool(a > b));
927 ip += 1;
928 }
929 Instruction::Ge => {
930 let b = pop_num(&mut stack)?;
931 let a = pop_num(&mut stack)?;
932 stack.push(VmValue::Bool(a >= b));
933 ip += 1;
934 }
935
936 Instruction::And => {
938 let b = pop_value(&mut stack)?;
939 let a = pop_value(&mut stack)?;
940 stack.push(VmValue::Bool(a.is_truthy() && b.is_truthy()));
941 ip += 1;
942 }
943 Instruction::Or => {
944 let b = pop_value(&mut stack)?;
945 let a = pop_value(&mut stack)?;
946 stack.push(VmValue::Bool(a.is_truthy() || b.is_truthy()));
947 ip += 1;
948 }
949 Instruction::Not => {
950 let a = pop_value(&mut stack)?;
951 stack.push(VmValue::Bool(!a.is_truthy()));
952 ip += 1;
953 }
954
955 Instruction::JumpIfFalse(target) => {
957 let target = *target;
958 let cond = pop_value(&mut stack)?;
959 if !cond.is_truthy() {
960 stack.push(VmValue::Bool(false));
965 ip = target;
966 stats.jumps_taken += 1;
967 } else {
968 ip += 1;
969 }
970 }
971 Instruction::JumpIfTrue(target) => {
972 let target = *target;
973 let cond = pop_value(&mut stack)?;
974 if cond.is_truthy() {
975 stack.push(VmValue::Bool(true));
977 ip = target;
978 stats.jumps_taken += 1;
979 } else {
980 ip += 1;
981 }
982 }
983 Instruction::Jump(target) => {
984 ip = *target;
985 stats.jumps_taken += 1;
986 }
987
988 Instruction::LoadVar(name) => {
990 let val = local_env
991 .get(name)
992 .ok_or_else(|| VmError::UnboundVariable(name.clone()))?
993 .clone();
994 stack.push(val);
995 ip += 1;
996 }
997 Instruction::StoreVar(name) => {
998 let val = pop_value(&mut stack)?;
999 local_env.set(name.clone(), val);
1000 ip += 1;
1001 }
1002
1003 Instruction::TNorm => {
1005 let b = pop_num(&mut stack)?;
1006 let a = pop_num(&mut stack)?;
1007 stack.push(VmValue::Num(a * b));
1009 ip += 1;
1010 }
1011 Instruction::TCoNorm => {
1012 let b = pop_num(&mut stack)?;
1013 let a = pop_num(&mut stack)?;
1014 stack.push(VmValue::Num(a + b - a * b));
1016 ip += 1;
1017 }
1018 Instruction::FuzzyNot => {
1019 let a = pop_num(&mut stack)?;
1020 stack.push(VmValue::Num(1.0 - a));
1022 ip += 1;
1023 }
1024
1025 Instruction::Halt => {
1027 let result = stack.pop().ok_or(VmError::StackUnderflow)?;
1028 if stats.max_stack_depth < stack.len() + 1 {
1030 stats.max_stack_depth = stack.len() + 1;
1031 }
1032 return Ok((result, stats));
1033 }
1034 }
1035
1036 if stack.len() > stats.max_stack_depth {
1038 stats.max_stack_depth = stack.len();
1039 }
1040 }
1041}
1042
1043#[inline]
1049fn pop_value(stack: &mut Vec<VmValue>) -> Result<VmValue, VmError> {
1050 stack.pop().ok_or(VmError::StackUnderflow)
1051}
1052
1053#[inline]
1055fn pop_num(stack: &mut Vec<VmValue>) -> Result<f64, VmError> {
1056 let val = stack.pop().ok_or(VmError::StackUnderflow)?;
1057 match val {
1058 VmValue::Num(n) => Ok(n),
1059 VmValue::Bool(_) => Err(VmError::TypeMismatch {
1060 expected: "Num",
1061 got: "Bool",
1062 }),
1063 VmValue::Sym(_) => Err(VmError::TypeMismatch {
1064 expected: "Num",
1065 got: "Sym",
1066 }),
1067 }
1068}
1069
1070#[inline]
1072fn values_equal(a: &VmValue, b: &VmValue) -> bool {
1073 match (a, b) {
1074 (VmValue::Num(x), VmValue::Num(y)) => x == y,
1075 (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
1076 (VmValue::Sym(x), VmValue::Sym(y)) => x == y,
1077 _ => false,
1078 }
1079}
1080
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{FuzzyNegationKind, TCoNormKind, TLExpr, TNormKind};

    /// Numeric literal shorthand.
    fn lit(value: f64) -> TLExpr {
        TLExpr::Constant(value)
    }

    /// Compiles `expr` and runs it against `env`; any failure aborts the test.
    fn run_in(expr: &TLExpr, env: &VmEnv) -> VmValue {
        let program = compile(expr).expect("compile failed");
        execute(&program, env).expect("execute failed")
    }

    /// Compiles `expr` and runs it with no variable bindings.
    fn run(expr: &TLExpr) -> VmValue {
        run_in(expr, &VmEnv::new())
    }

    /// Compiles `expr`, runs it with no bindings, and also returns the stats.
    fn run_with_stats(expr: &TLExpr) -> (VmValue, VmStats) {
        let program = compile(expr).expect("compile failed");
        execute_with_stats(&program, &VmEnv::new()).expect("execute failed")
    }

    /// Builds a program directly from `instrs`, bypassing the compiler.
    fn assemble(instrs: Vec<Instruction>) -> BytecodeProgram {
        let mut program = BytecodeProgram::new();
        for instr in instrs {
            program.push(instr);
        }
        program
    }

    /// Asserts that `value` is a number within 1e-10 of `expected`.
    fn assert_num_near(value: VmValue, expected: f64) {
        match value {
            VmValue::Num(n) => {
                assert!((n - expected).abs() < 1e-10, "expected {}, got {}", expected, n)
            }
            other => panic!("expected Num, got {:?}", other),
        }
    }

    #[test]
    fn test_compile_constant_shape() {
        let pi = std::f64::consts::PI;
        let program = compile(&lit(pi)).expect("compile failed");
        assert_eq!(
            program.instructions,
            vec![Instruction::PushNum(pi), Instruction::Halt],
            "should be [PushNum(PI), Halt]"
        );
    }

    #[test]
    fn test_execute_push_num() {
        let program = assemble(vec![Instruction::PushNum(5.0), Instruction::Halt]);
        let result = execute(&program, &VmEnv::new()).expect("execute failed");
        assert_eq!(result, VmValue::Num(5.0));
    }

    #[test]
    fn test_add() {
        assert_eq!(run(&TLExpr::add(lit(2.0), lit(3.0))), VmValue::Num(5.0));
    }

    #[test]
    fn test_sub() {
        assert_eq!(run(&TLExpr::sub(lit(10.0), lit(4.0))), VmValue::Num(6.0));
    }

    #[test]
    fn test_mul() {
        assert_eq!(run(&TLExpr::mul(lit(3.0), lit(4.0))), VmValue::Num(12.0));
    }

    #[test]
    fn test_div() {
        assert_eq!(run(&TLExpr::div(lit(10.0), lit(2.0))), VmValue::Num(5.0));
    }

    #[test]
    fn test_pow() {
        assert_eq!(run(&TLExpr::pow(lit(2.0), lit(8.0))), VmValue::Num(256.0));
    }

    #[test]
    fn test_eq_true() {
        assert_eq!(run(&TLExpr::eq(lit(3.0), lit(3.0))), VmValue::Bool(true));
    }

    #[test]
    fn test_lt_true() {
        assert_eq!(run(&TLExpr::lt(lit(1.0), lit(2.0))), VmValue::Bool(true));
    }

    #[test]
    fn test_and_false() {
        let expr = TLExpr::and(
            TLExpr::eq(lit(1.0), lit(1.0)),
            TLExpr::eq(lit(1.0), lit(2.0)),
        );
        assert_eq!(run(&expr), VmValue::Bool(false));
    }

    #[test]
    fn test_or_true() {
        let expr = TLExpr::or(
            TLExpr::eq(lit(1.0), lit(2.0)),
            TLExpr::eq(lit(1.0), lit(1.0)),
        );
        assert_eq!(run(&expr), VmValue::Bool(true));
    }

    #[test]
    fn test_not_false_to_true() {
        let expr = TLExpr::negate(TLExpr::eq(lit(1.0), lit(2.0)));
        assert_eq!(run(&expr), VmValue::Bool(true));
    }

    #[test]
    fn test_short_circuit_and_jump() {
        // The left operand is false, so the right operand must be skipped.
        let expr = TLExpr::and(
            TLExpr::eq(lit(1.0), lit(2.0)),
            TLExpr::eq(lit(3.0), lit(3.0)),
        );
        let (result, stats) = run_with_stats(&expr);
        assert_eq!(result, VmValue::Bool(false));
        assert!(stats.jumps_taken > 0, "JumpIfFalse should have been taken");
    }

    #[test]
    fn test_short_circuit_or_jump() {
        // The left operand is true, so the right operand must be skipped.
        let expr = TLExpr::or(
            TLExpr::eq(lit(1.0), lit(1.0)),
            TLExpr::eq(lit(3.0), lit(4.0)),
        );
        let (result, stats) = run_with_stats(&expr);
        assert_eq!(result, VmValue::Bool(true));
        assert!(stats.jumps_taken > 0, "JumpIfTrue should have been taken");
    }

    #[test]
    fn test_ite_true_branch() {
        let expr = TLExpr::if_then_else(TLExpr::eq(lit(1.0), lit(1.0)), lit(1.0), lit(2.0));
        assert_eq!(run(&expr), VmValue::Num(1.0));
    }

    #[test]
    fn test_ite_false_branch() {
        let expr = TLExpr::if_then_else(TLExpr::eq(lit(1.0), lit(2.0)), lit(1.0), lit(2.0));
        assert_eq!(run(&expr), VmValue::Num(2.0));
    }

    #[test]
    fn test_load_var() {
        let mut env = VmEnv::new();
        env.set_num("x", 42.0);
        assert_eq!(run_in(&TLExpr::pred("x", vec![]), &env), VmValue::Num(42.0));
    }

    #[test]
    fn test_let_binding() {
        // let y = 7 in y * 2
        let expr = TLExpr::Let {
            var: "y".to_string(),
            value: Box::new(lit(7.0)),
            body: Box::new(TLExpr::mul(TLExpr::pred("y", vec![]), lit(2.0))),
        };
        assert_eq!(run_in(&expr, &VmEnv::new()), VmValue::Num(14.0));
    }

    #[test]
    fn test_stack_underflow() {
        // `Add` with an empty stack has nothing to pop.
        let program = assemble(vec![Instruction::Add, Instruction::Halt]);
        let err = execute(&program, &VmEnv::new()).unwrap_err();
        assert!(
            matches!(err, VmError::StackUnderflow),
            "expected StackUnderflow, got {:?}",
            err
        );
    }

    #[test]
    fn test_unbound_variable() {
        let program = assemble(vec![
            Instruction::LoadVar("missing".to_string()),
            Instruction::Halt,
        ]);
        let err = execute(&program, &VmEnv::new()).unwrap_err();
        assert!(
            matches!(err, VmError::UnboundVariable(_)),
            "expected UnboundVariable, got {:?}",
            err
        );
    }

    #[test]
    fn test_stats_instructions_executed() {
        let (_value, stats) = run_with_stats(&lit(1.0));
        assert!(stats.instructions_executed > 0);
    }

    #[test]
    fn test_stats_max_stack_depth_single_push() {
        let program = assemble(vec![Instruction::PushNum(99.0), Instruction::Halt]);
        let (_value, stats) =
            execute_with_stats(&program, &VmEnv::new()).expect("execute failed");
        assert_eq!(stats.max_stack_depth, 1, "single push should give depth 1");
    }

    #[test]
    fn test_tnorm_product() {
        let expr = TLExpr::TNorm {
            kind: TNormKind::Product,
            left: Box::new(lit(0.5)),
            right: Box::new(lit(0.5)),
        };
        assert_num_near(run(&expr), 0.25);
    }

    #[test]
    fn test_fuzzy_not() {
        let expr = TLExpr::FuzzyNot {
            kind: FuzzyNegationKind::Standard,
            expr: Box::new(lit(0.3)),
        };
        assert_num_near(run(&expr), 0.7);
    }

    #[test]
    fn test_tconorm() {
        // Probabilistic sum: 0.5 + 0.5 - 0.25 = 0.75.
        let expr = TLExpr::TCoNorm {
            kind: TCoNormKind::ProbabilisticSum,
            left: Box::new(lit(0.5)),
            right: Box::new(lit(0.5)),
        };
        assert_num_near(run(&expr), 0.75);
    }

    #[test]
    fn test_nested_arithmetic() {
        // (1 + 2) * (3 + 4) = 21
        let expr = TLExpr::mul(
            TLExpr::add(lit(1.0), lit(2.0)),
            TLExpr::add(lit(3.0), lit(4.0)),
        );
        assert_eq!(run(&expr), VmValue::Num(21.0));
    }

    #[test]
    fn test_division_by_zero() {
        let program = assemble(vec![
            Instruction::PushNum(1.0),
            Instruction::PushNum(0.0),
            Instruction::Div,
            Instruction::Halt,
        ]);
        let err = execute(&program, &VmEnv::new()).unwrap_err();
        assert!(
            matches!(err, VmError::DivisionByZero),
            "expected DivisionByZero, got {:?}",
            err
        );
    }

    #[test]
    fn test_abs() {
        assert_eq!(run(&TLExpr::Abs(Box::new(lit(-5.0)))), VmValue::Num(5.0));
    }

    #[test]
    fn test_compile_unsupported_forall() {
        use tensorlogic_ir::Term;
        let expr = TLExpr::forall("x", "D", TLExpr::pred("P", vec![Term::var("x")]));
        let err = compile(&expr).unwrap_err();
        assert!(
            matches!(err, CompileError::UnsupportedExpr(_)),
            "expected UnsupportedExpr, got {:?}",
            err
        );
    }

    #[test]
    fn test_max_depth_exceeded() {
        // ((1 + 1) + 1) + 1 nests past a depth limit of 1.
        let expr = TLExpr::add(
            TLExpr::add(TLExpr::add(lit(1.0), lit(1.0)), lit(1.0)),
            lit(1.0),
        );
        let err = compile_with_config(&expr, 1).unwrap_err();
        assert!(
            matches!(err, CompileError::MaxDepthExceeded),
            "expected MaxDepthExceeded, got {:?}",
            err
        );
    }
}