wave_compiler/optimize/
mem2reg.rs1use std::collections::HashMap;
11
12use super::pass::Pass;
13use crate::hir::types::AddressSpace;
14use crate::mir::function::MirFunction;
15use crate::mir::instruction::MirInst;
16use crate::mir::value::ValueId;
17
/// Optimization pass that forwards values stored to `AddressSpace::Private`
/// memory into subsequent loads of the same address.
///
/// Stateless: construct it as `Mem2Reg` (or `Mem2Reg::default()`) and invoke it
/// through the `Pass` trait.
#[derive(Debug, Clone, Copy, Default)]
pub struct Mem2Reg;

21impl Pass for Mem2Reg {
22 fn name(&self) -> &'static str {
23 "mem2reg"
24 }
25
26 fn run(&self, func: &mut MirFunction) -> bool {
27 let mut local_stores: HashMap<ValueId, ValueId> = HashMap::new();
28 let mut promotable_loads: Vec<(usize, usize, ValueId, ValueId)> = Vec::new();
29 let mut changed = false;
30
31 for (block_idx, block) in func.blocks.iter().enumerate() {
32 for (inst_idx, inst) in block.instructions.iter().enumerate() {
33 match inst {
34 MirInst::Store {
35 addr,
36 value,
37 space: AddressSpace::Private,
38 } => {
39 local_stores.insert(*addr, *value);
40 }
41 MirInst::Load {
42 dest,
43 addr,
44 space: AddressSpace::Private,
45 ..
46 } => {
47 if let Some(&stored_val) = local_stores.get(addr) {
48 promotable_loads.push((block_idx, inst_idx, *dest, stored_val));
49 }
50 }
51 _ => {}
52 }
53 }
54 }
55
56 for (block_idx, inst_idx, dest, replacement_value) in promotable_loads.into_iter().rev() {
57 func.blocks[block_idx].instructions[inst_idx] = MirInst::Const {
58 dest,
59 value: crate::mir::instruction::ConstValue::U32(replacement_value.index()),
60 };
61 changed = true;
62 }
63
64 changed
65 }
66}
67
68#[cfg(test)]
69mod tests {
70 use super::*;
71 use crate::mir::basic_block::{BasicBlock, Terminator};
72 use crate::mir::instruction::MirInst;
73 use crate::mir::types::MirType;
74 use crate::mir::value::BlockId;
75
76 #[test]
77 fn test_mem2reg_promotes_private_load() {
78 let mut func = MirFunction::new("test".into(), BlockId(0));
79 let mut bb = BasicBlock::new(BlockId(0));
80 bb.instructions.push(MirInst::Store {
81 addr: ValueId(0),
82 value: ValueId(1),
83 space: AddressSpace::Private,
84 });
85 bb.instructions.push(MirInst::Load {
86 dest: ValueId(2),
87 addr: ValueId(0),
88 space: AddressSpace::Private,
89 ty: MirType::I32,
90 });
91 bb.terminator = Terminator::Return;
92 func.blocks.push(bb);
93
94 let pass = Mem2Reg;
95 assert!(pass.run(&mut func));
96 }
97
98 #[test]
99 fn test_mem2reg_ignores_device_memory() {
100 let mut func = MirFunction::new("test".into(), BlockId(0));
101 let mut bb = BasicBlock::new(BlockId(0));
102 bb.instructions.push(MirInst::Store {
103 addr: ValueId(0),
104 value: ValueId(1),
105 space: AddressSpace::Device,
106 });
107 bb.instructions.push(MirInst::Load {
108 dest: ValueId(2),
109 addr: ValueId(0),
110 space: AddressSpace::Device,
111 ty: MirType::I32,
112 });
113 bb.terminator = Terminator::Return;
114 func.blocks.push(bb);
115
116 let pass = Mem2Reg;
117 assert!(!pass.run(&mut func));
118 }
119}