wasm_core/optimizers/mod.rs

1use cfgraph::*;
2use prelude::{BTreeSet, BTreeMap};
3
4pub struct RemoveDeadBasicBlocks;
5
6impl Optimizer for RemoveDeadBasicBlocks {
7    type Return = ();
8
9    fn optimize(&self, cfg: &mut CFGraph) -> OptimizeResult<()> {
10        if cfg.blocks.len() == 0 {
11            return Ok(());
12        }
13
14        let mut reachable: BTreeSet<BlockId> = BTreeSet::new();
15
16        // Perform a depth-first search on the CFG to figure out reachable blocks.
17        {
18            let mut dfs_stack: Vec<BlockId> = vec! [ BlockId(0) ];
19
20            while let Some(blk_id) = dfs_stack.pop() {
21                if reachable.contains(&blk_id) {
22                    continue;
23                }
24
25                reachable.insert(blk_id);
26
27                let blk = &cfg.blocks[blk_id.0];
28                match *blk.br.as_ref().unwrap() {
29                    Branch::Jmp(t) => {
30                        dfs_stack.push(t);
31                    },
32                    Branch::JmpEither(a, b) => {
33                        dfs_stack.push(a);
34                        dfs_stack.push(b);
35                    },
36                    Branch::JmpTable(ref targets, otherwise) => {
37                        for t in targets {
38                            dfs_stack.push(*t);
39                        }
40                        dfs_stack.push(otherwise);
41                    },
42                    Branch::Return => {}
43                }
44            }
45        }
46
47        // Maps old block ids to new ones.
48        let mut block_id_mappings: BTreeMap<BlockId, BlockId> = BTreeMap::new();
49
50        // Reachable basic blocks
51        let mut new_basic_blocks = Vec::with_capacity(reachable.len());
52
53        {
54            // Old basic blocks
55            let mut old_basic_blocks = ::prelude::mem::replace(&mut cfg.blocks, Vec::new());
56
57            // reachable is a Set so blk_id will never duplicate.
58            for (i, blk_id) in reachable.iter().enumerate() {
59                block_id_mappings.insert(*blk_id, BlockId(i));
60                new_basic_blocks.push(
61                    ::prelude::mem::replace(
62                        &mut old_basic_blocks[blk_id.0],
63                        BasicBlock::new()
64                    )
65                );
66            }
67        }
68
69        for bb in &mut new_basic_blocks {
70            let old_br = bb.br.take().unwrap();
71            bb.br = Some(match old_br {
72                Branch::Jmp(id) => Branch::Jmp(*block_id_mappings.get(&id).unwrap()),
73                Branch::JmpEither(a, b) => Branch::JmpEither(
74                    *block_id_mappings.get(&a).unwrap(),
75                    *block_id_mappings.get(&b).unwrap()
76                ),
77                Branch::JmpTable(targets, otherwise) => Branch::JmpTable(
78                    targets.into_iter().map(|t| *block_id_mappings.get(&t).unwrap()).collect(),
79                    *block_id_mappings.get(&otherwise).unwrap()
80                ),
81                Branch::Return => Branch::Return
82            });
83        }
84
85        cfg.blocks = new_basic_blocks;
86
87        Ok(())
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use opcode::Opcode;

    /// Builds a CFG whose block 1 is never reached (block 0 jumps past it).
    /// Runs `RemoveDeadBasicBlocks` and checks that only that block disappears
    /// and that the surviving branches point at the renumbered blocks.
    #[test]
    fn test_remove_dead_basic_blocks() {
        let opcodes: Vec<Opcode> = vec![
            // Block 0 (opcodes 0-1): unconditional jump to opcode 3.
            Opcode::I32Const(100),
            Opcode::Jmp(3),
            // Block 1 (opcode 2): never reached.
            Opcode::I32Const(50),
            // Block 2 (opcodes 3-4): begins at the jump target.
            Opcode::I32Const(25),
            Opcode::JmpIf(0),
            // Block 3 (opcode 5).
            Opcode::Return,
        ];

        let mut cfg = CFGraph::from_function(&opcodes[..]).unwrap();
        cfg.validate().unwrap();
        cfg.optimize(RemoveDeadBasicBlocks).unwrap();
        cfg.validate().unwrap();

        // Old blocks 0, 2 and 3 survive and become new blocks 0, 1 and 2.
        let expected_branches = [
            Branch::Jmp(BlockId(1)),
            Branch::JmpEither(BlockId(0), BlockId(2)),
            Branch::Return,
        ];
        assert_eq!(cfg.blocks.len(), expected_branches.len());
        for (block, want) in cfg.blocks.iter().zip(expected_branches.iter()) {
            assert_eq!(block.br.as_ref(), Some(want));
        }

        eprintln!("{:?}", cfg);

        eprintln!("{:?}", cfg.gen_opcodes());
    }
}