Skip to main content

ternlang_core/codegen/
betbc.rs

1use crate::ast::*;
2use crate::vm::bet::pack_trits;
3use crate::trit::Trit;
4
/// Incrementally builds a flat bytecode buffer from AST nodes.
///
/// Besides the code itself the emitter tracks the tables needed to resolve
/// names to registers and addresses while emitting, plus lists of jump/call
/// operands that still have to be back-patched.
#[derive(Debug, Default)]
pub struct BytecodeEmitter {
    /// Bytecode emitted so far.
    code: Vec<u8>,
    /// Variable name -> register index for the function being emitted.
    symbols: std::collections::HashMap<String, u8>,
    /// Function name -> address of its first instruction in `code`.
    func_addrs: std::collections::HashMap<String, u16>,
    /// Callee name -> offsets in `code` of 2-byte call operands that are
    /// waiting for the callee's address.
    function_patches: std::collections::HashMap<String, Vec<usize>>,
    /// Offsets of `break` jump operands awaiting the enclosing loop's end.
    break_patches: Vec<usize>,
    /// Offsets of `continue` jump operands awaiting the enclosing loop's
    /// continue target.
    continue_patches: Vec<usize>,
    /// Next free register index.
    next_reg: u8,
    /// Struct name -> ordered field names.
    struct_layouts: std::collections::HashMap<String, Vec<String>>,
    /// Agent name -> numeric type id.
    agent_type_ids: std::collections::HashMap<String, u16>,
    /// `(agent type id, handler address)` pairs.
    agent_handlers: Vec<(u16, u16)>,
}
17
18impl BytecodeEmitter {
19    pub fn new() -> Self {
20        Self {
21            code: Vec::new(),
22            symbols: std::collections::HashMap::new(),
23            func_addrs: std::collections::HashMap::new(),
24            function_patches: std::collections::HashMap::new(),
25            break_patches: Vec::new(),
26            continue_patches: Vec::new(),
27            next_reg: 0,
28            struct_layouts: std::collections::HashMap::new(),
29            agent_type_ids: std::collections::HashMap::new(),
30            agent_handlers: Vec::new(),
31        }
32    }
33
34    pub fn register_agents(&self, vm: &mut crate::vm::BetVm) {
35        for &(type_id, addr) in &self.agent_handlers {
36            vm.register_agent_type(type_id, addr as usize);
37        }
38    }
39
40    pub fn emit_header_jump(&mut self) -> usize {
41        let patch_pos = self.code.len() + 1;
42        self.code.push(0x0b); // TJMP
43        self.code.extend_from_slice(&[0u8, 0u8]);
44        patch_pos
45    }
46
47    pub fn patch_header_jump(&mut self, patch_pos: usize) {
48        let addr = self.code.len() as u16;
49        self.patch_u16(patch_pos, addr);
50    }
51
52    pub fn emit_program(&mut self, program: &Program) {
53        let parent_next_reg = self.next_reg;
54        for s in &program.structs {
55            let names: Vec<String> = s.fields.iter().map(|(n, _)| n.clone()).collect();
56            self.struct_layouts.insert(s.name.clone(), names);
57        }
58        for (idx, agent) in program.agents.iter().enumerate() {
59            self.agent_type_ids.insert(agent.name.clone(), idx as u16);
60        }
61
62        // PASS 1: Addresses
63        let real_code = std::mem::take(&mut self.code);
64        let real_func_addrs = std::mem::take(&mut self.func_addrs);
65        let real_agent_handlers = std::mem::take(&mut self.agent_handlers);
66        let base_addr = real_code.len() as u16;
67
68        for agent in &program.agents {
69            let type_id = self.agent_type_ids[&agent.name];
70            let mut handler_addr = None;
71            for method in &agent.methods {
72                let addr = base_addr + self.code.len() as u16;
73                if handler_addr.is_none() { handler_addr = Some(addr); }
74                self.emit_function(method);
75                self.func_addrs.insert(format!("{}::{}", agent.name, method.name), addr);
76            }
77            if let Some(addr) = handler_addr { self.agent_handlers.push((type_id, addr)); }
78        }
79        for func in &program.functions {
80            let addr = base_addr + self.code.len() as u16;
81            self.func_addrs.insert(func.name.clone(), addr);
82            self.emit_function(func);
83        }
84
85        let final_func_addrs = std::mem::replace(&mut self.func_addrs, real_func_addrs);
86        let final_agent_handlers = std::mem::replace(&mut self.agent_handlers, real_agent_handlers);
87        self.code = real_code;
88        self.func_addrs = final_func_addrs;
89        self.agent_handlers = final_agent_handlers;
90        self.next_reg = parent_next_reg;
91
92        // PASS 2: Real
93        for agent in &program.agents {
94            for method in &agent.methods { self.emit_function(method); }
95        }
96        for func in &program.functions { self.emit_function(func); }
97    }
98
99    pub fn emit_function(&mut self, func: &Function) {
100        let func_addr = self.code.len() as u16;
101        self.func_addrs.insert(func.name.clone(), func_addr);
102        if let Some(patches) = self.function_patches.remove(&func.name) {
103            for p in patches {
104                self.code[p..p + 2].copy_from_slice(&func_addr.to_le_bytes());
105            }
106        }
107        let parent_symbols = self.symbols.clone();
108        let parent_next_reg = self.next_reg;
109        self.next_reg = 0;
110
111        // If function has @sparseskip, we could emit a special header here.
112        // For now, it's just a marker in the AST.
113
114        for (name, _) in func.params.iter().rev() {
115            let reg = self.next_reg;
116            self.symbols.insert(name.clone(), reg);
117            self.next_reg += 1;
118            self.code.push(0x08); self.code.push(reg);
119        }
120        for stmt in &func.body { self.emit_stmt(stmt); }
121        self.symbols = parent_symbols;
122        self.next_reg = parent_next_reg;
123        self.code.push(0x11); // TRET
124    }
125
    /// Emits bytecode for a single statement, recursing into nested
    /// statements and expressions.
    ///
    /// Jump instructions are emitted with a zeroed 2-byte operand whose offset
    /// is remembered and later filled in with `patch_u16` once the target
    /// address is known. Statement kinds without an arm below emit nothing.
    pub fn emit_stmt(&mut self, stmt: &Stmt) {
        match stmt {
            // `let`: produce the initial value, then store it into a freshly
            // allocated register bound to `name`.
            Stmt::Let { name, ty, value } => {
                let mut handled = false;
                if let Type::TritTensor { dims } = ty {
                    // Only auto-allocate if size is fixed (>0) and no literal is provided
                    if !dims.is_empty() && !dims.contains(&0) && !matches!(value, Expr::TritTensorLiteral(_)) {
                        let rows = dims[0];
                        // A one-dimensional tensor is treated as a single column.
                        let cols = if dims.len() > 1 { dims[1] } else { 1 };
                        // 0x0f followed by rows and cols as little-endian u16
                        // (same opcode `emit_expr` uses for tensor literals).
                        self.code.push(0x0f);
                        self.code.extend_from_slice(&(rows as u16).to_le_bytes());
                        self.code.extend_from_slice(&(cols as u16).to_le_bytes());
                        handled = true;
                    }
                }
                if !handled {
                    self.emit_expr(value);
                }
                let reg = self.next_reg;
                self.symbols.insert(name.clone(), reg);
                self.next_reg += 1;
                self.code.push(0x08); self.code.push(reg); // TSTORE
            }
            // Assignment to an existing variable. The value expression is
            // always emitted; the store is emitted only if `name` is known.
            // NOTE(review): for an unknown name the value is emitted with no
            // consuming store — presumably ruled out by earlier checks; confirm.
            Stmt::Set { name, value } => {
                self.emit_expr(value);
                if let Some(&reg) = self.symbols.get(name) {
                    self.code.push(0x08); self.code.push(reg);
                }
            }
            // Field assignment: looked up in `symbols` under "object.field".
            // As with `Set`, nothing is stored if that key is absent.
            Stmt::FieldSet { object, field, value } => {
                let key = format!("{}.{}", object, field);
                self.emit_expr(value);
                if let Some(&reg) = self.symbols.get(&key) {
                    self.code.push(0x08); self.code.push(reg);
                }
            }
            // Indexed store: load the object's register, emit row, col and
            // value, then opcode 0x23. Emits nothing if `object` is unknown.
            Stmt::IndexSet { object, row, col, value } => {
                if let Some(&reg) = self.symbols.get(object) {
                    self.code.push(0x09); self.code.push(reg);
                    self.emit_expr(row);
                    self.emit_expr(col);
                    self.emit_expr(value);
                    self.code.push(0x23);
                }
            }
            // Three-way branch. Emitted layout:
            //   cond; TDUP; TJMP_POS pos; TDUP; TJMP_ZERO zero;
            //   TPOP; <neg arm>; TJMP end;
            //   pos:  TPOP; <pos arm>; TJMP end;
            //   zero: TPOP; <zero arm>;
            //   end:
            Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
                self.emit_expr(condition);
                self.code.push(0x0a);
                let pos_patch = self.code.len() + 1;
                self.code.push(0x05); self.code.extend_from_slice(&[0, 0]);
                self.code.push(0x0a);
                let zero_patch = self.code.len() + 1;
                self.code.push(0x06); self.code.extend_from_slice(&[0, 0]);
                // Neither jump taken: negative arm.
                self.code.push(0x0c);
                self.emit_stmt(on_neg);
                let exit_patch = self.code.len() + 1;
                self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
                // Positive arm.
                let pos_addr = self.code.len() as u16;
                self.patch_u16(pos_patch, pos_addr);
                self.code.push(0x0c);
                self.emit_stmt(on_pos);
                let exit_pos = self.code.len() + 1;
                self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
                // Zero arm; falls through to the end.
                let zero_addr = self.code.len() as u16;
                self.patch_u16(zero_patch, zero_addr);
                self.code.push(0x0c);
                self.emit_stmt(on_zero);
                let end = self.code.len() as u16;
                self.patch_u16(exit_patch, end);
                self.patch_u16(exit_pos, end);
            }
            // `match`: the condition is evaluated once into a temporary
            // register; each arm reloads it, tests it, and either jumps into
            // its body or skips to the next arm's test.
            Stmt::Match { condition, arms } => {
                self.emit_expr(condition);
                let cond_reg = self.next_reg; self.next_reg += 1;
                self.code.push(0x08); self.code.push(cond_reg); // Tstore

                let mut end_patches = Vec::new();
                let mut next_arm_patch = None;

                for (val, stmt) in arms {
                    // The previous arm's "mismatch" jump lands here.
                    if let Some(p) = next_arm_patch {
                        let addr = self.code.len() as u16;
                        self.patch_u16(p, addr);
                    }

                    // Load condition for this arm
                    // NOTE(review): a fresh load is emitted per arm, but only a
                    // matching body and the final fall-through emit a Tpop. If
                    // the conditional jumps really peek, loads from earlier
                    // failed arms appear to stay on the stack — confirm
                    // against the VM.
                    self.code.push(0x09); self.code.push(cond_reg); // Tload

                    // Offset of this arm's "matched" jump operand. Note the
                    // opcode is pushed first here, so no `+ 1` is needed.
                    let match_patch;
                    match val {
                        1 => {
                            self.code.push(0x05); // TjmpPos (peeks)
                            match_patch = self.code.len();
                            self.code.extend_from_slice(&[0, 0]);
                        }
                        0 => {
                            self.code.push(0x06); // TjmpZero (peeks)
                            match_patch = self.code.len();
                            self.code.extend_from_slice(&[0, 0]);
                        }
                        -1 => {
                            self.code.push(0x07); // TjmpNeg (peeks)
                            match_patch = self.code.len();
                            self.code.extend_from_slice(&[0, 0]);
                        }
                        v => {
                            // Any other value: compare against an inline
                            // little-endian immediate, then the jump operand.
                            self.code.push(0x25); // TjmpEqInt (peeks)
                            self.code.extend_from_slice(&v.to_le_bytes());
                            match_patch = self.code.len();
                            self.code.extend_from_slice(&[0, 0]);
                        }
                    }

                    // Mismatch: Jump past body to the next arm's check
                    let skip_patch = self.code.len() + 1;
                    self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
                    next_arm_patch = Some(skip_patch);

                    // Match found: execute body
                    let body_addr = self.code.len() as u16;
                    self.patch_u16(match_patch, body_addr);
                    
                    // Body: first pop the condition we were peeking at
                    self.code.push(0x0c); // Tpop
                    self.emit_stmt(stmt);
                    
                    // After body, jump to end of match
                    let end_patch = self.code.len() + 1;
                    self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
                    end_patches.push(end_patch);
                }

                // The last arm's "mismatch" jump lands on the fall-through path.
                if let Some(p) = next_arm_patch {
                    let addr = self.code.len() as u16;
                    self.patch_u16(p, addr);
                }
                
                // If no arms matched, we still have one Tload on stack from the last failed arm check
                // unless arms was empty (but semantic enforces it isn't for Trit, and for Int it might be)
                if !arms.is_empty() {
                    self.code.push(0x0c); // Tpop
                }

                let end_addr = self.code.len() as u16;
                for p in end_patches { self.patch_u16(p, end_addr); }
                // Give back one register for the condition temporary.
                // NOTE(review): this assumes the arm bodies left `next_reg`
                // where it was; a `let` inside an arm also bumps it — confirm
                // the intended register lifetime.
                self.next_reg -= 1;
            }
            // `for var in iter`: an index-driven loop over the iterable.
            // Four registers are allocated (iterable, bound, index, loop
            // variable); none of them is released afterwards.
            Stmt::ForIn { var, iter, body } => {
                self.emit_expr(iter);
                // Keep the iterable in a register and reload it for 0x24
                // (named TSHAPE where `emit_expr` handles `length`).
                let it_reg = self.next_reg; self.next_reg += 1;
                self.code.push(0x08); self.code.push(it_reg);
                self.code.push(0x09); self.code.push(it_reg);
                self.code.push(0x24);
                // The value on top after TSHAPE becomes the loop bound; the
                // following TPOP discards what remains.
                // presumably TSHAPE leaves two values (rows/cols) — confirm order.
                let r_reg = self.next_reg; self.next_reg += 1;
                self.code.push(0x08); self.code.push(r_reg);
                self.code.push(0x0c);
                // Index register, initialised with the integer literal 0
                // (0x17 is the opcode `emit_expr` uses for `IntLiteral`).
                let i_reg = self.next_reg; self.next_reg += 1;
                self.code.push(0x17); self.code.extend_from_slice(&0i64.to_le_bytes());
                self.code.push(0x08); self.code.push(i_reg);
                
                // Loop head: everything from here is re-executed per iteration.
                let top = self.code.len() as u16;
                // Only patches recorded after these marks belong to this loop.
                let pre_break = self.break_patches.len();
                let pre_cont = self.continue_patches.len();

                // Compare index against bound with 0x14 (the `BinOp::Less`
                // opcode); a negative or zero result leaves the loop.
                self.code.push(0x09); self.code.push(i_reg);
                self.code.push(0x09); self.code.push(r_reg);
                self.code.push(0x14);
                self.code.push(0x0a);
                let neg = self.code.len() + 1;
                self.code.push(0x07); self.code.extend_from_slice(&[0, 0]);
                self.code.push(0x0a);
                let zero = self.code.len() + 1;
                self.code.push(0x06); self.code.extend_from_slice(&[0, 0]);
                self.code.push(0x0c);
                // Fetch element (index, 0) with 0x22 (the `Expr::Index`
                // opcode) and bind it to the loop variable.
                self.code.push(0x09); self.code.push(it_reg);
                self.code.push(0x09); self.code.push(i_reg);
                self.code.push(0x17); self.code.extend_from_slice(&0i64.to_le_bytes());
                self.code.push(0x22);
                let v_reg = self.next_reg; self.next_reg += 1;
                self.symbols.insert(var.clone(), v_reg);
                self.code.push(0x08); self.code.push(v_reg);
                self.emit_stmt(body);
                
                // `continue` inside the body jumps here, to the increment.
                let cont_addr = self.code.len() as u16;
                let cs: Vec<usize> = self.continue_patches.drain(pre_cont..).collect();
                for p in cs { self.patch_u16(p, cont_addr); }

                // index = index + 1 (0x18 — presumably integer add; confirm),
                // then jump back to the loop head.
                self.code.push(0x09); self.code.push(i_reg);
                self.code.push(0x17); self.code.extend_from_slice(&1i64.to_le_bytes());
                self.code.push(0x18);
                self.code.push(0x08); self.code.push(i_reg);
                let back = self.code.len() + 1;
                self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
                self.patch_u16(back, top);
                // Loop exit: target of both exit jumps and of every `break`.
                let end = self.code.len() as u16;
                self.patch_u16(neg, end); self.patch_u16(zero, end);
                let bs: Vec<usize> = self.break_patches.drain(pre_break..).collect();
                for p in bs { self.patch_u16(p, end); }
            }
            // Ternary `while`: the condition is re-evaluated at `top`; only
            // the positive arm jumps back, the other two arms leave the loop.
            Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
                let top = self.code.len() as u16;
                let pre_break = self.break_patches.len();
                let pre_cont = self.continue_patches.len();

                self.emit_expr(condition);
                self.code.push(0x0a); // TDUP
                let pos_patch = self.code.len() + 1;
                self.code.push(0x05); self.code.extend_from_slice(&[0, 0]); // TJMP_POS
                self.code.push(0x0a); // TDUP
                let zero_patch = self.code.len() + 1;
                self.code.push(0x06); self.code.extend_from_slice(&[0, 0]); // TJMP_ZERO
                
                // NEG ARM: execute and EXIT (don't loop back)
                self.code.push(0x0c); // TPOP
                self.emit_stmt(on_neg);
                let exit_neg = self.code.len() + 1;
                self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]); // TJMP to end

                // POS ARM: execute and LOOP BACK
                let pos_addr = self.code.len() as u16;
                self.patch_u16(pos_patch, pos_addr);
                self.code.push(0x0c); // TPOP
                self.emit_stmt(on_pos);
                let back_pos = self.code.len() + 1;
                self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
                self.patch_u16(back_pos, top);

                // ZERO ARM: execute and EXIT (don't loop back)
                let zero_addr = self.code.len() as u16;
                self.patch_u16(zero_patch, zero_addr);
                self.code.push(0x0c); // TPOP
                self.emit_stmt(on_zero);
                
                let end = self.code.len() as u16;
                self.patch_u16(exit_neg, end);

                // `continue` re-evaluates the condition; `break` leaves the loop.
                let cs: Vec<usize> = self.continue_patches.drain(pre_cont..).collect();
                for p in cs { self.patch_u16(p, top); }
                let bs: Vec<usize> = self.break_patches.drain(pre_break..).collect();
                for p in bs { self.patch_u16(p, end); }
            }
            // Unconditional loop: body followed by a jump back to its start.
            // `continue` targets the start, `break` the first byte after the
            // back-jump.
            Stmt::Loop { body } => {
                let top = self.code.len() as u16;
                let pre_break = self.break_patches.len();
                let pre_cont = self.continue_patches.len();
                self.emit_stmt(body);
                let back = self.code.len() + 1;
                self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
                self.patch_u16(back, top);
                let end = self.code.len() as u16;
                let cs: Vec<usize> = self.continue_patches.drain(pre_cont..).collect();
                for p in cs { self.patch_u16(p, top); }
                let bs: Vec<usize> = self.break_patches.drain(pre_break..).collect();
                for p in bs { self.patch_u16(p, end); }
            }
            // `break` / `continue`: a TJMP whose operand is recorded so the
            // enclosing loop arm can patch it. Outside any loop the operand is
            // never patched and stays zero.
            Stmt::Break => {
                let p = self.code.len() + 1;
                self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
                self.break_patches.push(p);
            }
            Stmt::Continue => {
                let p = self.code.len() + 1;
                self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]);
                self.continue_patches.push(p);
            }
            // Message send: target, then message, then the send opcode.
            Stmt::Send { target, message } => {
                self.emit_expr(target);
                self.emit_expr(message);
                self.code.push(0x31); // TSEND
            }
            // Return value followed by 0x11 (TRET).
            Stmt::Return(e) => { self.emit_expr(e); self.code.push(0x11); }
            // Statements of a block are emitted in order.
            Stmt::Block(ss) => { for s in ss { self.emit_stmt(s); } }
            // Expression statement: evaluate, then 0x0c (Tpop) the result.
            Stmt::Expr(e) => { self.emit_expr(e); self.code.push(0x0c); }
            // The directive is ignored here; only the inner statement is emitted.
            Stmt::Decorated { directive: _, stmt } => { self.emit_stmt(stmt); }
            // Every other statement kind emits no code.
            _ => {}
        }
    }
