1use std::{collections::HashMap, mem, vec};
2
3use rush_analyzer::{ast::*, InfixOp, PrefixOp, Type};
4
5use crate::{
6 instruction::{self, Instruction, Program},
7 value::{Pointer, Value},
8};
9
10#[derive(Default)]
11pub struct Compiler<'src> {
12 functions: Vec<Vec<Instruction>>,
14 fn_names: HashMap<&'src str, usize>,
16
17 globals: HashMap<&'src str, usize>,
19
20 scopes: Vec<Scope<'src>>,
22
23 local_let_count: usize,
25
26 setmp_indices: Vec<usize>,
29
30 loops: Vec<Loop>,
32}
33
34type Scope<'src> = HashMap<&'src str, Variable>;
36
/// Where the value of a resolved variable lives.
#[derive(Clone, Copy, Debug)]
enum Variable {
    /// A unit- or never-typed variable: it is never materialized and has no slot.
    Unit,
    /// A function-local variable, addressed through a relative pointer (`Pointer::Rel`).
    Local { offset: isize },
    /// A global variable, addressed through an absolute pointer (`Pointer::Abs`).
    Global { addr: usize },
}
43
/// Unresolved jumps emitted inside the loop currently being compiled.
///
/// Both lists hold instruction indices (within the current function) of jumps
/// whose destination is only known once the whole loop has been emitted.
#[derive(Default)]
struct Loop {
    /// Jumps leaving the loop: `break`s and the condition check of `while`/`for`.
    break_jmp_indices: Vec<usize>,
    /// Jumps starting the next iteration: `continue`s.
    continue_jmp_indices: Vec<usize>,
}
53
54impl<'src> Compiler<'src> {
55 pub(crate) fn new() -> Self {
56 Self {
57 functions: vec![vec![]],
59 ..Default::default()
60 }
61 }
62
63 #[inline]
64 fn insert(&mut self, instruction: Instruction) {
66 self.functions
67 .last_mut()
68 .expect("there is always a function")
69 .push(instruction)
70 }
71
72 #[inline]
73 fn curr_fn(&self) -> &Vec<Instruction> {
75 self.functions.last().expect("there is always a function")
76 }
77
78 #[inline]
79 fn curr_fn_mut(&mut self) -> &mut Vec<Instruction> {
81 self.functions
82 .last_mut()
83 .expect("there is always a function")
84 }
85
86 #[inline]
87 fn scope_mut(&mut self) -> &mut Scope<'src> {
89 self.scopes.last_mut().expect("there is always a scope")
90 }
91
92 #[inline]
93 fn curr_loop_mut(&mut self) -> &mut Loop {
95 self.loops
96 .last_mut()
97 .expect("there is always a loop when called")
98 }
99
100 fn resolve_var(&self, name: &'src str) -> Variable {
102 for scope in self.scopes.iter().rev() {
103 if let Some(i) = scope.get(name) {
104 return *i;
105 };
106 }
107 Variable::Global {
108 addr: self.globals[name],
109 }
110 }
111
112 fn load_var(&mut self, name: &'src str) {
114 let var = self.resolve_var(name);
115 match var {
116 Variable::Unit => {} Variable::Local { offset, .. } => {
118 self.insert(Instruction::Push(Value::Ptr(Pointer::Rel(offset))));
119 self.insert(Instruction::GetVar)
120 }
121 Variable::Global { addr } => {
122 self.insert(Instruction::Push(Value::Ptr(Pointer::Abs(addr))));
123 self.insert(Instruction::GetVar)
124 }
125 }
126 }
127
128 pub(crate) fn compile(mut self, ast: AnalyzedProgram<'src>) -> Program {
129 for (idx, func) in ast.functions.iter().filter(|f| f.used).enumerate() {
131 self.fn_names.insert(func.name, idx + 2);
132 }
133
134 self.insert(Instruction::SetMp(ast.globals.len() as isize));
136
137 for var in ast.globals.into_iter().filter(|g| g.used) {
139 self.declare_global(var);
140 }
141
142 self.insert(Instruction::Call(1));
144
145 self.main_fn(ast.main_fn);
147
148 for func in ast.functions.into_iter().filter(|f| f.used) {
150 self.functions.push(vec![]);
151 self.fn_declaration(func);
152 }
153
154 Program(self.functions)
155 }
156
157 fn declare_global(&mut self, node: AnalyzedLetStmt<'src>) {
158 let addr = self.globals.len();
160 self.globals.insert(node.name, addr);
161 self.expression(node.expr);
163 self.insert(Instruction::SetVarImm(Pointer::Abs(addr)));
165 }
166
167 fn fn_declaration(&mut self, node: AnalyzedFunctionDefinition<'src>) {
168 self.local_let_count = 0;
169 self.scopes.push(Scope::default());
170 mem::take(&mut self.setmp_indices);
171
172 let setmp_idx = self.curr_fn().len();
174 self.insert(Instruction::SetMp(isize::MAX));
175
176 for param in node.params.iter().rev() {
177 let offset = -(self.local_let_count as isize);
178
179 let var = match param.type_ {
180 Type::Unit | Type::Never => Variable::Unit,
181 _ => {
182 self.insert(Instruction::SetVarImm(Pointer::Rel(offset)));
183 self.local_let_count += 1;
184 Variable::Local { offset }
185 }
186 };
187 self.scope_mut().insert(param.name, var);
188 }
189
190 self.block(node.block, false);
191
192 self.curr_fn_mut()[setmp_idx] = Instruction::SetMp(self.local_let_count as isize);
194
195 self.scopes.pop();
196
197 let pos = self.curr_fn().len();
199 self.setmp_indices.push(pos);
200 self.insert(Instruction::SetMp(isize::MIN));
201 self.insert(Instruction::Ret);
202
203 self.correct_setmp_values();
205 }
206
207 fn correct_setmp_values(&mut self) {
208 let offset = -(self.local_let_count as isize);
209 for idx in self.setmp_indices.clone() {
210 match (&mut self.curr_fn_mut()[idx], offset) {
211 (_, 0) => self.curr_fn_mut()[idx] = Instruction::Nop,
212 (Instruction::SetMp(o), _) => *o = offset,
213 other => unreachable!("other instructions do not modify mp: {other:?}"),
214 }
215 }
216 }
217
218 fn main_fn(&mut self, node: AnalyzedBlock<'src>) {
219 self.functions.push(vec![]);
220 self.local_let_count = 0;
221 self.fn_names.insert("main", 1);
222
223 let setmp_idx = self.curr_fn().len();
225 self.insert(Instruction::SetMp(isize::MAX));
226
227 self.block(node, true);
228
229 self.curr_fn_mut()[setmp_idx] = Instruction::SetMp(self.local_let_count as isize);
231
232 self.correct_setmp_values()
233 }
234
235 fn block(&mut self, node: AnalyzedBlock<'src>, new_scope: bool) {
239 if new_scope {
240 self.scopes.push(Scope::default());
241 }
242 for stmt in node.stmts {
243 self.statement(stmt);
244 }
245 if let Some(expr) = node.expr {
246 self.expression(expr);
247 }
248 if new_scope {
249 self.scopes.pop();
250 }
251 }
252
253 fn statement(&mut self, node: AnalyzedStatement<'src>) {
254 match node {
255 AnalyzedStatement::Let(node) => self.let_stmt(node),
256 AnalyzedStatement::Return(expr) => {
257 if let Some(expr) = expr {
258 self.expression(expr);
259 }
260 let pos = self.curr_fn().len();
261 self.setmp_indices.push(pos);
262 self.insert(Instruction::SetMp(isize::MIN));
263 self.insert(Instruction::Ret);
264 }
265 AnalyzedStatement::Loop(node) => self.loop_stmt(node),
266 AnalyzedStatement::While(node) => self.while_stmt(node),
267 AnalyzedStatement::For(node) => self.for_stmt(node),
268 AnalyzedStatement::Break => {
269 let pos = self.curr_fn().len();
271 self.curr_loop_mut().break_jmp_indices.push(pos);
272 self.insert(Instruction::Jmp(usize::MAX));
273 }
274 AnalyzedStatement::Continue => {
275 let pos = self.curr_fn().len();
277 self.curr_loop_mut().continue_jmp_indices.push(pos);
278 self.insert(Instruction::Jmp(usize::MAX));
279 }
280 AnalyzedStatement::Expr(node) => {
281 let expr_type = node.result_type();
282 self.expression(node);
283 if !matches!(expr_type, Type::Unit | Type::Never) {
284 self.insert(Instruction::Drop)
285 }
286 }
287 }
288 }
289
290 fn let_stmt(&mut self, node: AnalyzedLetStmt<'src>) {
291 match node.expr.result_type() {
292 Type::Unit | Type::Never => {
293 self.expression(node.expr);
294 self.scope_mut().insert(node.name, Variable::Unit);
295 }
296 _ => {
297 self.expression(node.expr);
298
299 let offset = -(self.local_let_count as isize);
300 self.insert(Instruction::SetVarImm(Pointer::Rel(offset)));
301
302 self.scope_mut()
303 .insert(node.name, Variable::Local { offset });
304 self.local_let_count += 1;
305 }
306 }
307 }
308
309 fn fill_blank_jmps(&mut self, jmps: &[usize], target: usize) {
311 for idx in jmps {
312 match &mut self.curr_fn_mut()[*idx] {
313 Instruction::Jmp(o) => *o = target,
314 Instruction::JmpFalse(o) => *o = target,
315 _ => unreachable!("other instructions do not jump"),
316 }
317 }
318 }
319
320 fn loop_stmt(&mut self, node: AnalyzedLoopStmt<'src>) {
321 let loop_head_pos = self.curr_fn().len();
323 self.loops.push(Loop::default());
324
325 let block_expr_type = node
327 .block
328 .expr
329 .as_ref()
330 .map_or(Type::Unit, |expr| expr.result_type());
331 self.block(node.block, true);
332 if !matches!(block_expr_type, Type::Unit | Type::Never) {
333 self.insert(Instruction::Drop);
334 }
335
336 self.insert(Instruction::Jmp(loop_head_pos));
338
339 let loop_ = self.loops.pop().expect("pushed above");
341 let pos = self.curr_fn().len();
342 self.fill_blank_jmps(&loop_.break_jmp_indices, pos);
343 self.fill_blank_jmps(&loop_.continue_jmp_indices, loop_head_pos);
344 }
345
346 fn while_stmt(&mut self, node: AnalyzedWhileStmt<'src>) {
347 let loop_head_pos = self.curr_fn().len();
349
350 self.expression(node.cond);
352
353 self.loops.push(Loop::default());
355
356 let end = self.curr_fn().len();
358 self.curr_loop_mut().break_jmp_indices.push(end);
359 self.insert(Instruction::JmpFalse(usize::MAX));
360
361 let block_expr_type = node
363 .block
364 .expr
365 .as_ref()
366 .map_or(Type::Unit, |expr| expr.result_type());
367 self.block(node.block, true);
368 if !matches!(block_expr_type, Type::Unit | Type::Never) {
369 self.insert(Instruction::Drop);
370 }
371
372 self.insert(Instruction::Jmp(loop_head_pos));
374
375 let loop_ = self.loops.pop().expect("pushed above");
377 let pos = self.curr_fn().len();
378 self.fill_blank_jmps(&loop_.break_jmp_indices, pos);
379 self.fill_blank_jmps(&loop_.continue_jmp_indices, loop_head_pos);
380 }
381
382 fn for_stmt(&mut self, node: AnalyzedForStmt<'src>) {
383 self.scopes.push(HashMap::new());
385 match node.initializer.result_type() {
386 Type::Unit | Type::Never => {
387 self.expression(node.initializer);
388 self.scope_mut().insert(node.ident, Variable::Unit);
389 }
390 _ => {
391 self.expression(node.initializer);
392 let offset = self.local_let_count as isize;
393 self.insert(Instruction::SetVarImm(Pointer::Rel(offset)));
394 self.scope_mut()
395 .insert(node.ident, Variable::Local { offset });
396 self.local_let_count += 1;
397 }
398 }
399
400 let loop_head_pos = self.curr_fn().len();
402
403 self.expression(node.cond);
405
406 self.loops.push(Loop::default());
407
408 let curr_pos = self.curr_fn().len();
410 self.curr_loop_mut().break_jmp_indices.push(curr_pos);
411 self.insert(Instruction::JmpFalse(usize::MAX));
412
413 let block_expr_type = node
414 .block
415 .expr
416 .as_ref()
417 .map_or(Type::Unit, |expr| expr.result_type());
418 self.block(node.block, true);
419 if !matches!(block_expr_type, Type::Unit | Type::Never) {
420 self.insert(Instruction::Drop);
421 }
422
423 let curr_pos = self.curr_fn().len();
425 let loop_ = self.loops.pop().expect("pushed above");
426 self.fill_blank_jmps(&loop_.continue_jmp_indices, curr_pos);
427
428 let update_type = node.update.result_type();
430 self.expression(node.update);
431 if !matches!(update_type, Type::Unit | Type::Never) {
432 self.insert(Instruction::Drop);
433 }
434
435 self.insert(Instruction::Jmp(loop_head_pos));
437
438 let pos = self.curr_fn().len();
440 self.fill_blank_jmps(&loop_.break_jmp_indices, pos);
441
442 self.scopes.pop();
443 }
444
445 fn expression(&mut self, node: AnalyzedExpression<'src>) {
446 match node {
447 AnalyzedExpression::Int(value) => self.insert(Instruction::Push(Value::Int(value))),
448 AnalyzedExpression::Float(value) => self.insert(Instruction::Push(Value::Float(value))),
449 AnalyzedExpression::Bool(value) => self.insert(Instruction::Push(Value::Bool(value))),
450 AnalyzedExpression::Char(value) => self.insert(Instruction::Push(Value::Char(value))),
451 AnalyzedExpression::Ident(node) => self.load_var(node.ident),
452 AnalyzedExpression::Block(node) => self.block(*node, true),
453 AnalyzedExpression::If(node) => self.if_expr(*node),
454 AnalyzedExpression::Prefix(node) => self.prefix_expr(*node),
455 AnalyzedExpression::Infix(node) => self.infix_expr(*node),
456 AnalyzedExpression::Assign(node) => self.assign_expr(*node),
457 AnalyzedExpression::Call(node) => self.call_expr(*node),
458 AnalyzedExpression::Cast(node) => self.cast_expr(*node),
459 AnalyzedExpression::Grouped(node) => self.expression(*node),
460 }
461 }
462
463 fn if_expr(&mut self, node: AnalyzedIfExpr<'src>) {
464 self.expression(node.cond);
466 let after_condition = self.curr_fn().len();
467 self.insert(Instruction::JmpFalse(usize::MAX)); self.block(node.then_block, true);
471 let after_then_idx = self.curr_fn().len();
472
473 if let Some(else_block) = node.else_block {
474 self.insert(Instruction::Jmp(usize::MAX)); self.curr_fn_mut()[after_condition] = Instruction::JmpFalse(after_then_idx + 1);
478
479 self.block(else_block, true);
480 let after_else = self.curr_fn().len();
481
482 self.curr_fn_mut()[after_then_idx] = Instruction::Jmp(after_else);
484 } else {
485 self.curr_fn_mut()[after_condition] = Instruction::JmpFalse(after_then_idx);
487 }
488 }
489
490 fn prefix_expr(&mut self, node: AnalyzedPrefixExpr<'src>) {
491 match Instruction::try_from(node.op) {
492 Ok(insruction) => {
493 self.expression(node.expr);
494 self.insert(insruction)
495 }
496 Err(_) => match node.op == PrefixOp::Ref {
497 true => {
499 if let AnalyzedExpression::Ident(ident) = node.expr {
500 match self.resolve_var(ident.ident) {
501 Variable::Local { offset, .. } => {
502 self.insert(Instruction::RelToAddr(offset))
503 }
504 Variable::Global { addr } => {
505 self.insert(Instruction::Push(Value::Ptr(Pointer::Abs(addr))));
506 }
507 Variable::Unit => unreachable!("unit values cannot be referenced"),
508 }
509 return;
510 }
511 unreachable!("the parser guarantees that only idents can be referenced")
512 }
513 false => {
515 self.expression(node.expr);
516 self.insert(Instruction::GetVar)
517 }
518 },
519 }
520 }
521
522 fn infix_expr(&mut self, node: AnalyzedInfixExpr<'src>) {
523 match node.op {
524 InfixOp::Or | InfixOp::And => {
525 self.expression(node.lhs);
526 if node.op == InfixOp::Or {
527 self.insert(Instruction::Not);
528 }
529 let merge_jmp_idx = self.curr_fn().len();
530 self.insert(Instruction::JmpFalse(usize::MAX));
531 self.expression(node.rhs);
532 let pos = self.curr_fn().len() + 2;
533 self.insert(Instruction::Jmp(pos));
534 self.insert(Instruction::Push(Value::Bool(node.op == InfixOp::Or)));
535 self.curr_fn_mut()[merge_jmp_idx] = Instruction::JmpFalse(self.curr_fn().len() - 1);
536 }
537 op => {
538 self.expression(node.lhs);
539 self.expression(node.rhs);
540 self.insert(Instruction::from(op));
541 }
542 }
543 }
544
545 fn assign_expr(&mut self, node: AnalyzedAssignExpr<'src>) {
546 let assignee = self.resolve_var(node.assignee);
547
548 let ptr = match assignee {
549 Variable::Local { offset } => Pointer::Rel(offset),
550 Variable::Global { addr } => Pointer::Abs(addr),
551 Variable::Unit => unreachable!("cannot assign to unit values"),
552 };
553
554 self.insert(Instruction::Push(Value::Ptr(ptr)));
555
556 let mut ptr_count = node.assignee_ptr_count;
557 while ptr_count > 0 {
558 self.insert(Instruction::GetVar);
559 ptr_count -= 1;
560 }
561
562 match node.op.try_into() {
563 Ok(instruction) => {
564 self.insert(Instruction::Clone);
566
567 match assignee {
569 Variable::Unit => {}
570 _ => self.insert(Instruction::GetVar),
571 };
572
573 self.expression(node.expr);
574 self.insert(instruction);
575 }
576 Err(()) => self.expression(node.expr),
577 }
578
579 match assignee {
580 Variable::Unit => {}
581 _ => self.insert(Instruction::SetVar),
582 };
583 }
584
585 fn call_expr(&mut self, node: AnalyzedCallExpr<'src>) {
586 for arg in node.args {
587 self.expression(arg);
588 }
589
590 match node.func {
591 "exit" => self.insert(Instruction::Exit),
592 func => {
593 let fn_idx = self.fn_names[func];
594 self.insert(Instruction::Call(fn_idx));
595 }
596 }
597 }
598
599 fn cast_expr(&mut self, node: AnalyzedCastExpr<'src>) {
600 let expr_type = node.expr.result_type();
601 self.expression(node.expr);
602 match (expr_type, node.type_) {
603 (from, to) if from == to => {}
604 (_, to) => self.insert(Instruction::Cast(instruction::Type::from(to))),
605 }
606 }
607}