1use std::collections::{BTreeMap, HashMap};
10use std::fmt::Write;
11
12use rustc_hash::{FxHashMap, FxHashSet};
13
14use crate::{
15 block::{Block, BlockIterator, Label},
16 context::Context,
17 error::IrError,
18 irtype::Type,
19 metadata::MetadataIndex,
20 module::Module,
21 value::{Value, ValueDatum},
22 variable::{LocalVar, LocalVarContent},
23 BlockArgument, BranchToWithArgs,
24};
25use crate::{Constant, InstOp};
26
27#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
30pub struct Function(pub slotmap::DefaultKey);
31
32#[doc(hidden)]
33pub struct FunctionContent {
34 pub name: String,
35 pub abi_errors_display: String,
44 pub arguments: Vec<(String, Value)>,
45 pub return_type: Type,
46 pub blocks: Vec<Block>,
47 pub module: Module,
48 pub is_public: bool,
49 pub is_entry: bool,
50 pub is_original_entry: bool,
53 pub is_fallback: bool,
54 pub selector: Option<[u8; 4]>,
55 pub metadata: Option<MetadataIndex>,
56
57 pub local_storage: BTreeMap<String, LocalVar>, next_label_idx: u64,
60}
61
62impl Function {
63 #[allow(clippy::too_many_arguments)]
71 pub fn new(
72 context: &mut Context,
73 module: Module,
74 name: String,
75 abi_errors_display: String,
76 args: Vec<(String, Type, Option<MetadataIndex>)>,
77 return_type: Type,
78 selector: Option<[u8; 4]>,
79 is_public: bool,
80 is_entry: bool,
81 is_original_entry: bool,
82 is_fallback: bool,
83 metadata: Option<MetadataIndex>,
84 ) -> Function {
85 let content = FunctionContent {
86 name,
87 abi_errors_display,
88 arguments: Vec::new(),
91 return_type,
92 blocks: Vec::new(),
93 module,
94 is_public,
95 is_entry,
96 is_original_entry,
97 is_fallback,
98 selector,
99 metadata,
100 local_storage: BTreeMap::new(),
101 next_label_idx: 0,
102 };
103 let func = Function(context.functions.insert(content));
104
105 context.modules[module.0].functions.push(func);
106
107 let entry_block = Block::new(context, func, Some("entry".to_owned()));
108 context
109 .functions
110 .get_mut(func.0)
111 .unwrap()
112 .blocks
113 .push(entry_block);
114
115 let arguments: Vec<_> = args
117 .into_iter()
118 .enumerate()
119 .map(|(idx, (name, ty, arg_metadata))| {
120 (
121 name,
122 Value::new_argument(
123 context,
124 BlockArgument {
125 block: entry_block,
126 idx,
127 ty,
128 is_immutable: false,
129 },
130 )
131 .add_metadatum(context, arg_metadata),
132 )
133 })
134 .collect();
135 context
136 .functions
137 .get_mut(func.0)
138 .unwrap()
139 .arguments
140 .clone_from(&arguments);
141 let (_, arg_vals): (Vec<_>, Vec<_>) = arguments.iter().cloned().unzip();
142 context.blocks.get_mut(entry_block.0).unwrap().args = arg_vals;
143
144 func
145 }
146
147 pub fn create_block(&self, context: &mut Context, label: Option<Label>) -> Block {
149 let block = Block::new(context, *self, label);
150 let func = context.functions.get_mut(self.0).unwrap();
151 func.blocks.push(block);
152 block
153 }
154
155 pub fn create_block_before(
159 &self,
160 context: &mut Context,
161 other: &Block,
162 label: Option<Label>,
163 ) -> Result<Block, IrError> {
164 let block_idx = context.functions[self.0]
165 .blocks
166 .iter()
167 .position(|block| block == other)
168 .ok_or_else(|| {
169 let label = &context.blocks[other.0].label;
170 IrError::MissingBlock(label.clone())
171 })?;
172
173 let new_block = Block::new(context, *self, label);
174 context.functions[self.0]
175 .blocks
176 .insert(block_idx, new_block);
177 Ok(new_block)
178 }
179
180 pub fn create_block_after(
184 &self,
185 context: &mut Context,
186 other: &Block,
187 label: Option<Label>,
188 ) -> Result<Block, IrError> {
189 let new_block = Block::new(context, *self, label);
192 let func = context.functions.get_mut(self.0).unwrap();
193 func.blocks
194 .iter()
195 .position(|block| block == other)
196 .map(|idx| {
197 func.blocks.insert(idx + 1, new_block);
198 new_block
199 })
200 .ok_or_else(|| {
201 let label = &context.blocks[other.0].label;
202 IrError::MissingBlock(label.clone())
203 })
204 }
205
206 pub fn remove_block(&self, context: &mut Context, block: &Block) -> Result<(), IrError> {
211 let label = block.get_label(context);
212 let func = context.functions.get_mut(self.0).unwrap();
213 let block_idx = func
214 .blocks
215 .iter()
216 .position(|b| b == block)
217 .ok_or(IrError::RemoveMissingBlock(label))?;
218 func.blocks.remove(block_idx);
219 Ok(())
220 }
221
222 pub fn remove_instructions<T: Fn(Value) -> bool>(&self, context: &mut Context, pred: T) {
224 for block in context.functions[self.0].blocks.clone() {
225 block.remove_instructions(context, &pred);
226 }
227 }
228
229 pub fn get_unique_label(&self, context: &mut Context, hint: Option<String>) -> String {
237 match hint {
238 Some(hint) => {
239 if context.functions[self.0]
240 .blocks
241 .iter()
242 .any(|block| context.blocks[block.0].label == hint)
243 {
244 let idx = self.get_next_label_idx(context);
245 self.get_unique_label(context, Some(format!("{hint}{idx}")))
246 } else {
247 hint
248 }
249 }
250 None => {
251 let idx = self.get_next_label_idx(context);
252 self.get_unique_label(context, Some(format!("block{idx}")))
253 }
254 }
255 }
256
257 fn get_next_label_idx(&self, context: &mut Context) -> u64 {
258 let func = context.functions.get_mut(self.0).unwrap();
259 let idx = func.next_label_idx;
260 func.next_label_idx += 1;
261 idx
262 }
263
264 pub fn num_blocks(&self, context: &Context) -> usize {
266 context.functions[self.0].blocks.len()
267 }
268
269 pub fn num_instructions(&self, context: &Context) -> usize {
279 self.block_iter(context)
280 .map(|block| block.num_instructions(context))
281 .sum()
282 }
283
284 pub fn num_instructions_incl_asm_instructions(&self, context: &Context) -> usize {
296 self.instruction_iter(context).fold(0, |num, (_, value)| {
297 match &value
298 .get_instruction(context)
299 .expect("We are iterating through the instructions.")
300 .op
301 {
302 InstOp::AsmBlock(asm, _) => num + asm.body.len(),
303 _ => num + 1,
304 }
305 })
306 }
307
308 pub fn get_name<'a>(&self, context: &'a Context) -> &'a str {
310 &context.functions[self.0].name
311 }
312
313 pub fn get_abi_errors_display(&self, context: &Context) -> String {
316 context.functions[self.0].abi_errors_display.clone()
317 }
318
319 pub fn get_module(&self, context: &Context) -> Module {
321 context.functions[self.0].module
322 }
323
324 pub fn get_entry_block(&self, context: &Context) -> Block {
326 context.functions[self.0].blocks[0]
327 }
328
329 pub fn get_metadata(&self, context: &Context) -> Option<MetadataIndex> {
331 context.functions[self.0].metadata
332 }
333
334 pub fn has_selector(&self, context: &Context) -> bool {
336 context.functions[self.0].selector.is_some()
337 }
338
339 pub fn get_selector(&self, context: &Context) -> Option<[u8; 4]> {
341 context.functions[self.0].selector
342 }
343
344 pub fn is_entry(&self, context: &Context) -> bool {
347 context.functions[self.0].is_entry
348 }
349
350 pub fn is_original_entry(&self, context: &Context) -> bool {
353 context.functions[self.0].is_original_entry
354 }
355
356 pub fn is_fallback(&self, context: &Context) -> bool {
358 context.functions[self.0].is_fallback
359 }
360
361 pub fn get_return_type(&self, context: &Context) -> Type {
363 context.functions[self.0].return_type
364 }
365
366 pub fn set_return_type(&self, context: &mut Context, new_ret_type: Type) {
368 context.functions.get_mut(self.0).unwrap().return_type = new_ret_type
369 }
370
371 pub fn num_args(&self, context: &Context) -> usize {
373 context.functions[self.0].arguments.len()
374 }
375
376 pub fn get_arg(&self, context: &Context, name: &str) -> Option<Value> {
378 context.functions[self.0]
379 .arguments
380 .iter()
381 .find_map(|(arg_name, val)| (arg_name == name).then_some(val))
382 .copied()
383 }
384
385 pub fn add_arg<S: Into<String>>(&self, context: &mut Context, name: S, arg: Value) {
390 match context.values[arg.0].value {
391 ValueDatum::Argument(BlockArgument { idx, .. })
392 if idx == context.functions[self.0].arguments.len() =>
393 {
394 context.functions[self.0].arguments.push((name.into(), arg));
395 }
396 _ => panic!("Inconsistent function argument being added"),
397 }
398 }
399
400 pub fn lookup_arg_name<'a>(&self, context: &'a Context, value: &Value) -> Option<&'a String> {
402 context.functions[self.0]
403 .arguments
404 .iter()
405 .find_map(|(name, arg_val)| (arg_val == value).then_some(name))
406 }
407
408 pub fn args_iter<'a>(&self, context: &'a Context) -> impl Iterator<Item = &'a (String, Value)> {
410 context.functions[self.0].arguments.iter()
411 }
412
413 pub fn is_arg_immutable(&self, context: &Context, i: usize) -> bool {
415 if let Some((_, val)) = context.functions[self.0].arguments.get(i) {
416 if let ValueDatum::Argument(arg) = &context.values[val.0].value {
417 return arg.is_immutable;
418 }
419 }
420 false
421 }
422
423 pub fn get_local_var(&self, context: &Context, name: &str) -> Option<LocalVar> {
425 context.functions[self.0].local_storage.get(name).copied()
426 }
427
428 pub fn lookup_local_name<'a>(
430 &self,
431 context: &'a Context,
432 var: &LocalVar,
433 ) -> Option<&'a String> {
434 context.functions[self.0]
435 .local_storage
436 .iter()
437 .find_map(|(name, local_var)| if local_var == var { Some(name) } else { None })
438 }
439
440 pub fn new_local_var(
444 &self,
445 context: &mut Context,
446 name: String,
447 local_type: Type,
448 initializer: Option<Constant>,
449 mutable: bool,
450 ) -> Result<LocalVar, IrError> {
451 let var = LocalVar::new(context, local_type, initializer, mutable);
452 let func = context.functions.get_mut(self.0).unwrap();
453 func.local_storage
454 .insert(name.clone(), var)
455 .map(|_| Err(IrError::FunctionLocalClobbered(func.name.clone(), name)))
456 .unwrap_or(Ok(var))
457 }
458
459 pub fn new_unique_local_var(
463 &self,
464 context: &mut Context,
465 name: String,
466 local_type: Type,
467 initializer: Option<Constant>,
468 mutable: bool,
469 ) -> LocalVar {
470 let func = &context.functions[self.0];
471 let new_name = if func.local_storage.contains_key(&name) {
472 (0..)
475 .find_map(|n| {
476 let candidate = format!("{name}{n}");
477 if func.local_storage.contains_key(&candidate) {
478 None
479 } else {
480 Some(candidate)
481 }
482 })
483 .unwrap()
484 } else {
485 name
486 };
487 self.new_local_var(context, new_name, local_type, initializer, mutable)
488 .unwrap()
489 }
490
491 pub fn locals_iter<'a>(
493 &self,
494 context: &'a Context,
495 ) -> impl Iterator<Item = (&'a String, &'a LocalVar)> {
496 context.functions[self.0].local_storage.iter()
497 }
498
499 pub fn remove_locals(&self, context: &mut Context, removals: &Vec<String>) {
501 for remove in removals {
502 if let Some(local) = context.functions[self.0].local_storage.remove(remove) {
503 context.local_vars.remove(local.0);
504 }
505 }
506 }
507
508 pub fn merge_locals_from(
515 &self,
516 context: &mut Context,
517 other: Function,
518 ) -> HashMap<LocalVar, LocalVar> {
519 let mut var_map = HashMap::new();
520 let old_vars: Vec<(String, LocalVar, LocalVarContent)> = context.functions[other.0]
521 .local_storage
522 .iter()
523 .map(|(name, var)| (name.clone(), *var, context.local_vars[var.0].clone()))
524 .collect();
525 for (name, old_var, old_var_content) in old_vars {
526 let old_ty = old_var_content
527 .ptr_ty
528 .get_pointee_type(context)
529 .expect("LocalVar types are always pointers.");
530 let new_var = self.new_unique_local_var(
531 context,
532 name.clone(),
533 old_ty,
534 old_var_content.initializer,
535 old_var_content.mutable,
536 );
537 var_map.insert(old_var, new_var);
538 }
539 var_map
540 }
541
542 pub fn block_iter(&self, context: &Context) -> BlockIterator {
544 BlockIterator::new(context, self)
545 }
546
547 pub fn instruction_iter<'a>(
552 &self,
553 context: &'a Context,
554 ) -> impl Iterator<Item = (Block, Value)> + 'a {
555 context.functions[self.0]
556 .blocks
557 .iter()
558 .flat_map(move |block| {
559 block
560 .instruction_iter(context)
561 .map(move |ins_val| (*block, ins_val))
562 })
563 }
564
565 pub fn replace_values(
573 &self,
574 context: &mut Context,
575 replace_map: &FxHashMap<Value, Value>,
576 starting_block: Option<Block>,
577 ) {
578 let mut block_iter = self.block_iter(context).peekable();
579
580 if let Some(ref starting_block) = starting_block {
581 while block_iter
583 .next_if(|block| block != starting_block)
584 .is_some()
585 {}
586 }
587
588 for block in block_iter {
589 block.replace_values(context, replace_map);
590 }
591 }
592
593 pub fn replace_value(
594 &self,
595 context: &mut Context,
596 old_val: Value,
597 new_val: Value,
598 starting_block: Option<Block>,
599 ) {
600 let mut map = FxHashMap::<Value, Value>::default();
601 map.insert(old_val, new_val);
602 self.replace_values(context, &map, starting_block);
603 }
604
605 pub fn dot_cfg(&self, context: &Context) -> String {
607 let mut worklist = Vec::<Block>::new();
608 let mut visited = FxHashSet::<Block>::default();
609 let entry = self.get_entry_block(context);
610 let mut res = format!("digraph {} {{\n", self.get_name(context));
611
612 worklist.push(entry);
613 while let Some(n) = worklist.pop() {
614 visited.insert(n);
615 for BranchToWithArgs { block: n_succ, .. } in n.successors(context) {
616 let _ = writeln!(
617 res,
618 "\t{} -> {}\n",
619 n.get_label(context),
620 n_succ.get_label(context)
621 );
622 if !visited.contains(&n_succ) {
623 worklist.push(n_succ);
624 }
625 }
626 }
627
628 res += "}\n";
629 res
630 }
631}
632
633pub struct FunctionIterator {
635 functions: Vec<slotmap::DefaultKey>,
636 next: usize,
637}
638
639impl FunctionIterator {
640 pub fn new(context: &Context, module: &Module) -> FunctionIterator {
642 FunctionIterator {
645 functions: context.modules[module.0]
646 .functions
647 .iter()
648 .map(|func| func.0)
649 .collect(),
650 next: 0,
651 }
652 }
653}
654
655impl Iterator for FunctionIterator {
656 type Item = Function;
657
658 fn next(&mut self) -> Option<Function> {
659 if self.next < self.functions.len() {
660 let idx = self.next;
661 self.next += 1;
662 Some(Function(self.functions[idx]))
663 } else {
664 None
665 }
666 }
667}