sway_ir/optimize/
mem2reg.rs

1use indexmap::IndexMap;
2/// Promote local memory to SSA registers.
3/// This pass is essentially SSA construction. A good readable reference is:
4/// https://www.cs.princeton.edu/~appel/modern/c/
5/// We use block arguments instead of explicit PHI nodes. Conceptually,
6/// they are both the same.
7use rustc_hash::FxHashMap;
8use std::collections::HashSet;
9use sway_utils::mapped_stack::MappedStack;
10
11use crate::{
12    AnalysisResults, Block, BranchToWithArgs, Constant, Context, DomFronts, DomTree, Function,
13    InstOp, Instruction, IrError, LocalVar, Pass, PassMutability, PostOrder, ScopedPass, Type,
14    Value, ValueDatum, DOMINATORS_NAME, DOM_FRONTS_NAME, POSTORDER_NAME,
15};
16
/// Name under which the mem2reg pass is registered with the pass manager.
pub const MEM2REG_NAME: &str = "mem2reg";
18
19pub fn create_mem2reg_pass() -> Pass {
20    Pass {
21        name: MEM2REG_NAME,
22        descr: "Promotion of memory to SSA registers",
23        deps: vec![POSTORDER_NAME, DOMINATORS_NAME, DOM_FRONTS_NAME],
24        runner: ScopedPass::FunctionPass(PassMutability::Transform(promote_to_registers)),
25    }
26}
27
28// Check if a value is a valid (for our optimization) local pointer
29fn get_validate_local_var(
30    context: &Context,
31    function: &Function,
32    val: &Value,
33) -> Option<(String, LocalVar)> {
34    match context.values[val.0].value {
35        ValueDatum::Instruction(Instruction {
36            op: InstOp::GetLocal(local_var),
37            ..
38        }) => {
39            let name = function.lookup_local_name(context, &local_var);
40            name.map(|name| (name.clone(), local_var))
41        }
42        _ => None,
43    }
44}
45
46fn is_promotable_type(context: &Context, ty: Type) -> bool {
47    ty.is_unit(context)
48        || ty.is_bool(context)
49        || ty.is_ptr(context)
50        || (ty.is_uint(context) && ty.get_uint_width(context).unwrap() <= 64)
51}
52
53// Returns those locals that can be promoted to SSA registers.
54fn filter_usable_locals(context: &mut Context, function: &Function) -> HashSet<String> {
55    // The size of an SSA register is target specific.  Here we're going to just stick with atomic
56    // types which can fit in 64-bits.
57    let mut locals: HashSet<String> = function
58        .locals_iter(context)
59        .filter_map(|(name, var)| {
60            let ty = var.get_inner_type(context);
61            is_promotable_type(context, ty).then_some(name.clone())
62        })
63        .collect();
64
65    for (_, inst) in function.instruction_iter(context) {
66        match context.values[inst.0].value {
67            ValueDatum::Instruction(Instruction {
68                op: InstOp::Load(_),
69                ..
70            }) => {}
71            ValueDatum::Instruction(Instruction {
72                op:
73                    InstOp::Store {
74                        dst_val_ptr: _,
75                        stored_val,
76                    },
77                ..
78            }) => {
79                // Make sure that a local's address isn't stored.
80                // E.g., in cases like `let r = &some_local;`.
81                if let Some((local, _)) = get_validate_local_var(context, function, &stored_val) {
82                    locals.remove(&local);
83                }
84            }
85            _ => {
86                // Make sure that no local escapes into instructions we don't understand.
87                let operands = inst.get_instruction(context).unwrap().op.get_operands();
88                for opd in operands {
89                    if let Some((local, ..)) = get_validate_local_var(context, function, &opd) {
90                        locals.remove(&local);
91                    }
92                }
93            }
94        }
95    }
96    locals
97}
98
99// For each block, compute the set of locals that are live-in.
100// TODO: Use rustc_index::bit_set::ChunkedBitSet by mapping local names to indices.
101//       This will allow more efficient set operations.
102pub fn compute_livein(
103    context: &mut Context,
104    function: &Function,
105    po: &PostOrder,
106    locals: &HashSet<String>,
107) -> FxHashMap<Block, HashSet<String>> {
108    let mut result = FxHashMap::<Block, HashSet<String>>::default();
109    for block in &po.po_to_block {
110        result.insert(*block, HashSet::<String>::default());
111    }
112
113    let mut changed = true;
114    while changed {
115        changed = false;
116        for block in &po.po_to_block {
117            // we begin by unioning the liveins at successor blocks.
118            let mut cur_live = HashSet::<String>::default();
119            for BranchToWithArgs { block: succ, .. } in block.successors(context) {
120                let succ_livein = &result[&succ];
121                cur_live.extend(succ_livein.iter().cloned());
122            }
123            // Scan the instructions, in reverse.
124            for inst in block.instruction_iter(context).rev() {
125                match context.values[inst.0].value {
126                    ValueDatum::Instruction(Instruction {
127                        op: InstOp::Load(ptr),
128                        ..
129                    }) => {
130                        let local_var = get_validate_local_var(context, function, &ptr);
131                        match local_var {
132                            Some((local, ..)) if locals.contains(&local) => {
133                                cur_live.insert(local);
134                            }
135                            _ => {}
136                        }
137                    }
138                    ValueDatum::Instruction(Instruction {
139                        op: InstOp::Store { dst_val_ptr, .. },
140                        ..
141                    }) => {
142                        let local_var = get_validate_local_var(context, function, &dst_val_ptr);
143                        match local_var {
144                            Some((local, _)) if locals.contains(&local) => {
145                                cur_live.remove(&local);
146                            }
147                            _ => (),
148                        }
149                    }
150                    _ => (),
151                }
152            }
153            if result[block] != cur_live {
154                // Whatever's live now, is the live-in for the block.
155                result.get_mut(block).unwrap().extend(cur_live);
156                changed = true;
157            }
158        }
159    }
160    result
161}
162
163/// Promote loads of globals constants to SSA registers
164/// We promote only non-mutable globals of copy types
165fn promote_globals(context: &mut Context, function: &Function) -> Result<bool, IrError> {
166    let mut replacements = FxHashMap::<Value, Constant>::default();
167    for (_, inst) in function.instruction_iter(context) {
168        if let ValueDatum::Instruction(Instruction {
169            op: InstOp::Load(ptr),
170            ..
171        }) = context.values[inst.0].value
172        {
173            if let ValueDatum::Instruction(Instruction {
174                op: InstOp::GetGlobal(global_var),
175                ..
176            }) = context.values[ptr.0].value
177            {
178                if !global_var.is_mutable(context)
179                    && is_promotable_type(context, global_var.get_inner_type(context))
180                {
181                    let constant = *global_var
182                        .get_initializer(context)
183                        .expect("`global_var` is not mutable so it must be initialized");
184                    replacements.insert(inst, constant);
185                }
186            }
187        }
188    }
189
190    if replacements.is_empty() {
191        return Ok(false);
192    }
193
194    let replacements = replacements
195        .into_iter()
196        .map(|(k, v)| (k, Value::new_constant(context, v)))
197        .collect::<FxHashMap<_, _>>();
198
199    function.replace_values(context, &replacements, None);
200
201    Ok(true)
202}
203
204/// Promote memory values that are accessed via load/store to SSA registers.
205pub fn promote_to_registers(
206    context: &mut Context,
207    analyses: &AnalysisResults,
208    function: Function,
209) -> Result<bool, IrError> {
210    let mut modified = false;
211    modified |= promote_globals(context, &function)?;
212    modified |= promote_locals(context, analyses, function)?;
213    Ok(modified)
214}
215
216/// Promote locals to registers. We promote only locals of copy types,
217/// whose every use is in a `get_local` without offsets, and the result of
218/// such a `get_local` is used only in a load or a store.
219pub fn promote_locals(
220    context: &mut Context,
221    analyses: &AnalysisResults,
222    function: Function,
223) -> Result<bool, IrError> {
224    let safe_locals = filter_usable_locals(context, &function);
225    if safe_locals.is_empty() {
226        return Ok(false);
227    }
228
229    let po: &PostOrder = analyses.get_analysis_result(function);
230    let dom_tree: &DomTree = analyses.get_analysis_result(function);
231    let dom_fronts: &DomFronts = analyses.get_analysis_result(function);
232    let liveins = compute_livein(context, &function, po, &safe_locals);
233
234    // A list of the PHIs we insert in this transform.
235    let mut new_phi_tracker = HashSet::<(String, Block)>::new();
236    // A map from newly inserted block args to the Local that it's a PHI for.
237    let mut worklist = Vec::<(String, Type, Block)>::new();
238    let mut phi_to_local = FxHashMap::<Value, String>::default();
239    // Insert PHIs for each definition (store) at its dominance frontiers.
240    // Start by adding the existing definitions (stores) to a worklist,
241    // in program order (reverse post order). This is for faster convergence (or maybe not).
242    for (block, inst) in po
243        .po_to_block
244        .iter()
245        .rev()
246        .flat_map(|b| b.instruction_iter(context).map(|i| (*b, i)))
247    {
248        if let ValueDatum::Instruction(Instruction {
249            op: InstOp::Store { dst_val_ptr, .. },
250            ..
251        }) = context.values[inst.0].value
252        {
253            match get_validate_local_var(context, &function, &dst_val_ptr) {
254                Some((local, var)) if safe_locals.contains(&local) => {
255                    worklist.push((local, var.get_inner_type(context), block));
256                }
257                _ => (),
258            }
259        }
260    }
261    // Transitively add PHIs, till nothing more to do.
262    while let Some((local, ty, known_def)) = worklist.pop() {
263        for df in dom_fronts[&known_def].iter() {
264            if !new_phi_tracker.contains(&(local.clone(), *df)) && liveins[df].contains(&local) {
265                // Insert PHI for this local at block df.
266                let index = df.new_arg(context, ty);
267                phi_to_local.insert(df.get_arg(context, index).unwrap(), local.clone());
268                new_phi_tracker.insert((local.clone(), *df));
269                // Add df to the worklist.
270                worklist.push((local.clone(), ty, *df));
271            }
272        }
273    }
274
275    // We're just left with rewriting the loads and stores into SSA.
276    // For efficiency, we first collect the rewrites
277    // and then apply them all together in the next step.
278    #[allow(clippy::too_many_arguments)]
279    fn record_rewrites(
280        context: &mut Context,
281        function: &Function,
282        dom_tree: &DomTree,
283        node: Block,
284        safe_locals: &HashSet<String>,
285        phi_to_local: &FxHashMap<Value, String>,
286        name_stack: &mut MappedStack<String, Value>,
287        rewrites: &mut FxHashMap<Value, Value>,
288        deletes: &mut Vec<(Block, Value)>,
289    ) {
290        // Whatever new definitions we find in this block, they must be popped
291        // when we're done. So let's keep track of that locally as a count.
292        let mut num_local_pushes = IndexMap::<String, u32>::new();
293
294        // Start with relevant block args, they are new definitions.
295        for arg in node.arg_iter(context) {
296            if let Some(local) = phi_to_local.get(arg) {
297                name_stack.push(local.clone(), *arg);
298                num_local_pushes
299                    .entry(local.clone())
300                    .and_modify(|count| *count += 1)
301                    .or_insert(1);
302            }
303        }
304
305        for inst in node.instruction_iter(context) {
306            match context.values[inst.0].value {
307                ValueDatum::Instruction(Instruction {
308                    op: InstOp::Load(ptr),
309                    ..
310                }) => {
311                    let local_var = get_validate_local_var(context, function, &ptr);
312                    match local_var {
313                        Some((local, var)) if safe_locals.contains(&local) => {
314                            // We should replace all uses of inst with new_stack[local].
315                            let new_val = match name_stack.get(&local) {
316                                Some(val) => *val,
317                                None => {
318                                    // Nothing on the stack, let's attempt to get the initializer
319                                    let constant = *var
320                                        .get_initializer(context)
321                                        .expect("We're dealing with an uninitialized value");
322                                    Value::new_constant(context, constant)
323                                }
324                            };
325                            rewrites.insert(inst, new_val);
326                            deletes.push((node, inst));
327                        }
328                        _ => (),
329                    }
330                }
331                ValueDatum::Instruction(Instruction {
332                    op:
333                        InstOp::Store {
334                            dst_val_ptr,
335                            stored_val,
336                        },
337                    ..
338                }) => {
339                    let local_var = get_validate_local_var(context, function, &dst_val_ptr);
340                    match local_var {
341                        Some((local, _)) if safe_locals.contains(&local) => {
342                            // Henceforth, everything that's dominated by this inst must use stored_val
343                            // instead of loading from dst_val.
344                            name_stack.push(local.clone(), stored_val);
345                            num_local_pushes
346                                .entry(local)
347                                .and_modify(|count| *count += 1)
348                                .or_insert(1);
349                            deletes.push((node, inst));
350                        }
351                        _ => (),
352                    }
353                }
354                _ => (),
355            }
356        }
357
358        // Update arguments to successor blocks (i.e., PHI args).
359        for BranchToWithArgs { block: succ, .. } in node.successors(context) {
360            let args: Vec<_> = succ.arg_iter(context).copied().collect();
361            // For every arg of succ, if it's in phi_to_local,
362            // we pass, as arg, the top value of local
363            for arg in args {
364                if let Some(local) = phi_to_local.get(&arg) {
365                    let ptr = function.get_local_var(context, local).unwrap();
366                    let new_val = match name_stack.get(local) {
367                        Some(val) => *val,
368                        None => {
369                            // Nothing on the stack, let's attempt to get the initializer
370                            let constant = *ptr
371                                .get_initializer(context)
372                                .expect("We're dealing with an uninitialized value");
373                            Value::new_constant(context, constant)
374                        }
375                    };
376                    let params = node.get_succ_params_mut(context, &succ).unwrap();
377                    params.push(new_val);
378                }
379            }
380        }
381
382        // Process dominator children.
383        for child in dom_tree.children(node) {
384            record_rewrites(
385                context,
386                function,
387                dom_tree,
388                child,
389                safe_locals,
390                phi_to_local,
391                name_stack,
392                rewrites,
393                deletes,
394            );
395        }
396
397        // Pop from the names stack.
398        for (local, pushes) in num_local_pushes.iter() {
399            for _ in 0..*pushes {
400                name_stack.pop(local);
401            }
402        }
403    }
404
405    let mut name_stack = MappedStack::<String, Value>::default();
406    let mut value_replacement = FxHashMap::<Value, Value>::default();
407    let mut delete_insts = Vec::<(Block, Value)>::new();
408    record_rewrites(
409        context,
410        &function,
411        dom_tree,
412        function.get_entry_block(context),
413        &safe_locals,
414        &phi_to_local,
415        &mut name_stack,
416        &mut value_replacement,
417        &mut delete_insts,
418    );
419
420    // Apply the rewrites.
421    function.replace_values(context, &value_replacement, None);
422    // Delete the loads and stores.
423    for (block, inst) in delete_insts {
424        block.remove_instruction(context, inst);
425    }
426
427    Ok(true)
428}