vicis_core/pass/transform/mem2reg.rs

1use crate::{
2    ir::{
3        function::{
4            basic_block::{BasicBlock, BasicBlockId},
5            instruction::{Instruction, InstructionId, Opcode, Operand, Phi},
6            Function,
7        },
8        value::{Value, ValueId},
9    },
10    pass::{analysis::dom_tree, transform::sccp::SCCP, TransformPass},
11};
12use rustc_hash::{FxHashMap, FxHashSet};
13use std::{any::Any, cmp::Ordering, collections::BinaryHeap};
14
15pub struct Mem2RegPass;
16
17pub struct Mem2Reg<'a> {
18    func: &'a mut Function,
19    dom_tree: dom_tree::DominatorTree<BasicBlock>,
20    inst_indexes: InstructionIndexes,
21}
22
23type InstructionIndex = usize;
24
25struct InstructionIndexes(FxHashMap<InstructionId, InstructionIndex>);
26
27#[derive(Debug, Clone, Copy, Eq, PartialEq)]
28struct BlockLevel(usize, BasicBlockId);
29
30struct RenameData {
31    cur: BasicBlockId,
32    pred: Option<BasicBlockId>,
33    incoming: FxHashMap<InstructionId, ValueId>,
34}
35
36impl<'a> Mem2Reg<'a> {
37    pub fn new(func: &'a mut Function) -> Self {
38        Self {
39            dom_tree: dom_tree::DominatorTree::new(func),
40            inst_indexes: InstructionIndexes::default(),
41            func,
42        }
43    }
44
45    pub fn run(&mut self) {
46        let mut single_store_alloca_list = vec![];
47        let mut single_block_alloca_list = vec![];
48        let mut multi_block_alloca_list = vec![];
49
50        for block_id in self.func.layout.block_iter() {
51            for inst_id in self.func.layout.inst_iter(block_id) {
52                let inst = self.func.data.inst_ref(inst_id);
53                if !inst.opcode.is_alloca() {
54                    continue;
55                }
56
57                let alloca = inst;
58
59                let is_promotable = self.is_promotable(alloca);
60                debug!(is_promotable);
61
62                if !is_promotable {
63                    continue;
64                }
65
66                let is_stored_only_once = self.is_stored_only_once(alloca);
67                debug!(is_stored_only_once);
68
69                let is_only_used_in_single_block = self.is_only_used_in_single_block(alloca);
70                debug!(is_only_used_in_single_block);
71
72                if is_stored_only_once {
73                    single_store_alloca_list.push(inst_id);
74                    continue;
75                }
76
77                if is_only_used_in_single_block {
78                    single_block_alloca_list.push(inst_id);
79                    continue;
80                }
81
82                multi_block_alloca_list.push(inst_id);
83            }
84        }
85
86        for alloca in single_store_alloca_list {
87            self.promote_single_store_alloca(alloca);
88        }
89
90        for alloca in single_block_alloca_list {
91            self.promote_single_block_alloca(alloca);
92        }
93
94        let mut phi_to_alloca = FxHashMap::default();
95        let mut added_phis = FxHashMap::default();
96        for &alloca in &multi_block_alloca_list {
97            self.promote_multi_block_alloca(alloca, &mut phi_to_alloca, &mut added_phis);
98        }
99
100        self.rename(multi_block_alloca_list, phi_to_alloca, added_phis);
101
102        SCCP::new(self.func).run();
103    }
104
105    fn promote_single_store_alloca(&mut self, alloca_id: InstructionId) {
106        let mut src = None;
107        let mut store_to_remove = None;
108        let mut loads_to_remove = vec![];
109
110        for &user_id in self.func.data.users_of(alloca_id) {
111            let user = self.func.data.inst_ref(user_id);
112            match user.opcode {
113                Opcode::Load => loads_to_remove.push(user_id),
114                Opcode::Store => {
115                    src = Some(user.operand.as_store().unwrap().src_val());
116                    store_to_remove = Some(user_id);
117                }
118                _ => unreachable!(),
119            }
120        }
121
122        let src = src.unwrap();
123        let store_to_remove = store_to_remove.unwrap();
124        let store_idx = self.inst_indexes.get(self.func, store_to_remove);
125
126        let mut remove_all_loads = true;
127        loads_to_remove.retain(|&load_id| {
128            let load = self.func.data.inst_ref(load_id);
129            let store = self.func.data.inst_ref(store_to_remove);
130            let valid = if load.parent == store.parent {
131                let load_idx = self.inst_indexes.get(self.func, load_id);
132                store_idx < load_idx
133            } else {
134                self.dom_tree.dominates(store.parent, load.parent)
135            };
136            remove_all_loads &= valid;
137            valid
138        });
139
140        if remove_all_loads {
141            self.func.remove_inst(store_to_remove);
142            self.func.remove_inst(alloca_id);
143        }
144
145        for load_id in loads_to_remove {
146            self.func.remove_inst(load_id);
147            for user_id in self.func.data.users_of(load_id).clone() {
148                self.func.data.replace_inst_arg(user_id, load_id, src);
149            }
150        }
151    }
152
153    fn promote_single_block_alloca(&mut self, alloca_id: InstructionId) {
154        fn find_nearest_store(
155            store_indexes: &[(InstructionId, InstructionIndex)],
156            load_idx: InstructionIndex,
157        ) -> Option<InstructionId> {
158            let i = store_indexes
159                .binary_search_by(|(_, store_idx)| store_idx.cmp(&load_idx))
160                .unwrap_or_else(|x| x);
161            if i == 0 {
162                return None;
163            }
164            Some(store_indexes[i - 1].0)
165        }
166
167        let mut store_indexes = vec![];
168        let mut loads = vec![];
169
170        for &user_id in self.func.data.users_of(alloca_id) {
171            let user = self.func.data.inst_ref(user_id);
172            match user.opcode {
173                Opcode::Store => {
174                    store_indexes.push((user_id, self.inst_indexes.get(self.func, user_id)))
175                }
176                Opcode::Load => loads.push(user_id),
177                _ => unreachable!(),
178            }
179        }
180
181        store_indexes.sort_by(|(_, x), (_, y)| x.cmp(y));
182
183        let mut remove_all_access = true;
184        let mut stores_to_remove = vec![];
185
186        for load_id in loads {
187            let load_idx = self.inst_indexes.get(self.func, load_id);
188            let nearest_store_id = match find_nearest_store(&store_indexes, load_idx) {
189                Some(nearest_store_id) => nearest_store_id,
190                None => {
191                    remove_all_access = false;
192                    continue;
193                }
194            };
195            let nearest_store = self.func.data.inst_ref(nearest_store_id);
196            let src = nearest_store.operand.as_store().unwrap().src_val();
197
198            stores_to_remove.push(nearest_store_id);
199
200            self.func.remove_inst(load_id);
201            for user_id in self.func.data.users_of(load_id).clone() {
202                self.func.data.replace_inst_arg(user_id, load_id, src);
203            }
204        }
205
206        if remove_all_access {
207            self.func.remove_inst(alloca_id);
208        }
209
210        for store in stores_to_remove {
211            self.func.remove_inst(store);
212        }
213    }
214
215    fn promote_multi_block_alloca(
216        &mut self,
217        alloca_id: InstructionId,
218        phi_to_alloca: &mut FxHashMap<InstructionId, InstructionId>,
219        added_phis: &mut FxHashMap<BasicBlockId, Vec<InstructionId>>,
220    ) {
221        let mut def_blocks = vec![];
222        let mut use_blocks = vec![];
223        let mut livein_blocks = FxHashSet::default();
224
225        for &user_id in self.func.data.users_of(alloca_id) {
226            let user = self.func.data.inst_ref(user_id);
227            match user.opcode {
228                Opcode::Store => def_blocks.push(user.parent),
229                Opcode::Load => use_blocks.push(user.parent),
230                _ => unreachable!(),
231            }
232        }
233
234        let mut worklist = use_blocks;
235        while let Some(block) = worklist.pop() {
236            if !livein_blocks.insert(block) {
237                continue;
238            }
239            for pred in self.func.data.basic_blocks[block].preds() {
240                if def_blocks.contains(pred) {
241                    continue;
242                }
243                worklist.push(*pred)
244            }
245        }
246
247        let mut queue = def_blocks
248            .iter()
249            .map(|&def| BlockLevel(self.dom_tree.level_of(def).unwrap(), def))
250            .collect::<BinaryHeap<_>>();
251        let mut visited_worklist = FxHashSet::default();
252        let mut visited_queue = FxHashSet::default();
253
254        while let Some(BlockLevel(root_level, root_block_id)) = queue.pop() {
255            let mut worklist = vec![root_block_id];
256            visited_worklist.insert(root_block_id);
257
258            while let Some(block_id) = worklist.pop() {
259                let block = &self.func.data.basic_blocks[block_id];
260                for succ_id in block.succs().clone() {
261                    let succ_level = self.dom_tree.level_of(succ_id).unwrap();
262                    if succ_level > root_level {
263                        continue;
264                    }
265                    if !visited_queue.insert(succ_id) {
266                        continue;
267                    }
268                    if !livein_blocks.contains(&succ_id) {
269                        continue;
270                    }
271
272                    {
273                        let ty = self
274                            .func
275                            .data
276                            .inst_ref(alloca_id)
277                            .operand
278                            .as_alloca()
279                            .unwrap()
280                            .ty();
281                        let phi = Opcode::Phi
282                            .with_block(succ_id)
283                            .with_operand(Operand::Phi(Phi {
284                                ty,
285                                args: vec![],
286                                blocks: vec![],
287                            }));
288                        let phi_id = self.func.data.create_inst(phi);
289                        self.func.layout.insert_inst_at_start(phi_id, succ_id);
290                        added_phis
291                            .entry(succ_id)
292                            .or_insert_with(Vec::new)
293                            .push(phi_id);
294                        phi_to_alloca.insert(phi_id, alloca_id);
295                    }
296
297                    if !def_blocks.contains(&succ_id) {
298                        queue.push(BlockLevel(succ_level, succ_id));
299                    }
300                }
301
302                if let Some(dom_children) = self.dom_tree.children_of(block_id) {
303                    for child in dom_children {
304                        if visited_worklist.insert(*child) {
305                            worklist.push(*child);
306                        }
307                    }
308                }
309            }
310        }
311    }
312
313    fn rename(
314        &mut self,
315        alloca_list: Vec<InstructionId>,
316        phi_to_alloca: FxHashMap<InstructionId, InstructionId>,
317        mut added_phis: FxHashMap<BasicBlockId, Vec<InstructionId>>,
318    ) {
319        let entry = self.func.layout.first_block.unwrap();
320
321        let mut visited = FxHashSet::default();
322        let mut worklist = vec![RenameData {
323            cur: entry,
324            pred: None,
325            incoming: FxHashMap::default(),
326        }];
327
328        while let Some(data) = worklist.pop() {
329            self.rename_sub(
330                &alloca_list,
331                &phi_to_alloca,
332                &mut worklist,
333                &mut added_phis,
334                &mut visited,
335                data,
336            );
337        }
338
339        for alloca_id in alloca_list {
340            self.func.remove_inst(alloca_id);
341        }
342    }
343
344    fn rename_sub(
345        &mut self,
346        alloca_list: &[InstructionId],
347        phi_to_alloca: &FxHashMap<InstructionId, InstructionId>,
348        worklist: &mut Vec<RenameData>,
349        added_phis: &mut FxHashMap<BasicBlockId, Vec<InstructionId>>,
350        visited: &mut FxHashSet<BasicBlockId>,
351        mut data: RenameData,
352    ) {
353        loop {
354            for phi_id in added_phis.get(&data.cur).unwrap_or(&vec![]) {
355                let alloca_id = phi_to_alloca[phi_id];
356                let incoming_id = data
357                    .incoming
358                    .get_mut(&alloca_id)
359                    .expect("TODO: return undef");
360                let phi = self.func.data.inst_ref_mut(*phi_id);
361                let phi = phi.operand.as_phi_mut().unwrap();
362                phi.args_mut().push(*incoming_id);
363                phi.blocks_mut().push(data.pred.unwrap());
364                self.func.data.validate_inst_uses(*phi_id);
365                *incoming_id = self.func.data.create_value(Value::Instruction(*phi_id));
366            }
367
368            if !visited.insert(data.cur) {
369                break;
370            }
371
372            let mut removal_list = vec![];
373
374            for inst_id in self.func.layout.inst_iter(data.cur) {
375                let inst = self.func.data.inst_ref(inst_id);
376                let alloca_id = *self
377                    .func
378                    .data
379                    .value_ref(match inst.opcode {
380                        Opcode::Store => inst.operand.as_store().unwrap().dst_val(),
381                        Opcode::Load => inst.operand.as_load().unwrap().src_val(),
382                        _ => continue,
383                    })
384                    .as_inst();
385                if !alloca_list.contains(&alloca_id) {
386                    continue;
387                }
388                match inst.opcode {
389                    Opcode::Store => {
390                        data.incoming
391                            .insert(alloca_id, inst.operand.as_store().unwrap().src_val());
392                    }
393                    Opcode::Load => {
394                        if let Some(val) = data.incoming.get(&alloca_id) {
395                            self.func.data.replace_all_uses(inst_id, *val);
396                        }
397                    }
398                    _ => unreachable!(),
399                }
400                removal_list.push(inst_id);
401            }
402
403            for remove in removal_list {
404                self.func.remove_inst(remove);
405            }
406
407            let block = &self.func.data.basic_blocks[data.cur];
408
409            if block.succs().is_empty() {
410                break;
411            }
412
413            data.pred = Some(data.cur);
414            let mut succ_iter = block.succs().iter();
415            data.cur = *succ_iter.next().unwrap();
416            for succ in succ_iter {
417                worklist.push(RenameData {
418                    cur: *succ,
419                    pred: data.pred,
420                    incoming: data.incoming.clone(),
421                })
422            }
423        }
424    }
425
426    fn is_promotable(&self, alloca: &Instruction) -> bool {
427        let alloca_id = alloca.id.unwrap();
428        let alloca = alloca.operand.as_alloca().unwrap();
429        let ty = alloca.ty();
430        (ty.is_primitive() || ty.is_pointer(&self.func.types))
431            && self.func.data.users_of(alloca_id).iter().all(|&user_id| {
432                let user = self.func.data.inst_ref(user_id);
433                user.opcode.is_load()
434                    || (user.opcode.is_store() && {
435                        let dst_id = user.operand.as_store().unwrap().dst_val();
436                        let dst = self.func.data.value_ref(dst_id);
437                        matches!(dst, Value::Instruction(id) if id == &alloca_id)
438                    })
439            })
440    }
441
442    fn is_stored_only_once(&self, alloca: &Instruction) -> bool {
443        let alloca_id = alloca.id.unwrap();
444        self.func
445            .data
446            .users_of(alloca_id)
447            .iter()
448            .fold(0usize, |acc, &user_id| {
449                let user = self.func.data.inst_ref(user_id);
450                user.opcode.is_store() as usize + acc
451            })
452            == 1
453    }
454
455    fn is_only_used_in_single_block(&self, alloca: &Instruction) -> bool {
456        let alloca_id = alloca.id.unwrap();
457        let mut last_parent = None;
458        self.func.data.users_of(alloca_id).iter().all(|&user_id| {
459            let user = self.func.data.inst_ref(user_id);
460            let eq = last_parent.get_or_insert(user.parent) == &user.parent;
461            last_parent = Some(user.parent);
462            eq
463        })
464    }
465}
466
467impl Default for InstructionIndexes {
468    fn default() -> Self {
469        Self(FxHashMap::default())
470    }
471}
472
473impl InstructionIndexes {
474    pub fn get(&mut self, func: &Function, inst_id: InstructionId) -> InstructionIndex {
475        if let Some(idx) = self.0.get(&inst_id) {
476            return *idx;
477        }
478
479        let inst = func.data.inst_ref(inst_id);
480        for (i, inst_id) in func.layout.inst_iter(inst.parent).enumerate() {
481            let opcode = func.data.inst_ref(inst_id).opcode;
482            let is_interesting = opcode.is_store() || opcode.is_load() || opcode.is_alloca();
483            if is_interesting {
484                self.0.insert(inst_id, i);
485            }
486        }
487
488        self.get(func, inst_id)
489    }
490}
491
492impl TransformPass<Function> for Mem2RegPass {
493    fn run_on(&self, func: &mut Function, _result: &mut Box<dyn Any>) {
494        Mem2Reg::new(func).run();
495    }
496}
497
498impl Ord for BlockLevel {
499    fn cmp(&self, other: &Self) -> Ordering {
500        self.0.cmp(&other.0)
501    }
502}
503
504impl PartialOrd for BlockLevel {
505    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
506        Some(self.cmp(other))
507    }
508}