1use crate::ast::*;
2use crate::vm::bet::pack_trits;
3use crate::trit::Trit;
4
/// Single-pass emitter that lowers AST statements and expressions into a flat
/// byte buffer of VM instructions.
///
/// Besides the output buffer, the emitter tracks the tables it needs while
/// walking the tree: register assignments for names, bytecode addresses of
/// emitted functions, pending `break` jumps, struct layouts and agent ids.
#[derive(Debug, Default)]
pub struct BytecodeEmitter {
    /// Bytecode emitted so far; jump/call operands are little-endian `u16`
    /// offsets into this buffer.
    code: Vec<u8>,
    /// Register assigned to each bound name. Struct fields are keyed as
    /// `"<var>.<field>"` and `for` iterables as `"__iter_<var>"`.
    symbols: std::collections::HashMap<String, u8>,
    /// Start address of every emitted function; agent methods are additionally
    /// recorded under `"<Agent>::<method>"`.
    func_addrs: std::collections::HashMap<String, u16>,
    /// Operand positions of jumps emitted for `break` that still await the
    /// address of the end of their loop.
    break_patches: Vec<usize>,
    /// Next register index to hand out; never reset or reused by the emitter.
    next_reg: u8,
    /// Ordered field names of each declared struct, keyed by struct name.
    struct_layouts: std::collections::HashMap<String, Vec<String>>,
    /// Type id of each agent, i.e. its index in the program's agent list.
    agent_type_ids: std::collections::HashMap<String, u16>,
    /// `(type_id, handler address)` pairs; the handler is the agent's first
    /// emitted method.
    agent_handlers: Vec<(u16, u16)>,
}
18
19impl BytecodeEmitter {
20 pub fn new() -> Self {
21 Self {
22 code: Vec::new(),
23 symbols: std::collections::HashMap::new(),
24 func_addrs: std::collections::HashMap::new(),
25 break_patches: Vec::new(),
26 next_reg: 0,
27 struct_layouts: std::collections::HashMap::new(),
28 agent_type_ids: std::collections::HashMap::new(),
29 agent_handlers: Vec::new(),
30 }
31 }
32
33 pub fn register_agents(&self, vm: &mut crate::vm::BetVm) {
35 for &(type_id, addr) in &self.agent_handlers {
36 vm.register_agent_type(type_id, addr as usize);
37 }
38 }
39
40 pub fn emit_program(&mut self, program: &Program) {
41 for s in &program.structs {
43 let field_names: Vec<String> = s.fields.iter().map(|(n, _)| n.clone()).collect();
44 self.struct_layouts.insert(s.name.clone(), field_names);
45 }
46 for (idx, agent) in program.agents.iter().enumerate() {
48 self.agent_type_ids.insert(agent.name.clone(), idx as u16);
49 }
50
51 let entry_jmp_patch = self.code.len() + 1;
53 self.code.push(0x0b); self.code.extend_from_slice(&[0u8, 0u8]);
55
56 for agent in &program.agents {
58 let type_id = self.agent_type_ids[&agent.name];
59 let mut handler_addr: Option<u16> = None;
62 for method in &agent.methods {
63 let addr = self.code.len() as u16;
64 if handler_addr.is_none() {
65 handler_addr = Some(addr);
66 }
67 self.emit_function(method);
68 let fq = format!("{}::{}", agent.name, method.name);
70 self.func_addrs.insert(fq, addr);
71 }
72 if let Some(addr) = handler_addr {
73 self.agent_handlers.push((type_id, addr));
74 }
75 }
76
77 for func in &program.functions {
79 self.emit_function(func);
80 }
81
82 let after_funcs = self.code.len() as u16;
84 self.patch_u16(entry_jmp_patch, after_funcs);
85 }
86
87 pub fn emit_function(&mut self, func: &Function) {
88 let func_addr = self.code.len() as u16;
90 self.func_addrs.insert(func.name.clone(), func_addr);
91
92 for stmt in &func.body {
93 self.emit_stmt(stmt);
94 }
95 self.code.push(0x11); }
98
99 pub fn emit_stmt(&mut self, stmt: &Stmt) {
100 match stmt {
101 Stmt::Let { name, ty, value } => {
102 match ty {
103 Type::TritTensor { dims } => {
104 let size: usize = dims.iter().product();
105 self.code.push(0x0f); self.code.extend_from_slice(&(size as u16).to_le_bytes());
107 let reg = self.next_reg;
108 self.symbols.insert(name.clone(), reg);
109 self.next_reg += 1;
110 self.code.push(0x08); self.code.push(reg);
112 }
113 Type::Named(struct_name) => {
114 let fields = self.struct_layouts.get(struct_name)
117 .cloned()
118 .unwrap_or_default();
119 let base_reg = self.next_reg;
121 self.symbols.insert(name.clone(), base_reg);
122 for field in &fields {
123 let reg = self.next_reg;
124 self.next_reg += 1;
125 self.symbols.insert(format!("{}.{}", name, field), reg);
126 self.code.push(0x01); self.code.extend(crate::vm::bet::pack_trits(&[crate::trit::Trit::Tend]));
129 self.code.push(0x08); self.code.push(reg);
131 }
132 if fields.is_empty() {
134 self.next_reg += 1;
135 self.code.push(0x01);
136 self.code.extend(crate::vm::bet::pack_trits(&[crate::trit::Trit::Tend]));
137 self.code.push(0x08);
138 self.code.push(base_reg);
139 }
140 }
141 _ => {
142 self.emit_expr(value);
143 let reg = self.next_reg;
144 self.symbols.insert(name.clone(), reg);
145 self.next_reg += 1;
146 self.code.push(0x08); self.code.push(reg);
148 }
149 }
150 }
151 Stmt::FieldSet { object, field, value } => {
152 let key = format!("{}.{}", object, field);
154 self.emit_expr(value);
155 if let Some(®) = self.symbols.get(&key) {
156 self.code.push(0x08); self.code.push(reg);
158 }
159 }
161 Stmt::IndexSet { object, row, col, value } => {
162 if let Some(®) = self.symbols.get(object) {
163 self.code.push(0x09); self.code.push(reg); self.emit_expr(row);
165 self.emit_expr(col);
166 self.emit_expr(value);
167 self.code.push(0x23); }
169 }
170 Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
171 self.emit_expr(condition);
172
173 self.code.push(0x0a); let jmp_pos_patch = self.code.len() + 1;
176 self.code.push(0x05); self.code.extend_from_slice(&[0, 0]);
178
179 self.code.push(0x0a); let jmp_zero_patch = self.code.len() + 1;
182 self.code.push(0x06); self.code.extend_from_slice(&[0, 0]);
184
185 self.code.push(0x0c); self.emit_stmt(on_neg);
188 let end_jmp_neg_patch = self.code.len() + 1;
189 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
191
192 let pos_addr = self.code.len() as u16;
194 self.patch_u16(jmp_pos_patch, pos_addr);
195 self.code.push(0x0c); self.emit_stmt(on_pos);
197 let end_jmp_pos_patch = self.code.len() + 1;
198 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
200
201 let zero_addr = self.code.len() as u16;
203 self.patch_u16(jmp_zero_patch, zero_addr);
204 self.code.push(0x0c); self.emit_stmt(on_zero);
206
207 let end_addr = self.code.len() as u16;
209 self.patch_u16(end_jmp_neg_patch, end_addr);
210 self.patch_u16(end_jmp_pos_patch, end_addr);
211 }
212 Stmt::Match { condition, arms } => {
213 self.emit_expr(condition);
214
215 let mut patches = Vec::new();
216 let mut end_patches = Vec::new();
217
218 for (val, _stmt) in arms {
219 self.code.push(0x0a); let patch_pos = self.code.len() + 1;
221 match val {
222 1 => self.code.push(0x05), 0 => self.code.push(0x06), -1 => self.code.push(0x07), _ => unreachable!(),
226 }
227 self.code.extend_from_slice(&[0, 0]);
228 patches.push((patch_pos, *val));
229 }
230
231 self.code.push(0x0c); let end_jmp_no_match = self.code.len() + 1;
233 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
235
236 for (patch_pos, val) in patches {
237 let addr = self.code.len() as u16;
238 self.patch_u16(patch_pos, addr);
239 self.code.push(0x0c); let stmt = arms.iter().find(|(v, _)| *v == val).unwrap().1.clone();
243 self.emit_stmt(&stmt);
244
245 let end_patch = self.code.len() + 1;
246 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
248 end_patches.push(end_patch);
249 }
250
251 let end_addr = self.code.len() as u16;
252 self.patch_u16(end_jmp_no_match, end_addr);
253 for p in end_patches {
254 self.patch_u16(p, end_addr);
255 }
256 }
257 Stmt::ForIn { var, iter, body } => {
260 self.emit_expr(iter);
262 let iter_reg = self.next_reg;
263 self.symbols.insert(format!("__iter_{}", var), iter_reg);
264 self.next_reg += 1;
265 self.code.push(0x08); self.code.push(iter_reg); let idx_reg = self.next_reg;
269 self.next_reg += 1;
270 self.code.push(0x09); self.code.push(iter_reg); self.code.push(0x24); let bound_reg = self.next_reg; self.next_reg += 1;
274 self.code.push(0x08); self.code.push(bound_reg); self.code.push(0x0c); self.code.push(0x01);
280 self.code.extend(pack_trits(&[Trit::Tend]));
281 self.code.push(0x08); self.code.push(idx_reg); let loop_top = self.code.len() as u16;
284
285
286
287 self.code.push(0x09); self.code.push(iter_reg); self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Tend])); self.code.push(0x09); self.code.push(idx_reg); self.code.push(0x22); let var_reg = self.next_reg; self.next_reg += 1;
293 self.symbols.insert(var.clone(), var_reg);
294 self.code.push(0x08); self.code.push(var_reg); self.emit_stmt(body);
298
299 let jmp_back = self.code.len() + 1;
301 self.code.push(0x0b);
302 self.code.extend_from_slice(&[0, 0]);
303 self.patch_u16(jmp_back, loop_top);
304 }
305
306 Stmt::Loop { body } => {
308 let loop_top = self.code.len() as u16;
309
310 let pre_break_count = self.break_patches.len();
312 self.emit_stmt(body);
313 let jmp_back = self.code.len() + 1;
315 self.code.push(0x0b);
316 self.code.extend_from_slice(&[0, 0]);
317 self.patch_u16(jmp_back, loop_top);
318 let after_loop = self.code.len() as u16;
320 let patches: Vec<usize> = self.break_patches.drain(pre_break_count..).collect();
321 for patch in patches {
322 self.patch_u16(patch, after_loop);
323 }
324 }
325
326 Stmt::Break => {
327 let patch = self.code.len() + 1;
328 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
330 self.break_patches.push(patch);
331 }
332
333 Stmt::Continue => {
334 }
337
338 Stmt::Use { .. } => {
339 }
341 Stmt::Send { target, message } => {
342 self.emit_expr(target);
344 self.emit_expr(message);
345 self.code.push(0x31); }
347
348 Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
349 let loop_top = self.code.len() as u16;
350 self.emit_expr(condition);
351 self.code.push(0x0a); let jmp_pos_patch = self.code.len() + 1;
354 self.code.push(0x05); self.code.extend_from_slice(&[0, 0]); self.code.push(0x0a); let jmp_zero_patch = self.code.len() + 1;
357 self.code.push(0x06); self.code.extend_from_slice(&[0, 0]); self.code.push(0x0c); self.emit_stmt(on_neg);
361 let exit_patch = self.code.len() + 1;
362 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]); let pos_addr = self.code.len() as u16;
366 self.patch_u16(jmp_pos_patch, pos_addr);
367 self.code.push(0x0c); self.emit_stmt(on_pos);
369 let back_pos = self.code.len() + 1;
370 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
371 self.patch_u16(back_pos, loop_top);
372
373 let zero_addr = self.code.len() as u16;
375 self.patch_u16(jmp_zero_patch, zero_addr);
376 self.code.push(0x0c); self.emit_stmt(on_zero);
378 let back_zero = self.code.len() + 1;
379 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
380 self.patch_u16(back_zero, loop_top);
381
382 let exit_addr = self.code.len() as u16;
384 self.patch_u16(exit_patch, exit_addr);
385 }
386
387 Stmt::Return(expr) => {
388 self.emit_expr(expr);
389 self.code.push(0x11); }
391 Stmt::Block(stmts) => {
392 for stmt in stmts {
393 self.emit_stmt(stmt);
394 }
395 }
396 Stmt::Decorated { directive, stmt } => {
397 if directive == "sparseskip" {
398 let opcode = 0x21u8; if let Stmt::Expr(inner_expr) = stmt.as_ref() {
402 if let Expr::Call { callee, args } = inner_expr {
403 if callee == "matmul" && args.len() == 2 {
404 self.emit_expr(&args[0]);
405 self.emit_expr(&args[1]);
406 self.code.push(opcode); return;
408 }
409 }
410 }
411 if let Stmt::Let { name, value, .. } = stmt.as_ref() {
413 if let Expr::Call { callee, args } = value {
414 if callee == "matmul" && args.len() == 2 {
415 self.emit_expr(&args[0]);
416 self.emit_expr(&args[1]);
417 self.code.push(opcode); if opcode == 0x21 {
420 self.code.push(0x0c); }
423
424 let reg = self.next_reg;
425 self.symbols.insert(name.clone(), reg);
426 self.next_reg += 1;
427 self.code.push(0x08); self.code.push(reg);
429 return;
430 }
431 }
432 }
433 }
434 self.emit_stmt(stmt);
436 }
437 _ => {}
438 }
439 }
440
441 fn emit_expr(&mut self, expr: &Expr) {
442 match expr {
443 Expr::TritLiteral(val) => {
444 self.code.push(0x01); let trit = Trit::from(*val);
446 self.code.extend(pack_trits(&[trit]));
447 }
448 Expr::Ident(name) => {
449 if let Some(®) = self.symbols.get(name) {
450 self.code.push(0x09); self.code.push(reg);
452 }
453 }
454 Expr::BinaryOp { op, lhs, rhs } => {
455 self.emit_expr(lhs);
456 self.emit_expr(rhs);
457 match op {
458 BinOp::Add => self.code.push(0x02), BinOp::Mul => self.code.push(0x03), BinOp::Sub => { self.code.push(0x04); self.code.push(0x02); } BinOp::Equal => self.code.push(0x16), BinOp::NotEqual => { self.code.push(0x16); self.code.push(0x04); } BinOp::And => self.code.push(0x03), BinOp::Or => self.code.push(0x0e), BinOp::Less => self.code.push(0x14), BinOp::Greater => self.code.push(0x15), }
468 }
469 Expr::UnaryOp { op, expr } => {
470 self.emit_expr(expr);
471 match op {
472 UnOp::Neg => self.code.push(0x04),
473 }
474 }
475 Expr::Call { callee, args } => {
476 for arg in args {
477 self.emit_expr(arg);
478 }
479 match callee.as_str() {
480 "consensus" => {
481 if args.len() == 2 {
482 self.code.push(0x0e); }
484 }
485 "invert" => {
486 if args.len() == 1 {
487 self.code.push(0x04); }
489 }
490 "truth" => {
491 self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Affirm]));
493 }
494 "hold" => {
495 self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Tend]));
497 }
498 "conflict" => {
499 self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Reject]));
501 }
502 "matmul" => {
503 if args.len() == 2 {
504 self.code.push(0x20); }
506 }
507 "sparsity" => {
508 if args.len() == 1 {
509 self.code.push(0x25); }
511 }
512 "shape" => {
513 if args.len() == 1 {
514 self.code.push(0x24); }
516 }
517 _ => {
518 if let Some(&addr) = self.func_addrs.get(callee) {
520 self.code.push(0x10); self.code.extend_from_slice(&addr.to_le_bytes());
522 } else {
523 self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Tend]));
527 }
528 }
529 }
530 }
531 Expr::FieldAccess { object, field } => {
532 if let Expr::Ident(obj_name) = object.as_ref() {
535 let key = format!("{}.{}", obj_name, field);
536 if let Some(®) = self.symbols.get(&key) {
537 self.code.push(0x09); self.code.push(reg);
539 return;
540 }
541 }
542 self.code.push(0x01);
544 self.code.extend(pack_trits(&[Trit::Tend]));
545 }
546 Expr::Index { object, row, col } => {
547 self.emit_expr(object);
548 self.emit_expr(row);
549 self.emit_expr(col);
550 self.code.push(0x22); }
552 Expr::Propagate { expr } => {
553 self.emit_expr(expr);
555 self.code.push(0x0a);
557 let neg_patch = self.code.len() + 1;
559 self.code.push(0x07); self.code.extend_from_slice(&[0u8, 0u8]);
561 let skip_patch = self.code.len() + 1;
563 self.code.push(0x0b); self.code.extend_from_slice(&[0u8, 0u8]);
565 let prop_addr = self.code.len() as u16;
567 self.patch_u16(neg_patch, prop_addr);
568 self.code.push(0x11); let skip_addr = self.code.len() as u16;
571 self.patch_u16(skip_patch, skip_addr);
572 }
573 Expr::Cast { expr, .. } => {
574 self.emit_expr(expr);
577 }
578 Expr::Spawn { agent_name, node_addr } => {
579 if let Some(addr) = node_addr {
580 self.emit_expr(&Expr::StringLiteral(addr.clone()));
582 if let Some(&type_id) = self.agent_type_ids.get(agent_name) {
583 self.code.push(0x33); self.code.extend_from_slice(&type_id.to_le_bytes());
585 } else {
586 self.code.push(0x01);
587 self.code.extend(pack_trits(&[Trit::Tend]));
588 }
589 } else if let Some(&type_id) = self.agent_type_ids.get(agent_name) {
590 self.code.push(0x30); self.code.extend_from_slice(&type_id.to_le_bytes());
593 } else {
594 self.code.push(0x01);
596 self.code.extend(pack_trits(&[Trit::Tend]));
597 }
598 }
599 Expr::StringLiteral(_s) => {
600 self.code.push(0x01);
607 self.code.extend(pack_trits(&[Trit::Tend]));
608 }
609 Expr::NodeId => {
610 self.code.push(0x12); }
612 Expr::Await { target } => {
613 self.emit_expr(target);
616 self.code.push(0x32); }
618 _ => {}
619 }
620 }
621
622 pub fn emit_entry_call(&mut self, func_name: &str) {
629 if let Some(&addr) = self.func_addrs.get(func_name) {
630 self.code.push(0x10); self.code.extend_from_slice(&addr.to_le_bytes());
632 }
633 }
634
635 pub fn finalize(mut self) -> Vec<u8> {
636 self.code.push(0x00); self.code
638 }
639
640 fn patch_u16(&mut self, pos: usize, val: u16) {
641 let bytes = val.to_le_bytes();
642 self.code[pos] = bytes[0];
643 self.code[pos + 1] = bytes[1];
644 }
645}
646
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;
    use crate::vm::{BetVm, Value};

    /// Parses `source` statement by statement, emitting each one, and returns
    /// the finalized bytecode. Parsing stops at the first statement the parser
    /// rejects (normally end of input).
    fn compile(source: &str) -> Vec<u8> {
        let mut parser = Parser::new(source);
        let mut emitter = BytecodeEmitter::new();
        while let Ok(stmt) = parser.parse_stmt() {
            emitter.emit_stmt(&stmt);
        }
        emitter.finalize()
    }

    /// Runs `code` to completion on a fresh VM and hands the VM back so the
    /// caller can inspect its registers.
    fn run(code: Vec<u8>) -> BetVm {
        let mut vm = BetVm::new(code);
        vm.run().unwrap();
        vm
    }

    /// Negating a positive trit stored in `x` leaves `Reject` in `y`'s
    /// register (register 1).
    #[test]
    fn test_compile_and_run_simple() {
        let vm = run(compile("let x: trit = 1; let y: trit = -x; return y;"));

        assert_eq!(vm.get_register(1), Value::Trit(Trit::Reject));
    }

    /// `@sparseskip` on a `matmul` binding must select the sparse opcode and
    /// leave a tensor reference in the bound register (register 2).
    #[test]
    fn test_sparseskip_emits_tsparse_matmul() {
        let code = compile(
            "let a: trittensor<2 x 2>; let b: trittensor<2 x 2>; @sparseskip let c: trittensor<2 x 2> = matmul(a, b);",
        );
        assert!(code.contains(&0x21), "Expected TSPARSE_MATMUL (0x21) in bytecode");
        assert!(!code.contains(&0x20), "Expected no dense TMATMUL (0x20) when @sparseskip used");

        let vm = run(code);
        assert!(matches!(vm.get_register(2), Value::TensorRef(_)));
    }

    /// With `x = 1` the `1 =>` arm runs, binding `y = -1` in register 1.
    #[test]
    fn test_compile_match() {
        let vm = run(compile(
            "let x: trit = 1; match x { 1 => { let y: trit = -1; } 0 => { let y: trit = 0; } -1 => { let y: trit = 1; } }",
        ));

        assert_eq!(vm.get_register(1), Value::Trit(Trit::Reject));
    }
}