1use crate::{
9 AnalysisResults, BlockArgument, ConstantContent, Context, Function, InstOp, Instruction,
10 InstructionInserter, IrError, Module, Pass, PassMutability, ScopedPass, Type, Value,
11};
12
13pub const RET_DEMOTION_NAME: &str = "ret-demotion";
14
15pub fn create_ret_demotion_pass() -> Pass {
16 Pass {
17 name: RET_DEMOTION_NAME,
18 descr: "Demotion of by-value function return values to by-reference",
19 deps: Vec::new(),
20 runner: ScopedPass::ModulePass(PassMutability::Transform(ret_val_demotion)),
21 }
22}
23
24pub fn ret_val_demotion(
25 context: &mut Context,
26 _analyses: &AnalysisResults,
27 module: Module,
28) -> Result<bool, IrError> {
29 let mut changed = false;
32 for function in module.function_iter(context) {
33 let ret_type = function.get_return_type(context);
35 if !super::target_fuel::is_demotable_type(context, &ret_type) {
36 continue;
38 }
39
40 changed = true;
41
42 let ptr_ret_type = Type::new_typed_pointer(context, ret_type);
44 let unit_ty = Type::get_unit(context);
45
46 let entry_block = function.get_entry_block(context);
49 let ptr_arg_val = if function.is_entry(context) {
50 function.set_return_type(context, ptr_ret_type);
52
53 let ret_var = function.new_unique_local_var(
55 context,
56 "__ret_value".to_owned(),
57 ret_type,
58 None,
59 false,
60 );
61
62 let get_ret_var =
64 Value::new_instruction(context, entry_block, InstOp::GetLocal(ret_var));
65 entry_block.prepend_instructions(context, vec![get_ret_var]);
66 get_ret_var
67 } else {
68 function.set_return_type(context, unit_ty);
70
71 let ptr_arg_val = Value::new_argument(
72 context,
73 BlockArgument {
74 block: entry_block,
75 idx: function.num_args(context),
76 ty: ptr_ret_type,
77 is_immutable: false,
78 },
79 );
80 function.add_arg(context, "__ret_value", ptr_arg_val);
81 entry_block.add_arg(context, ptr_arg_val);
82 ptr_arg_val
83 };
84
85 let ret_blocks = function
87 .block_iter(context)
88 .filter_map(|block| {
89 block.get_terminator(context).and_then(|term| {
90 if let InstOp::Ret(ret_val, _ty) = term.op {
91 Some((block, ret_val))
92 } else {
93 None
94 }
95 })
96 })
97 .collect::<Vec<_>>();
98
99 for (ret_block, ret_val) in ret_blocks {
101 let last_instr_pos = ret_block.num_instructions(context) - 1;
104 let orig_ret_val = ret_block.get_instruction_at(context, last_instr_pos);
105 ret_block.remove_instruction_at(context, last_instr_pos);
106 let md_idx = orig_ret_val.and_then(|val| val.get_metadata(context));
107
108 ret_block
109 .append(context)
110 .store(ptr_arg_val, ret_val)
111 .add_metadatum(context, md_idx);
112
113 if !function.is_entry(context) {
114 let unit_ret = ConstantContent::get_unit(context);
115 ret_block
116 .append(context)
117 .ret(unit_ret, unit_ty)
118 .add_metadatum(context, md_idx);
119 } else {
120 ret_block
122 .append(context)
123 .ret(ptr_arg_val, ptr_ret_type)
124 .add_metadatum(context, md_idx);
125 }
126 }
127
128 if !function.is_entry(context) {
131 update_callers(context, function, ret_type);
132 }
133 }
134
135 Ok(changed)
136}
137
138fn update_callers(context: &mut Context, function: Function, ret_type: Type) {
139 let call_sites = context
142 .module_iter()
143 .flat_map(|module| module.function_iter(context))
144 .flat_map(|ref call_from_func| {
145 call_from_func
146 .block_iter(context)
147 .flat_map(|ref block| {
148 block
149 .instruction_iter(context)
150 .filter_map(|instr_val| {
151 if let Instruction {
152 op: InstOp::Call(call_to_func, _),
153 ..
154 } = instr_val
155 .get_instruction(context)
156 .expect("`instruction_iter()` must return instruction values.")
157 {
158 (*call_to_func == function).then_some((
159 *call_from_func,
160 *block,
161 instr_val,
162 ))
163 } else {
164 None
165 }
166 })
167 .collect::<Vec<_>>()
168 })
169 .collect::<Vec<_>>()
170 })
171 .collect::<Vec<_>>();
172
173 for (calling_func, calling_block, call_val) in call_sites {
176 let loc_var = calling_func.new_unique_local_var(
178 context,
179 "__ret_val".to_owned(),
180 ret_type,
181 None,
182 false,
183 );
184 let get_loc_val = Value::new_instruction(context, calling_block, InstOp::GetLocal(loc_var));
185
186 let Some(Instruction {
188 op: InstOp::Call(_, args),
189 ..
190 }) = call_val.get_instruction(context)
191 else {
192 unreachable!("`call_val` is definitely a call instruction.");
193 };
194 let mut new_args = args.clone();
195 new_args.push(get_loc_val);
196 let new_call_val =
197 Value::new_instruction(context, calling_block, InstOp::Call(function, new_args));
198
199 let load_val = Value::new_instruction(context, calling_block, InstOp::Load(get_loc_val));
201
202 calling_block
203 .replace_instruction(context, call_val, get_loc_val, false)
204 .unwrap();
205 let mut inserter = InstructionInserter::new(
206 context,
207 calling_block,
208 crate::InsertionPosition::After(get_loc_val),
209 );
210 inserter.insert_slice(&[new_call_val, load_val]);
211
212 calling_func.replace_value(context, call_val, load_val, None);
214 }
215}