1use std::{cell::RefCell, collections::HashMap};
6
7use rustc_hash::FxHashMap;
8
9use crate::{
10 asm::AsmArg,
11 block::Block,
12 call_graph, compute_post_order,
13 context::Context,
14 error::IrError,
15 function::Function,
16 instruction::{FuelVmInstruction, InstOp},
17 irtype::Type,
18 metadata::{combine, MetadataIndex},
19 value::{Value, ValueContent, ValueDatum},
20 variable::LocalVar,
21 AnalysisResults, BlockArgument, Instruction, Module, Pass, PassMutability, ScopedPass,
22};
23
24pub const FN_INLINE_NAME: &str = "inline";
25
26pub fn create_fn_inline_pass() -> Pass {
27 Pass {
28 name: FN_INLINE_NAME,
29 descr: "Function inlining",
30 deps: vec![],
31 runner: ScopedPass::ModulePass(PassMutability::Transform(fn_inline)),
32 }
33}
34
/// An inlining directive attached to a function through metadata.
///
/// Produced by [`metadata_to_inline`] from an `inline` metadata struct whose single
/// string field is `"always"` or `"never"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Inline {
    /// The metadata string was `"always"`.
    Always,
    /// The metadata string was `"never"`.
    Never,
}
42
43pub fn metadata_to_inline(context: &Context, md_idx: Option<MetadataIndex>) -> Option<Inline> {
44 fn for_each_md_idx<T, F: FnMut(MetadataIndex) -> Option<T>>(
45 context: &Context,
46 md_idx: Option<MetadataIndex>,
47 mut f: F,
48 ) -> Option<T> {
49 md_idx.and_then(|md_idx| {
51 if let Some(md_idcs) = md_idx.get_content(context).unwrap_list() {
52 md_idcs.iter().find_map(|md_idx| f(*md_idx))
53 } else {
54 f(md_idx)
55 }
56 })
57 }
58 for_each_md_idx(context, md_idx, |md_idx| {
59 md_idx
61 .get_content(context)
62 .unwrap_struct("inline", 1)
63 .and_then(|fields| fields[0].unwrap_string())
64 .and_then(|inline_str| {
65 let inline = match inline_str {
66 "always" => Some(Inline::Always),
67 "never" => Some(Inline::Never),
68 _otherwise => None,
69 }?;
70 Some(inline)
71 })
72 })
73}
74
75pub fn fn_inline(
76 context: &mut Context,
77 _: &AnalysisResults,
78 module: Module,
79) -> Result<bool, IrError> {
80 let call_counts: HashMap<Function, u64> =
82 module
83 .function_iter(context)
84 .fold(HashMap::new(), |mut counts, func| {
85 for (_block, ins) in func.instruction_iter(context) {
86 if let Some(Instruction {
87 op: InstOp::Call(callee, _args),
88 ..
89 }) = ins.get_instruction(context)
90 {
91 counts
92 .entry(*callee)
93 .and_modify(|count| *count += 1)
94 .or_insert(1);
95 }
96 }
97 counts
98 });
99
100 let inline_heuristic = |ctx: &Context, func: &Function, _call_site: &Value| {
101 if func.is_original_entry(ctx) {
106 return false;
107 }
108
109 let attributed_inline = metadata_to_inline(ctx, func.get_metadata(ctx));
110 match attributed_inline {
111 Some(Inline::Always) => {
112 }
115 Some(Inline::Never) => {
116 return false;
117 }
118 None => {}
119 }
120
121 if call_counts.get(func).copied().unwrap_or(0) == 1 {
123 return true;
124 }
125
126 const MAX_INLINE_INSTRS_COUNT: usize = 12;
128 if func.num_instructions_incl_asm_instructions(ctx) <= MAX_INLINE_INSTRS_COUNT {
129 return true;
130 }
131
132 false
133 };
134
135 let cg =
136 call_graph::build_call_graph(context, &module.function_iter(context).collect::<Vec<_>>());
137 let functions = call_graph::callee_first_order(&cg);
138 let mut modified = false;
139
140 for function in functions {
141 modified |= inline_some_function_calls(context, &function, inline_heuristic)?;
142 }
143 Ok(modified)
144}
145
146pub fn inline_all_function_calls(
151 context: &mut Context,
152 function: &Function,
153) -> Result<bool, IrError> {
154 inline_some_function_calls(context, function, |_, _, _| true)
155}
156
157pub fn inline_some_function_calls<F: Fn(&Context, &Function, &Value) -> bool>(
166 context: &mut Context,
167 function: &Function,
168 predicate: F,
169) -> Result<bool, IrError> {
170 let (call_sites, call_data): (Vec<_>, FxHashMap<_, _>) = function
174 .instruction_iter(context)
175 .filter_map(|(block, call_val)| match context.values[call_val.0].value {
176 ValueDatum::Instruction(Instruction {
177 op: InstOp::Call(inlined_function, _),
178 ..
179 }) => predicate(context, &inlined_function, &call_val).then_some((
180 call_val,
181 (call_val, RefCell::new((block, inlined_function))),
182 )),
183 _ => None,
184 })
185 .unzip();
186
187 for call_site in &call_sites {
188 let call_site_in = call_data.get(call_site).unwrap();
189 let (block, inlined_function) = *call_site_in.borrow();
190
191 if function == &inlined_function {
192 continue;
194 }
195
196 inline_function_call(
197 context,
198 *function,
199 block,
200 *call_site,
201 inlined_function,
202 &call_data,
203 )?;
204 }
205
206 Ok(!call_data.is_empty())
207}
208
209pub fn is_small_fn(
216 max_blocks: Option<usize>,
217 max_instrs: Option<usize>,
218 max_stack_size: Option<usize>,
219) -> impl Fn(&Context, &Function, &Value) -> bool {
220 fn count_type_elements(context: &Context, ty: &Type) -> usize {
221 if ty.is_array(context) {
223 count_type_elements(context, &ty.get_array_elem_type(context).unwrap())
224 * ty.get_array_len(context).unwrap() as usize
225 } else if ty.is_union(context) {
226 ty.get_field_types(context)
227 .iter()
228 .map(|ty| count_type_elements(context, ty))
229 .max()
230 .unwrap_or(1)
231 } else if ty.is_struct(context) {
232 ty.get_field_types(context)
233 .iter()
234 .map(|ty| count_type_elements(context, ty))
235 .sum()
236 } else {
237 1
238 }
239 }
240
241 move |context: &Context, function: &Function, _call_site: &Value| -> bool {
242 max_blocks.is_none_or(|max_block_count| function.num_blocks(context) <= max_block_count)
243 && max_instrs.is_none_or(|max_instrs_count| {
244 function.num_instructions_incl_asm_instructions(context) <= max_instrs_count
245 })
246 && max_stack_size.is_none_or(|max_stack_size_count| {
247 function
248 .locals_iter(context)
249 .map(|(_name, ptr)| count_type_elements(context, &ptr.get_inner_type(context)))
250 .sum::<usize>()
251 <= max_stack_size_count
252 })
253 }
254}
255
/// Inlines the single call `call_site` (located in `block` of `function`) to
/// `inlined_function`.
///
/// Outline of the transformation:
/// 1. `block` is split right after the call. The call ends up as the last instruction
///    of `pre_block` and is removed from it. The instructions that followed it go to
///    `post_block`.
/// 2. `post_block` gains one block argument of the call's type. Every use of the call
///    value in `function` is replaced with that argument, so the callee's `ret`s (which
///    `inline_instruction` turns into branches to `post_block`) deliver the result.
/// 3. The callee's locals are merged into `function`. The callee's arguments are mapped
///    to the values passed at the call site.
/// 4. The callee's entry block maps onto `pre_block`. Every other callee block gets a
///    fresh block, inserted before `post_block` and labelled `<callee>_<label>`, with
///    matching block arguments.
/// 5. Callee instructions are copied over in reverse post-order via `inline_instruction`.
///
/// `call_data` holds the pending call sites of the caller. Entries for calls that were
/// moved into `post_block` are updated to point at `post_block`.
///
/// # Panics
/// Panics if `call_site` is not an instruction of `block`, if `post_block` already has
/// block arguments, or if the callee has no blocks.
///
/// # Errors
/// The visible body always returns `Ok(())`.
pub fn inline_function_call(
    context: &mut Context,
    function: Function,
    block: Block,
    call_site: Value,
    inlined_function: Function,
    call_data: &FxHashMap<Value, RefCell<(Block, Function)>>,
) -> Result<(), IrError> {
    // Split the block right after the call site.
    let call_site_idx = block
        .instruction_iter(context)
        .position(|v| v == call_site)
        .unwrap();
    let (pre_block, post_block) = block.split_at(context, call_site_idx + 1);
    if post_block != block {
        // Pending call sites that moved into `post_block` must have their recorded block
        // updated, otherwise later inlines would look for them in the wrong block.
        for inst in post_block.instruction_iter(context).filter(|inst| {
            matches!(
                context.values[inst.0].value,
                ValueDatum::Instruction(Instruction {
                    op: InstOp::Call(..),
                    ..
                })
            )
        }) {
            if let Some(call_info) = call_data.get(&inst) {
                call_info.borrow_mut().0 = post_block;
            }
        }
    }

    // Drop the call from `pre_block`'s instruction list. Its value is still in
    // `context.values` and is removed further below.
    pre_block.remove_last_instruction(context);

    // The returned value arrives in `post_block` as its (single) block argument.
    if post_block.new_arg(context, call_site.get_type(context).unwrap()) != 0 {
        panic!("Expected newly created post_block to not have block args")
    }
    function.replace_value(
        context,
        call_site,
        post_block.get_arg(context, 0).unwrap(),
        None,
    );

    // Bring the callee's locals into this function. `ptr_map` maps callee locals to the
    // caller's new copies.
    let ptr_map = function.merge_locals_from(context, inlined_function);
    let mut value_map = HashMap::new();

    // Map each callee argument value to the value passed for it at the call site.
    if let ValueDatum::Instruction(Instruction {
        op: InstOp::Call(_, passed_vals),
        ..
    }) = &context.values[call_site.0].value
    {
        for (arg_val, passed_val) in context.functions[inlined_function.0]
            .arguments
            .iter()
            .zip(passed_vals.iter())
        {
            value_map.insert(arg_val.1, *passed_val);
        }
    }

    // Keep the call's metadata so it can be combined into every inlined instruction.
    let metadata = context.values[call_site.0].metadata;

    // The call value itself is no longer referenced; remove it.
    context.values.remove(call_site.0);

    // Build the callee-block -> caller-block map. The callee's first block is mapped to
    // `pre_block` so its instructions are appended there directly. Every other callee
    // block gets a new, empty block placed before `post_block`.
    let inlined_fn_name = inlined_function.get_name(context).to_owned();
    let mut block_map = HashMap::new();
    let mut block_iter = context.functions[inlined_function.0]
        .blocks
        .clone()
        .into_iter();
    block_map.insert(block_iter.next().unwrap(), pre_block);
    block_map = block_iter.fold(block_map, |mut block_map, inlined_block| {
        let inlined_block_label = inlined_block.get_label(context);
        let new_block = function
            .create_block_before(
                context,
                &post_block,
                Some(format!("{inlined_fn_name}_{inlined_block_label}")),
            )
            .unwrap();
        block_map.insert(inlined_block, new_block);
        // Collected first so that `context` can be mutably borrowed inside the loop.
        let inlined_args: Vec<_> = inlined_block.arg_iter(context).copied().collect();
        for inlined_arg in inlined_args {
            // Mirror each callee block argument on the new block and map old -> new.
            if let ValueDatum::Argument(BlockArgument {
                block: _,
                idx: _,
                ty,
                is_immutable: _,
            }) = &context.values[inlined_arg.0].value
            {
                let index = new_block.new_arg(context, *ty);
                value_map.insert(inlined_arg, new_block.get_arg(context, index).unwrap());
            } else {
                unreachable!("Expected a block argument")
            }
        }
        block_map
    });

    // Copy the callee's instructions, visiting blocks in reverse post-order so that
    // (outside of block arguments, which are already mapped) definitions are normally
    // translated before their uses. `value_map` keeps growing as new instructions are
    // created.
    let inlined_block_iter = compute_post_order(context, &inlined_function)
        .po_to_block
        .into_iter()
        .rev();
    for ref block in inlined_block_iter {
        for ins in block.instruction_iter(context) {
            inline_instruction(
                context,
                block_map.get(block).unwrap(),
                &post_block,
                &ins,
                &block_map,
                &mut value_map,
                &ptr_map,
                metadata,
            );
        }
    }

    Ok(())
}
402
/// Copies one callee instruction into `new_block`, translating the values, blocks and
/// locals it refers to.
///
/// - Blocks are translated via `block_map`. A missing block panics.
/// - Values are translated via `value_map`. A value missing from the map is reused
///   unchanged.
/// - Locals are translated via `local_map`. A missing local panics.
/// - `ret v` becomes `br post_block(v')`, passing the returned value as `post_block`'s
///   block argument.
/// - The new instruction's metadata is `combine(fn_metadata, <instruction metadata>)`.
///
/// On success the mapping `instruction -> new instruction` is added to `value_map`. If
/// `instruction` is not an instruction value, nothing happens.
// Many distinct maps/blocks are needed by the translation, hence the lint override.
#[allow(clippy::too_many_arguments)]
fn inline_instruction(
    context: &mut Context,
    new_block: &Block,
    post_block: &Block,
    instruction: &Value,
    block_map: &HashMap<Block, Block>,
    value_map: &mut HashMap<Value, Value>,
    local_map: &HashMap<LocalVar, LocalVar>,
    fn_metadata: Option<MetadataIndex>,
) {
    // Translate an old block to its new counterpart; every callee block must be mapped.
    let map_block = |old_block| *block_map.get(&old_block).unwrap();

    // Translate an old value; values not in the map are reused as-is.
    let map_value = |old_val: Value| value_map.get(&old_val).copied().unwrap_or(old_val);
    // Translate an old local; every callee local must be mapped.
    let map_local = |old_local| local_map.get(&old_local).copied().unwrap();

    // The value content is cloned so that `context` can be mutably borrowed while the
    // replacement instruction is built.
    if let ValueContent {
        value: ValueDatum::Instruction(old_ins),
        metadata: val_metadata,
    } = context.values[instruction.0].clone()
    {
        // Merge the call-site metadata with the instruction's own metadata.
        let metadata = combine(context, &fn_metadata, &val_metadata);

        let new_ins = match old_ins.op {
            InstOp::AsmBlock(asm, args) => {
                let new_args = args
                    .iter()
                    .map(|AsmArg { name, initializer }| AsmArg {
                        name: name.clone(),
                        initializer: initializer.map(map_value),
                    })
                    .collect();

                // The asm block body is reused; only its argument initializers change.
                new_block.append(context).asm_block_from_asm(asm, new_args)
            }
            InstOp::BitCast(value, ty) => new_block.append(context).bitcast(map_value(value), ty),
            InstOp::UnaryOp { op, arg } => new_block.append(context).unary_op(op, map_value(arg)),
            InstOp::BinaryOp { op, arg1, arg2 } => {
                new_block
                    .append(context)
                    .binary_op(op, map_value(arg1), map_value(arg2))
            }
            InstOp::Branch(b) => new_block.append(context).branch(
                map_block(b.block),
                b.args.iter().map(|v| map_value(*v)).collect(),
            ),
            InstOp::Call(f, args) => new_block.append(context).call(
                f,
                args.iter()
                    .map(|old_val: &Value| map_value(*old_val))
                    .collect::<Vec<Value>>()
                    .as_slice(),
            ),
            InstOp::CastPtr(val, ty) => new_block.append(context).cast_ptr(map_value(val), ty),
            InstOp::Cmp(pred, lhs_value, rhs_value) => {
                new_block
                    .append(context)
                    .cmp(pred, map_value(lhs_value), map_value(rhs_value))
            }
            InstOp::ConditionalBranch {
                cond_value,
                true_block,
                false_block,
            } => new_block.append(context).conditional_branch(
                map_value(cond_value),
                map_block(true_block.block),
                map_block(false_block.block),
                true_block.args.iter().map(|v| map_value(*v)).collect(),
                false_block.args.iter().map(|v| map_value(*v)).collect(),
            ),
            InstOp::ContractCall {
                return_type,
                name,
                params,
                coins,
                asset_id,
                gas,
            } => new_block.append(context).contract_call(
                return_type,
                name,
                map_value(params),
                map_value(coins),
                map_value(asset_id),
                map_value(gas),
            ),
            InstOp::FuelVm(fuel_vm_instr) => match fuel_vm_instr {
                FuelVmInstruction::Gtf { index, tx_field_id } => {
                    new_block.append(context).gtf(map_value(index), tx_field_id)
                }
                FuelVmInstruction::Log {
                    log_val,
                    log_ty,
                    log_id,
                } => new_block
                    .append(context)
                    .log(map_value(log_val), log_ty, map_value(log_id)),
                FuelVmInstruction::ReadRegister(reg) => {
                    new_block.append(context).read_register(reg)
                }
                FuelVmInstruction::Revert(val) => new_block.append(context).revert(map_value(val)),
                FuelVmInstruction::JmpMem => new_block.append(context).jmp_mem(),
                FuelVmInstruction::Smo {
                    recipient,
                    message,
                    message_size,
                    coins,
                } => new_block.append(context).smo(
                    map_value(recipient),
                    map_value(message),
                    map_value(message_size),
                    map_value(coins),
                ),
                FuelVmInstruction::StateClear {
                    key,
                    number_of_slots,
                } => new_block
                    .append(context)
                    .state_clear(map_value(key), map_value(number_of_slots)),
                FuelVmInstruction::StateLoadQuadWord {
                    load_val,
                    key,
                    number_of_slots,
                } => new_block.append(context).state_load_quad_word(
                    map_value(load_val),
                    map_value(key),
                    map_value(number_of_slots),
                ),
                FuelVmInstruction::StateLoadWord(key) => {
                    new_block.append(context).state_load_word(map_value(key))
                }
                FuelVmInstruction::StateStoreQuadWord {
                    stored_val,
                    key,
                    number_of_slots,
                } => new_block.append(context).state_store_quad_word(
                    map_value(stored_val),
                    map_value(key),
                    map_value(number_of_slots),
                ),
                FuelVmInstruction::StateStoreWord { stored_val, key } => new_block
                    .append(context)
                    .state_store_word(map_value(stored_val), map_value(key)),
                FuelVmInstruction::WideUnaryOp { op, arg, result } => new_block
                    .append(context)
                    .wide_unary_op(op, map_value(arg), map_value(result)),
                FuelVmInstruction::WideBinaryOp {
                    op,
                    arg1,
                    arg2,
                    result,
                } => new_block.append(context).wide_binary_op(
                    op,
                    map_value(arg1),
                    map_value(arg2),
                    map_value(result),
                ),
                FuelVmInstruction::WideModularOp {
                    op,
                    result,
                    arg1,
                    arg2,
                    arg3,
                } => new_block.append(context).wide_modular_op(
                    op,
                    map_value(result),
                    map_value(arg1),
                    map_value(arg2),
                    map_value(arg3),
                ),
                FuelVmInstruction::WideCmpOp { op, arg1, arg2 } => new_block
                    .append(context)
                    .wide_cmp_op(op, map_value(arg1), map_value(arg2)),
                FuelVmInstruction::Retd { ptr, len } => new_block
                    .append(context)
                    .retd(map_value(ptr), map_value(len)),
            },
            InstOp::GetElemPtr {
                base,
                elem_ptr_ty,
                indices,
            } => {
                // The builder takes the element type, so unwrap the pointer type first
                // (this must happen before `append` mutably borrows `context`).
                let elem_ty = elem_ptr_ty.get_pointee_type(context).unwrap();
                new_block.append(context).get_elem_ptr(
                    map_value(base),
                    elem_ty,
                    indices.iter().map(|idx| map_value(*idx)).collect(),
                )
            }
            InstOp::GetLocal(local_var) => {
                new_block.append(context).get_local(map_local(local_var))
            }
            InstOp::GetGlobal(global_var) => new_block.append(context).get_global(global_var),
            InstOp::GetStorageKey(storage_key) => {
                new_block.append(context).get_storage_key(storage_key)
            }
            InstOp::GetConfig(module, name) => new_block.append(context).get_config(module, name),
            InstOp::IntToPtr(value, ty) => {
                new_block.append(context).int_to_ptr(map_value(value), ty)
            }
            InstOp::Load(src_val) => new_block.append(context).load(map_value(src_val)),
            InstOp::MemCopyBytes {
                dst_val_ptr,
                src_val_ptr,
                byte_len,
            } => new_block.append(context).mem_copy_bytes(
                map_value(dst_val_ptr),
                map_value(src_val_ptr),
                byte_len,
            ),
            InstOp::MemCopyVal {
                dst_val_ptr,
                src_val_ptr,
            } => new_block
                .append(context)
                .mem_copy_val(map_value(dst_val_ptr), map_value(src_val_ptr)),
            InstOp::MemClearVal { dst_val_ptr } => new_block
                .append(context)
                .mem_clear_val(map_value(dst_val_ptr)),
            InstOp::Nop => new_block.append(context).nop(),
            InstOp::PtrToInt(value, ty) => {
                new_block.append(context).ptr_to_int(map_value(value), ty)
            }
            // A callee `ret` becomes a branch to `post_block`, passing the returned value
            // as that block's argument.
            InstOp::Ret(val, _) => new_block
                .append(context)
                .branch(*post_block, vec![map_value(val)]),
            InstOp::Store {
                dst_val_ptr,
                stored_val,
            } => new_block
                .append(context)
                .store(map_value(dst_val_ptr), map_value(stored_val)),
        }
        .add_metadatum(context, metadata);

        // Later callee instructions that use this one now refer to its copy.
        value_map.insert(*instruction, new_ins);
    }
}