wasm_core/
cfgraph.rs

1use opcode::Opcode;
2use prelude::{BTreeMap, BTreeSet};
3
4#[derive(Clone, Debug)]
5pub struct CFGraph {
6    pub blocks: Vec<BasicBlock>
7}
8
9#[derive(Clone, Debug)]
10pub struct BasicBlock {
11    pub opcodes: Vec<Opcode>,
12    pub br: Option<Branch> // must be Some in a valid control graph
13}
14
15#[derive(Clone, Debug, Eq, PartialEq)]
16pub enum Branch {
17    Jmp(BlockId),
18    JmpEither(BlockId, BlockId), // (if_true, if_false)
19    JmpTable(Vec<BlockId>, BlockId),
20    Return
21}
22
/// Index of a basic block inside `CFGraph::blocks`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct BlockId(
    /// Position of the block in the graph's block list.
    pub usize,
);
25
/// Result alias used throughout this module; the error type is always `OptimizeError`.
pub type OptimizeResult<T> = Result<T, OptimizeError>;
27
/// Failures reported while building, validating or optimizing a control-flow graph.
#[derive(Clone, Debug)]
pub enum OptimizeError {
    /// A branch refers to a location that does not correspond to any basic block.
    InvalidBranchTarget,
    /// Any other failure, described by a free-form message.
    Custom(String),
}
33
34pub trait Optimizer {
35    type Return;
36
37    fn optimize(&self, cfg: &mut CFGraph) -> OptimizeResult<Self::Return>;
38}
39
40fn _assert_optimizer_trait_object_safe() {
41    struct Opt {}
42    impl Optimizer for Opt {
43        type Return = ();
44        fn optimize(&self, _: &mut CFGraph) -> OptimizeResult<Self::Return> { Ok(()) }
45    }
46
47    let _obj: Box<Optimizer<Return = ()>> = Box::new(Opt {});
48}
49
50trait CheckedBranchTarget {
51    type TValue;
52
53    fn checked_branch_target(&self) -> OptimizeResult<Self::TValue>;
54}
55
56impl<'a> CheckedBranchTarget for Option<&'a BlockId> {
57    type TValue = BlockId;
58
59    fn checked_branch_target(&self) -> OptimizeResult<BlockId> {
60        match *self {
61            Some(v) => Ok(*v),
62            None => Err(OptimizeError::InvalidBranchTarget)
63        }
64    }
65}
66
67impl CFGraph {
68    pub fn from_function(fops: &[Opcode]) -> OptimizeResult<CFGraph> {
69        Ok(CFGraph {
70            blocks: scan_basic_blocks(fops)?
71        })
72    }
73
74    pub fn validate(&self) -> OptimizeResult<()> {
75        for blk in &self.blocks {
76            for op in &blk.opcodes {
77                if op.is_branch() {
78                    return Err(OptimizeError::Custom(
79                        "Branch instruction(s) found in the middle of a basic block".into()
80                    ));
81                }
82            }
83            let br = if let Some(ref br) = blk.br {
84                br
85            } else {
86                return Err(OptimizeError::Custom(
87                    "Empty branch target(s) found".into()
88                ));
89            };
90            let br_ok = match *br {
91                Branch::Jmp(id) => {
92                    if id.0 >= self.blocks.len() {
93                        false
94                    } else {
95                        true
96                    }
97                },
98                Branch::JmpEither(a, b) => {
99                    if a.0 >= self.blocks.len() || b.0 >= self.blocks.len() {
100                        false
101                    } else {
102                        true
103                    }
104                },
105                Branch::JmpTable(ref targets, otherwise) => {
106                    let mut ok = true;
107                    for t in targets {
108                        if t.0 >= self.blocks.len() {
109                            ok = false;
110                            break;
111                        }
112                    }
113                    if ok {
114                        if otherwise.0 >= self.blocks.len() {
115                            false
116                        } else {
117                            true
118                        }
119                    } else {
120                        false
121                    }
122                },
123                Branch::Return => true
124            };
125            if !br_ok {
126                return Err(OptimizeError::Custom(
127                    "Invalid branch target(s)".into()
128                ));
129            }
130        }
131
132        Ok(())
133    }
134
135    /// Generate sequential opcodes.
136    pub fn gen_opcodes(&self) -> Vec<Opcode> {
137        enum OpOrBr {
138            Op(Opcode),
139            Br(Branch) // pending branch to basic block
140        }
141
142        let mut seq: Vec<OpOrBr> = Vec::new();
143        let mut begin_instrs: Vec<u32> = Vec::with_capacity(self.blocks.len());
144
145        for (i, bb) in self.blocks.iter().enumerate() {
146            begin_instrs.push(seq.len() as u32);
147            for op in &bb.opcodes {
148                seq.push(OpOrBr::Op(op.clone()));
149            }
150            seq.push(OpOrBr::Br(bb.br.as_ref().unwrap().clone()));
151        }
152
153        seq.into_iter().map(|oob| {
154            match oob {
155                OpOrBr::Op(op) => op,
156                OpOrBr::Br(br) => {
157                    match br {
158                        Branch::Jmp(BlockId(id)) => Opcode::Jmp(begin_instrs[id]),
159                        Branch::JmpEither(BlockId(if_true), BlockId(if_false)) => {
160                            Opcode::JmpEither(
161                                begin_instrs[if_true],
162                                begin_instrs[if_false]
163                            )
164                        },
165                        Branch::JmpTable(targets, BlockId(otherwise)) => Opcode::JmpTable(
166                            targets.into_iter().map(|BlockId(id)| begin_instrs[id]).collect(),
167                            begin_instrs[otherwise]
168                        ),
169                        Branch::Return => Opcode::Return
170                    }
171                }
172            }
173        }).collect()
174    }
175
176    pub fn optimize<
177        T: Optimizer<Return = R>,
178        R
179    >(&mut self, optimizer: T) -> OptimizeResult<R> {
180        optimizer.optimize(self)
181    }
182}
183
184impl BasicBlock {
185    pub fn new() -> BasicBlock {
186        BasicBlock {
187            opcodes: vec! [],
188            br: None
189        }
190    }
191}
192
193impl Opcode {
194    fn is_branch(&self) -> bool {
195        match *self {
196            Opcode::Jmp(_) | Opcode::JmpIf(_) | Opcode::JmpEither(_, _) | Opcode::JmpTable(_, _) | Opcode::Return => true,
197            _ => false
198        }
199    }
200}
201
202/// Constructs a Vec of basic blocks.
203fn scan_basic_blocks(ops: &[Opcode]) -> OptimizeResult<Vec<BasicBlock>> {
204    if ops.len() == 0 {
205        return Ok(Vec::new());
206    }
207
208    let mut jmp_targets: BTreeSet<u32> = BTreeSet::new();
209
210    // Entry point.
211    jmp_targets.insert(0);
212
213    {
214        // Detect jmp targets
215        for (i, op) in ops.iter().enumerate() {
216            if op.is_branch() {
217                match *op {
218                    Opcode::Jmp(id) => {
219                        jmp_targets.insert(id);
220                    },
221                    Opcode::JmpIf(id) => {
222                        jmp_targets.insert(id);
223                    },
224                    Opcode::JmpEither(a, b) => {
225                        jmp_targets.insert(a);
226                        jmp_targets.insert(b);
227                    },
228                    Opcode::JmpTable(ref targets, otherwise) => {
229                        for t in targets {
230                            jmp_targets.insert(*t);
231                        }
232                        jmp_targets.insert(otherwise);
233                    },
234                    Opcode::Return => {},
235                    _ => unreachable!()
236                }
237
238                // The instruction following a branch starts a new basic block.
239                jmp_targets.insert((i + 1) as u32);
240            }
241        }
242    }
243
244    // Split opcodes into basic blocks
245    let (bb_ops, instr_mappings): (Vec<&[Opcode]>, BTreeMap<u32, BlockId>) = {
246        let mut bb_ops: Vec<&[Opcode]> = Vec::new();
247        let mut instr_mappings: BTreeMap<u32, BlockId> = BTreeMap::new();
248
249        // jmp_targets.len() >= 1 holds here because of `jmp_targets.insert(0)`
250        let mut jmp_targets: Vec<u32> = jmp_targets.iter().map(|v| *v).collect();
251
252        // [start, end) ...
253        // ops.len
254        {
255            let last = *jmp_targets.last().unwrap() as usize;
256            if last > ops.len() {
257                return Err(OptimizeError::InvalidBranchTarget);
258            }
259
260            // ops.len() >= 1 holds here.
261            // if last == 0 (same as jmp_targets.len() == 1) then a new jmp target will still be pushed
262            // so that jmp_targets.len() >= 2 always hold after this.
263            if last < ops.len() {
264                jmp_targets.push(ops.len() as u32);
265            }
266        }
267
268        for i in 0..jmp_targets.len() - 1 {
269            // [st..ed)
270            let st = jmp_targets[i] as usize;
271            let ed = jmp_targets[i + 1] as usize;
272            instr_mappings.insert(st as u32, BlockId(bb_ops.len()));
273            bb_ops.push(&ops[st..ed]);
274        }
275
276        (bb_ops, instr_mappings)
277    };
278
279    let mut bbs: Vec<BasicBlock> = Vec::new();
280
281    for (i, bb) in bb_ops.iter().enumerate() {
282        let mut bb = bb.to_vec();
283
284        let br: Option<Branch> = if let Some(op) = bb.last() {
285            if op.is_branch() {
286                Some(match *op {
287                    Opcode::Jmp(target) => Branch::Jmp(instr_mappings.get(&target).checked_branch_target()?),
288                    Opcode::JmpIf(target) => Branch::JmpEither(
289                        instr_mappings.get(&target).checked_branch_target()?, // if true
290                        BlockId(i + 1) // otherwise
291                    ),
292                    Opcode::JmpEither(a, b) => Branch::JmpEither(
293                        instr_mappings.get(&a).checked_branch_target()?,
294                        instr_mappings.get(&b).checked_branch_target()?
295                    ),
296                    Opcode::JmpTable(ref targets, otherwise) => {
297                        let mut br_targets: Vec<BlockId> = Vec::new();
298                        for t in targets {
299                            br_targets.push(instr_mappings.get(t).checked_branch_target()?);
300                        }
301                        Branch::JmpTable(
302                            br_targets,
303                            instr_mappings.get(&otherwise).checked_branch_target()?
304                        )
305                    },
306                    Opcode::Return => Branch::Return,
307                    _ => unreachable!()
308                })
309            } else {
310                None
311            }
312        } else {
313            None
314        };
315
316        let br: Branch = if let Some(v) = br {
317            bb.pop().unwrap();
318            v
319        } else {
320            Branch::Jmp(BlockId(i + 1))
321        };
322
323        let mut result = BasicBlock::new();
324        result.opcodes = bb;
325        result.br = Some(br);
326
327        bbs.push(result);
328    }
329
330    Ok(bbs)
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the graph for `ops`, checks that it validates, and hands it back.
    fn build_validated(ops: &[Opcode]) -> CFGraph {
        let cfg = CFGraph::from_function(ops).unwrap();
        cfg.validate().unwrap();
        cfg
    }

    /// Prints the graph and the opcode sequence regenerated from it to stderr.
    fn dump(cfg: &CFGraph) {
        eprintln!("{:?}", cfg);

        eprintln!("{:?}", cfg.gen_opcodes());
    }

    #[test]
    fn test_jmp() {
        let ops = [
            Opcode::I32Const(100), // instr 0, bb 0
            Opcode::Jmp(3),        // instr 1, terminates bb 0
            Opcode::I32Const(50),  // instr 2, bb 1 (falls through implicitly)
            Opcode::I32Const(25),  // instr 3, bb 2 (jump target)
            Opcode::Return,        // instr 4, terminates bb 2
        ];

        let cfg = build_validated(&ops);

        assert_eq!(cfg.blocks.len(), 3);
        assert_eq!(cfg.blocks[0].br, Some(Branch::Jmp(BlockId(2))));
        assert_eq!(cfg.blocks[1].br, Some(Branch::Jmp(BlockId(2))));
        assert_eq!(cfg.blocks[2].br, Some(Branch::Return));

        dump(&cfg);
    }

    #[test]
    fn test_jmp_if() {
        let ops = [
            Opcode::I32Const(100), // instr 0, bb 0
            Opcode::JmpIf(3),      // instr 1, terminates bb 0
            Opcode::I32Const(50),  // instr 2, bb 1 (falls through implicitly)
            Opcode::I32Const(25),  // instr 3, bb 2 (jump target)
            Opcode::Return,        // instr 4, terminates bb 2
        ];

        let cfg = build_validated(&ops);

        assert_eq!(cfg.blocks.len(), 3);
        assert_eq!(cfg.blocks[0].br, Some(Branch::JmpEither(BlockId(2), BlockId(1))));
        assert_eq!(cfg.blocks[1].br, Some(Branch::Jmp(BlockId(2))));
        assert_eq!(cfg.blocks[2].br, Some(Branch::Return));

        dump(&cfg);
    }

    #[test]
    fn test_circular() {
        let ops = [
            Opcode::I32Const(100), // instr 0, bb 0
            Opcode::JmpIf(0),      // instr 1, loops back to the start of bb 0
            Opcode::Return,        // instr 2, bb 1
        ];

        let cfg = build_validated(&ops);

        assert_eq!(cfg.blocks.len(), 2);
        assert_eq!(cfg.blocks[0].br, Some(Branch::JmpEither(BlockId(0), BlockId(1))));

        dump(&cfg);
    }

    #[test]
    fn test_invalid_branch_target() {
        // The only instruction jumps far past the end of the function.
        let result = CFGraph::from_function(&[Opcode::Jmp(10)]);
        match result {
            Err(OptimizeError::InvalidBranchTarget) => {},
            _ => panic!("Expecting an InvalidBranchTarget error")
        }
    }
}