Skip to main content

runmat_vm/compiler/
core.rs

1use crate::compiler::CompileError;
2use crate::functions::UserFunction;
3use crate::instr::{EmitLabel, Instr};
4use runmat_builtins::{self, Type};
5use runmat_hir::{HirExpr, HirExprKind, HirProgram, HirStmt};
6use std::collections::HashMap;
7
/// Jump placeholders collected while compiling the body of one loop.
///
/// Both lists hold instruction indices (positions in `Compiler::instructions`,
/// presumably) of jumps whose targets are not yet known when they are emitted.
pub struct LoopLabels {
    /// Jumps emitted for `continue`; presumably back-patched to the loop's
    /// continuation point once it is known.
    pub continue_jumps: Vec<usize>,
    /// Jumps emitted for `break`; presumably back-patched to the loop's exit
    /// once it is known.
    pub break_jumps: Vec<usize>,
}
12
13pub struct Compiler {
14    pub instructions: Vec<Instr>,
15    pub instr_spans: Vec<runmat_hir::Span>,
16    pub call_arg_spans: Vec<Option<Vec<runmat_hir::Span>>>,
17    pub var_count: usize,
18    pub loop_stack: Vec<LoopLabels>,
19    pub functions: HashMap<String, UserFunction>,
20    pub imports: Vec<(Vec<String>, bool)>,
21    pub var_types: Vec<Type>,
22    current_span: Option<runmat_hir::Span>,
23}
24
25struct SpanGuard {
26    compiler: *mut Compiler,
27    prev: Option<runmat_hir::Span>,
28}
29
30impl SpanGuard {
31    fn new(compiler: &mut Compiler, span: runmat_hir::Span) -> Self {
32        let prev = compiler.current_span;
33        compiler.current_span = Some(span);
34        Self {
35            compiler: compiler as *mut Compiler,
36            prev,
37        }
38    }
39}
40
41impl Drop for SpanGuard {
42    fn drop(&mut self) {
43        unsafe {
44            if let Some(compiler) = self.compiler.as_mut() {
45                compiler.current_span = self.prev;
46            }
47        }
48    }
49}
50
51impl Compiler {
52    pub(crate) fn normalize_class_literal_name(raw: &str) -> String {
53        if raw.len() >= 2 {
54            let bytes = raw.as_bytes();
55            let first = bytes[0] as char;
56            let last = bytes[raw.len() - 1] as char;
57            if (first == '\'' || first == '"') && first == last {
58                return raw[1..raw.len() - 1].to_string();
59            }
60        }
61        raw.to_string()
62    }
63
64    pub(crate) fn emit_multiassign_outputs(&mut self, vars: &[Option<runmat_hir::VarId>]) {
65        for v in vars.iter().flatten() {
66            self.emit(Instr::EmitVar {
67                var_index: v.0,
68                label: EmitLabel::Var(v.0),
69            });
70        }
71    }
72
73    pub fn new(prog: &HirProgram) -> Self {
74        let mut max_var = 0;
75        fn visit_expr(expr: &HirExpr, max: &mut usize) {
76            match &expr.kind {
77                HirExprKind::Var(id) => {
78                    if id.0 + 1 > *max {
79                        *max = id.0 + 1;
80                    }
81                }
82                HirExprKind::Unary(_, e) => visit_expr(e, max),
83                HirExprKind::Binary(left, _, right) => {
84                    visit_expr(left, max);
85                    visit_expr(right, max);
86                }
87                HirExprKind::Tensor(rows) | HirExprKind::Cell(rows) => {
88                    for row in rows {
89                        for expr in row {
90                            visit_expr(expr, max);
91                        }
92                    }
93                }
94                HirExprKind::Index(expr, indices) | HirExprKind::IndexCell(expr, indices) => {
95                    visit_expr(expr, max);
96                    for idx in indices {
97                        visit_expr(idx, max);
98                    }
99                }
100                HirExprKind::Range(start, step, end) => {
101                    visit_expr(start, max);
102                    if let Some(step) = step {
103                        visit_expr(step, max);
104                    }
105                    visit_expr(end, max);
106                }
107                HirExprKind::FuncCall(_, args) => {
108                    for arg in args {
109                        visit_expr(arg, max);
110                    }
111                }
112                HirExprKind::MethodCall(base, _, args)
113                | HirExprKind::DottedInvoke(base, _, args) => {
114                    visit_expr(base, max);
115                    for arg in args {
116                        visit_expr(arg, max);
117                    }
118                }
119                HirExprKind::Member(base, _) => visit_expr(base, max),
120                HirExprKind::MemberDynamic(base, name) => {
121                    visit_expr(base, max);
122                    visit_expr(name, max);
123                }
124                HirExprKind::AnonFunc { body, .. } => visit_expr(body, max),
125                HirExprKind::Number(_)
126                | HirExprKind::String(_)
127                | HirExprKind::Constant(_)
128                | HirExprKind::Colon
129                | HirExprKind::End
130                | HirExprKind::FuncHandle(_)
131                | HirExprKind::MetaClass(_) => {}
132            }
133        }
134
135        fn visit_stmts(stmts: &[HirStmt], max: &mut usize) {
136            for s in stmts {
137                match s {
138                    HirStmt::Assign(id, expr, _, _) => {
139                        if id.0 + 1 > *max {
140                            *max = id.0 + 1;
141                        }
142                        visit_expr(expr, max);
143                    }
144                    HirStmt::ExprStmt(expr, _, _) => visit_expr(expr, max),
145                    HirStmt::Return(_) => {}
146                    HirStmt::If {
147                        cond,
148                        then_body,
149                        elseif_blocks,
150                        else_body,
151                        ..
152                    } => {
153                        visit_expr(cond, max);
154                        visit_stmts(then_body, max);
155                        for (cond, body) in elseif_blocks {
156                            visit_expr(cond, max);
157                            visit_stmts(body, max);
158                        }
159                        if let Some(body) = else_body {
160                            visit_stmts(body, max);
161                        }
162                    }
163                    HirStmt::While { cond, body, .. } => {
164                        visit_expr(cond, max);
165                        visit_stmts(body, max);
166                    }
167                    HirStmt::For {
168                        var, expr, body, ..
169                    } => {
170                        if var.0 + 1 > *max {
171                            *max = var.0 + 1;
172                        }
173                        visit_expr(expr, max);
174                        visit_stmts(body, max);
175                    }
176                    HirStmt::Switch {
177                        expr,
178                        cases,
179                        otherwise,
180                        ..
181                    } => {
182                        visit_expr(expr, max);
183                        for (c, b) in cases {
184                            visit_expr(c, max);
185                            visit_stmts(b, max);
186                        }
187                        if let Some(b) = otherwise {
188                            visit_stmts(b, max);
189                        }
190                    }
191                    HirStmt::TryCatch {
192                        try_body,
193                        catch_var,
194                        catch_body,
195                        ..
196                    } => {
197                        if let Some(v) = catch_var {
198                            if v.0 + 1 > *max {
199                                *max = v.0 + 1;
200                            }
201                        }
202                        visit_stmts(try_body, max);
203                        visit_stmts(catch_body, max);
204                    }
205                    HirStmt::Global(vars, _) | HirStmt::Persistent(vars, _) => {
206                        for (v, _name) in vars {
207                            if v.0 + 1 > *max {
208                                *max = v.0 + 1;
209                            }
210                        }
211                    }
212                    HirStmt::AssignLValue(_, expr, _, _) => visit_expr(expr, max),
213                    HirStmt::MultiAssign(vars, expr, _, _) => {
214                        for v in vars.iter().flatten() {
215                            if v.0 + 1 > *max {
216                                *max = v.0 + 1;
217                            }
218                        }
219                        visit_expr(expr, max);
220                    }
221                    HirStmt::Function { .. }
222                    | HirStmt::ClassDef { .. }
223                    | HirStmt::Import { .. }
224                    | HirStmt::Break(_)
225                    | HirStmt::Continue(_) => {}
226                }
227            }
228        }
229
230        visit_stmts(&prog.body, &mut max_var);
231        let mut var_types = prog.var_types.clone();
232        if var_types.len() < max_var {
233            var_types.resize(max_var, Type::Unknown);
234        }
235        Self {
236            instructions: Vec::new(),
237            instr_spans: Vec::new(),
238            call_arg_spans: Vec::new(),
239            var_count: max_var,
240            loop_stack: Vec::new(),
241            functions: HashMap::new(),
242            imports: Vec::new(),
243            var_types,
244            current_span: None,
245        }
246    }
247
248    fn ensure_var(&mut self, id: usize) {
249        if id + 1 > self.var_count {
250            self.var_count = id + 1;
251        }
252        while self.var_types.len() <= id {
253            self.var_types.push(Type::Unknown);
254        }
255    }
256
257    pub(crate) fn alloc_temp(&mut self) -> usize {
258        let id = self.var_count;
259        self.var_count += 1;
260        if self.var_types.len() <= id {
261            self.var_types.push(Type::Unknown);
262        }
263        id
264    }
265
266    pub fn emit(&mut self, instr: Instr) -> usize {
267        match &instr {
268            Instr::LoadVar(id) | Instr::StoreVar(id) => self.ensure_var(*id),
269            _ => {}
270        }
271        let pc = self.instructions.len();
272        self.instructions.push(instr);
273        let span = self.current_span.unwrap_or_default();
274        self.instr_spans.push(span);
275        self.call_arg_spans.push(None);
276        pc
277    }
278
279    pub(crate) fn emit_call_with_arg_spans(
280        &mut self,
281        instr: Instr,
282        arg_spans: &[runmat_hir::Span],
283    ) -> usize {
284        let pc = self.emit(instr);
285        if !arg_spans.is_empty() {
286            if let Some(slot) = self.call_arg_spans.get_mut(pc) {
287                *slot = Some(arg_spans.to_vec());
288            }
289        }
290        pc
291    }
292
293    pub fn patch(&mut self, idx: usize, instr: Instr) {
294        self.instructions[idx] = instr;
295    }
296
297    pub(crate) fn compile_error(&self, message: impl Into<String>) -> CompileError {
298        let mut err = CompileError::new(message);
299        if let Some(span) = self.current_span {
300            err = err.with_span(span);
301        }
302        err
303    }
304
305    pub fn compile_program(&mut self, prog: &HirProgram) -> Result<(), CompileError> {
306        // Validate imports early for duplicate/specific-name ambiguities
307        runmat_hir::validate_imports(prog)?;
308        // Validate class definitions for attribute correctness and name conflicts
309        runmat_hir::validate_classdefs(prog)?;
310        // Pre-collect imports (both wildcard and specific) for name resolution
311        for stmt in &prog.body {
312            let _span_guard = SpanGuard::new(self, stmt.span());
313            if let HirStmt::Import { path, wildcard, .. } = stmt {
314                self.imports.push((path.clone(), *wildcard));
315                self.emit(Instr::RegisterImport {
316                    path: path.clone(),
317                    wildcard: *wildcard,
318                });
319            }
320            if let HirStmt::Global(vars, _) = stmt {
321                let ids: Vec<usize> = vars.iter().map(|(v, _n)| v.0).collect();
322                let names: Vec<String> = vars.iter().map(|(_v, n)| n.clone()).collect();
323                self.emit(Instr::DeclareGlobalNamed(ids, names));
324            }
325            if let HirStmt::Persistent(vars, _) = stmt {
326                let ids: Vec<usize> = vars.iter().map(|(v, _n)| v.0).collect();
327                let names: Vec<String> = vars.iter().map(|(_v, n)| n.clone()).collect();
328                self.emit(Instr::DeclarePersistentNamed(ids, names));
329            }
330        }
331        for stmt in &prog.body {
332            if !matches!(
333                stmt,
334                HirStmt::Import { .. } | HirStmt::Global(_, _) | HirStmt::Persistent(_, _)
335            ) {
336                self.compile_stmt(stmt)?;
337            }
338        }
339        Ok(())
340    }
341
342    pub fn compile_stmt(&mut self, stmt: &HirStmt) -> Result<(), CompileError> {
343        let _span_guard = SpanGuard::new(self, stmt.span());
344        self.compile_stmt_impl(stmt)
345    }
346
347    pub fn compile_expr(&mut self, expr: &HirExpr) -> Result<(), CompileError> {
348        let _span_guard = SpanGuard::new(self, expr.span);
349        self.compile_expr_impl(expr)
350    }
351}