rush_interpreter_vm/
compiler.rs

1use std::{collections::HashMap, mem, vec};
2
3use rush_analyzer::{ast::*, InfixOp, PrefixOp, Type};
4
5use crate::{
6    instruction::{self, Instruction, Program},
7    value::{Pointer, Value},
8};
9
10#[derive(Default)]
11pub struct Compiler<'src> {
12    /// The first item is the prelude: execution will start here
13    functions: Vec<Vec<Instruction>>,
14    /// Maps a function name to its position in the `functions` Vec
15    fn_names: HashMap<&'src str, usize>,
16
17    /// Maps a name to a global index and the variable type.
18    globals: HashMap<&'src str, usize>,
19
20    /// Contains the scopes of the current function. The last item is the current scope.
21    scopes: Vec<Scope<'src>>,
22
23    /// Counter for let bindings in the current function.
24    local_let_count: usize,
25
26    /// Contains the indices of `SetMp` instructions which need a value correction at the end of a
27    /// function declaration.
28    setmp_indices: Vec<usize>,
29
30    /// Contains information about the current loop(s)
31    loops: Vec<Loop>,
32}
33
34/// Maps idents to variables
35type Scope<'src> = HashMap<&'src str, Variable>;
36
/// Storage location of a named binding.
#[derive(Clone, Copy, Debug)]
enum Variable {
    /// A binding of type `()` / `!`; no memory slot is allocated for it.
    Unit,
    /// A function-local slot, addressed through `Pointer::Rel(offset)`.
    Local { offset: isize },
    /// A global slot, addressed through `Pointer::Abs(addr)`.
    Global { addr: usize },
}
43
/// Bookkeeping for one loop that is currently being compiled.
///
/// Jumps emitted for `break` / `continue` initially carry a placeholder target. Their
/// positions are recorded here so the real targets can be patched in once the loop's
/// layout is known.
#[derive(Default)]
struct Loop {
    /// Positions (in the current function) of the placeholder jumps emitted for `continue`.
    continue_jmp_indices: Vec<usize>,
    /// Positions (in the current function) of the placeholder jumps that leave the loop
    /// (`break`, and the condition check of `while` / `for`).
    break_jmp_indices: Vec<usize>,
}
53
54impl<'src> Compiler<'src> {
55    pub(crate) fn new() -> Self {
56        Self {
57            // begin with empty `prelude`
58            functions: vec![vec![]],
59            ..Default::default()
60        }
61    }
62
63    #[inline]
64    /// Emits a new instruction and appends it to the `instructions` [`Vec`].
65    fn insert(&mut self, instruction: Instruction) {
66        self.functions
67            .last_mut()
68            .expect("there is always a function")
69            .push(instruction)
70    }
71
72    #[inline]
73    /// Returns a reference to the current function
74    fn curr_fn(&self) -> &Vec<Instruction> {
75        self.functions.last().expect("there is always a function")
76    }
77
78    #[inline]
79    /// Returns a mutable reference to the current function
80    fn curr_fn_mut(&mut self) -> &mut Vec<Instruction> {
81        self.functions
82            .last_mut()
83            .expect("there is always a function")
84    }
85
86    #[inline]
87    /// Returns a mutable reference to the current scope
88    fn scope_mut(&mut self) -> &mut Scope<'src> {
89        self.scopes.last_mut().expect("there is always a scope")
90    }
91
92    #[inline]
93    /// Returns a mutable reference to the current loop
94    fn curr_loop_mut(&mut self) -> &mut Loop {
95        self.loops
96            .last_mut()
97            .expect("there is always a loop when called")
98    }
99
100    /// Returns the specified variable given its identifier
101    fn resolve_var(&self, name: &'src str) -> Variable {
102        for scope in self.scopes.iter().rev() {
103            if let Some(i) = scope.get(name) {
104                return *i;
105            };
106        }
107        Variable::Global {
108            addr: self.globals[name],
109        }
110    }
111
112    /// Loads the value of the specified variable name on the stack
113    fn load_var(&mut self, name: &'src str) {
114        let var = self.resolve_var(name);
115        match var {
116            Variable::Unit => {} // ignore unit / never values
117            Variable::Local { offset, .. } => {
118                self.insert(Instruction::Push(Value::Ptr(Pointer::Rel(offset))));
119                self.insert(Instruction::GetVar)
120            }
121            Variable::Global { addr } => {
122                self.insert(Instruction::Push(Value::Ptr(Pointer::Abs(addr))));
123                self.insert(Instruction::GetVar)
124            }
125        }
126    }
127
128    pub(crate) fn compile(mut self, ast: AnalyzedProgram<'src>) -> Program {
129        // map function names to indices
130        for (idx, func) in ast.functions.iter().filter(|f| f.used).enumerate() {
131            self.fn_names.insert(func.name, idx + 2);
132        }
133
134        // add stack space for the globals
135        self.insert(Instruction::SetMp(ast.globals.len() as isize));
136
137        // add global variables
138        for var in ast.globals.into_iter().filter(|g| g.used) {
139            self.declare_global(var);
140        }
141
142        // call the main fn
143        self.insert(Instruction::Call(1));
144
145        // compile the main function
146        self.main_fn(ast.main_fn);
147
148        // compile all other functions
149        for func in ast.functions.into_iter().filter(|f| f.used) {
150            self.functions.push(vec![]);
151            self.fn_declaration(func);
152        }
153
154        Program(self.functions)
155    }
156
157    fn declare_global(&mut self, node: AnalyzedLetStmt<'src>) {
158        // map the name to the new global index
159        let addr = self.globals.len();
160        self.globals.insert(node.name, addr);
161        // push global value onto the stack
162        self.expression(node.expr);
163        // pop and set the value as global
164        self.insert(Instruction::SetVarImm(Pointer::Abs(addr)));
165    }
166
167    fn fn_declaration(&mut self, node: AnalyzedFunctionDefinition<'src>) {
168        self.local_let_count = 0;
169        self.scopes.push(Scope::default());
170        mem::take(&mut self.setmp_indices);
171
172        // contains a placeholder value which is corrected later
173        let setmp_idx = self.curr_fn().len();
174        self.insert(Instruction::SetMp(isize::MAX));
175
176        for param in node.params.iter().rev() {
177            let offset = -(self.local_let_count as isize);
178
179            let var = match param.type_ {
180                Type::Unit | Type::Never => Variable::Unit,
181                _ => {
182                    self.insert(Instruction::SetVarImm(Pointer::Rel(offset)));
183                    self.local_let_count += 1;
184                    Variable::Local { offset }
185                }
186            };
187            self.scope_mut().insert(param.name, var);
188        }
189
190        self.block(node.block, false);
191
192        // correct the placeholder set mp offset
193        self.curr_fn_mut()[setmp_idx] = Instruction::SetMp(self.local_let_count as isize);
194
195        self.scopes.pop();
196
197        // `return` also deallocates space used by this function
198        let pos = self.curr_fn().len();
199        self.setmp_indices.push(pos);
200        self.insert(Instruction::SetMp(isize::MIN));
201        self.insert(Instruction::Ret);
202
203        // correct values in `SetMp` instructions before return
204        self.correct_setmp_values();
205    }
206
207    fn correct_setmp_values(&mut self) {
208        let offset = -(self.local_let_count as isize);
209        for idx in self.setmp_indices.clone() {
210            match (&mut self.curr_fn_mut()[idx], offset) {
211                (_, 0) => self.curr_fn_mut()[idx] = Instruction::Nop,
212                (Instruction::SetMp(o), _) => *o = offset,
213                other => unreachable!("other instructions do not modify mp: {other:?}"),
214            }
215        }
216    }
217
218    fn main_fn(&mut self, node: AnalyzedBlock<'src>) {
219        self.functions.push(vec![]);
220        self.local_let_count = 0;
221        self.fn_names.insert("main", 1);
222
223        // contains a placeholder value which is corrected later
224        let setmp_idx = self.curr_fn().len();
225        self.insert(Instruction::SetMp(isize::MAX));
226
227        self.block(node, true);
228
229        // correct the placeholder set mp offset
230        self.curr_fn_mut()[setmp_idx] = Instruction::SetMp(self.local_let_count as isize);
231
232        self.correct_setmp_values()
233    }
234
235    /// Compiles a block of statements.
236    /// Results in the optional expr (unit if there is none).
237    /// Automatically pushes a new [`Scope`] for the block when `new_scope` is `true`.
238    fn block(&mut self, node: AnalyzedBlock<'src>, new_scope: bool) {
239        if new_scope {
240            self.scopes.push(Scope::default());
241        }
242        for stmt in node.stmts {
243            self.statement(stmt);
244        }
245        if let Some(expr) = node.expr {
246            self.expression(expr);
247        }
248        if new_scope {
249            self.scopes.pop();
250        }
251    }
252
253    fn statement(&mut self, node: AnalyzedStatement<'src>) {
254        match node {
255            AnalyzedStatement::Let(node) => self.let_stmt(node),
256            AnalyzedStatement::Return(expr) => {
257                if let Some(expr) = expr {
258                    self.expression(expr);
259                }
260                let pos = self.curr_fn().len();
261                self.setmp_indices.push(pos);
262                self.insert(Instruction::SetMp(isize::MIN));
263                self.insert(Instruction::Ret);
264            }
265            AnalyzedStatement::Loop(node) => self.loop_stmt(node),
266            AnalyzedStatement::While(node) => self.while_stmt(node),
267            AnalyzedStatement::For(node) => self.for_stmt(node),
268            AnalyzedStatement::Break => {
269                // the jmp instruction is corrected later
270                let pos = self.curr_fn().len();
271                self.curr_loop_mut().break_jmp_indices.push(pos);
272                self.insert(Instruction::Jmp(usize::MAX));
273            }
274            AnalyzedStatement::Continue => {
275                // the jmp instruction is corrected later
276                let pos = self.curr_fn().len();
277                self.curr_loop_mut().continue_jmp_indices.push(pos);
278                self.insert(Instruction::Jmp(usize::MAX));
279            }
280            AnalyzedStatement::Expr(node) => {
281                let expr_type = node.result_type();
282                self.expression(node);
283                if !matches!(expr_type, Type::Unit | Type::Never) {
284                    self.insert(Instruction::Drop)
285                }
286            }
287        }
288    }
289
290    fn let_stmt(&mut self, node: AnalyzedLetStmt<'src>) {
291        match node.expr.result_type() {
292            Type::Unit | Type::Never => {
293                self.expression(node.expr);
294                self.scope_mut().insert(node.name, Variable::Unit);
295            }
296            _ => {
297                self.expression(node.expr);
298
299                let offset = -(self.local_let_count as isize);
300                self.insert(Instruction::SetVarImm(Pointer::Rel(offset)));
301
302                self.scope_mut()
303                    .insert(node.name, Variable::Local { offset });
304                self.local_let_count += 1;
305            }
306        }
307    }
308
309    /// Fills in any blank-value `jmp` / `jmpfalse` instructions to point to the specified target.
310    fn fill_blank_jmps(&mut self, jmps: &[usize], target: usize) {
311        for idx in jmps {
312            match &mut self.curr_fn_mut()[*idx] {
313                Instruction::Jmp(o) => *o = target,
314                Instruction::JmpFalse(o) => *o = target,
315                _ => unreachable!("other instructions do not jump"),
316            }
317        }
318    }
319
320    fn loop_stmt(&mut self, node: AnalyzedLoopStmt<'src>) {
321        // save location of the loop head (for continue stmts)
322        let loop_head_pos = self.curr_fn().len();
323        self.loops.push(Loop::default());
324
325        // compile the loop body
326        let block_expr_type = node
327            .block
328            .expr
329            .as_ref()
330            .map_or(Type::Unit, |expr| expr.result_type());
331        self.block(node.block, true);
332        if !matches!(block_expr_type, Type::Unit | Type::Never) {
333            self.insert(Instruction::Drop);
334        }
335
336        // jump back to the top
337        self.insert(Instruction::Jmp(loop_head_pos));
338
339        // correct placeholder `break` / `continue` values
340        let loop_ = self.loops.pop().expect("pushed above");
341        let pos = self.curr_fn().len();
342        self.fill_blank_jmps(&loop_.break_jmp_indices, pos);
343        self.fill_blank_jmps(&loop_.continue_jmp_indices, loop_head_pos);
344    }
345
346    fn while_stmt(&mut self, node: AnalyzedWhileStmt<'src>) {
347        // save location of the loop head (for continue stmts)
348        let loop_head_pos = self.curr_fn().len();
349
350        // compile the while condition
351        self.expression(node.cond);
352
353        // push the loop here (`continue` / `break` can be in cond)
354        self.loops.push(Loop::default());
355
356        // jump to the end if the condition is false
357        let end = self.curr_fn().len();
358        self.curr_loop_mut().break_jmp_indices.push(end);
359        self.insert(Instruction::JmpFalse(usize::MAX));
360
361        // compile the loop body
362        let block_expr_type = node
363            .block
364            .expr
365            .as_ref()
366            .map_or(Type::Unit, |expr| expr.result_type());
367        self.block(node.block, true);
368        if !matches!(block_expr_type, Type::Unit | Type::Never) {
369            self.insert(Instruction::Drop);
370        }
371
372        // jump back to the top
373        self.insert(Instruction::Jmp(loop_head_pos));
374
375        // correct placeholder `break` / `continue` values
376        let loop_ = self.loops.pop().expect("pushed above");
377        let pos = self.curr_fn().len();
378        self.fill_blank_jmps(&loop_.break_jmp_indices, pos);
379        self.fill_blank_jmps(&loop_.continue_jmp_indices, loop_head_pos);
380    }
381
382    fn for_stmt(&mut self, node: AnalyzedForStmt<'src>) {
383        // compile the init expression
384        self.scopes.push(HashMap::new());
385        match node.initializer.result_type() {
386            Type::Unit | Type::Never => {
387                self.expression(node.initializer);
388                self.scope_mut().insert(node.ident, Variable::Unit);
389            }
390            _ => {
391                self.expression(node.initializer);
392                let offset = self.local_let_count as isize;
393                self.insert(Instruction::SetVarImm(Pointer::Rel(offset)));
394                self.scope_mut()
395                    .insert(node.ident, Variable::Local { offset });
396                self.local_let_count += 1;
397            }
398        }
399
400        // save location of the loop head (for repetition)
401        let loop_head_pos = self.curr_fn().len();
402
403        // compile the condition expr
404        self.expression(node.cond);
405
406        self.loops.push(Loop::default());
407
408        // jump to the end of the loop if the condition is false
409        let curr_pos = self.curr_fn().len();
410        self.curr_loop_mut().break_jmp_indices.push(curr_pos);
411        self.insert(Instruction::JmpFalse(usize::MAX));
412
413        let block_expr_type = node
414            .block
415            .expr
416            .as_ref()
417            .map_or(Type::Unit, |expr| expr.result_type());
418        self.block(node.block, true);
419        if !matches!(block_expr_type, Type::Unit | Type::Never) {
420            self.insert(Instruction::Drop);
421        }
422
423        // correct placeholder `continue` values
424        let curr_pos = self.curr_fn().len();
425        let loop_ = self.loops.pop().expect("pushed above");
426        self.fill_blank_jmps(&loop_.continue_jmp_indices, curr_pos);
427
428        // compile the update expression
429        let update_type = node.update.result_type();
430        self.expression(node.update);
431        if !matches!(update_type, Type::Unit | Type::Never) {
432            self.insert(Instruction::Drop);
433        }
434
435        // jump back to the top
436        self.insert(Instruction::Jmp(loop_head_pos));
437
438        // correct placeholder break values
439        let pos = self.curr_fn().len();
440        self.fill_blank_jmps(&loop_.break_jmp_indices, pos);
441
442        self.scopes.pop();
443    }
444
445    fn expression(&mut self, node: AnalyzedExpression<'src>) {
446        match node {
447            AnalyzedExpression::Int(value) => self.insert(Instruction::Push(Value::Int(value))),
448            AnalyzedExpression::Float(value) => self.insert(Instruction::Push(Value::Float(value))),
449            AnalyzedExpression::Bool(value) => self.insert(Instruction::Push(Value::Bool(value))),
450            AnalyzedExpression::Char(value) => self.insert(Instruction::Push(Value::Char(value))),
451            AnalyzedExpression::Ident(node) => self.load_var(node.ident),
452            AnalyzedExpression::Block(node) => self.block(*node, true),
453            AnalyzedExpression::If(node) => self.if_expr(*node),
454            AnalyzedExpression::Prefix(node) => self.prefix_expr(*node),
455            AnalyzedExpression::Infix(node) => self.infix_expr(*node),
456            AnalyzedExpression::Assign(node) => self.assign_expr(*node),
457            AnalyzedExpression::Call(node) => self.call_expr(*node),
458            AnalyzedExpression::Cast(node) => self.cast_expr(*node),
459            AnalyzedExpression::Grouped(node) => self.expression(*node),
460        }
461    }
462
463    fn if_expr(&mut self, node: AnalyzedIfExpr<'src>) {
464        // compile the condition
465        self.expression(node.cond);
466        let after_condition = self.curr_fn().len();
467        self.insert(Instruction::JmpFalse(usize::MAX)); // placeholder
468
469        // compile the `then` branch
470        self.block(node.then_block, true);
471        let after_then_idx = self.curr_fn().len();
472
473        if let Some(else_block) = node.else_block {
474            self.insert(Instruction::Jmp(usize::MAX)); // placeholder
475
476            // if there is `else`, jump to the instruction after the jump after `then`
477            self.curr_fn_mut()[after_condition] = Instruction::JmpFalse(after_then_idx + 1);
478
479            self.block(else_block, true);
480            let after_else = self.curr_fn().len();
481
482            // skip the `else` block when coming from the `then` block
483            self.curr_fn_mut()[after_then_idx] = Instruction::Jmp(after_else);
484        } else {
485            // if there is no `else` branch, jump after the last instruction of the `then` branch
486            self.curr_fn_mut()[after_condition] = Instruction::JmpFalse(after_then_idx);
487        }
488    }
489
490    fn prefix_expr(&mut self, node: AnalyzedPrefixExpr<'src>) {
491        match Instruction::try_from(node.op) {
492            Ok(insruction) => {
493                self.expression(node.expr);
494                self.insert(insruction)
495            }
496            Err(_) => match node.op == PrefixOp::Ref {
497                //ref
498                true => {
499                    if let AnalyzedExpression::Ident(ident) = node.expr {
500                        match self.resolve_var(ident.ident) {
501                            Variable::Local { offset, .. } => {
502                                self.insert(Instruction::RelToAddr(offset))
503                            }
504                            Variable::Global { addr } => {
505                                self.insert(Instruction::Push(Value::Ptr(Pointer::Abs(addr))));
506                            }
507                            Variable::Unit => unreachable!("unit values cannot be referenced"),
508                        }
509                        return;
510                    }
511                    unreachable!("the parser guarantees that only idents can be referenced")
512                }
513                // deref
514                false => {
515                    self.expression(node.expr);
516                    self.insert(Instruction::GetVar)
517                }
518            },
519        }
520    }
521
522    fn infix_expr(&mut self, node: AnalyzedInfixExpr<'src>) {
523        match node.op {
524            InfixOp::Or | InfixOp::And => {
525                self.expression(node.lhs);
526                if node.op == InfixOp::Or {
527                    self.insert(Instruction::Not);
528                }
529                let merge_jmp_idx = self.curr_fn().len();
530                self.insert(Instruction::JmpFalse(usize::MAX));
531                self.expression(node.rhs);
532                let pos = self.curr_fn().len() + 2;
533                self.insert(Instruction::Jmp(pos));
534                self.insert(Instruction::Push(Value::Bool(node.op == InfixOp::Or)));
535                self.curr_fn_mut()[merge_jmp_idx] = Instruction::JmpFalse(self.curr_fn().len() - 1);
536            }
537            op => {
538                self.expression(node.lhs);
539                self.expression(node.rhs);
540                self.insert(Instruction::from(op));
541            }
542        }
543    }
544
545    fn assign_expr(&mut self, node: AnalyzedAssignExpr<'src>) {
546        let assignee = self.resolve_var(node.assignee);
547
548        let ptr = match assignee {
549            Variable::Local { offset } => Pointer::Rel(offset),
550            Variable::Global { addr } => Pointer::Abs(addr),
551            Variable::Unit => unreachable!("cannot assign to unit values"),
552        };
553
554        self.insert(Instruction::Push(Value::Ptr(ptr)));
555
556        let mut ptr_count = node.assignee_ptr_count;
557        while ptr_count > 0 {
558            self.insert(Instruction::GetVar);
559            ptr_count -= 1;
560        }
561
562        match node.op.try_into() {
563            Ok(instruction) => {
564                // insert a clone so that th setter instructions can still use the index
565                self.insert(Instruction::Clone);
566
567                // load the assignee value
568                match assignee {
569                    Variable::Unit => {}
570                    _ => self.insert(Instruction::GetVar),
571                };
572
573                self.expression(node.expr);
574                self.insert(instruction);
575            }
576            Err(()) => self.expression(node.expr),
577        }
578
579        match assignee {
580            Variable::Unit => {}
581            _ => self.insert(Instruction::SetVar),
582        };
583    }
584
585    fn call_expr(&mut self, node: AnalyzedCallExpr<'src>) {
586        for arg in node.args {
587            self.expression(arg);
588        }
589
590        match node.func {
591            "exit" => self.insert(Instruction::Exit),
592            func => {
593                let fn_idx = self.fn_names[func];
594                self.insert(Instruction::Call(fn_idx));
595            }
596        }
597    }
598
599    fn cast_expr(&mut self, node: AnalyzedCastExpr<'src>) {
600        let expr_type = node.expr.result_type();
601        self.expression(node.expr);
602        match (expr_type, node.type_) {
603            (from, to) if from == to => {}
604            (_, to) => self.insert(Instruction::Cast(instruction::Type::from(to))),
605        }
606    }
607}