403
    /// Emits bytecode for an expression.
    ///
    /// Operands are emitted before the opcode that combines them. Expression
    /// kinds without an arm below emit nothing.
    fn emit_expr(&mut self, expr: &Expr) {
        match expr {
            // Literals: an opcode followed by the encoded immediate.
            Expr::TritLiteral(v) => {
                // 0x01 followed by the single trit as packed by `pack_trits`.
                self.code.push(0x01);
                self.code.extend(pack_trits(&[Trit::from(*v)]));
            }
            Expr::IntLiteral(v) => {
                // 0x17 followed by the value's little-endian bytes.
                self.code.push(0x17);
                self.code.extend_from_slice(&v.to_le_bytes());
            }
            Expr::FloatLiteral(val) => {
                // 0x19 followed by the value's little-endian bytes.
                self.code.push(0x19);
                self.code.extend_from_slice(&val.to_le_bytes());
            }
            Expr::StringLiteral(val) => {
                self.code.push(0x21); // TPUSH_STRING
                let bytes = val.as_bytes();
                // Length prefix is a u16; the cast truncates for strings of
                // 65536 bytes or more while all bytes are still appended.
                self.code.extend_from_slice(&(bytes.len() as u16).to_le_bytes());
                self.code.extend_from_slice(bytes);
            }
            // Variable read: 0x09 (Tload) from the variable's register.
            // NOTE(review): an unknown name emits nothing at all — presumably
            // excluded by earlier checks; confirm.
            Expr::Ident(name) => {
                if let Some(&r) = self.symbols.get(name) {
                    self.code.push(0x09); self.code.push(r);
                }
            }
            // Binary operators: left operand, right operand, then the opcode(s).
            Expr::BinaryOp { op, lhs, rhs } => {
                self.emit_expr(lhs); self.emit_expr(rhs);
                match op {
                    BinOp::Add => self.code.push(0x02),
                    BinOp::Mul => self.code.push(0x03),
                    BinOp::Div => self.code.push(0x1e),
                    BinOp::Mod => self.code.push(0x1f),
                    // Subtraction is emitted as the unary-negation opcode
                    // (0x04) followed by the addition opcode.
                    BinOp::Sub => { self.code.push(0x04); self.code.push(0x02); }
                    BinOp::Equal => self.code.push(0x16),
                    // Inequality is the equality opcode followed by 0x04.
                    BinOp::NotEqual => { self.code.push(0x16); self.code.push(0x04); }
                    // `And` shares its opcode with `Mul`; `Or` shares its
                    // opcode with the two-argument `consensus` builtin.
                    BinOp::And => self.code.push(0x03),
                    BinOp::Or => self.code.push(0x0e),
                    BinOp::Less => self.code.push(0x14),
                    BinOp::Greater => self.code.push(0x15),
                    BinOp::LessEqual => self.code.push(0x26),
                    BinOp::GreaterEqual => self.code.push(0x27),
                }
            }
            Expr::UnaryOp { op, expr } => {
                self.emit_expr(expr);
                match op { UnOp::Neg => self.code.push(0x04) }
            }
            // Calls: a handful of builtin names are expanded inline; anything
            // else becomes a TCALL to a user function.
            Expr::Call { callee, args } => {
                match callee.as_str() {
                    // Both names emit identical code: every argument is
                    // followed by its own TPRINT.
                    "print" | "println" => {
                        for a in args {
                            self.emit_expr(a);
                            self.code.push(0x20); // TPRINT
                        }
                        self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Tend])); // return hold()
                    }
                    // The combining opcode is emitted only for exactly two
                    // arguments; with any other count just the arguments are
                    // emitted.
                    "consensus" => {
                        for a in args { self.emit_expr(a); }
                        if args.len() == 2 { self.code.push(0x0e); }
                    }
                    // Emits nothing unless there is exactly one argument.
                    // NOTE(review): `ForIn` in `emit_stmt` stores the value on
                    // top after TSHAPE as its bound, whereas this discards the
                    // top value — confirm which TSHAPE result each site wants.
                    "length" => {
                        if args.len() == 1 {
                            self.emit_expr(&args[0]);
                            self.code.push(0x24); // TSHAPE
                            self.code.push(0x0c); // TPOP (cols)
                        }
                    }
                    // Same two-argument rule as `consensus`, with the `Mul` opcode.
                    "mul" => {
                        for a in args { self.emit_expr(a); }
                        if args.len() == 2 { self.code.push(0x03); }
                    }
                    // Constant trits; any arguments are ignored.
                    "truth" => { self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Affirm])); }
                    "hold" => { self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Tend])); }
                    "conflict" => { self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Reject])); }
                    // User function: arguments in order, then TCALL with a
                    // 2-byte address.
                    _ => {
                        for a in args { self.emit_expr(a); }
                        self.code.push(0x10); // TCALL
                        if let Some(&addr) = self.func_addrs.get(callee) {
                            self.code.extend_from_slice(&addr.to_le_bytes());
                        } else {
                            // Address not known yet: emit a zero placeholder
                            // and record its offset so `emit_function` can
                            // fill it in when the callee is emitted.
                            let patch = self.code.len();
                            self.code.extend_from_slice(&[0, 0]);
                            self.function_patches.entry(callee.to_string()).or_default().push(patch);
                        }
                    }
                }
            }
            // Spawn: TSPAWN with the agent's type id. An unknown agent name
            // yields the `Trit::Tend` constant instead.
            Expr::Spawn { agent_name, .. } => {
                if let Some(&type_id) = self.agent_type_ids.get(agent_name) {
                    self.code.push(0x30); // TSPAWN
                    self.code.extend_from_slice(&type_id.to_le_bytes());
                } else {
                    self.code.push(0x01); self.code.extend(pack_trits(&[Trit::Tend]));
                }
            }
            Expr::Await { target } => {
                self.emit_expr(target);
                self.code.push(0x32); // TAWAIT
            }
            // Tensor literal: 0x0f with a `len x 1` shape, parked in a
            // temporary register, filled one element at a time, then reloaded
            // as the expression's value. The temporary is not released.
            Expr::TritTensorLiteral(vs) => {
                let rows = vs.len();
                let cols = 1;
                self.code.push(0x0f);
                self.code.extend_from_slice(&(rows as u16).to_le_bytes());
                self.code.extend_from_slice(&(cols as u16).to_le_bytes());
                let tr = self.next_reg; self.next_reg += 1;
                self.code.push(0x08); self.code.push(tr);
                for (idx, &v) in vs.iter().enumerate() {
                    // tensor, row = idx, col = 0, value, then 0x23 (the
                    // opcode `emit_stmt` uses for `IndexSet`).
                    self.code.push(0x09); self.code.push(tr);
                    self.code.push(0x17); self.code.extend_from_slice(&(idx as i64).to_le_bytes());
                    self.code.push(0x17); self.code.extend_from_slice(&0i64.to_le_bytes());
                    self.code.push(0x01); self.code.extend(pack_trits(&[Trit::from(v)]));
                    self.code.push(0x23);
                }
                self.code.push(0x09); self.code.push(tr);
            }
            // Propagation: a negative result jumps to an inline TRET;
            // otherwise execution skips over that TRET and continues.
            Expr::Propagate { expr } => {
                self.emit_expr(expr);
                self.code.push(0x0a); // TDUP
                let patch = self.code.len() + 1;
                self.code.push(0x07); self.code.extend_from_slice(&[0, 0]); // TJMP_NEG
                let skip = self.code.len() + 1;
                self.code.push(0x0b); self.code.extend_from_slice(&[0, 0]); // TJMP
                let early_ret = self.code.len() as u16;
                self.patch_u16(patch, early_ret);
                self.code.push(0x11); // TRET
                let next = self.code.len() as u16;
                self.patch_u16(skip, next);
            }
            // Indexed read: object, row, col, then 0x22.
            Expr::Index { object, row, col } => {
                self.emit_expr(object); self.emit_expr(row); self.emit_expr(col);
                self.code.push(0x22);
            }
            // Every other expression kind emits no code.
            _ => {}
        }
    }
540
541    pub fn emit_entry_call(&mut self, name: &str) {
542        if let Some(&addr) = self.func_addrs.get(name) {
543            self.code.push(0x10); self.code.extend_from_slice(&addr.to_le_bytes());
544        }
545    }
546
547    pub fn get_agent_handlers(&self) -> Vec<(u16, usize)> {
548        self.agent_handlers.iter().map(|&(id, addr)| (id, addr as usize)).collect()
549    }
550
551    pub fn finalize(&mut self) -> Vec<u8> { std::mem::take(&mut self.code) }
552
553    fn patch_u16(&mut self, pos: usize, val: u16) {
554        let b = val.to_le_bytes();
555        self.code[pos] = b[0]; self.code[pos + 1] = b[1];
556    }
557}