sway_ir/optimize/
ret_demotion.rs

/// Return value demotion.
///
/// This pass demotes 'by-value' function return types to 'by-reference' pointer types, based on
/// target specific parameters.
///
/// An extra argument pointer is added to the function.
/// The return value is stored via the new argument pointer instead of being returned by value.
8use crate::{
9    AnalysisResults, BlockArgument, ConstantContent, Context, Function, InstOp, Instruction,
10    InstructionInserter, IrError, Module, Pass, PassMutability, ScopedPass, Type, Value,
11};
12
13pub const RET_DEMOTION_NAME: &str = "ret-demotion";
14
15pub fn create_ret_demotion_pass() -> Pass {
16    Pass {
17        name: RET_DEMOTION_NAME,
18        descr: "Demotion of by-value function return values to by-reference",
19        deps: Vec::new(),
20        runner: ScopedPass::ModulePass(PassMutability::Transform(ret_val_demotion)),
21    }
22}
23
24pub fn ret_val_demotion(
25    context: &mut Context,
26    _analyses: &AnalysisResults,
27    module: Module,
28) -> Result<bool, IrError> {
29    // This is a module pass because we need to update all the callers of a function if we change
30    // its signature.
31    let mut changed = false;
32    for function in module.function_iter(context) {
33        // Reject non-candidate.
34        let ret_type = function.get_return_type(context);
35        if !super::target_fuel::is_demotable_type(context, &ret_type) {
36            // Return type fits in a register.
37            continue;
38        }
39
40        changed = true;
41
42        // Change the function signature.
43        let ptr_ret_type = Type::new_typed_pointer(context, ret_type);
44        let unit_ty = Type::get_unit(context);
45
46        // The storage for the return value must be determined.  For entry-point functions it's a new
47        // local and otherwise it's an extra argument.
48        let entry_block = function.get_entry_block(context);
49        let ptr_arg_val = if function.is_entry(context) {
50            // Entry functions return a pointer to the original return type.
51            function.set_return_type(context, ptr_ret_type);
52
53            // Create a local variable to hold the return value.
54            let ret_var = function.new_unique_local_var(
55                context,
56                "__ret_value".to_owned(),
57                ret_type,
58                None,
59                false,
60            );
61
62            // Insert the return value pointer at the start of the entry block.
63            let get_ret_var =
64                Value::new_instruction(context, entry_block, InstOp::GetLocal(ret_var));
65            entry_block.prepend_instructions(context, vec![get_ret_var]);
66            get_ret_var
67        } else {
68            // non-entry functions now return unit.
69            function.set_return_type(context, unit_ty);
70
71            let ptr_arg_val = Value::new_argument(
72                context,
73                BlockArgument {
74                    block: entry_block,
75                    idx: function.num_args(context),
76                    ty: ptr_ret_type,
77                    is_immutable: false,
78                },
79            );
80            function.add_arg(context, "__ret_value", ptr_arg_val);
81            entry_block.add_arg(context, ptr_arg_val);
82            ptr_arg_val
83        };
84
85        // Gather the blocks which are returning.
86        let ret_blocks = function
87            .block_iter(context)
88            .filter_map(|block| {
89                block.get_terminator(context).and_then(|term| {
90                    if let InstOp::Ret(ret_val, _ty) = term.op {
91                        Some((block, ret_val))
92                    } else {
93                        None
94                    }
95                })
96            })
97            .collect::<Vec<_>>();
98
99        // Update each `ret` to store the return value to the 'out' arg and then return the pointer.
100        for (ret_block, ret_val) in ret_blocks {
101            // This is a special case where we're replacing the terminator.  We can just pop it off the
102            // end of the block and add new instructions.
103            let last_instr_pos = ret_block.num_instructions(context) - 1;
104            let orig_ret_val = ret_block.get_instruction_at(context, last_instr_pos);
105            ret_block.remove_instruction_at(context, last_instr_pos);
106            let md_idx = orig_ret_val.and_then(|val| val.get_metadata(context));
107
108            ret_block
109                .append(context)
110                .store(ptr_arg_val, ret_val)
111                .add_metadatum(context, md_idx);
112
113            if !function.is_entry(context) {
114                let unit_ret = ConstantContent::get_unit(context);
115                ret_block
116                    .append(context)
117                    .ret(unit_ret, unit_ty)
118                    .add_metadatum(context, md_idx);
119            } else {
120                // Entry functions still return the pointer to the return value.
121                ret_block
122                    .append(context)
123                    .ret(ptr_arg_val, ptr_ret_type)
124                    .add_metadatum(context, md_idx);
125            }
126        }
127
128        // If the function isn't an entry point we need to update all the callers to pass the extra
129        // argument.
130        if !function.is_entry(context) {
131            update_callers(context, function, ret_type);
132        }
133    }
134
135    Ok(changed)
136}
137
138fn update_callers(context: &mut Context, function: Function, ret_type: Type) {
139    // Now update all the callers to pass the return value argument. Find all the call sites for
140    // this function.
141    let call_sites = context
142        .module_iter()
143        .flat_map(|module| module.function_iter(context))
144        .flat_map(|ref call_from_func| {
145            call_from_func
146                .block_iter(context)
147                .flat_map(|ref block| {
148                    block
149                        .instruction_iter(context)
150                        .filter_map(|instr_val| {
151                            if let Instruction {
152                                op: InstOp::Call(call_to_func, _),
153                                ..
154                            } = instr_val
155                                .get_instruction(context)
156                                .expect("`instruction_iter()` must return instruction values.")
157                            {
158                                (*call_to_func == function).then_some((
159                                    *call_from_func,
160                                    *block,
161                                    instr_val,
162                                ))
163                            } else {
164                                None
165                            }
166                        })
167                        .collect::<Vec<_>>()
168                })
169                .collect::<Vec<_>>()
170        })
171        .collect::<Vec<_>>();
172
173    // Create a local var to receive the return value for each call site.  Replace the `call`
174    // instruction with a `get_local`, an updated `call` and a `load`.
175    for (calling_func, calling_block, call_val) in call_sites {
176        // First make a new local variable.
177        let loc_var = calling_func.new_unique_local_var(
178            context,
179            "__ret_val".to_owned(),
180            ret_type,
181            None,
182            false,
183        );
184        let get_loc_val = Value::new_instruction(context, calling_block, InstOp::GetLocal(loc_var));
185
186        // Next we need to copy the original `call` but add the extra arg.
187        let Some(Instruction {
188            op: InstOp::Call(_, args),
189            ..
190        }) = call_val.get_instruction(context)
191        else {
192            unreachable!("`call_val` is definitely a call instruction.");
193        };
194        let mut new_args = args.clone();
195        new_args.push(get_loc_val);
196        let new_call_val =
197            Value::new_instruction(context, calling_block, InstOp::Call(function, new_args));
198
199        // And finally load the value from the new local var.
200        let load_val = Value::new_instruction(context, calling_block, InstOp::Load(get_loc_val));
201
202        calling_block
203            .replace_instruction(context, call_val, get_loc_val, false)
204            .unwrap();
205        let mut inserter = InstructionInserter::new(
206            context,
207            calling_block,
208            crate::InsertionPosition::After(get_loc_val),
209        );
210        inserter.insert_slice(&[new_call_val, load_val]);
211
212        // Replace the old call with the new load.
213        calling_func.replace_value(context, call_val, load_val, None);
214    }
215}