vicis_core/pass/transform/
mem2reg.rs1use crate::{
2 ir::{
3 function::{
4 basic_block::{BasicBlock, BasicBlockId},
5 instruction::{Instruction, InstructionId, Opcode, Operand, Phi},
6 Function,
7 },
8 value::{Value, ValueId},
9 },
10 pass::{analysis::dom_tree, transform::sccp::SCCP, TransformPass},
11};
12use rustc_hash::{FxHashMap, FxHashSet};
13use std::{any::Any, cmp::Ordering, collections::BinaryHeap};
14
/// Transform pass that promotes promotable stack slots (`alloca`s accessed
/// only through loads and stores) of a function into SSA values.
#[derive(Debug, Clone, Copy, Default)]
pub struct Mem2RegPass;
16
17pub struct Mem2Reg<'a> {
18 func: &'a mut Function,
19 dom_tree: dom_tree::DominatorTree<BasicBlock>,
20 inst_indexes: InstructionIndexes,
21}
22
23type InstructionIndex = usize;
24
25struct InstructionIndexes(FxHashMap<InstructionId, InstructionIndex>);
26
27#[derive(Debug, Clone, Copy, Eq, PartialEq)]
28struct BlockLevel(usize, BasicBlockId);
29
30struct RenameData {
31 cur: BasicBlockId,
32 pred: Option<BasicBlockId>,
33 incoming: FxHashMap<InstructionId, ValueId>,
34}
35
36impl<'a> Mem2Reg<'a> {
37 pub fn new(func: &'a mut Function) -> Self {
38 Self {
39 dom_tree: dom_tree::DominatorTree::new(func),
40 inst_indexes: InstructionIndexes::default(),
41 func,
42 }
43 }
44
45 pub fn run(&mut self) {
46 let mut single_store_alloca_list = vec![];
47 let mut single_block_alloca_list = vec![];
48 let mut multi_block_alloca_list = vec![];
49
50 for block_id in self.func.layout.block_iter() {
51 for inst_id in self.func.layout.inst_iter(block_id) {
52 let inst = self.func.data.inst_ref(inst_id);
53 if !inst.opcode.is_alloca() {
54 continue;
55 }
56
57 let alloca = inst;
58
59 let is_promotable = self.is_promotable(alloca);
60 debug!(is_promotable);
61
62 if !is_promotable {
63 continue;
64 }
65
66 let is_stored_only_once = self.is_stored_only_once(alloca);
67 debug!(is_stored_only_once);
68
69 let is_only_used_in_single_block = self.is_only_used_in_single_block(alloca);
70 debug!(is_only_used_in_single_block);
71
72 if is_stored_only_once {
73 single_store_alloca_list.push(inst_id);
74 continue;
75 }
76
77 if is_only_used_in_single_block {
78 single_block_alloca_list.push(inst_id);
79 continue;
80 }
81
82 multi_block_alloca_list.push(inst_id);
83 }
84 }
85
86 for alloca in single_store_alloca_list {
87 self.promote_single_store_alloca(alloca);
88 }
89
90 for alloca in single_block_alloca_list {
91 self.promote_single_block_alloca(alloca);
92 }
93
94 let mut phi_to_alloca = FxHashMap::default();
95 let mut added_phis = FxHashMap::default();
96 for &alloca in &multi_block_alloca_list {
97 self.promote_multi_block_alloca(alloca, &mut phi_to_alloca, &mut added_phis);
98 }
99
100 self.rename(multi_block_alloca_list, phi_to_alloca, added_phis);
101
102 SCCP::new(self.func).run();
103 }
104
105 fn promote_single_store_alloca(&mut self, alloca_id: InstructionId) {
106 let mut src = None;
107 let mut store_to_remove = None;
108 let mut loads_to_remove = vec![];
109
110 for &user_id in self.func.data.users_of(alloca_id) {
111 let user = self.func.data.inst_ref(user_id);
112 match user.opcode {
113 Opcode::Load => loads_to_remove.push(user_id),
114 Opcode::Store => {
115 src = Some(user.operand.as_store().unwrap().src_val());
116 store_to_remove = Some(user_id);
117 }
118 _ => unreachable!(),
119 }
120 }
121
122 let src = src.unwrap();
123 let store_to_remove = store_to_remove.unwrap();
124 let store_idx = self.inst_indexes.get(self.func, store_to_remove);
125
126 let mut remove_all_loads = true;
127 loads_to_remove.retain(|&load_id| {
128 let load = self.func.data.inst_ref(load_id);
129 let store = self.func.data.inst_ref(store_to_remove);
130 let valid = if load.parent == store.parent {
131 let load_idx = self.inst_indexes.get(self.func, load_id);
132 store_idx < load_idx
133 } else {
134 self.dom_tree.dominates(store.parent, load.parent)
135 };
136 remove_all_loads &= valid;
137 valid
138 });
139
140 if remove_all_loads {
141 self.func.remove_inst(store_to_remove);
142 self.func.remove_inst(alloca_id);
143 }
144
145 for load_id in loads_to_remove {
146 self.func.remove_inst(load_id);
147 for user_id in self.func.data.users_of(load_id).clone() {
148 self.func.data.replace_inst_arg(user_id, load_id, src);
149 }
150 }
151 }
152
153 fn promote_single_block_alloca(&mut self, alloca_id: InstructionId) {
154 fn find_nearest_store(
155 store_indexes: &[(InstructionId, InstructionIndex)],
156 load_idx: InstructionIndex,
157 ) -> Option<InstructionId> {
158 let i = store_indexes
159 .binary_search_by(|(_, store_idx)| store_idx.cmp(&load_idx))
160 .unwrap_or_else(|x| x);
161 if i == 0 {
162 return None;
163 }
164 Some(store_indexes[i - 1].0)
165 }
166
167 let mut store_indexes = vec![];
168 let mut loads = vec![];
169
170 for &user_id in self.func.data.users_of(alloca_id) {
171 let user = self.func.data.inst_ref(user_id);
172 match user.opcode {
173 Opcode::Store => {
174 store_indexes.push((user_id, self.inst_indexes.get(self.func, user_id)))
175 }
176 Opcode::Load => loads.push(user_id),
177 _ => unreachable!(),
178 }
179 }
180
181 store_indexes.sort_by(|(_, x), (_, y)| x.cmp(y));
182
183 let mut remove_all_access = true;
184 let mut stores_to_remove = vec![];
185
186 for load_id in loads {
187 let load_idx = self.inst_indexes.get(self.func, load_id);
188 let nearest_store_id = match find_nearest_store(&store_indexes, load_idx) {
189 Some(nearest_store_id) => nearest_store_id,
190 None => {
191 remove_all_access = false;
192 continue;
193 }
194 };
195 let nearest_store = self.func.data.inst_ref(nearest_store_id);
196 let src = nearest_store.operand.as_store().unwrap().src_val();
197
198 stores_to_remove.push(nearest_store_id);
199
200 self.func.remove_inst(load_id);
201 for user_id in self.func.data.users_of(load_id).clone() {
202 self.func.data.replace_inst_arg(user_id, load_id, src);
203 }
204 }
205
206 if remove_all_access {
207 self.func.remove_inst(alloca_id);
208 }
209
210 for store in stores_to_remove {
211 self.func.remove_inst(store);
212 }
213 }
214
215 fn promote_multi_block_alloca(
216 &mut self,
217 alloca_id: InstructionId,
218 phi_to_alloca: &mut FxHashMap<InstructionId, InstructionId>,
219 added_phis: &mut FxHashMap<BasicBlockId, Vec<InstructionId>>,
220 ) {
221 let mut def_blocks = vec![];
222 let mut use_blocks = vec![];
223 let mut livein_blocks = FxHashSet::default();
224
225 for &user_id in self.func.data.users_of(alloca_id) {
226 let user = self.func.data.inst_ref(user_id);
227 match user.opcode {
228 Opcode::Store => def_blocks.push(user.parent),
229 Opcode::Load => use_blocks.push(user.parent),
230 _ => unreachable!(),
231 }
232 }
233
234 let mut worklist = use_blocks;
235 while let Some(block) = worklist.pop() {
236 if !livein_blocks.insert(block) {
237 continue;
238 }
239 for pred in self.func.data.basic_blocks[block].preds() {
240 if def_blocks.contains(pred) {
241 continue;
242 }
243 worklist.push(*pred)
244 }
245 }
246
247 let mut queue = def_blocks
248 .iter()
249 .map(|&def| BlockLevel(self.dom_tree.level_of(def).unwrap(), def))
250 .collect::<BinaryHeap<_>>();
251 let mut visited_worklist = FxHashSet::default();
252 let mut visited_queue = FxHashSet::default();
253
254 while let Some(BlockLevel(root_level, root_block_id)) = queue.pop() {
255 let mut worklist = vec![root_block_id];
256 visited_worklist.insert(root_block_id);
257
258 while let Some(block_id) = worklist.pop() {
259 let block = &self.func.data.basic_blocks[block_id];
260 for succ_id in block.succs().clone() {
261 let succ_level = self.dom_tree.level_of(succ_id).unwrap();
262 if succ_level > root_level {
263 continue;
264 }
265 if !visited_queue.insert(succ_id) {
266 continue;
267 }
268 if !livein_blocks.contains(&succ_id) {
269 continue;
270 }
271
272 {
273 let ty = self
274 .func
275 .data
276 .inst_ref(alloca_id)
277 .operand
278 .as_alloca()
279 .unwrap()
280 .ty();
281 let phi = Opcode::Phi
282 .with_block(succ_id)
283 .with_operand(Operand::Phi(Phi {
284 ty,
285 args: vec![],
286 blocks: vec![],
287 }));
288 let phi_id = self.func.data.create_inst(phi);
289 self.func.layout.insert_inst_at_start(phi_id, succ_id);
290 added_phis
291 .entry(succ_id)
292 .or_insert_with(Vec::new)
293 .push(phi_id);
294 phi_to_alloca.insert(phi_id, alloca_id);
295 }
296
297 if !def_blocks.contains(&succ_id) {
298 queue.push(BlockLevel(succ_level, succ_id));
299 }
300 }
301
302 if let Some(dom_children) = self.dom_tree.children_of(block_id) {
303 for child in dom_children {
304 if visited_worklist.insert(*child) {
305 worklist.push(*child);
306 }
307 }
308 }
309 }
310 }
311 }
312
313 fn rename(
314 &mut self,
315 alloca_list: Vec<InstructionId>,
316 phi_to_alloca: FxHashMap<InstructionId, InstructionId>,
317 mut added_phis: FxHashMap<BasicBlockId, Vec<InstructionId>>,
318 ) {
319 let entry = self.func.layout.first_block.unwrap();
320
321 let mut visited = FxHashSet::default();
322 let mut worklist = vec![RenameData {
323 cur: entry,
324 pred: None,
325 incoming: FxHashMap::default(),
326 }];
327
328 while let Some(data) = worklist.pop() {
329 self.rename_sub(
330 &alloca_list,
331 &phi_to_alloca,
332 &mut worklist,
333 &mut added_phis,
334 &mut visited,
335 data,
336 );
337 }
338
339 for alloca_id in alloca_list {
340 self.func.remove_inst(alloca_id);
341 }
342 }
343
344 fn rename_sub(
345 &mut self,
346 alloca_list: &[InstructionId],
347 phi_to_alloca: &FxHashMap<InstructionId, InstructionId>,
348 worklist: &mut Vec<RenameData>,
349 added_phis: &mut FxHashMap<BasicBlockId, Vec<InstructionId>>,
350 visited: &mut FxHashSet<BasicBlockId>,
351 mut data: RenameData,
352 ) {
353 loop {
354 for phi_id in added_phis.get(&data.cur).unwrap_or(&vec![]) {
355 let alloca_id = phi_to_alloca[phi_id];
356 let incoming_id = data
357 .incoming
358 .get_mut(&alloca_id)
359 .expect("TODO: return undef");
360 let phi = self.func.data.inst_ref_mut(*phi_id);
361 let phi = phi.operand.as_phi_mut().unwrap();
362 phi.args_mut().push(*incoming_id);
363 phi.blocks_mut().push(data.pred.unwrap());
364 self.func.data.validate_inst_uses(*phi_id);
365 *incoming_id = self.func.data.create_value(Value::Instruction(*phi_id));
366 }
367
368 if !visited.insert(data.cur) {
369 break;
370 }
371
372 let mut removal_list = vec![];
373
374 for inst_id in self.func.layout.inst_iter(data.cur) {
375 let inst = self.func.data.inst_ref(inst_id);
376 let alloca_id = *self
377 .func
378 .data
379 .value_ref(match inst.opcode {
380 Opcode::Store => inst.operand.as_store().unwrap().dst_val(),
381 Opcode::Load => inst.operand.as_load().unwrap().src_val(),
382 _ => continue,
383 })
384 .as_inst();
385 if !alloca_list.contains(&alloca_id) {
386 continue;
387 }
388 match inst.opcode {
389 Opcode::Store => {
390 data.incoming
391 .insert(alloca_id, inst.operand.as_store().unwrap().src_val());
392 }
393 Opcode::Load => {
394 if let Some(val) = data.incoming.get(&alloca_id) {
395 self.func.data.replace_all_uses(inst_id, *val);
396 }
397 }
398 _ => unreachable!(),
399 }
400 removal_list.push(inst_id);
401 }
402
403 for remove in removal_list {
404 self.func.remove_inst(remove);
405 }
406
407 let block = &self.func.data.basic_blocks[data.cur];
408
409 if block.succs().is_empty() {
410 break;
411 }
412
413 data.pred = Some(data.cur);
414 let mut succ_iter = block.succs().iter();
415 data.cur = *succ_iter.next().unwrap();
416 for succ in succ_iter {
417 worklist.push(RenameData {
418 cur: *succ,
419 pred: data.pred,
420 incoming: data.incoming.clone(),
421 })
422 }
423 }
424 }
425
426 fn is_promotable(&self, alloca: &Instruction) -> bool {
427 let alloca_id = alloca.id.unwrap();
428 let alloca = alloca.operand.as_alloca().unwrap();
429 let ty = alloca.ty();
430 (ty.is_primitive() || ty.is_pointer(&self.func.types))
431 && self.func.data.users_of(alloca_id).iter().all(|&user_id| {
432 let user = self.func.data.inst_ref(user_id);
433 user.opcode.is_load()
434 || (user.opcode.is_store() && {
435 let dst_id = user.operand.as_store().unwrap().dst_val();
436 let dst = self.func.data.value_ref(dst_id);
437 matches!(dst, Value::Instruction(id) if id == &alloca_id)
438 })
439 })
440 }
441
442 fn is_stored_only_once(&self, alloca: &Instruction) -> bool {
443 let alloca_id = alloca.id.unwrap();
444 self.func
445 .data
446 .users_of(alloca_id)
447 .iter()
448 .fold(0usize, |acc, &user_id| {
449 let user = self.func.data.inst_ref(user_id);
450 user.opcode.is_store() as usize + acc
451 })
452 == 1
453 }
454
455 fn is_only_used_in_single_block(&self, alloca: &Instruction) -> bool {
456 let alloca_id = alloca.id.unwrap();
457 let mut last_parent = None;
458 self.func.data.users_of(alloca_id).iter().all(|&user_id| {
459 let user = self.func.data.inst_ref(user_id);
460 let eq = last_parent.get_or_insert(user.parent) == &user.parent;
461 last_parent = Some(user.parent);
462 eq
463 })
464 }
465}
466
467impl Default for InstructionIndexes {
468 fn default() -> Self {
469 Self(FxHashMap::default())
470 }
471}
472
473impl InstructionIndexes {
474 pub fn get(&mut self, func: &Function, inst_id: InstructionId) -> InstructionIndex {
475 if let Some(idx) = self.0.get(&inst_id) {
476 return *idx;
477 }
478
479 let inst = func.data.inst_ref(inst_id);
480 for (i, inst_id) in func.layout.inst_iter(inst.parent).enumerate() {
481 let opcode = func.data.inst_ref(inst_id).opcode;
482 let is_interesting = opcode.is_store() || opcode.is_load() || opcode.is_alloca();
483 if is_interesting {
484 self.0.insert(inst_id, i);
485 }
486 }
487
488 self.get(func, inst_id)
489 }
490}
491
492impl TransformPass<Function> for Mem2RegPass {
493 fn run_on(&self, func: &mut Function, _result: &mut Box<dyn Any>) {
494 Mem2Reg::new(func).run();
495 }
496}
497
498impl Ord for BlockLevel {
499 fn cmp(&self, other: &Self) -> Ordering {
500 self.0.cmp(&other.0)
501 }
502}
503
504impl PartialOrd for BlockLevel {
505 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
506 Some(self.cmp(other))
507 }
508}