Skip to main content

wave_compiler/optimize/
sccp.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! Sparse conditional constant propagation pass.
5//!
6//! Combines constant propagation with unreachable code detection. More
7//! powerful than simple constant folding because it tracks which branches
8//! are always taken and can prove code unreachable.
9
10use std::collections::{HashMap, HashSet, VecDeque};
11
12use super::pass::Pass;
13use crate::hir::expr::BinOp;
14use crate::mir::basic_block::Terminator;
15use crate::mir::function::MirFunction;
16use crate::mir::instruction::{ConstValue, MirInst};
17use crate::mir::value::{BlockId, ValueId};
18
/// Abstract value tracked for each MIR value during SCCP.
///
/// Only `Constant` carries information in this pass. `Top` and `Bottom` are
/// both treated as "not a known constant" when folding instructions and when
/// deciding which branch successors are reachable.
#[derive(Clone, Debug, PartialEq)]
enum Lattice {
    /// Initial state assigned to function parameters.
    Top,
    /// The value is known to equal this 32-bit integer (booleans are stored
    /// as 0/1 and `u32` payloads are reinterpreted bit-for-bit).
    Constant(i32),
    /// The value is not a single known constant; also assumed for values that
    /// have no lattice entry.
    Bottom,
}
26
/// Sparse conditional constant propagation pass.
///
/// The pass holds no state of its own: every analysis table is local to a
/// single `Pass::run` call, so one instance can be reused across functions.
#[derive(Debug, Default, Clone, Copy)]
pub struct Sccp;
29
30impl Pass for Sccp {
31    fn name(&self) -> &'static str {
32        "sccp"
33    }
34
35    fn run(&self, func: &mut MirFunction) -> bool {
36        let mut lattice: HashMap<ValueId, Lattice> = HashMap::new();
37        let mut executable: HashSet<BlockId> = HashSet::new();
38        let mut worklist: VecDeque<BlockId> = VecDeque::new();
39
40        for param in &func.params {
41            lattice.insert(param.value, Lattice::Top);
42        }
43
44        executable.insert(func.entry);
45        worklist.push_back(func.entry);
46
47        while let Some(bid) = worklist.pop_front() {
48            let block = match func.block(bid) {
49                Some(b) => b.clone(),
50                None => continue,
51            };
52
53            for inst in &block.instructions {
54                evaluate_instruction(inst, &mut lattice);
55            }
56
57            match &block.terminator {
58                Terminator::Branch { target } => {
59                    if executable.insert(*target) {
60                        worklist.push_back(*target);
61                    }
62                }
63                Terminator::CondBranch {
64                    cond,
65                    true_target,
66                    false_target,
67                } => match lattice.get(cond) {
68                    Some(Lattice::Constant(v)) if *v != 0 => {
69                        if executable.insert(*true_target) {
70                            worklist.push_back(*true_target);
71                        }
72                    }
73                    Some(Lattice::Constant(0)) => {
74                        if executable.insert(*false_target) {
75                            worklist.push_back(*false_target);
76                        }
77                    }
78                    _ => {
79                        if executable.insert(*true_target) {
80                            worklist.push_back(*true_target);
81                        }
82                        if executable.insert(*false_target) {
83                            worklist.push_back(*false_target);
84                        }
85                    }
86                },
87                Terminator::Return => {}
88            }
89        }
90
91        let mut changed = false;
92
93        for block in &mut func.blocks {
94            for inst in &mut block.instructions {
95                if let Some(dest) = inst.dest() {
96                    if let Some(Lattice::Constant(v)) = lattice.get(&dest) {
97                        if !matches!(inst, MirInst::Const { .. }) {
98                            *inst = MirInst::Const {
99                                dest,
100                                value: ConstValue::I32(*v),
101                            };
102                            changed = true;
103                        }
104                    }
105                }
106            }
107        }
108
109        let original_count = func.blocks.len();
110        func.blocks
111            .retain(|b| executable.contains(&b.id) || b.id == func.entry);
112        if func.blocks.len() != original_count {
113            changed = true;
114        }
115
116        changed
117    }
118}
119
120fn evaluate_instruction(inst: &MirInst, lattice: &mut HashMap<ValueId, Lattice>) {
121    match inst {
122        MirInst::Const { dest, value } => {
123            let v = match value {
124                ConstValue::I32(v) => *v,
125                ConstValue::U32(v) => i32::from_ne_bytes(v.to_ne_bytes()),
126                ConstValue::Bool(v) => i32::from(*v),
127                ConstValue::F32(_) => return,
128            };
129            lattice.insert(*dest, Lattice::Constant(v));
130        }
131        MirInst::BinOp {
132            dest, op, lhs, rhs, ..
133        } => {
134            let l = lattice.get(lhs).cloned().unwrap_or(Lattice::Bottom);
135            let r = lattice.get(rhs).cloned().unwrap_or(Lattice::Bottom);
136
137            let result = match (&l, &r) {
138                (Lattice::Constant(a), Lattice::Constant(b)) => match op {
139                    BinOp::Add => Lattice::Constant(a.wrapping_add(*b)),
140                    BinOp::Sub => Lattice::Constant(a.wrapping_sub(*b)),
141                    BinOp::Mul => Lattice::Constant(a.wrapping_mul(*b)),
142                    BinOp::Lt => Lattice::Constant(i32::from(*a < *b)),
143                    BinOp::Eq => Lattice::Constant(i32::from(*a == *b)),
144                    _ => Lattice::Bottom,
145                },
146                _ => Lattice::Bottom,
147            };
148            lattice.insert(*dest, result);
149        }
150        _ => {}
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::types::MirType;

    #[test]
    fn test_sccp_folds_constants() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut bb = BasicBlock::new(BlockId(0));
        bb.instructions.push(MirInst::Const {
            dest: ValueId(0),
            value: ConstValue::I32(10),
        });
        bb.instructions.push(MirInst::Const {
            dest: ValueId(1),
            value: ConstValue::I32(20),
        });
        bb.instructions.push(MirInst::BinOp {
            dest: ValueId(2),
            op: BinOp::Add,
            lhs: ValueId(0),
            rhs: ValueId(1),
            ty: MirType::I32,
        });
        bb.terminator = Terminator::Return;
        func.blocks.push(bb);

        let pass = Sccp;
        assert!(pass.run(&mut func));
        match &func.blocks[0].instructions[2] {
            MirInst::Const { value, .. } => assert_eq!(*value, ConstValue::I32(30)),
            other => panic!("expected Const, got {other:?}"),
        }
    }

    #[test]
    fn test_sccp_no_change_without_constants() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut bb = BasicBlock::new(BlockId(0));
        bb.instructions.push(MirInst::BinOp {
            dest: ValueId(2),
            op: BinOp::Add,
            lhs: ValueId(0),
            rhs: ValueId(1),
            ty: MirType::I32,
        });
        bb.terminator = Terminator::Return;
        func.blocks.push(bb);

        let pass = Sccp;
        assert!(!pass.run(&mut func));
    }

    /// A branch on a constant `false` must keep only the false successor.
    #[test]
    fn test_sccp_prunes_branch_with_constant_condition() {
        let mut func = MirFunction::new("test".into(), BlockId(0));

        let mut entry = BasicBlock::new(BlockId(0));
        entry.instructions.push(MirInst::Const {
            dest: ValueId(0),
            value: ConstValue::Bool(false),
        });
        entry.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(1),
            false_target: BlockId(2),
        };
        func.blocks.push(entry);

        let mut then_bb = BasicBlock::new(BlockId(1));
        then_bb.terminator = Terminator::Return;
        func.blocks.push(then_bb);

        let mut else_bb = BasicBlock::new(BlockId(2));
        else_bb.terminator = Terminator::Return;
        func.blocks.push(else_bb);

        let pass = Sccp;
        assert!(pass.run(&mut func));
        assert_eq!(func.blocks.len(), 2);
        assert!(func.blocks.iter().any(|b| b.id == BlockId(0)));
        assert!(func.blocks.iter().any(|b| b.id == BlockId(2)));
        assert!(!func.blocks.iter().any(|b| b.id == BlockId(1)));
    }

    /// A branch on a value with no known definition keeps both successors.
    #[test]
    fn test_sccp_keeps_both_branches_for_unknown_condition() {
        let mut func = MirFunction::new("test".into(), BlockId(0));

        // ValueId(0) is never defined, so it cannot be a known constant.
        let mut entry = BasicBlock::new(BlockId(0));
        entry.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(1),
            false_target: BlockId(2),
        };
        func.blocks.push(entry);

        let mut then_bb = BasicBlock::new(BlockId(1));
        then_bb.terminator = Terminator::Return;
        func.blocks.push(then_bb);

        let mut else_bb = BasicBlock::new(BlockId(2));
        else_bb.terminator = Terminator::Return;
        func.blocks.push(else_bb);

        let pass = Sccp;
        assert!(!pass.run(&mut func));
        assert_eq!(func.blocks.len(), 3);
    }
}