1use indexmap::IndexMap;
2use rustc_hash::FxHashMap;
8use std::collections::HashSet;
9use sway_utils::mapped_stack::MappedStack;
10
11use crate::{
12 AnalysisResults, Block, BranchToWithArgs, Constant, Context, DomFronts, DomTree, Function,
13 InstOp, Instruction, IrError, LocalVar, Pass, PassMutability, PostOrder, ScopedPass, Type,
14 Value, ValueDatum, DOMINATORS_NAME, DOM_FRONTS_NAME, POSTORDER_NAME,
15};
16
17pub const MEM2REG_NAME: &str = "mem2reg";
18
19pub fn create_mem2reg_pass() -> Pass {
20 Pass {
21 name: MEM2REG_NAME,
22 descr: "Promotion of memory to SSA registers",
23 deps: vec![POSTORDER_NAME, DOMINATORS_NAME, DOM_FRONTS_NAME],
24 runner: ScopedPass::FunctionPass(PassMutability::Transform(promote_to_registers)),
25 }
26}
27
28fn get_validate_local_var(
30 context: &Context,
31 function: &Function,
32 val: &Value,
33) -> Option<(String, LocalVar)> {
34 match context.values[val.0].value {
35 ValueDatum::Instruction(Instruction {
36 op: InstOp::GetLocal(local_var),
37 ..
38 }) => {
39 let name = function.lookup_local_name(context, &local_var);
40 name.map(|name| (name.clone(), local_var))
41 }
42 _ => None,
43 }
44}
45
46fn is_promotable_type(context: &Context, ty: Type) -> bool {
47 ty.is_unit(context)
48 || ty.is_bool(context)
49 || ty.is_ptr(context)
50 || (ty.is_uint(context) && ty.get_uint_width(context).unwrap() <= 64)
51}
52
53fn filter_usable_locals(context: &mut Context, function: &Function) -> HashSet<String> {
55 let mut locals: HashSet<String> = function
58 .locals_iter(context)
59 .filter_map(|(name, var)| {
60 let ty = var.get_inner_type(context);
61 is_promotable_type(context, ty).then_some(name.clone())
62 })
63 .collect();
64
65 for (_, inst) in function.instruction_iter(context) {
66 match context.values[inst.0].value {
67 ValueDatum::Instruction(Instruction {
68 op: InstOp::Load(_),
69 ..
70 }) => {}
71 ValueDatum::Instruction(Instruction {
72 op:
73 InstOp::Store {
74 dst_val_ptr: _,
75 stored_val,
76 },
77 ..
78 }) => {
79 if let Some((local, _)) = get_validate_local_var(context, function, &stored_val) {
82 locals.remove(&local);
83 }
84 }
85 _ => {
86 let operands = inst.get_instruction(context).unwrap().op.get_operands();
88 for opd in operands {
89 if let Some((local, ..)) = get_validate_local_var(context, function, &opd) {
90 locals.remove(&local);
91 }
92 }
93 }
94 }
95 }
96 locals
97}
98
99pub fn compute_livein(
103 context: &mut Context,
104 function: &Function,
105 po: &PostOrder,
106 locals: &HashSet<String>,
107) -> FxHashMap<Block, HashSet<String>> {
108 let mut result = FxHashMap::<Block, HashSet<String>>::default();
109 for block in &po.po_to_block {
110 result.insert(*block, HashSet::<String>::default());
111 }
112
113 let mut changed = true;
114 while changed {
115 changed = false;
116 for block in &po.po_to_block {
117 let mut cur_live = HashSet::<String>::default();
119 for BranchToWithArgs { block: succ, .. } in block.successors(context) {
120 let succ_livein = &result[&succ];
121 cur_live.extend(succ_livein.iter().cloned());
122 }
123 for inst in block.instruction_iter(context).rev() {
125 match context.values[inst.0].value {
126 ValueDatum::Instruction(Instruction {
127 op: InstOp::Load(ptr),
128 ..
129 }) => {
130 let local_var = get_validate_local_var(context, function, &ptr);
131 match local_var {
132 Some((local, ..)) if locals.contains(&local) => {
133 cur_live.insert(local);
134 }
135 _ => {}
136 }
137 }
138 ValueDatum::Instruction(Instruction {
139 op: InstOp::Store { dst_val_ptr, .. },
140 ..
141 }) => {
142 let local_var = get_validate_local_var(context, function, &dst_val_ptr);
143 match local_var {
144 Some((local, _)) if locals.contains(&local) => {
145 cur_live.remove(&local);
146 }
147 _ => (),
148 }
149 }
150 _ => (),
151 }
152 }
153 if result[block] != cur_live {
154 result.get_mut(block).unwrap().extend(cur_live);
156 changed = true;
157 }
158 }
159 }
160 result
161}
162
163fn promote_globals(context: &mut Context, function: &Function) -> Result<bool, IrError> {
166 let mut replacements = FxHashMap::<Value, Constant>::default();
167 for (_, inst) in function.instruction_iter(context) {
168 if let ValueDatum::Instruction(Instruction {
169 op: InstOp::Load(ptr),
170 ..
171 }) = context.values[inst.0].value
172 {
173 if let ValueDatum::Instruction(Instruction {
174 op: InstOp::GetGlobal(global_var),
175 ..
176 }) = context.values[ptr.0].value
177 {
178 if !global_var.is_mutable(context)
179 && is_promotable_type(context, global_var.get_inner_type(context))
180 {
181 let constant = *global_var
182 .get_initializer(context)
183 .expect("`global_var` is not mutable so it must be initialized");
184 replacements.insert(inst, constant);
185 }
186 }
187 }
188 }
189
190 if replacements.is_empty() {
191 return Ok(false);
192 }
193
194 let replacements = replacements
195 .into_iter()
196 .map(|(k, v)| (k, Value::new_constant(context, v)))
197 .collect::<FxHashMap<_, _>>();
198
199 function.replace_values(context, &replacements, None);
200
201 Ok(true)
202}
203
204pub fn promote_to_registers(
206 context: &mut Context,
207 analyses: &AnalysisResults,
208 function: Function,
209) -> Result<bool, IrError> {
210 let mut modified = false;
211 modified |= promote_globals(context, &function)?;
212 modified |= promote_locals(context, analyses, function)?;
213 Ok(modified)
214}
215
216pub fn promote_locals(
220 context: &mut Context,
221 analyses: &AnalysisResults,
222 function: Function,
223) -> Result<bool, IrError> {
224 let safe_locals = filter_usable_locals(context, &function);
225 if safe_locals.is_empty() {
226 return Ok(false);
227 }
228
229 let po: &PostOrder = analyses.get_analysis_result(function);
230 let dom_tree: &DomTree = analyses.get_analysis_result(function);
231 let dom_fronts: &DomFronts = analyses.get_analysis_result(function);
232 let liveins = compute_livein(context, &function, po, &safe_locals);
233
234 let mut new_phi_tracker = HashSet::<(String, Block)>::new();
236 let mut worklist = Vec::<(String, Type, Block)>::new();
238 let mut phi_to_local = FxHashMap::<Value, String>::default();
239 for (block, inst) in po
243 .po_to_block
244 .iter()
245 .rev()
246 .flat_map(|b| b.instruction_iter(context).map(|i| (*b, i)))
247 {
248 if let ValueDatum::Instruction(Instruction {
249 op: InstOp::Store { dst_val_ptr, .. },
250 ..
251 }) = context.values[inst.0].value
252 {
253 match get_validate_local_var(context, &function, &dst_val_ptr) {
254 Some((local, var)) if safe_locals.contains(&local) => {
255 worklist.push((local, var.get_inner_type(context), block));
256 }
257 _ => (),
258 }
259 }
260 }
261 while let Some((local, ty, known_def)) = worklist.pop() {
263 for df in dom_fronts[&known_def].iter() {
264 if !new_phi_tracker.contains(&(local.clone(), *df)) && liveins[df].contains(&local) {
265 let index = df.new_arg(context, ty);
267 phi_to_local.insert(df.get_arg(context, index).unwrap(), local.clone());
268 new_phi_tracker.insert((local.clone(), *df));
269 worklist.push((local.clone(), ty, *df));
271 }
272 }
273 }
274
275 #[allow(clippy::too_many_arguments)]
279 fn record_rewrites(
280 context: &mut Context,
281 function: &Function,
282 dom_tree: &DomTree,
283 node: Block,
284 safe_locals: &HashSet<String>,
285 phi_to_local: &FxHashMap<Value, String>,
286 name_stack: &mut MappedStack<String, Value>,
287 rewrites: &mut FxHashMap<Value, Value>,
288 deletes: &mut Vec<(Block, Value)>,
289 ) {
290 let mut num_local_pushes = IndexMap::<String, u32>::new();
293
294 for arg in node.arg_iter(context) {
296 if let Some(local) = phi_to_local.get(arg) {
297 name_stack.push(local.clone(), *arg);
298 num_local_pushes
299 .entry(local.clone())
300 .and_modify(|count| *count += 1)
301 .or_insert(1);
302 }
303 }
304
305 for inst in node.instruction_iter(context) {
306 match context.values[inst.0].value {
307 ValueDatum::Instruction(Instruction {
308 op: InstOp::Load(ptr),
309 ..
310 }) => {
311 let local_var = get_validate_local_var(context, function, &ptr);
312 match local_var {
313 Some((local, var)) if safe_locals.contains(&local) => {
314 let new_val = match name_stack.get(&local) {
316 Some(val) => *val,
317 None => {
318 let constant = *var
320 .get_initializer(context)
321 .expect("We're dealing with an uninitialized value");
322 Value::new_constant(context, constant)
323 }
324 };
325 rewrites.insert(inst, new_val);
326 deletes.push((node, inst));
327 }
328 _ => (),
329 }
330 }
331 ValueDatum::Instruction(Instruction {
332 op:
333 InstOp::Store {
334 dst_val_ptr,
335 stored_val,
336 },
337 ..
338 }) => {
339 let local_var = get_validate_local_var(context, function, &dst_val_ptr);
340 match local_var {
341 Some((local, _)) if safe_locals.contains(&local) => {
342 name_stack.push(local.clone(), stored_val);
345 num_local_pushes
346 .entry(local)
347 .and_modify(|count| *count += 1)
348 .or_insert(1);
349 deletes.push((node, inst));
350 }
351 _ => (),
352 }
353 }
354 _ => (),
355 }
356 }
357
358 for BranchToWithArgs { block: succ, .. } in node.successors(context) {
360 let args: Vec<_> = succ.arg_iter(context).copied().collect();
361 for arg in args {
364 if let Some(local) = phi_to_local.get(&arg) {
365 let ptr = function.get_local_var(context, local).unwrap();
366 let new_val = match name_stack.get(local) {
367 Some(val) => *val,
368 None => {
369 let constant = *ptr
371 .get_initializer(context)
372 .expect("We're dealing with an uninitialized value");
373 Value::new_constant(context, constant)
374 }
375 };
376 let params = node.get_succ_params_mut(context, &succ).unwrap();
377 params.push(new_val);
378 }
379 }
380 }
381
382 for child in dom_tree.children(node) {
384 record_rewrites(
385 context,
386 function,
387 dom_tree,
388 child,
389 safe_locals,
390 phi_to_local,
391 name_stack,
392 rewrites,
393 deletes,
394 );
395 }
396
397 for (local, pushes) in num_local_pushes.iter() {
399 for _ in 0..*pushes {
400 name_stack.pop(local);
401 }
402 }
403 }
404
405 let mut name_stack = MappedStack::<String, Value>::default();
406 let mut value_replacement = FxHashMap::<Value, Value>::default();
407 let mut delete_insts = Vec::<(Block, Value)>::new();
408 record_rewrites(
409 context,
410 &function,
411 dom_tree,
412 function.get_entry_block(context),
413 &safe_locals,
414 &phi_to_local,
415 &mut name_stack,
416 &mut value_replacement,
417 &mut delete_insts,
418 );
419
420 function.replace_values(context, &value_replacement, None);
422 for (block, inst) in delete_insts {
424 block.remove_instruction(context, inst);
425 }
426
427 Ok(true)
428}