sonatina_codegen/optim/
sccp.rs

1//! This module contains a solver for Sparse Conditional Constant Propagation.
2//!
//! The algorithm is based on Mark N. Wegman, F. Kenneth Zadeck: Constant propagation with conditional branches:  
4//! ACM Transactions on Programming Languages and Systems Volume 13 Issue 2 April 1991 pp 181–210:  
5//! <https://doi.org/10.1145/103135.103136>
6
7use std::{collections::BTreeSet, ops};
8
9use cranelift_entity::SecondaryMap;
10
11use crate::cfg::ControlFlowGraph;
12
13use sonatina_ir::{
14    func_cursor::{CursorLocation, FuncCursor, InsnInserter},
15    insn::{BinaryOp, CastOp, InsnData, UnaryOp},
16    Block, Function, Immediate, Insn, Type, Value,
17};
18
/// Solver state for Sparse Conditional Constant Propagation.
///
/// The solver can be reused across functions: `run` clears all state before it starts.
#[derive(Debug)]
pub struct SccpSolver {
    /// Lattice cell currently assigned to each value.
    lattice: SecondaryMap<Value, LatticeCell>,
    /// Control-flow edges that have been found reachable so far.
    reachable_edges: BTreeSet<FlowEdge>,
    /// Blocks that have been found reachable so far.
    reachable_blocks: BTreeSet<Block>,

