1use crate::compiler::CompileError;
2use crate::functions::UserFunction;
3use crate::instr::{EmitLabel, Instr};
4use runmat_builtins::{self, Type};
5use runmat_hir::{HirExpr, HirExprKind, HirProgram, HirStmt};
6use std::collections::HashMap;
7
/// Jump instructions emitted for `break`/`continue` inside one enclosing loop.
///
/// Both vectors hold instruction indices (as returned by `Compiler::emit`), presumably
/// recorded so the loop compiler can patch their targets once the loop's exit and
/// continuation points are known — confirm against the loop compilation code.
#[derive(Debug, Default)]
pub struct LoopLabels {
    /// Indices of jumps emitted for `break` statements.
    pub break_jumps: Vec<usize>,
    /// Indices of jumps emitted for `continue` statements.
    pub continue_jumps: Vec<usize>,
}
12
/// Bytecode compiler state for a single HIR program.
///
/// Instructions are appended via `Compiler::emit`; three parallel vectors
/// (`instructions`, `instr_spans`, `call_arg_spans`) are kept the same length, one entry
/// per emitted instruction.
pub struct Compiler {
    /// Emitted instructions; an instruction's index is its program counter (pc).
    pub instructions: Vec<Instr>,
    /// Source span for each instruction in `instructions` (same index).
    pub instr_spans: Vec<runmat_hir::Span>,
    /// Optional per-argument spans for each instruction (same index); only set for calls
    /// emitted with argument spans.
    pub call_arg_spans: Vec<Option<Vec<runmat_hir::Span>>>,
    /// Number of variable slots in use (program variables plus compiler temporaries).
    pub var_count: usize,
    /// Break/continue bookkeeping for the loops currently being compiled, innermost last
    /// (presumably pushed/popped by the loop compilation code — confirm).
    pub loop_stack: Vec<LoopLabels>,
    /// User-defined functions; presumably keyed by function name — verify against callers.
    pub functions: HashMap<String, UserFunction>,
    /// Imports recorded by `compile_program`: (path segments, is-wildcard) pairs.
    pub imports: Vec<(Vec<String>, bool)>,
    /// Static type per variable slot, indexed by variable id.
    pub var_types: Vec<Type>,
    /// Span of the statement/expression currently being compiled; attached to emitted
    /// instructions and to errors built by `compile_error`.
    current_span: Option<runmat_hir::Span>,
}
24
25struct SpanGuard {
26 compiler: *mut Compiler,
27 prev: Option<runmat_hir::Span>,
28}
29
30impl SpanGuard {
31 fn new(compiler: &mut Compiler, span: runmat_hir::Span) -> Self {
32 let prev = compiler.current_span;
33 compiler.current_span = Some(span);
34 Self {
35 compiler: compiler as *mut Compiler,
36 prev,
37 }
38 }
39}
40
41impl Drop for SpanGuard {
42 fn drop(&mut self) {
43 unsafe {
44 if let Some(compiler) = self.compiler.as_mut() {
45 compiler.current_span = self.prev;
46 }
47 }
48 }
49}
50
51impl Compiler {
52 pub(crate) fn normalize_class_literal_name(raw: &str) -> String {
53 if raw.len() >= 2 {
54 let bytes = raw.as_bytes();
55 let first = bytes[0] as char;
56 let last = bytes[raw.len() - 1] as char;
57 if (first == '\'' || first == '"') && first == last {
58 return raw[1..raw.len() - 1].to_string();
59 }
60 }
61 raw.to_string()
62 }
63
64 pub(crate) fn emit_multiassign_outputs(&mut self, vars: &[Option<runmat_hir::VarId>]) {
65 for v in vars.iter().flatten() {
66 self.emit(Instr::EmitVar {
67 var_index: v.0,
68 label: EmitLabel::Var(v.0),
69 });
70 }
71 }
72
73 pub fn new(prog: &HirProgram) -> Self {
74 let mut max_var = 0;
75 fn visit_expr(expr: &HirExpr, max: &mut usize) {
76 match &expr.kind {
77 HirExprKind::Var(id) => {
78 if id.0 + 1 > *max {
79 *max = id.0 + 1;
80 }
81 }
82 HirExprKind::Unary(_, e) => visit_expr(e, max),
83 HirExprKind::Binary(left, _, right) => {
84 visit_expr(left, max);
85 visit_expr(right, max);
86 }
87 HirExprKind::Tensor(rows) | HirExprKind::Cell(rows) => {
88 for row in rows {
89 for expr in row {
90 visit_expr(expr, max);
91 }
92 }
93 }
94 HirExprKind::Index(expr, indices) | HirExprKind::IndexCell(expr, indices) => {
95 visit_expr(expr, max);
96 for idx in indices {
97 visit_expr(idx, max);
98 }
99 }
100 HirExprKind::Range(start, step, end) => {
101 visit_expr(start, max);
102 if let Some(step) = step {
103 visit_expr(step, max);
104 }
105 visit_expr(end, max);
106 }
107 HirExprKind::FuncCall(_, args) => {
108 for arg in args {
109 visit_expr(arg, max);
110 }
111 }
112 HirExprKind::MethodCall(base, _, args)
113 | HirExprKind::DottedInvoke(base, _, args) => {
114 visit_expr(base, max);
115 for arg in args {
116 visit_expr(arg, max);
117 }
118 }
119 HirExprKind::Member(base, _) => visit_expr(base, max),
120 HirExprKind::MemberDynamic(base, name) => {
121 visit_expr(base, max);
122 visit_expr(name, max);
123 }
124 HirExprKind::AnonFunc { body, .. } => visit_expr(body, max),
125 HirExprKind::Number(_)
126 | HirExprKind::String(_)
127 | HirExprKind::Constant(_)
128 | HirExprKind::Colon
129 | HirExprKind::End
130 | HirExprKind::FuncHandle(_)
131 | HirExprKind::MetaClass(_) => {}
132 }
133 }
134
135 fn visit_stmts(stmts: &[HirStmt], max: &mut usize) {
136 for s in stmts {
137 match s {
138 HirStmt::Assign(id, expr, _, _) => {
139 if id.0 + 1 > *max {
140 *max = id.0 + 1;
141 }
142 visit_expr(expr, max);
143 }
144 HirStmt::ExprStmt(expr, _, _) => visit_expr(expr, max),
145 HirStmt::Return(_) => {}
146 HirStmt::If {
147 cond,
148 then_body,
149 elseif_blocks,
150 else_body,
151 ..
152 } => {
153 visit_expr(cond, max);
154 visit_stmts(then_body, max);
155 for (cond, body) in elseif_blocks {
156 visit_expr(cond, max);
157 visit_stmts(body, max);
158 }
159 if let Some(body) = else_body {
160 visit_stmts(body, max);
161 }
162 }
163 HirStmt::While { cond, body, .. } => {
164 visit_expr(cond, max);
165 visit_stmts(body, max);
166 }
167 HirStmt::For {
168 var, expr, body, ..
169 } => {
170 if var.0 + 1 > *max {
171 *max = var.0 + 1;
172 }
173 visit_expr(expr, max);
174 visit_stmts(body, max);
175 }
176 HirStmt::Switch {
177 expr,
178 cases,
179 otherwise,
180 ..
181 } => {
182 visit_expr(expr, max);
183 for (c, b) in cases {
184 visit_expr(c, max);
185 visit_stmts(b, max);
186 }
187 if let Some(b) = otherwise {
188 visit_stmts(b, max);
189 }
190 }
191 HirStmt::TryCatch {
192 try_body,
193 catch_var,
194 catch_body,
195 ..
196 } => {
197 if let Some(v) = catch_var {
198 if v.0 + 1 > *max {
199 *max = v.0 + 1;
200 }
201 }
202 visit_stmts(try_body, max);
203 visit_stmts(catch_body, max);
204 }
205 HirStmt::Global(vars, _) | HirStmt::Persistent(vars, _) => {
206 for (v, _name) in vars {
207 if v.0 + 1 > *max {
208 *max = v.0 + 1;
209 }
210 }
211 }
212 HirStmt::AssignLValue(_, expr, _, _) => visit_expr(expr, max),
213 HirStmt::MultiAssign(vars, expr, _, _) => {
214 for v in vars.iter().flatten() {
215 if v.0 + 1 > *max {
216 *max = v.0 + 1;
217 }
218 }
219 visit_expr(expr, max);
220 }
221 HirStmt::Function { .. }
222 | HirStmt::ClassDef { .. }
223 | HirStmt::Import { .. }
224 | HirStmt::Break(_)
225 | HirStmt::Continue(_) => {}
226 }
227 }
228 }
229
230 visit_stmts(&prog.body, &mut max_var);
231 let mut var_types = prog.var_types.clone();
232 if var_types.len() < max_var {
233 var_types.resize(max_var, Type::Unknown);
234 }
235 Self {
236 instructions: Vec::new(),
237 instr_spans: Vec::new(),
238 call_arg_spans: Vec::new(),
239 var_count: max_var,
240 loop_stack: Vec::new(),
241 functions: HashMap::new(),
242 imports: Vec::new(),
243 var_types,
244 current_span: None,
245 }
246 }
247
248 fn ensure_var(&mut self, id: usize) {
249 if id + 1 > self.var_count {
250 self.var_count = id + 1;
251 }
252 while self.var_types.len() <= id {
253 self.var_types.push(Type::Unknown);
254 }
255 }
256
257 pub(crate) fn alloc_temp(&mut self) -> usize {
258 let id = self.var_count;
259 self.var_count += 1;
260 if self.var_types.len() <= id {
261 self.var_types.push(Type::Unknown);
262 }
263 id
264 }
265
266 pub fn emit(&mut self, instr: Instr) -> usize {
267 match &instr {
268 Instr::LoadVar(id) | Instr::StoreVar(id) => self.ensure_var(*id),
269 _ => {}
270 }
271 let pc = self.instructions.len();
272 self.instructions.push(instr);
273 let span = self.current_span.unwrap_or_default();
274 self.instr_spans.push(span);
275 self.call_arg_spans.push(None);
276 pc
277 }
278
279 pub(crate) fn emit_call_with_arg_spans(
280 &mut self,
281 instr: Instr,
282 arg_spans: &[runmat_hir::Span],
283 ) -> usize {
284 let pc = self.emit(instr);
285 if !arg_spans.is_empty() {
286 if let Some(slot) = self.call_arg_spans.get_mut(pc) {
287 *slot = Some(arg_spans.to_vec());
288 }
289 }
290 pc
291 }
292
293 pub fn patch(&mut self, idx: usize, instr: Instr) {
294 self.instructions[idx] = instr;
295 }
296
297 pub(crate) fn compile_error(&self, message: impl Into<String>) -> CompileError {
298 let mut err = CompileError::new(message);
299 if let Some(span) = self.current_span {
300 err = err.with_span(span);
301 }
302 err
303 }
304
305 pub fn compile_program(&mut self, prog: &HirProgram) -> Result<(), CompileError> {
306 runmat_hir::validate_imports(prog)?;
308 runmat_hir::validate_classdefs(prog)?;
310 for stmt in &prog.body {
312 let _span_guard = SpanGuard::new(self, stmt.span());
313 if let HirStmt::Import { path, wildcard, .. } = stmt {
314 self.imports.push((path.clone(), *wildcard));
315 self.emit(Instr::RegisterImport {
316 path: path.clone(),
317 wildcard: *wildcard,
318 });
319 }
320 if let HirStmt::Global(vars, _) = stmt {
321 let ids: Vec<usize> = vars.iter().map(|(v, _n)| v.0).collect();
322 let names: Vec<String> = vars.iter().map(|(_v, n)| n.clone()).collect();
323 self.emit(Instr::DeclareGlobalNamed(ids, names));
324 }
325 if let HirStmt::Persistent(vars, _) = stmt {
326 let ids: Vec<usize> = vars.iter().map(|(v, _n)| v.0).collect();
327 let names: Vec<String> = vars.iter().map(|(_v, n)| n.clone()).collect();
328 self.emit(Instr::DeclarePersistentNamed(ids, names));
329 }
330 }
331 for stmt in &prog.body {
332 if !matches!(
333 stmt,
334 HirStmt::Import { .. } | HirStmt::Global(_, _) | HirStmt::Persistent(_, _)
335 ) {
336 self.compile_stmt(stmt)?;
337 }
338 }
339 Ok(())
340 }
341
342 pub fn compile_stmt(&mut self, stmt: &HirStmt) -> Result<(), CompileError> {
343 let _span_guard = SpanGuard::new(self, stmt.span());
344 self.compile_stmt_impl(stmt)
345 }
346
347 pub fn compile_expr(&mut self, expr: &HirExpr) -> Result<(), CompileError> {
348 let _span_guard = SpanGuard::new(self, expr.span);
349 self.compile_expr_impl(expr)
350 }
351}