Skip to main content

wave_compiler/optimize/
mem2reg.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
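//! Memory-to-register promotion pass.
//!
//! Rewrites loads from private memory that can be resolved from an earlier
//! store to the same address, so that the load no longer touches memory.
//! Phi-node insertion at dominance frontiers (full SSA construction) is not
//! implemented, so the pass does not reason about values that reach a load
//! from multiple predecessor blocks.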
9
10use std::collections::HashMap;
11
12use super::pass::Pass;
13use crate::hir::types::AddressSpace;
14use crate::mir::function::MirFunction;
15use crate::mir::instruction::MirInst;
16use crate::mir::value::ValueId;
17
/// Memory-to-register promotion pass.
///
/// The pass carries no state of its own (it is a unit struct), so one value
/// can be reused to run the pass over any number of functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct Mem2Reg;

21impl Pass for Mem2Reg {
22    fn name(&self) -> &'static str {
23        "mem2reg"
24    }
25
26    fn run(&self, func: &mut MirFunction) -> bool {
27        let mut local_stores: HashMap<ValueId, ValueId> = HashMap::new();
28        let mut promotable_loads: Vec<(usize, usize, ValueId, ValueId)> = Vec::new();
29        let mut changed = false;
30
31        for (block_idx, block) in func.blocks.iter().enumerate() {
32            for (inst_idx, inst) in block.instructions.iter().enumerate() {
33                match inst {
34                    MirInst::Store {
35                        addr,
36                        value,
37                        space: AddressSpace::Private,
38                    } => {
39                        local_stores.insert(*addr, *value);
40                    }
41                    MirInst::Load {
42                        dest,
43                        addr,
44                        space: AddressSpace::Private,
45                        ..
46                    } => {
47                        if let Some(&stored_val) = local_stores.get(addr) {
48                            promotable_loads.push((block_idx, inst_idx, *dest, stored_val));
49                        }
50                    }
51                    _ => {}
52                }
53            }
54        }
55
56        for (block_idx, inst_idx, dest, replacement_value) in promotable_loads.into_iter().rev() {
57            func.blocks[block_idx].instructions[inst_idx] = MirInst::Const {
58                dest,
59                value: crate::mir::instruction::ConstValue::U32(replacement_value.index()),
60            };
61            changed = true;
62        }
63
64        changed
65    }
66}
67
68#[cfg(test)]
69mod tests {
70    use super::*;
71    use crate::mir::basic_block::{BasicBlock, Terminator};
72    use crate::mir::instruction::MirInst;
73    use crate::mir::types::MirType;
74    use crate::mir::value::BlockId;
75
76    #[test]
77    fn test_mem2reg_promotes_private_load() {
78        let mut func = MirFunction::new("test".into(), BlockId(0));
79        let mut bb = BasicBlock::new(BlockId(0));
80        bb.instructions.push(MirInst::Store {
81            addr: ValueId(0),
82            value: ValueId(1),
83            space: AddressSpace::Private,
84        });
85        bb.instructions.push(MirInst::Load {
86            dest: ValueId(2),
87            addr: ValueId(0),
88            space: AddressSpace::Private,
89            ty: MirType::I32,
90        });
91        bb.terminator = Terminator::Return;
92        func.blocks.push(bb);
93
94        let pass = Mem2Reg;
95        assert!(pass.run(&mut func));
96    }
97
98    #[test]
99    fn test_mem2reg_ignores_device_memory() {
100        let mut func = MirFunction::new("test".into(), BlockId(0));
101        let mut bb = BasicBlock::new(BlockId(0));
102        bb.instructions.push(MirInst::Store {
103            addr: ValueId(0),
104            value: ValueId(1),
105            space: AddressSpace::Device,
106        });
107        bb.instructions.push(MirInst::Load {
108            dest: ValueId(2),
109            addr: ValueId(0),
110            space: AddressSpace::Device,
111            ty: MirType::I32,
112        });
113        bb.terminator = Terminator::Return;
114        func.blocks.push(bb);
115
116        let pass = Mem2Reg;
117        assert!(!pass.run(&mut func));
118    }
119}