    /// Edges waiting to be processed by `eval_edge`.
    flow_work: Vec<FlowEdge>,
    /// Values whose lattice cell changed; their users are re-evaluated in `run`.
    ssa_work: Vec<Value>,
}
28
29impl SccpSolver {
30    pub fn new() -> Self {
31        Self {
32            lattice: SecondaryMap::default(),
33            reachable_edges: BTreeSet::default(),
34            reachable_blocks: BTreeSet::default(),
35            flow_work: Vec::default(),
36            ssa_work: Vec::default(),
37        }
38    }
39    pub fn run(&mut self, func: &mut Function, cfg: &mut ControlFlowGraph) {
40        self.clear();
41
42        let entry_block = match func.layout.entry_block() {
43            Some(block) => block,
44            _ => return,
45        };
46
47        // Function arguments must be `LatticeCell::Top`
48        for arg in &func.arg_values {
49            self.lattice[*arg] = LatticeCell::Top;
50        }
51
52        // Evaluate all values in entry block.
53        self.reachable_blocks.insert(entry_block);
54        self.eval_insns_in(func, entry_block);
55
56        let mut changed = true;
57        while changed {
58            changed = false;
59
60            while let Some(edge) = self.flow_work.pop() {
61                changed = true;
62                self.eval_edge(func, edge);
63            }
64
65            while let Some(value) = self.ssa_work.pop() {
66                changed = true;
67                for &user in func.dfg.users(value) {
68                    let user_block = func.layout.insn_block(user);
69                    if self.reachable_blocks.contains(&user_block) {
70                        if func.dfg.is_phi(user) {
71                            self.eval_phi(func, user);
72                        } else {
73                            self.eval_insn(func, user);
74                        }
75                    }
76                }
77            }
78        }
79
80        self.remove_unreachable_edges(func);
81        cfg.compute(func);
82        self.fold_insns(func, cfg);
83    }
84
85    pub fn clear(&mut self) {
86        self.lattice.clear();
87        self.reachable_edges.clear();
88        self.reachable_blocks.clear();
89        self.flow_work.clear();
90        self.ssa_work.clear();
91    }
92
93    fn eval_edge(&mut self, func: &mut Function, edge: FlowEdge) {
94        let dest = edge.to;
95
96        if self.reachable_edges.contains(&edge) {
97            return;
98        }
99        self.reachable_edges.insert(edge);
100
101        if self.reachable_blocks.contains(&dest) {
102            self.eval_phis_in(func, dest);
103        } else {
104            self.reachable_blocks.insert(dest);
105            self.eval_insns_in(func, dest);
106        }
107
108        if let Some(last_insn) = func.layout.last_insn_of(dest) {
109            let branch_info = func.dfg.analyze_branch(last_insn);
110            if branch_info.dests_num() == 1 {
111                self.flow_work.push(FlowEdge::new(
112                    last_insn,
113                    branch_info.iter_dests().next().unwrap(),
114                ))
115            }
116        }
117    }
118
119    fn eval_phis_in(&mut self, func: &Function, block: Block) {
120        for insn in func.layout.iter_insn(block) {
121            if func.dfg.is_phi(insn) {
122                self.eval_phi(func, insn);
123            }
124        }
125    }
126
127    fn eval_phi(&mut self, func: &Function, insn: Insn) {
128        debug_assert!(func.dfg.is_phi(insn));
129        debug_assert!(self
130            .reachable_blocks
131            .contains(&func.layout.insn_block(insn)));
132
133        for &arg in func.dfg.insn_args(insn) {
134            if let Some(imm) = func.dfg.value_imm(arg) {
135                self.set_lattice_cell(arg, LatticeCell::Const(imm));
136            }
137        }
138
139        let block = func.layout.insn_block(insn);
140
141        let mut eval_result = LatticeCell::Bot;
142        for (i, from) in func.dfg.phi_blocks(insn).iter().enumerate() {
143            if self.is_reachable(func, *from, block) {
144                let phi_arg = func.dfg.insn_arg(insn, i);
145                let v_cell = self.lattice[phi_arg];
146                eval_result = eval_result.join(v_cell);
147            }
148        }
149
150        let phi_value = func.dfg.insn_result(insn).unwrap();
151        if eval_result != self.lattice[phi_value] {
152            self.ssa_work.push(phi_value);
153            self.lattice[phi_value] = eval_result;
154        }
155    }
156
157    fn eval_insns_in(&mut self, func: &Function, block: Block) {
158        for insn in func.layout.iter_insn(block) {
159            if func.dfg.is_phi(insn) {
160                self.eval_phi(func, insn);
161            } else {
162                self.eval_insn(func, insn);
163            }
164        }
165    }
166
167    fn eval_insn(&mut self, func: &Function, insn: Insn) {
168        debug_assert!(!func.dfg.is_phi(insn));
169        for &arg in func.dfg.insn_args(insn) {
170            if let Some(imm) = func.dfg.value_imm(arg) {
171                self.set_lattice_cell(arg, LatticeCell::Const(imm));
172            }
173        }
174
175        let cell = match func.dfg.insn_data(insn) {
176            InsnData::Unary { code, args } => {
177                let arg_cell = self.lattice[args[0]];
178                match *code {
179                    UnaryOp::Not => arg_cell.not(),
180                    UnaryOp::Neg => arg_cell.neg(),
181                }
182            }
183
184            InsnData::Binary { code, args } => {
185                let lhs = self.lattice[args[0]];
186                let rhs = self.lattice[args[1]];
187                match *code {
188                    BinaryOp::Add => lhs.add(rhs),
189                    BinaryOp::Sub => lhs.sub(rhs),
190                    BinaryOp::Mul => lhs.mul(rhs),
191                    BinaryOp::Udiv => lhs.udiv(rhs),
192                    BinaryOp::Sdiv => lhs.sdiv(rhs),
193                    BinaryOp::Lt => lhs.lt(rhs),
194                    BinaryOp::Gt => lhs.gt(rhs),
195                    BinaryOp::Slt => lhs.slt(rhs),
196                    BinaryOp::Sgt => lhs.sgt(rhs),
197                    BinaryOp::Le => lhs.le(rhs),
198                    BinaryOp::Ge => lhs.ge(rhs),
199                    BinaryOp::Sle => lhs.sle(rhs),
200                    BinaryOp::Sge => lhs.sge(rhs),
201                    BinaryOp::Eq => lhs.eq(rhs),
202                    BinaryOp::Ne => lhs.ne(rhs),
203                    BinaryOp::And => lhs.and(rhs),
204                    BinaryOp::Or => lhs.or(rhs),
205                    BinaryOp::Xor => lhs.xor(rhs),
206                }
207            }
208
209            InsnData::Cast { code, args, ty } => {
210                let arg_cell = self.lattice[args[0]];
211                match code {
212                    CastOp::Sext => arg_cell.sext(*ty),
213                    CastOp::Zext => arg_cell.zext(*ty),
214                    CastOp::Trunc => arg_cell.trunc(*ty),
215                    CastOp::BitCast => LatticeCell::Top,
216                }
217            }
218
219            InsnData::Load { .. } => LatticeCell::Top,
220
221            InsnData::Call { .. } => LatticeCell::Top,
222
223            InsnData::Jump { dests, .. } => {
224                self.flow_work.push(FlowEdge::new(insn, dests[0]));
225                return;
226            }
227
228            InsnData::Branch { args, dests } => {
229                let v_cell = self.lattice[args[0]];
230
231                if v_cell.is_top() {
232                    // Add both then and else edges.
233                    self.flow_work.push(FlowEdge::new(insn, dests[0]));
234                    self.flow_work.push(FlowEdge::new(insn, dests[1]));
235                } else if v_cell.is_bot() {
236                    unreachable!();
237                } else if v_cell.is_zero() {
238                    // Add else edge.
239                    self.flow_work.push(FlowEdge::new(insn, dests[1]));
240                } else {
241                    // Add then edge.
242                    self.flow_work.push(FlowEdge::new(insn, dests[0]));
243                }
244
245                return;
246            }
247
248            InsnData::BrTable {
249                args,
250                default,
251                table,
252            } => {
253                // An closure that add all destinations of the `BrTable.
254                let mut add_all_dest = || {
255                    if let Some(default) = default {
256                        self.flow_work.push(FlowEdge::new(insn, *default));
257                    }
258                    for dest in table {
259                        self.flow_work.push(FlowEdge::new(insn, *dest));
260                    }
261                };
262
263                let v_cell = self.lattice[args[0]];
264
265                // If the argument of the `BrTable` is top, then add all destinations.
266                if v_cell.is_top() {
267                    add_all_dest();
268                    return;
269                }
270
271                // Verifier verifies that the use of the argument must dominated by the its
272                // definition, so `v_cell` must not be bot.
273                if v_cell.is_bot() {
274                    unreachable!()
275                }
276
277                let mut contains_top = false;
278                for (value, dest) in args[1..].iter().zip(table.iter()) {
279                    if self.lattice[*value] == v_cell {
280                        self.flow_work.push(FlowEdge::new(insn, *dest));
281                        return;
282                    } else if v_cell.is_top() {
283                        contains_top = true;
284                    }
285                }
286
287                if contains_top {
288                    // If one of the table value is top, then add all dests.
289                    add_all_dest();
290                } else {
291                    // If all table values is not top, then just add default destination.
292                    if let Some(default) = default {
293                        self.flow_work.push(FlowEdge::new(insn, *default));
294                    }
295                }
296
297                return;
298            }
299
300            InsnData::Alloca { .. } | InsnData::Gep { .. } => LatticeCell::Top,
301
302            InsnData::Store { .. } | InsnData::Return { .. } => {
303                // No insn result. Do nothing.
304                return;
305            }
306
307            InsnData::Phi { .. } => unreachable!(),
308        };
309
310        let insn_result = func.dfg.insn_result(insn).unwrap();
311        self.set_lattice_cell(insn_result, cell);
312    }
313
314    /// Remove unreachable edges and blocks.
315    fn remove_unreachable_edges(&self, func: &mut Function) {
316        let entry_block = func.layout.entry_block().unwrap();
317        let mut inserter = InsnInserter::new(func, CursorLocation::BlockTop(entry_block));
318
319        loop {
320            match inserter.loc() {
321                CursorLocation::BlockTop(block) => {
322                    if !self.reachable_blocks.contains(&block) {
323                        inserter.remove_block();
324                    } else {
325                        inserter.proceed();
326                    }
327                }
328
329                CursorLocation::BlockBottom(..) => inserter.proceed(),
330
331                CursorLocation::At(insn) => {
332                    if inserter.func().dfg.is_branch(insn) {
333                        let branch_info = inserter.func().dfg.analyze_branch(insn);
334                        for dest in branch_info.iter_dests().collect::<Vec<_>>() {
335                            if !self.is_reachable_edge(insn, dest) {
336                                inserter.func_mut().dfg.remove_branch_dest(insn, dest);
337                            }
338                        }
339                    }
340                    inserter.proceed();
341                }
342
343                CursorLocation::NoWhere => break,
344            }
345        }
346    }
347
348    fn is_reachable_edge(&self, insn: Insn, dest: Block) -> bool {
349        self.reachable_edges.contains(&FlowEdge::new(insn, dest))
350    }
351
352    fn fold_insns(&mut self, func: &mut Function, cfg: &ControlFlowGraph) {
353        let mut rpo: Vec<_> = cfg.post_order().collect();
354        rpo.reverse();
355
356        for block in rpo {
357            let mut next_insn = func.layout.first_insn_of(block);
358            while let Some(insn) = next_insn {
359                next_insn = func.layout.next_insn_of(insn);
360                self.fold(func, insn);
361            }
362        }
363    }
364
365    fn fold(&self, func: &mut Function, insn: Insn) {
366        let insn_result = match func.dfg.insn_result(insn) {
367            Some(result) => result,
368            None => return,
369        };
370
371        match self.lattice[insn_result].to_imm() {
372            Some(imm) => {
373                InsnInserter::new(func, CursorLocation::At(insn)).remove_insn();
374                let new_value = func.dfg.make_imm_value(imm);
375                func.dfg.change_to_alias(insn_result, new_value);
376            }
377            None => {
378                if func.dfg.is_phi(insn) {
379                    self.try_fold_phi(func, insn)
380                }
381            }
382        }
383    }
384
385    fn try_fold_phi(&self, func: &mut Function, insn: Insn) {
386        debug_assert!(func.dfg.is_phi(insn));
387
388        let mut blocks = func.dfg.phi_blocks(insn).to_vec();
389        blocks.retain(|block| !self.reachable_blocks.contains(block));
390        for block in blocks {
391            func.dfg.remove_phi_arg(insn, block);
392        }
393
394        // Remove phi function if it has just one argument.
395        if func.dfg.insn_args_num(insn) == 1 {
396            let phi_value = func.dfg.insn_result(insn).unwrap();
397            func.dfg
398                .change_to_alias(phi_value, func.dfg.insn_arg(insn, 0));
399            InsnInserter::new(func, CursorLocation::At(insn)).remove_insn();
400        }
401    }
402
403    fn is_reachable(&self, func: &Function, from: Block, to: Block) -> bool {
404        let last_insn = if let Some(insn) = func.layout.last_insn_of(from) {
405            insn
406        } else {
407            return false;
408        };
409        for dest in func.dfg.analyze_branch(last_insn).iter_dests() {
410            if dest == to
411                && self
412                    .reachable_edges
413                    .contains(&FlowEdge::new(last_insn, dest))
414            {
415                return true;
416            }
417        }
418        false
419    }
420
421    fn set_lattice_cell(&mut self, value: Value, cell: LatticeCell) {
422        let old_cell = &self.lattice[value];
423        if old_cell != &cell {
424            self.lattice[value] = cell;
425            self.ssa_work.push(value);
426        }
427    }
428}
429
430impl Default for SccpSolver {
431    fn default() -> Self {
432        Self::new()
433    }
434}
435
436#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
437struct FlowEdge {
438    insn: Insn,
439    to: Block,
440}
441
442impl FlowEdge {
443    fn new(insn: Insn, to: Block) -> Self {
444        Self { insn, to }
445    }
446}
447
/// Lattice value tracked for each SSA value during the analysis.
#[derive(Debug, Clone, Copy)]
enum LatticeCell {
    /// The value is not treated as a constant (e.g. function arguments, loads, calls).
    Top,
    /// The value is known to be this immediate.
    Const(Immediate),
    /// Nothing is known yet; this is the default for every value.
    Bot,
}
454
455impl PartialEq for LatticeCell {
456    fn eq(&self, rhs: &Self) -> bool {
457        match (self, rhs) {
458            (Self::Top, Self::Top) | (Self::Bot, Self::Bot) => true,
459            (Self::Const(v1), Self::Const(v2)) => v1 == v2,
460            _ => false,
461        }
462    }
463}
464
465impl LatticeCell {
466    fn to_imm(self) -> Option<Immediate> {
467        match self {
468            Self::Top | Self::Bot => None,
469            Self::Const(imm) => Some(imm),
470        }
471    }
472
473    fn is_zero(self) -> bool {
474        match self {
475            Self::Top | Self::Bot => false,
476            Self::Const(c) => c.is_zero(),
477        }
478    }
479
480    fn is_top(self) -> bool {
481        matches!(self, Self::Top)
482    }
483
484    fn is_bot(self) -> bool {
485        matches!(self, Self::Bot)
486    }
487
488    fn join(self, rhs: Self) -> Self {
489        match (self, rhs) {
490            (Self::Top, _) | (_, Self::Top) => Self::Top,
491            (Self::Const(v1), Self::Const(v2)) => {
492                if v1 == v2 {
493                    self
494                } else {
495                    Self::Top
496                }
497            }
498            (Self::Bot, other) | (other, Self::Bot) => other,
499        }
500    }
501
502    fn apply_unop<F>(self, f: F) -> Self
503    where
504        F: FnOnce(Immediate) -> Immediate,
505    {
506        match self {
507            Self::Top => Self::Top,
508            Self::Const(lhs) => Self::Const(f(lhs)),
509            Self::Bot => Self::Bot,
510        }
511    }
512
513    fn apply_binop<F>(self, rhs: Self, f: F) -> Self
514    where
515        F: FnOnce(Immediate, Immediate) -> Immediate,
516    {
517        match (self, rhs) {
518            (Self::Top, _) | (_, Self::Top) => Self::Top,
519            (Self::Const(lhs), Self::Const(rhs)) => Self::Const(f(lhs, rhs)),
520            (Self::Bot, _) | (_, Self::Bot) => Self::Bot,
521        }
522    }
523
524    fn not(self) -> Self {
525        self.apply_unop(ops::Not::not)
526    }
527
528    fn neg(self) -> Self {
529        self.apply_unop(ops::Neg::neg)
530    }
531
532    fn add(self, rhs: Self) -> Self {
533        self.apply_binop(rhs, ops::Add::add)
534    }
535
536    fn sub(self, rhs: Self) -> Self {
537        self.apply_binop(rhs, ops::Sub::sub)
538    }
539
540    fn mul(self, rhs: Self) -> Self {
541        self.apply_binop(rhs, ops::Mul::mul)
542    }
543
544    fn udiv(self, rhs: Self) -> Self {
545        self.apply_binop(rhs, Immediate::udiv)
546    }
547
548    fn sdiv(self, rhs: Self) -> Self {
549        self.apply_binop(rhs, Immediate::sdiv)
550    }
551
552    fn lt(self, rhs: Self) -> Self {
553        self.apply_binop(rhs, Immediate::lt)
554    }
555
556    fn gt(self, rhs: Self) -> Self {
557        self.apply_binop(rhs, Immediate::gt)
558    }
559
560    fn slt(self, rhs: Self) -> Self {
561        self.apply_binop(rhs, Immediate::slt)
562    }
563
564    fn sgt(self, rhs: Self) -> Self {
565        self.apply_binop(rhs, Immediate::sgt)
566    }
567
568    fn le(self, rhs: Self) -> Self {
569        self.apply_binop(rhs, Immediate::le)
570    }
571
572    fn ge(self, rhs: Self) -> Self {
573        self.apply_binop(rhs, Immediate::ge)
574    }
575
576    fn sle(self, rhs: Self) -> Self {
577        self.apply_binop(rhs, Immediate::sle)
578    }
579
580    fn sge(self, rhs: Self) -> Self {
581        self.apply_binop(rhs, Immediate::sge)
582    }
583
584    fn eq(self, rhs: Self) -> Self {
585        self.apply_binop(rhs, Immediate::imm_eq)
586    }
587
588    fn ne(self, rhs: Self) -> Self {
589        self.apply_binop(rhs, Immediate::imm_ne)
590    }
591
592    fn and(self, rhs: Self) -> Self {
593        self.apply_binop(rhs, ops::BitAnd::bitand)
594    }
595
596    fn or(self, rhs: Self) -> Self {
597        self.apply_binop(rhs, ops::BitOr::bitor)
598    }
599
600    fn xor(self, rhs: Self) -> Self {
601        self.apply_binop(rhs, ops::BitXor::bitxor)
602    }
603
604    fn sext(self, ty: Type) -> Self {
605        self.apply_unop(|val| Immediate::sext(val, ty))
606    }
607
608    fn zext(self, ty: Type) -> Self {
609        self.apply_unop(|val| Immediate::zext(val, ty))
610    }
611
612    fn trunc(self, ty: Type) -> Self {
613        self.apply_unop(|val| Immediate::trunc(val, ty))
614    }
615}
616
617impl Default for LatticeCell {
618    fn default() -> Self {
619        Self::Bot
620    }
621}