1use std::collections::{BTreeMap, HashMap};
10use std::fmt::Write;
11
12use rustc_hash::{FxHashMap, FxHashSet};
13
14use crate::{
15 block::{Block, BlockIterator, Label},
16 context::Context,
17 error::IrError,
18 irtype::Type,
19 metadata::MetadataIndex,
20 module::Module,
21 value::{Value, ValueDatum},
22 variable::{LocalVar, LocalVarContent},
23 BlockArgument, BranchToWithArgs,
24};
25use crate::{Constant, InstOp};
26
27#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
30pub struct Function(pub slotmap::DefaultKey);
31
32#[doc(hidden)]
33pub struct FunctionContent {
34 pub name: String,
35 pub arguments: Vec<(String, Value)>,
36 pub return_type: Type,
37 pub blocks: Vec<Block>,
38 pub module: Module,
39 pub is_public: bool,
40 pub is_entry: bool,
41 pub is_original_entry: bool,
44 pub is_fallback: bool,
45 pub selector: Option<[u8; 4]>,
46 pub metadata: Option<MetadataIndex>,
47
48 pub local_storage: BTreeMap<String, LocalVar>, next_label_idx: u64,
51}
52
53impl Function {
54 #[allow(clippy::too_many_arguments)]
62 pub fn new(
63 context: &mut Context,
64 module: Module,
65 name: String,
66 args: Vec<(String, Type, Option<MetadataIndex>)>,
67 return_type: Type,
68 selector: Option<[u8; 4]>,
69 is_public: bool,
70 is_entry: bool,
71 is_original_entry: bool,
72 is_fallback: bool,
73 metadata: Option<MetadataIndex>,
74 ) -> Function {
75 let content = FunctionContent {
76 name,
77 arguments: Vec::new(),
80 return_type,
81 blocks: Vec::new(),
82 module,
83 is_public,
84 is_entry,
85 is_original_entry,
86 is_fallback,
87 selector,
88 metadata,
89 local_storage: BTreeMap::new(),
90 next_label_idx: 0,
91 };
92 let func = Function(context.functions.insert(content));
93
94 context.modules[module.0].functions.push(func);
95
96 let entry_block = Block::new(context, func, Some("entry".to_owned()));
97 context
98 .functions
99 .get_mut(func.0)
100 .unwrap()
101 .blocks
102 .push(entry_block);
103
104 let arguments: Vec<_> = args
106 .into_iter()
107 .enumerate()
108 .map(|(idx, (name, ty, arg_metadata))| {
109 (
110 name,
111 Value::new_argument(
112 context,
113 BlockArgument {
114 block: entry_block,
115 idx,
116 ty,
117 is_immutable: false,
118 },
119 )
120 .add_metadatum(context, arg_metadata),
121 )
122 })
123 .collect();
124 context
125 .functions
126 .get_mut(func.0)
127 .unwrap()
128 .arguments
129 .clone_from(&arguments);
130 let (_, arg_vals): (Vec<_>, Vec<_>) = arguments.iter().cloned().unzip();
131 context.blocks.get_mut(entry_block.0).unwrap().args = arg_vals;
132
133 func
134 }
135
136 pub fn create_block(&self, context: &mut Context, label: Option<Label>) -> Block {
138 let block = Block::new(context, *self, label);
139 let func = context.functions.get_mut(self.0).unwrap();
140 func.blocks.push(block);
141 block
142 }
143
144 pub fn create_block_before(
148 &self,
149 context: &mut Context,
150 other: &Block,
151 label: Option<Label>,
152 ) -> Result<Block, IrError> {
153 let block_idx = context.functions[self.0]
154 .blocks
155 .iter()
156 .position(|block| block == other)
157 .ok_or_else(|| {
158 let label = &context.blocks[other.0].label;
159 IrError::MissingBlock(label.clone())
160 })?;
161
162 let new_block = Block::new(context, *self, label);
163 context.functions[self.0]
164 .blocks
165 .insert(block_idx, new_block);
166 Ok(new_block)
167 }
168
169 pub fn create_block_after(
173 &self,
174 context: &mut Context,
175 other: &Block,
176 label: Option<Label>,
177 ) -> Result<Block, IrError> {
178 let new_block = Block::new(context, *self, label);
181 let func = context.functions.get_mut(self.0).unwrap();
182 func.blocks
183 .iter()
184 .position(|block| block == other)
185 .map(|idx| {
186 func.blocks.insert(idx + 1, new_block);
187 new_block
188 })
189 .ok_or_else(|| {
190 let label = &context.blocks[other.0].label;
191 IrError::MissingBlock(label.clone())
192 })
193 }
194
195 pub fn remove_block(&self, context: &mut Context, block: &Block) -> Result<(), IrError> {
200 let label = block.get_label(context);
201 let func = context.functions.get_mut(self.0).unwrap();
202 let block_idx = func
203 .blocks
204 .iter()
205 .position(|b| b == block)
206 .ok_or(IrError::RemoveMissingBlock(label))?;
207 func.blocks.remove(block_idx);
208 Ok(())
209 }
210
211 pub fn get_unique_label(&self, context: &mut Context, hint: Option<String>) -> String {
219 match hint {
220 Some(hint) => {
221 if context.functions[self.0]
222 .blocks
223 .iter()
224 .any(|block| context.blocks[block.0].label == hint)
225 {
226 let idx = self.get_next_label_idx(context);
227 self.get_unique_label(context, Some(format!("{hint}{idx}")))
228 } else {
229 hint
230 }
231 }
232 None => {
233 let idx = self.get_next_label_idx(context);
234 self.get_unique_label(context, Some(format!("block{idx}")))
235 }
236 }
237 }
238
239 fn get_next_label_idx(&self, context: &mut Context) -> u64 {
240 let func = context.functions.get_mut(self.0).unwrap();
241 let idx = func.next_label_idx;
242 func.next_label_idx += 1;
243 idx
244 }
245
246 pub fn num_blocks(&self, context: &Context) -> usize {
248 context.functions[self.0].blocks.len()
249 }
250
251 pub fn num_instructions(&self, context: &Context) -> usize {
261 self.block_iter(context)
262 .map(|block| block.num_instructions(context))
263 .sum()
264 }
265
266 pub fn num_instructions_incl_asm_instructions(&self, context: &Context) -> usize {
278 self.instruction_iter(context).fold(0, |num, (_, value)| {
279 match &value
280 .get_instruction(context)
281 .expect("We are iterating through the instructions.")
282 .op
283 {
284 InstOp::AsmBlock(asm, _) => num + asm.body.len(),
285 _ => num + 1,
286 }
287 })
288 }
289
290 pub fn get_name<'a>(&self, context: &'a Context) -> &'a str {
292 &context.functions[self.0].name
293 }
294
295 pub fn get_module(&self, context: &Context) -> Module {
297 context.functions[self.0].module
298 }
299
300 pub fn get_entry_block(&self, context: &Context) -> Block {
302 context.functions[self.0].blocks[0]
303 }
304
305 pub fn get_metadata(&self, context: &Context) -> Option<MetadataIndex> {
307 context.functions[self.0].metadata
308 }
309
310 pub fn has_selector(&self, context: &Context) -> bool {
312 context.functions[self.0].selector.is_some()
313 }
314
315 pub fn get_selector(&self, context: &Context) -> Option<[u8; 4]> {
317 context.functions[self.0].selector
318 }
319
320 pub fn is_entry(&self, context: &Context) -> bool {
323 context.functions[self.0].is_entry
324 }
325
326 pub fn is_original_entry(&self, context: &Context) -> bool {
329 context.functions[self.0].is_original_entry
330 }
331
332 pub fn is_fallback(&self, context: &Context) -> bool {
334 context.functions[self.0].is_fallback
335 }
336
337 pub fn get_return_type(&self, context: &Context) -> Type {
339 context.functions[self.0].return_type
340 }
341
342 pub fn set_return_type(&self, context: &mut Context, new_ret_type: Type) {
344 context.functions.get_mut(self.0).unwrap().return_type = new_ret_type
345 }
346
347 pub fn num_args(&self, context: &Context) -> usize {
349 context.functions[self.0].arguments.len()
350 }
351
352 pub fn get_arg(&self, context: &Context, name: &str) -> Option<Value> {
354 context.functions[self.0]
355 .arguments
356 .iter()
357 .find_map(|(arg_name, val)| (arg_name == name).then_some(val))
358 .copied()
359 }
360
361 pub fn add_arg<S: Into<String>>(&self, context: &mut Context, name: S, arg: Value) {
366 match context.values[arg.0].value {
367 ValueDatum::Argument(BlockArgument { idx, .. })
368 if idx == context.functions[self.0].arguments.len() =>
369 {
370 context.functions[self.0].arguments.push((name.into(), arg));
371 }
372 _ => panic!("Inconsistent function argument being added"),
373 }
374 }
375
376 pub fn lookup_arg_name<'a>(&self, context: &'a Context, value: &Value) -> Option<&'a String> {
378 context.functions[self.0]
379 .arguments
380 .iter()
381 .find_map(|(name, arg_val)| (arg_val == value).then_some(name))
382 }
383
384 pub fn args_iter<'a>(&self, context: &'a Context) -> impl Iterator<Item = &'a (String, Value)> {
386 context.functions[self.0].arguments.iter()
387 }
388
389 pub fn is_arg_immutable(&self, context: &Context, i: usize) -> bool {
391 if let Some((_, val)) = context.functions[self.0].arguments.get(i) {
392 if let ValueDatum::Argument(arg) = &context.values[val.0].value {
393 return arg.is_immutable;
394 }
395 }
396 false
397 }
398
399 pub fn get_local_var(&self, context: &Context, name: &str) -> Option<LocalVar> {
401 context.functions[self.0].local_storage.get(name).copied()
402 }
403
404 pub fn lookup_local_name<'a>(
406 &self,
407 context: &'a Context,
408 var: &LocalVar,
409 ) -> Option<&'a String> {
410 context.functions[self.0]
411 .local_storage
412 .iter()
413 .find_map(|(name, local_var)| if local_var == var { Some(name) } else { None })
414 }
415
416 pub fn new_local_var(
420 &self,
421 context: &mut Context,
422 name: String,
423 local_type: Type,
424 initializer: Option<Constant>,
425 mutable: bool,
426 ) -> Result<LocalVar, IrError> {
427 let var = LocalVar::new(context, local_type, initializer, mutable);
428 let func = context.functions.get_mut(self.0).unwrap();
429 func.local_storage
430 .insert(name.clone(), var)
431 .map(|_| Err(IrError::FunctionLocalClobbered(func.name.clone(), name)))
432 .unwrap_or(Ok(var))
433 }
434
435 pub fn new_unique_local_var(
439 &self,
440 context: &mut Context,
441 name: String,
442 local_type: Type,
443 initializer: Option<Constant>,
444 mutable: bool,
445 ) -> LocalVar {
446 let func = &context.functions[self.0];
447 let new_name = if func.local_storage.contains_key(&name) {
448 (0..)
451 .find_map(|n| {
452 let candidate = format!("{name}{n}");
453 if func.local_storage.contains_key(&candidate) {
454 None
455 } else {
456 Some(candidate)
457 }
458 })
459 .unwrap()
460 } else {
461 name
462 };
463 self.new_local_var(context, new_name, local_type, initializer, mutable)
464 .unwrap()
465 }
466
467 pub fn locals_iter<'a>(
469 &self,
470 context: &'a Context,
471 ) -> impl Iterator<Item = (&'a String, &'a LocalVar)> {
472 context.functions[self.0].local_storage.iter()
473 }
474
475 pub fn remove_locals(&self, context: &mut Context, removals: &Vec<String>) {
477 for remove in removals {
478 if let Some(local) = context.functions[self.0].local_storage.remove(remove) {
479 context.local_vars.remove(local.0);
480 }
481 }
482 }
483
484 pub fn merge_locals_from(
491 &self,
492 context: &mut Context,
493 other: Function,
494 ) -> HashMap<LocalVar, LocalVar> {
495 let mut var_map = HashMap::new();
496 let old_vars: Vec<(String, LocalVar, LocalVarContent)> = context.functions[other.0]
497 .local_storage
498 .iter()
499 .map(|(name, var)| (name.clone(), *var, context.local_vars[var.0].clone()))
500 .collect();
501 for (name, old_var, old_var_content) in old_vars {
502 let old_ty = old_var_content
503 .ptr_ty
504 .get_pointee_type(context)
505 .expect("LocalVar types are always pointers.");
506 let new_var = self.new_unique_local_var(
507 context,
508 name.clone(),
509 old_ty,
510 old_var_content.initializer,
511 old_var_content.mutable,
512 );
513 var_map.insert(old_var, new_var);
514 }
515 var_map
516 }
517
518 pub fn block_iter(&self, context: &Context) -> BlockIterator {
520 BlockIterator::new(context, self)
521 }
522
523 pub fn instruction_iter<'a>(
528 &self,
529 context: &'a Context,
530 ) -> impl Iterator<Item = (Block, Value)> + 'a {
531 context.functions[self.0]
532 .blocks
533 .iter()
534 .flat_map(move |block| {
535 block
536 .instruction_iter(context)
537 .map(move |ins_val| (*block, ins_val))
538 })
539 }
540
541 pub fn replace_values(
549 &self,
550 context: &mut Context,
551 replace_map: &FxHashMap<Value, Value>,
552 starting_block: Option<Block>,
553 ) {
554 let mut block_iter = self.block_iter(context).peekable();
555
556 if let Some(ref starting_block) = starting_block {
557 while block_iter
559 .next_if(|block| block != starting_block)
560 .is_some()
561 {}
562 }
563
564 for block in block_iter {
565 block.replace_values(context, replace_map);
566 }
567 }
568
569 pub fn replace_value(
570 &self,
571 context: &mut Context,
572 old_val: Value,
573 new_val: Value,
574 starting_block: Option<Block>,
575 ) {
576 let mut map = FxHashMap::<Value, Value>::default();
577 map.insert(old_val, new_val);
578 self.replace_values(context, &map, starting_block);
579 }
580
581 pub fn dot_cfg(&self, context: &Context) -> String {
583 let mut worklist = Vec::<Block>::new();
584 let mut visited = FxHashSet::<Block>::default();
585 let entry = self.get_entry_block(context);
586 let mut res = format!("digraph {} {{\n", self.get_name(context));
587
588 worklist.push(entry);
589 while let Some(n) = worklist.pop() {
590 visited.insert(n);
591 for BranchToWithArgs { block: n_succ, .. } in n.successors(context) {
592 let _ = writeln!(
593 res,
594 "\t{} -> {}\n",
595 n.get_label(context),
596 n_succ.get_label(context)
597 );
598 if !visited.contains(&n_succ) {
599 worklist.push(n_succ);
600 }
601 }
602 }
603
604 res += "}\n";
605 res
606 }
607}
608
609pub struct FunctionIterator {
611 functions: Vec<slotmap::DefaultKey>,
612 next: usize,
613}
614
615impl FunctionIterator {
616 pub fn new(context: &Context, module: &Module) -> FunctionIterator {
618 FunctionIterator {
621 functions: context.modules[module.0]
622 .functions
623 .iter()
624 .map(|func| func.0)
625 .collect(),
626 next: 0,
627 }
628 }
629}
630
631impl Iterator for FunctionIterator {
632 type Item = Function;
633
634 fn next(&mut self) -> Option<Function> {
635 if self.next < self.functions.len() {
636 let idx = self.next;
637 self.next += 1;
638 Some(Function(self.functions[idx]))
639 } else {
640 None
641 }
642 }
643}