sway_ir/optimize/
inline.rs

1//! Function inlining.
2//!
3//! Function inlining is pretty hairy so these passes must be maintained with care.
4
5use std::{cell::RefCell, collections::HashMap};
6
7use rustc_hash::FxHashMap;
8
9use crate::{
10    asm::AsmArg,
11    block::Block,
12    call_graph, compute_post_order,
13    context::Context,
14    error::IrError,
15    function::Function,
16    instruction::{FuelVmInstruction, InstOp},
17    irtype::Type,
18    metadata::{combine, MetadataIndex},
19    value::{Value, ValueContent, ValueDatum},
20    variable::LocalVar,
21    AnalysisResults, BlockArgument, Instruction, Module, Pass, PassMutability, ScopedPass,
22};
23
24pub const FN_INLINE_NAME: &str = "inline";
25
26pub fn create_fn_inline_pass() -> Pass {
27    Pass {
28        name: FN_INLINE_NAME,
29        descr: "Function inlining",
30        deps: vec![],
31        runner: ScopedPass::ModulePass(PassMutability::Transform(fn_inline)),
32    }
33}
34
/// An inlining request attached to a function, as decoded by `metadata_to_inline`.
///
/// This is a copy of sway_core::inline::Inline.
/// TODO: Reuse: Depend on sway_core? Move it to sway_types?
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Inline {
    /// Corresponds to the metadata string `"always"`.
    Always,
    /// Corresponds to the metadata string `"never"`.
    Never,
}
42
43pub fn metadata_to_inline(context: &Context, md_idx: Option<MetadataIndex>) -> Option<Inline> {
44    fn for_each_md_idx<T, F: FnMut(MetadataIndex) -> Option<T>>(
45        context: &Context,
46        md_idx: Option<MetadataIndex>,
47        mut f: F,
48    ) -> Option<T> {
49        // If md_idx is not None and is a list then try them all.
50        md_idx.and_then(|md_idx| {
51            if let Some(md_idcs) = md_idx.get_content(context).unwrap_list() {
52                md_idcs.iter().find_map(|md_idx| f(*md_idx))
53            } else {
54                f(md_idx)
55            }
56        })
57    }
58    for_each_md_idx(context, md_idx, |md_idx| {
59        // Create a new inline and save it in the cache.
60        md_idx
61            .get_content(context)
62            .unwrap_struct("inline", 1)
63            .and_then(|fields| fields[0].unwrap_string())
64            .and_then(|inline_str| {
65                let inline = match inline_str {
66                    "always" => Some(Inline::Always),
67                    "never" => Some(Inline::Never),
68                    _otherwise => None,
69                }?;
70                Some(inline)
71            })
72    })
73}
74
75pub fn fn_inline(
76    context: &mut Context,
77    _: &AnalysisResults,
78    module: Module,
79) -> Result<bool, IrError> {
80    // Inspect ALL calls and count how often each function is called.
81    let call_counts: HashMap<Function, u64> =
82        module
83            .function_iter(context)
84            .fold(HashMap::new(), |mut counts, func| {
85                for (_block, ins) in func.instruction_iter(context) {
86                    if let Some(Instruction {
87                        op: InstOp::Call(callee, _args),
88                        ..
89                    }) = ins.get_instruction(context)
90                    {
91                        counts
92                            .entry(*callee)
93                            .and_modify(|count| *count += 1)
94                            .or_insert(1);
95                    }
96                }
97                counts
98            });
99
100    let inline_heuristic = |ctx: &Context, func: &Function, _call_site: &Value| {
101        // The encoding code in the `__entry` functions contains pointer patterns that mark
102        // escape analysis and referred symbols as incomplete. This effectively forbids optimizations
103        // like SROA nad DCE. If we inline original entries, like e.g., `main`, the code in them will
104        // also not be optimized. Therefore, we forbid inlining of original entries into `__entry`.
105        if func.is_original_entry(ctx) {
106            return false;
107        }
108
109        let attributed_inline = metadata_to_inline(ctx, func.get_metadata(ctx));
110        match attributed_inline {
111            Some(Inline::Always) => {
112                // TODO: check if inlining of function is possible
113                // return true;
114            }
115            Some(Inline::Never) => {
116                return false;
117            }
118            None => {}
119        }
120
121        // If the function is called only once then definitely inline it.
122        if call_counts.get(func).copied().unwrap_or(0) == 1 {
123            return true;
124        }
125
126        // If the function is (still) small then also inline it.
127        const MAX_INLINE_INSTRS_COUNT: usize = 12;
128        if func.num_instructions_incl_asm_instructions(ctx) <= MAX_INLINE_INSTRS_COUNT {
129            return true;
130        }
131
132        false
133    };
134
135    let cg =
136        call_graph::build_call_graph(context, &module.function_iter(context).collect::<Vec<_>>());
137    let functions = call_graph::callee_first_order(&cg);
138    let mut modified = false;
139
140    for function in functions {
141        modified |= inline_some_function_calls(context, &function, inline_heuristic)?;
142    }
143    Ok(modified)
144}
145
146/// Inline all calls made from a specific function, effectively removing all `Call` instructions.
147///
148/// e.g., If this is applied to main() then all calls in the program are removed.  This is
149/// obviously dangerous for recursive functions, in which case this pass would inline forever.
150pub fn inline_all_function_calls(
151    context: &mut Context,
152    function: &Function,
153) -> Result<bool, IrError> {
154    inline_some_function_calls(context, function, |_, _, _| true)
155}
156
157/// Inline function calls based on a provided heuristic predicate.
158///
159/// There are many things to consider when deciding to inline a function.  For example:
160/// - The size of the function, especially if smaller than the call overhead size.
161/// - The stack frame size of the function.
162/// - The number of calls made to the function or if the function is called inside a loop.
163/// - A particular call has constant arguments implying further constant folding.
164/// - An attribute request, e.g., #[always_inline], #[never_inline].
165pub fn inline_some_function_calls<F: Fn(&Context, &Function, &Value) -> bool>(
166    context: &mut Context,
167    function: &Function,
168    predicate: F,
169) -> Result<bool, IrError> {
170    // Find call sites which passes the predicate.
171    // We use a RefCell so that the inliner can modify the value
172    // when it moves other instructions (which could be in call_date) after an inline.
173    let (call_sites, call_data): (Vec<_>, FxHashMap<_, _>) = function
174        .instruction_iter(context)
175        .filter_map(|(block, call_val)| match context.values[call_val.0].value {
176            ValueDatum::Instruction(Instruction {
177                op: InstOp::Call(inlined_function, _),
178                ..
179            }) => predicate(context, &inlined_function, &call_val).then_some((
180                call_val,
181                (call_val, RefCell::new((block, inlined_function))),
182            )),
183            _ => None,
184        })
185        .unzip();
186
187    for call_site in &call_sites {
188        let call_site_in = call_data.get(call_site).unwrap();
189        let (block, inlined_function) = *call_site_in.borrow();
190
191        if function == &inlined_function {
192            // We can't inline a function into itself.
193            continue;
194        }
195
196        inline_function_call(
197            context,
198            *function,
199            block,
200            *call_site,
201            inlined_function,
202            &call_data,
203        )?;
204    }
205
206    Ok(!call_data.is_empty())
207}
208
209/// A utility to get a predicate which can be passed to inline_some_function_calls() based on
210/// certain sizes of the function.  If a constraint is None then any size is assumed to be
211/// acceptable.
212///
213/// The max_stack_size is a bit tricky, as the IR doesn't really know (or care) about the size of
214/// types.  See the source code for how it works.
215pub fn is_small_fn(
216    max_blocks: Option<usize>,
217    max_instrs: Option<usize>,
218    max_stack_size: Option<usize>,
219) -> impl Fn(&Context, &Function, &Value) -> bool {
220    fn count_type_elements(context: &Context, ty: &Type) -> usize {
221        // This is meant to just be a heuristic rather than be super accurate.
222        if ty.is_array(context) {
223            count_type_elements(context, &ty.get_array_elem_type(context).unwrap())
224                * ty.get_array_len(context).unwrap() as usize
225        } else if ty.is_union(context) {
226            ty.get_field_types(context)
227                .iter()
228                .map(|ty| count_type_elements(context, ty))
229                .max()
230                .unwrap_or(1)
231        } else if ty.is_struct(context) {
232            ty.get_field_types(context)
233                .iter()
234                .map(|ty| count_type_elements(context, ty))
235                .sum()
236        } else {
237            1
238        }
239    }
240
241    move |context: &Context, function: &Function, _call_site: &Value| -> bool {
242        max_blocks.is_none_or(|max_block_count| function.num_blocks(context) <= max_block_count)
243            && max_instrs.is_none_or(|max_instrs_count| {
244                function.num_instructions_incl_asm_instructions(context) <= max_instrs_count
245            })
246            && max_stack_size.is_none_or(|max_stack_size_count| {
247                function
248                    .locals_iter(context)
249                    .map(|(_name, ptr)| count_type_elements(context, &ptr.get_inner_type(context)))
250                    .sum::<usize>()
251                    <= max_stack_size_count
252            })
253    }
254}
255
/// Inline a function to a specific call site within another function.
///
/// The destination function, block and call site must be specified along with the function to
/// inline.
///
/// The steps, in order:
/// 1. Split `block` right after `call_site`.  Any `call_data` entry for a call that ends up in
///    `post_block` is re-pointed at `post_block`.
/// 2. Remove the call from the end of `pre_block` and give `post_block` one block argument of
///    the call's type.  Then replace every use of the call value in `function` with that
///    argument.
/// 3. Copy the inlined function's locals into `function` and map its arguments to the values
///    passed at the call site.  Then delete the call value from `context.values`.
/// 4. Map the inlined function's entry block onto `pre_block`.  For each of its other blocks,
///    create an empty block before `post_block` with matching block arguments.
/// 5. Copy the instructions of the blocks listed by `compute_post_order`, in reverse post-order,
///    via `inline_instruction`, which turns each `ret` into a branch to `post_block`.
///
/// `call_data` maps each pending call site to its (block, callee) pair and is updated in place,
/// so the caller can still locate later call sites after this split.
///
/// As written, this always returns `Ok(())`.  It panics (via `unwrap`/`panic!`) if, for
/// example, `call_site` is not in `block`, `post_block` already has block arguments, or the
/// inlined function has no blocks.
pub fn inline_function_call(
    context: &mut Context,
    function: Function,
    block: Block,
    call_site: Value,
    inlined_function: Function,
    call_data: &FxHashMap<Value, RefCell<(Block, Function)>>,
) -> Result<(), IrError> {
    // Split the block at right after the call site.
    let call_site_idx = block
        .instruction_iter(context)
        .position(|v| v == call_site)
        .unwrap();
    let (pre_block, post_block) = block.split_at(context, call_site_idx + 1);
    if post_block != block {
        // Calls that followed the call site now live in `post_block`; update their recorded
        // block in `call_data` so they can still be found when it is their turn.
        for inst in post_block.instruction_iter(context).filter(|inst| {
            matches!(
                context.values[inst.0].value,
                ValueDatum::Instruction(Instruction {
                    op: InstOp::Call(..),
                    ..
                })
            )
        }) {
            if let Some(call_info) = call_data.get(&inst) {
                call_info.borrow_mut().0 = post_block;
            }
        }
    }

    // Remove the call from the pre_block instructions.  It's still in the context.values[] though.
    pre_block.remove_last_instruction(context);

    // Returned values, if any, go to `post_block`, so a block arg there.
    // We don't expect `post_block` to already have any block args.
    if post_block.new_arg(context, call_site.get_type(context).unwrap()) != 0 {
        panic!("Expected newly created post_block to not have block args")
    }
    function.replace_value(
        context,
        call_site,
        post_block.get_arg(context, 0).unwrap(),
        None,
    );

    // Take the locals from the inlined function and add them to this function.  `ptr_map` maps
    // the inlined function's original locals to their copies in this function.  `value_map` maps
    // values of the inlined function to their replacements here; it starts with the call
    // arguments and block args, and grows as instructions are copied.
    let ptr_map = function.merge_locals_from(context, inlined_function);
    let mut value_map = HashMap::new();

    // Add the mapping from argument values in the inlined function to the args passed to the call.
    if let ValueDatum::Instruction(Instruction {
        op: InstOp::Call(_, passed_vals),
        ..
    }) = &context.values[call_site.0].value
    {
        for (arg_val, passed_val) in context.functions[inlined_function.0]
            .arguments
            .iter()
            .zip(passed_vals.iter())
        {
            value_map.insert(arg_val.1, *passed_val);
        }
    }

    // Get the metadata attached to the function call which may need to be propagated to the
    // inlined instructions.
    let metadata = context.values[call_site.0].metadata;

    // Now remove the call altogether.
    context.values.remove(call_site.0);

    // Insert empty blocks from the inlined function between our split blocks, and create a mapping
    // from old blocks to new.  We need this when inlining branch instructions, so they branch to
    // the new blocks.
    //
    // We map the entry block in the inlined function (which we know must exist) to our `pre_block`
    // from the split above.  We'll start appending inlined instructions to that block rather than
    // a new one (with a redundant branch to it from the `pre_block`).
    let inlined_fn_name = inlined_function.get_name(context).to_owned();
    let mut block_map = HashMap::new();
    let mut block_iter = context.functions[inlined_function.0]
        .blocks
        .clone()
        .into_iter();
    block_map.insert(block_iter.next().unwrap(), pre_block);
    block_map = block_iter.fold(block_map, |mut block_map, inlined_block| {
        let inlined_block_label = inlined_block.get_label(context);
        let new_block = function
            .create_block_before(
                context,
                &post_block,
                Some(format!("{inlined_fn_name}_{inlined_block_label}")),
            )
            .unwrap();
        block_map.insert(inlined_block, new_block);
        // We collect so that context can be mutably borrowed later.
        let inlined_args: Vec<_> = inlined_block.arg_iter(context).copied().collect();
        for inlined_arg in inlined_args {
            // Mirror each block arg of the inlined block on the new block and record the mapping.
            if let ValueDatum::Argument(BlockArgument {
                block: _,
                idx: _,
                ty,
                is_immutable: _,
            }) = &context.values[inlined_arg.0].value
            {
                let index = new_block.new_arg(context, *ty);
                value_map.insert(inlined_arg, new_block.get_arg(context, index).unwrap());
            } else {
                unreachable!("Expected a block argument")
            }
        }
        block_map
    });

    // Use a reverse-post-order traversal to ensure that definitions are seen before uses.
    let inlined_block_iter = compute_post_order(context, &inlined_function)
        .po_to_block
        .into_iter()
        .rev();
    // We now have a mapping from old blocks to new (currently empty) blocks, and a mapping from
    // old values (function args and block args at this stage) to new values.  We can copy
    // instructions over, translating their blocks, values and locals to refer to the new ones.
    // The value map is still live as we add new instructions which replace the old ones to it too.
    for ref block in inlined_block_iter {
        for ins in block.instruction_iter(context) {
            inline_instruction(
                context,
                block_map.get(block).unwrap(),
                &post_block,
                &ins,
                &block_map,
                &mut value_map,
                &ptr_map,
                metadata,
            );
        }
    }

    Ok(())
}
402
/// Copy one instruction of the inlined function into `new_block`, translating its operands.
///
/// Translation rules:
/// - Operand values go through `value_map`; values not in the map are used unchanged.
/// - Branch targets go through `block_map`.
/// - Locals go through `local_map`.
/// - A `ret` becomes a branch to `post_block`, passing the returned value as its block argument.
///
/// The new instruction gets `fn_metadata` (the caller in this file passes the call site's
/// metadata) combined with the original instruction's metadata.  Afterwards `value_map` maps
/// the old instruction to the new one, so later instructions that use it are translated too.
/// Values that are not instructions are ignored.
///
/// Panics if a branch target is missing from `block_map` or a local is missing from `local_map`.
#[allow(clippy::too_many_arguments)]
fn inline_instruction(
    context: &mut Context,
    new_block: &Block,
    post_block: &Block,
    instruction: &Value,
    block_map: &HashMap<Block, Block>,
    value_map: &mut HashMap<Value, Value>,
    local_map: &HashMap<LocalVar, LocalVar>,
    fn_metadata: Option<MetadataIndex>,
) {
    // Util to translate old blocks to new.  If an old block isn't in the map then we panic, since
    // it should be guaranteed to be there...that's a bug otherwise.
    let map_block = |old_block| *block_map.get(&old_block).unwrap();

    // Util to translate old values to new.  If an old value isn't in the map it is reused as-is;
    // this is expected to be a value not owned by the inlined function, e.g. a constant.
    let map_value = |old_val: Value| value_map.get(&old_val).copied().unwrap_or(old_val);
    // Util to translate old locals to new.  Panics if a local was not merged into the caller.
    let map_local = |old_local| local_map.get(&old_local).copied().unwrap();

    // The instruction needs to be cloned into the new block, with each value and/or block
    // translated using the above maps.  Most of these are relatively cheap as Instructions
    // generally are lightweight, except maybe ASM blocks, but we're able to re-use the block
    // content since it's a black box and not concerned with Values, Blocks or Pointers.
    //
    // We need to clone the instruction here, which is unfortunate.  Maybe in the future we
    // restructure instructions somehow, so we don't need a persistent `&Context` to access them.
    if let ValueContent {
        value: ValueDatum::Instruction(old_ins),
        metadata: val_metadata,
    } = context.values[instruction.0].clone()
    {
        // Combine the function metadata with this instruction metadata so we don't lose the
        // function metadata after inlining.
        let metadata = combine(context, &fn_metadata, &val_metadata);

        let new_ins = match old_ins.op {
            InstOp::AsmBlock(asm, args) => {
                let new_args = args
                    .iter()
                    .map(|AsmArg { name, initializer }| AsmArg {
                        name: name.clone(),
                        initializer: initializer.map(map_value),
                    })
                    .collect();

                // We can re-use the old asm block with the updated args.
                new_block.append(context).asm_block_from_asm(asm, new_args)
            }
            InstOp::BitCast(value, ty) => new_block.append(context).bitcast(map_value(value), ty),
            InstOp::UnaryOp { op, arg } => new_block.append(context).unary_op(op, map_value(arg)),
            InstOp::BinaryOp { op, arg1, arg2 } => {
                new_block
                    .append(context)
                    .binary_op(op, map_value(arg1), map_value(arg2))
            }
            // For `br` and `cbr` below, target blocks are translated through `block_map` and the
            // block-argument values through `value_map`, right here.
            InstOp::Branch(b) => new_block.append(context).branch(
                map_block(b.block),
                b.args.iter().map(|v| map_value(*v)).collect(),
            ),
            InstOp::Call(f, args) => new_block.append(context).call(
                f,
                args.iter()
                    .map(|old_val: &Value| map_value(*old_val))
                    .collect::<Vec<Value>>()
                    .as_slice(),
            ),
            InstOp::CastPtr(val, ty) => new_block.append(context).cast_ptr(map_value(val), ty),
            InstOp::Cmp(pred, lhs_value, rhs_value) => {
                new_block
                    .append(context)
                    .cmp(pred, map_value(lhs_value), map_value(rhs_value))
            }
            InstOp::ConditionalBranch {
                cond_value,
                true_block,
                false_block,
            } => new_block.append(context).conditional_branch(
                map_value(cond_value),
                map_block(true_block.block),
                map_block(false_block.block),
                true_block.args.iter().map(|v| map_value(*v)).collect(),
                false_block.args.iter().map(|v| map_value(*v)).collect(),
            ),
            InstOp::ContractCall {
                return_type,
                name,
                params,
                coins,
                asset_id,
                gas,
            } => new_block.append(context).contract_call(
                return_type,
                name,
                map_value(params),
                map_value(coins),
                map_value(asset_id),
                map_value(gas),
            ),
            InstOp::FuelVm(fuel_vm_instr) => match fuel_vm_instr {
                FuelVmInstruction::Gtf { index, tx_field_id } => {
                    new_block.append(context).gtf(map_value(index), tx_field_id)
                }
                FuelVmInstruction::Log {
                    log_val,
                    log_ty,
                    log_id,
                } => new_block
                    .append(context)
                    .log(map_value(log_val), log_ty, map_value(log_id)),
                FuelVmInstruction::ReadRegister(reg) => {
                    new_block.append(context).read_register(reg)
                }
                FuelVmInstruction::Revert(val) => new_block.append(context).revert(map_value(val)),
                FuelVmInstruction::JmpMem => new_block.append(context).jmp_mem(),
                FuelVmInstruction::Smo {
                    recipient,
                    message,
                    message_size,
                    coins,
                } => new_block.append(context).smo(
                    map_value(recipient),
                    map_value(message),
                    map_value(message_size),
                    map_value(coins),
                ),
                FuelVmInstruction::StateClear {
                    key,
                    number_of_slots,
                } => new_block
                    .append(context)
                    .state_clear(map_value(key), map_value(number_of_slots)),
                FuelVmInstruction::StateLoadQuadWord {
                    load_val,
                    key,
                    number_of_slots,
                } => new_block.append(context).state_load_quad_word(
                    map_value(load_val),
                    map_value(key),
                    map_value(number_of_slots),
                ),
                FuelVmInstruction::StateLoadWord(key) => {
                    new_block.append(context).state_load_word(map_value(key))
                }
                FuelVmInstruction::StateStoreQuadWord {
                    stored_val,
                    key,
                    number_of_slots,
                } => new_block.append(context).state_store_quad_word(
                    map_value(stored_val),
                    map_value(key),
                    map_value(number_of_slots),
                ),
                FuelVmInstruction::StateStoreWord { stored_val, key } => new_block
                    .append(context)
                    .state_store_word(map_value(stored_val), map_value(key)),
                FuelVmInstruction::WideUnaryOp { op, arg, result } => new_block
                    .append(context)
                    .wide_unary_op(op, map_value(arg), map_value(result)),
                FuelVmInstruction::WideBinaryOp {
                    op,
                    arg1,
                    arg2,
                    result,
                } => new_block.append(context).wide_binary_op(
                    op,
                    map_value(arg1),
                    map_value(arg2),
                    map_value(result),
                ),
                FuelVmInstruction::WideModularOp {
                    op,
                    result,
                    arg1,
                    arg2,
                    arg3,
                } => new_block.append(context).wide_modular_op(
                    op,
                    map_value(result),
                    map_value(arg1),
                    map_value(arg2),
                    map_value(arg3),
                ),
                FuelVmInstruction::WideCmpOp { op, arg1, arg2 } => new_block
                    .append(context)
                    .wide_cmp_op(op, map_value(arg1), map_value(arg2)),
                FuelVmInstruction::Retd { ptr, len } => new_block
                    .append(context)
                    .retd(map_value(ptr), map_value(len)),
            },
            InstOp::GetElemPtr {
                base,
                elem_ptr_ty,
                indices,
            } => {
                let elem_ty = elem_ptr_ty.get_pointee_type(context).unwrap();
                new_block.append(context).get_elem_ptr(
                    map_value(base),
                    elem_ty,
                    indices.iter().map(|idx| map_value(*idx)).collect(),
                )
            }
            InstOp::GetLocal(local_var) => {
                new_block.append(context).get_local(map_local(local_var))
            }
            InstOp::GetGlobal(global_var) => new_block.append(context).get_global(global_var),
            InstOp::GetStorageKey(storage_key) => {
                new_block.append(context).get_storage_key(storage_key)
            }
            InstOp::GetConfig(module, name) => new_block.append(context).get_config(module, name),
            InstOp::IntToPtr(value, ty) => {
                new_block.append(context).int_to_ptr(map_value(value), ty)
            }
            InstOp::Load(src_val) => new_block.append(context).load(map_value(src_val)),
            InstOp::MemCopyBytes {
                dst_val_ptr,
                src_val_ptr,
                byte_len,
            } => new_block.append(context).mem_copy_bytes(
                map_value(dst_val_ptr),
                map_value(src_val_ptr),
                byte_len,
            ),
            InstOp::MemCopyVal {
                dst_val_ptr,
                src_val_ptr,
            } => new_block
                .append(context)
                .mem_copy_val(map_value(dst_val_ptr), map_value(src_val_ptr)),
            InstOp::MemClearVal { dst_val_ptr } => new_block
                .append(context)
                .mem_clear_val(map_value(dst_val_ptr)),
            InstOp::Nop => new_block.append(context).nop(),
            InstOp::PtrToInt(value, ty) => {
                new_block.append(context).ptr_to_int(map_value(value), ty)
            }
            // We convert `ret` to `br post_block`, passing the returned value as the block arg
            // that `post_block` was given for the call's result.
            InstOp::Ret(val, _) => new_block
                .append(context)
                .branch(*post_block, vec![map_value(val)]),
            InstOp::Store {
                dst_val_ptr,
                stored_val,
            } => new_block
                .append(context)
                .store(map_value(dst_val_ptr), map_value(stored_val)),
        }
        .add_metadatum(context, metadata);

        // Later instructions referring to the old instruction must use the new one.
        value_map.insert(*instruction, new_ins);
    }
}