1use crate::ast::*;
2use crate::vm::bet::pack_trits;
3use crate::trit::Trit;
4
/// Lowers the AST into the flat bytecode executed by `crate::vm::BetVm`.
///
/// The emitter is single-pass: jump and call targets are 16-bit little-endian
/// byte offsets into `code`, and forward jumps are written as `0x0000`
/// placeholders that are patched once the target address is known.
#[derive(Debug, Default)]
pub struct BytecodeEmitter {
    /// Bytecode emitted so far.
    code: Vec<u8>,
    /// Variable name -> register index. Struct fields are stored under
    /// `"<var>.<field>"` keys.
    symbols: std::collections::HashMap<String, u8>,
    /// Function name (and `"Agent::method"` for agent methods) -> bytecode address.
    func_addrs: std::collections::HashMap<String, u16>,
    /// Offsets of the placeholder operands of pending `break` jumps; the
    /// innermost enclosing loop patches them to its exit address.
    break_patches: Vec<usize>,
    /// Next register index to hand out; registers are never reused.
    next_reg: u8,
    /// Struct name -> field names in declaration order.
    struct_layouts: std::collections::HashMap<String, Vec<String>>,
    /// Agent name -> numeric type id (the agent's declaration index).
    agent_type_ids: std::collections::HashMap<String, u16>,
    /// `(type_id, handler_address)` pairs handed to the VM by `register_agents`.
    agent_handlers: Vec<(u16, u16)>,
}
18
19impl BytecodeEmitter {
20 pub fn new() -> Self {
21 Self {
22 code: Vec::new(),
23 symbols: std::collections::HashMap::new(),
24 func_addrs: std::collections::HashMap::new(),
25 break_patches: Vec::new(),
26 next_reg: 0,
27 struct_layouts: std::collections::HashMap::new(),
28 agent_type_ids: std::collections::HashMap::new(),
29 agent_handlers: Vec::new(),
30 }
31 }
32
33 pub fn register_agents(&self, vm: &mut crate::vm::BetVm) {
35 for &(type_id, addr) in &self.agent_handlers {
36 vm.register_agent_type(type_id, addr as usize);
37 }
38 }
39
40 pub fn emit_program(&mut self, program: &Program) {
41 for s in &program.structs {
43 let field_names: Vec<String> = s.fields.iter().map(|(n, _)| n.clone()).collect();
44 self.struct_layouts.insert(s.name.clone(), field_names);
45 }
46 for (idx, agent) in program.agents.iter().enumerate() {
48 self.agent_type_ids.insert(agent.name.clone(), idx as u16);
49 }
50
51 let entry_jmp_patch = self.code.len() + 1;
53 self.code.push(0x0b); self.code.extend_from_slice(&[0u8, 0u8]);
55
56 for agent in &program.agents {
58 let type_id = self.agent_type_ids[&agent.name];
59 let mut handler_addr: Option<u16> = None;
62 for method in &agent.methods {
63 let addr = self.code.len() as u16;
64 if handler_addr.is_none() {
65 handler_addr = Some(addr);
66 }
67 self.emit_function(method);
68 let fq = format!("{}::{}", agent.name, method.name);
70 self.func_addrs.insert(fq, addr);
71 }
72 if let Some(addr) = handler_addr {
73 self.agent_handlers.push((type_id, addr));
74 }
75 }
76
77 for func in &program.functions {
79 self.emit_function(func);
80 }
81
82 let after_funcs = self.code.len() as u16;
84 self.patch_u16(entry_jmp_patch, after_funcs);
85 }
86
87 pub fn emit_function(&mut self, func: &Function) {
88 let func_addr = self.code.len() as u16;
90 self.func_addrs.insert(func.name.clone(), func_addr);
91
92 for stmt in &func.body {
93 self.emit_stmt(stmt);
94 }
95 self.code.push(0x11); }
98
99 pub fn emit_stmt(&mut self, stmt: &Stmt) {
100 match stmt {
101 Stmt::Let { name, ty, value } => {
102 match ty {
103 Type::TritTensor { dims } => {
104 let size: usize = dims.iter().product();
105 self.code.push(0x0f); self.code.extend_from_slice(&(size as u16).to_le_bytes());
107 let reg = self.next_reg;
108 self.symbols.insert(name.clone(), reg);
109 self.next_reg += 1;
110 self.code.push(0x08); self.code.push(reg);
112 }
113 Type::Named(struct_name) => {
114 let fields = self.struct_layouts.get(struct_name)
117 .cloned()
118 .unwrap_or_default();
119 let base_reg = self.next_reg;
121 self.symbols.insert(name.clone(), base_reg);
122 for field in &fields {
123 let reg = self.next_reg;
124 self.next_reg += 1;
125 self.symbols.insert(format!("{}.{}", name, field), reg);
126 self.code.push(0x01); self.code.extend(crate::vm::bet::pack_trits(&[crate::trit::Trit::Zero]));
129 self.code.push(0x08); self.code.push(reg);
131 }
132 if fields.is_empty() {
134 self.next_reg += 1;
135 self.code.push(0x01);
136 self.code.extend(crate::vm::bet::pack_trits(&[crate::trit::Trit::Zero]));
137 self.code.push(0x08);
138 self.code.push(base_reg);
139 }
140 }
141 _ => {
142 self.emit_expr(value);
143 let reg = self.next_reg;
144 self.symbols.insert(name.clone(), reg);
145 self.next_reg += 1;
146 self.code.push(0x08); self.code.push(reg);
148 }
149 }
150 }
151 Stmt::FieldSet { object, field, value } => {
152 let key = format!("{}.{}", object, field);
154 self.emit_expr(value);
155 if let Some(®) = self.symbols.get(&key) {
156 self.code.push(0x08); self.code.push(reg);
158 }
159 }
161 Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
162 self.emit_expr(condition);
163
164 self.code.push(0x0a); let jmp_pos_patch = self.code.len() + 1;
167 self.code.push(0x05); self.code.extend_from_slice(&[0, 0]);
169
170 self.code.push(0x0a); let jmp_zero_patch = self.code.len() + 1;
173 self.code.push(0x06); self.code.extend_from_slice(&[0, 0]);
175
176 self.code.push(0x0c); self.emit_stmt(on_neg);
179 let end_jmp_neg_patch = self.code.len() + 1;
180 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
182
183 let pos_addr = self.code.len() as u16;
185 self.patch_u16(jmp_pos_patch, pos_addr);
186 self.code.push(0x0c); self.emit_stmt(on_pos);
188 let end_jmp_pos_patch = self.code.len() + 1;
189 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
191
192 let zero_addr = self.code.len() as u16;
194 self.patch_u16(jmp_zero_patch, zero_addr);
195 self.code.push(0x0c); self.emit_stmt(on_zero);
197
198 let end_addr = self.code.len() as u16;
200 self.patch_u16(end_jmp_neg_patch, end_addr);
201 self.patch_u16(end_jmp_pos_patch, end_addr);
202 }
203 Stmt::Match { condition, arms } => {
204 self.emit_expr(condition);
205
206 let mut patches = Vec::new();
207 let mut end_patches = Vec::new();
208
209 for (val, _stmt) in arms {
210 self.code.push(0x0a); let patch_pos = self.code.len() + 1;
212 match val {
213 1 => self.code.push(0x05), 0 => self.code.push(0x06), -1 => self.code.push(0x07), _ => unreachable!(),
217 }
218 self.code.extend_from_slice(&[0, 0]);
219 patches.push((patch_pos, *val));
220 }
221
222 self.code.push(0x0c); let end_jmp_no_match = self.code.len() + 1;
224 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
226
227 for (patch_pos, val) in patches {
228 let addr = self.code.len() as u16;
229 self.patch_u16(patch_pos, addr);
230 self.code.push(0x0c); let stmt = arms.iter().find(|(v, _)| *v == val).unwrap().1.clone();
234 self.emit_stmt(&stmt);
235
236 let end_patch = self.code.len() + 1;
237 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
239 end_patches.push(end_patch);
240 }
241
242 let end_addr = self.code.len() as u16;
243 self.patch_u16(end_jmp_no_match, end_addr);
244 for p in end_patches {
245 self.patch_u16(p, end_addr);
246 }
247 }
248 Stmt::ForIn { var, iter, body } => {
251 self.emit_expr(iter);
253 let iter_reg = self.next_reg;
254 self.symbols.insert(format!("__iter_{}", var), iter_reg);
255 self.next_reg += 1;
256 self.code.push(0x08); self.code.push(iter_reg); let idx_reg = self.next_reg;
260 self.next_reg += 1;
261 self.code.push(0x09); self.code.push(iter_reg); self.code.push(0x24); let bound_reg = self.next_reg; self.next_reg += 1;
265 self.code.push(0x08); self.code.push(bound_reg); self.code.push(0x0c); self.code.push(0x01);
271 self.code.extend(pack_trits(&[Trit::Zero]));
272 self.code.push(0x08); self.code.push(idx_reg); let loop_top = self.code.len() as u16;
275
276 self.code.push(0x09); self.code.push(iter_reg); self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Zero])); self.code.push(0x09); self.code.push(idx_reg); self.code.push(0x22); let var_reg = self.next_reg; self.next_reg += 1;
282 self.symbols.insert(var.clone(), var_reg);
283 self.code.push(0x08); self.code.push(var_reg); self.emit_stmt(body);
287
288 let jmp_back = self.code.len() + 1;
290 self.code.push(0x0b);
291 self.code.extend_from_slice(&[0, 0]);
292 self.patch_u16(jmp_back, loop_top);
293 }
294
295 Stmt::Loop { body } => {
297 let loop_top = self.code.len() as u16;
298 let pre_break_count = self.break_patches.len();
300 self.emit_stmt(body);
301 let jmp_back = self.code.len() + 1;
303 self.code.push(0x0b);
304 self.code.extend_from_slice(&[0, 0]);
305 self.patch_u16(jmp_back, loop_top);
306 let after_loop = self.code.len() as u16;
308 let patches: Vec<usize> = self.break_patches.drain(pre_break_count..).collect();
309 for patch in patches {
310 self.patch_u16(patch, after_loop);
311 }
312 }
313
314 Stmt::Break => {
315 let patch = self.code.len() + 1;
316 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
318 self.break_patches.push(patch);
319 }
320
321 Stmt::Continue => {
322 }
325
326 Stmt::Use { .. } => {
327 }
329 Stmt::Send { target, message } => {
330 self.emit_expr(target);
332 self.emit_expr(message);
333 self.code.push(0x31); }
335
336 Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
337 let loop_top = self.code.len() as u16;
338 self.emit_expr(condition);
339 self.code.push(0x0a); let jmp_pos_patch = self.code.len() + 1;
342 self.code.push(0x05); self.code.extend_from_slice(&[0, 0]); self.code.push(0x0a); let jmp_zero_patch = self.code.len() + 1;
345 self.code.push(0x06); self.code.extend_from_slice(&[0, 0]); self.code.push(0x0c); self.emit_stmt(on_neg);
349 let exit_patch = self.code.len() + 1;
350 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]); let pos_addr = self.code.len() as u16;
354 self.patch_u16(jmp_pos_patch, pos_addr);
355 self.code.push(0x0c); self.emit_stmt(on_pos);
357 let back_pos = self.code.len() + 1;
358 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
359 self.patch_u16(back_pos, loop_top);
360
361 let zero_addr = self.code.len() as u16;
363 self.patch_u16(jmp_zero_patch, zero_addr);
364 self.code.push(0x0c); self.emit_stmt(on_zero);
366 let back_zero = self.code.len() + 1;
367 self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
368 self.patch_u16(back_zero, loop_top);
369
370 let exit_addr = self.code.len() as u16;
372 self.patch_u16(exit_patch, exit_addr);
373 }
374
375 Stmt::Return(expr) => {
376 self.emit_expr(expr);
377 self.code.push(0x00); }
379 Stmt::Block(stmts) => {
380 for stmt in stmts {
381 self.emit_stmt(stmt);
382 }
383 }
384 Stmt::Decorated { directive, stmt } => {
385 if directive == "sparseskip" {
386 if let Stmt::Expr(inner_expr) = stmt.as_ref() {
388 if let Expr::Call { callee, args } = inner_expr {
389 if callee == "matmul" && args.len() == 2 {
390 self.emit_expr(&args[0]);
391 self.emit_expr(&args[1]);
392 self.code.push(0x21); return;
394 }
395 }
396 }
397 if let Stmt::Let { name, value, .. } = stmt.as_ref() {
399 if let Expr::Call { callee, args } = value {
400 if callee == "matmul" && args.len() == 2 {
401 self.emit_expr(&args[0]);
402 self.emit_expr(&args[1]);
403 self.code.push(0x21); self.code.push(0x0c); let reg = self.next_reg;
407 self.symbols.insert(name.clone(), reg);
408 self.next_reg += 1;
409 self.code.push(0x08); self.code.push(reg);
411 return;
412 }
413 }
414 }
415 }
416 self.emit_stmt(stmt);
418 }
419 _ => {}
420 }
421 }
422
423 fn emit_expr(&mut self, expr: &Expr) {
424 match expr {
425 Expr::TritLiteral(val) => {
426 self.code.push(0x01); let trit = Trit::from(*val);
428 self.code.extend(pack_trits(&[trit]));
429 }
430 Expr::Ident(name) => {
431 if let Some(®) = self.symbols.get(name) {
432 self.code.push(0x09); self.code.push(reg);
434 }
435 }
436 Expr::BinaryOp { op, lhs, rhs } => {
437 self.emit_expr(lhs);
438 self.emit_expr(rhs);
439 match op {
440 BinOp::Add => self.code.push(0x02), BinOp::Mul => self.code.push(0x03), BinOp::Sub => { self.code.push(0x04); self.code.push(0x02); } BinOp::Equal => self.code.push(0x0e), BinOp::NotEqual => { self.code.push(0x0e); self.code.push(0x04); } BinOp::And => self.code.push(0x03), BinOp::Or => self.code.push(0x0e), }
448 }
449 Expr::UnaryOp { op, expr } => {
450 self.emit_expr(expr);
451 match op {
452 UnOp::Neg => self.code.push(0x04),
453 }
454 }
455 Expr::Call { callee, args } => {
456 for arg in args {
457 self.emit_expr(arg);
458 }
459 match callee.as_str() {
460 "consensus" => {
461 if args.len() == 2 {
462 self.code.push(0x0e); }
464 }
465 "invert" => {
466 if args.len() == 1 {
467 self.code.push(0x04); }
469 }
470 "truth" => {
471 self.code.push(0x01); self.code.extend(pack_trits(&[Trit::PosOne]));
473 }
474 "hold" => {
475 self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Zero]));
477 }
478 "conflict" => {
479 self.code.push(0x01); self.code.extend(pack_trits(&[Trit::NegOne]));
481 }
482 "matmul" => {
483 if args.len() == 2 {
484 self.code.push(0x20); }
486 }
487 "sparsity" => {
488 if args.len() == 1 {
489 self.code.push(0x25); }
491 }
492 "shape" => {
493 if args.len() == 1 {
494 self.code.push(0x24); }
496 }
497 _ => {
498 if let Some(&addr) = self.func_addrs.get(callee) {
500 self.code.push(0x10); self.code.extend_from_slice(&addr.to_le_bytes());
502 } else {
503 self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Zero]));
507 }
508 }
509 }
510 }
511 Expr::FieldAccess { object, field } => {
512 if let Expr::Ident(obj_name) = object.as_ref() {
515 let key = format!("{}.{}", obj_name, field);
516 if let Some(®) = self.symbols.get(&key) {
517 self.code.push(0x09); self.code.push(reg);
519 return;
520 }
521 }
522 self.code.push(0x01);
524 self.code.extend(pack_trits(&[Trit::Zero]));
525 }
526 Expr::Cast { expr, .. } => {
527 self.emit_expr(expr);
530 }
531 Expr::Spawn { agent_name, node_addr } => {
532 if let Some(addr) = node_addr {
533 self.emit_expr(&Expr::StringLiteral(addr.clone()));
535 if let Some(&type_id) = self.agent_type_ids.get(agent_name) {
536 self.code.push(0x33); self.code.extend_from_slice(&type_id.to_le_bytes());
538 } else {
539 self.code.push(0x01);
540 self.code.extend(pack_trits(&[Trit::Zero]));
541 }
542 } else if let Some(&type_id) = self.agent_type_ids.get(agent_name) {
543 self.code.push(0x30); self.code.extend_from_slice(&type_id.to_le_bytes());
546 } else {
547 self.code.push(0x01);
549 self.code.extend(pack_trits(&[Trit::Zero]));
550 }
551 }
552 Expr::StringLiteral(s) => {
553 self.code.push(0x01);
560 self.code.extend(pack_trits(&[Trit::Zero]));
561 }
562 Expr::NodeId => {
563 self.code.push(0x12); }
565 Expr::Await { target } => {
566 self.emit_expr(target);
569 self.code.push(0x32); }
571 _ => {}
572 }
573 }
574
575 pub fn finalize(mut self) -> Vec<u8> {
576 self.code.push(0x00); self.code
578 }
579
580 fn patch_u16(&mut self, pos: usize, val: u16) {
581 let bytes = val.to_le_bytes();
582 self.code[pos] = bytes[0];
583 self.code[pos + 1] = bytes[1];
584 }
585}
586
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;
    use crate::vm::{BetVm, Value};

    /// Parses `input` one statement at a time, emits each one, and returns the
    /// finalized bytecode. Parsing stops at the first statement that fails to parse.
    fn compile(input: &str) -> Vec<u8> {
        let mut parser = Parser::new(input);
        let mut emitter = BytecodeEmitter::new();
        while let Ok(stmt) = parser.parse_stmt() {
            emitter.emit_stmt(&stmt);
        }
        emitter.finalize()
    }

    #[test]
    fn test_compile_and_run_simple() {
        let bytecode = compile("let x: trit = 1; let y: trit = -x; return y;");
        let mut vm = BetVm::new(bytecode);
        vm.run().unwrap();

        // `x` lives in register 0 and `y` in register 1.
        assert_eq!(vm.get_register(1), Value::Trit(Trit::NegOne));
    }

    #[test]
    fn test_sparseskip_emits_tsparse_matmul() {
        let source = "let a: trittensor<2 x 2>; let b: trittensor<2 x 2>; @sparseskip let c: trittensor<2 x 2> = matmul(a, b);";
        let bytecode = compile(source);

        assert!(bytecode.contains(&0x21), "Expected TSPARSE_MATMUL (0x21) in bytecode");
        assert!(!bytecode.contains(&0x20), "Expected no dense TMATMUL (0x20) when @sparseskip used");

        let mut vm = BetVm::new(bytecode);
        vm.run().unwrap();
        // `a`, `b`, `c` occupy registers 0, 1, 2.
        assert!(matches!(vm.get_register(2), Value::TensorRef(_)));
    }

    #[test]
    fn test_compile_match() {
        let source = "let x: trit = 1; match x { 1 => { let y: trit = -1; } 0 => { let y: trit = 0; } -1 => { let y: trit = 1; } }";
        let mut vm = BetVm::new(compile(source));
        vm.run().unwrap();

        // The `1 =>` arm is emitted first, so its `y` is register 1.
        assert_eq!(vm.get_register(1), Value::Trit(Trit::NegOne));
    }
}