1use opcode::Opcode;
2use prelude::{BTreeMap, BTreeSet};
3
4#[derive(Clone, Debug)]
5pub struct CFGraph {
6 pub blocks: Vec<BasicBlock>
7}
8
9#[derive(Clone, Debug)]
10pub struct BasicBlock {
11 pub opcodes: Vec<Opcode>,
12 pub br: Option<Branch> }
14
15#[derive(Clone, Debug, Eq, PartialEq)]
16pub enum Branch {
17 Jmp(BlockId),
18 JmpEither(BlockId, BlockId), JmpTable(Vec<BlockId>, BlockId),
20 Return
21}
22
/// Index of a basic block inside `CFGraph::blocks`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct BlockId(pub usize);
25
26pub type OptimizeResult<T> = Result<T, OptimizeError>;
27
/// Errors reported while building, validating or optimizing a CFG.
#[derive(Debug, Clone)]
pub enum OptimizeError {
    /// A branch refers to an instruction offset that does not correspond to
    /// the start of any basic block.
    InvalidBranchTarget,
    /// Any other failure, described by a free-form message.
    Custom(String),
}
33
34pub trait Optimizer {
35 type Return;
36
37 fn optimize(&self, cfg: &mut CFGraph) -> OptimizeResult<Self::Return>;
38}
39
40fn _assert_optimizer_trait_object_safe() {
41 struct Opt {}
42 impl Optimizer for Opt {
43 type Return = ();
44 fn optimize(&self, _: &mut CFGraph) -> OptimizeResult<Self::Return> { Ok(()) }
45 }
46
47 let _obj: Box<Optimizer<Return = ()>> = Box::new(Opt {});
48}
49
50trait CheckedBranchTarget {
51 type TValue;
52
53 fn checked_branch_target(&self) -> OptimizeResult<Self::TValue>;
54}
55
56impl<'a> CheckedBranchTarget for Option<&'a BlockId> {
57 type TValue = BlockId;
58
59 fn checked_branch_target(&self) -> OptimizeResult<BlockId> {
60 match *self {
61 Some(v) => Ok(*v),
62 None => Err(OptimizeError::InvalidBranchTarget)
63 }
64 }
65}
66
67impl CFGraph {
68 pub fn from_function(fops: &[Opcode]) -> OptimizeResult<CFGraph> {
69 Ok(CFGraph {
70 blocks: scan_basic_blocks(fops)?
71 })
72 }
73
74 pub fn validate(&self) -> OptimizeResult<()> {
75 for blk in &self.blocks {
76 for op in &blk.opcodes {
77 if op.is_branch() {
78 return Err(OptimizeError::Custom(
79 "Branch instruction(s) found in the middle of a basic block".into()
80 ));
81 }
82 }
83 let br = if let Some(ref br) = blk.br {
84 br
85 } else {
86 return Err(OptimizeError::Custom(
87 "Empty branch target(s) found".into()
88 ));
89 };
90 let br_ok = match *br {
91 Branch::Jmp(id) => {
92 if id.0 >= self.blocks.len() {
93 false
94 } else {
95 true
96 }
97 },
98 Branch::JmpEither(a, b) => {
99 if a.0 >= self.blocks.len() || b.0 >= self.blocks.len() {
100 false
101 } else {
102 true
103 }
104 },
105 Branch::JmpTable(ref targets, otherwise) => {
106 let mut ok = true;
107 for t in targets {
108 if t.0 >= self.blocks.len() {
109 ok = false;
110 break;
111 }
112 }
113 if ok {
114 if otherwise.0 >= self.blocks.len() {
115 false
116 } else {
117 true
118 }
119 } else {
120 false
121 }
122 },
123 Branch::Return => true
124 };
125 if !br_ok {
126 return Err(OptimizeError::Custom(
127 "Invalid branch target(s)".into()
128 ));
129 }
130 }
131
132 Ok(())
133 }
134
135 pub fn gen_opcodes(&self) -> Vec<Opcode> {
137 enum OpOrBr {
138 Op(Opcode),
139 Br(Branch) }
141
142 let mut seq: Vec<OpOrBr> = Vec::new();
143 let mut begin_instrs: Vec<u32> = Vec::with_capacity(self.blocks.len());
144
145 for (i, bb) in self.blocks.iter().enumerate() {
146 begin_instrs.push(seq.len() as u32);
147 for op in &bb.opcodes {
148 seq.push(OpOrBr::Op(op.clone()));
149 }
150 seq.push(OpOrBr::Br(bb.br.as_ref().unwrap().clone()));
151 }
152
153 seq.into_iter().map(|oob| {
154 match oob {
155 OpOrBr::Op(op) => op,
156 OpOrBr::Br(br) => {
157 match br {
158 Branch::Jmp(BlockId(id)) => Opcode::Jmp(begin_instrs[id]),
159 Branch::JmpEither(BlockId(if_true), BlockId(if_false)) => {
160 Opcode::JmpEither(
161 begin_instrs[if_true],
162 begin_instrs[if_false]
163 )
164 },
165 Branch::JmpTable(targets, BlockId(otherwise)) => Opcode::JmpTable(
166 targets.into_iter().map(|BlockId(id)| begin_instrs[id]).collect(),
167 begin_instrs[otherwise]
168 ),
169 Branch::Return => Opcode::Return
170 }
171 }
172 }
173 }).collect()
174 }
175
176 pub fn optimize<
177 T: Optimizer<Return = R>,
178 R
179 >(&mut self, optimizer: T) -> OptimizeResult<R> {
180 optimizer.optimize(self)
181 }
182}
183
184impl BasicBlock {
185 pub fn new() -> BasicBlock {
186 BasicBlock {
187 opcodes: vec! [],
188 br: None
189 }
190 }
191}
192
193impl Opcode {
194 fn is_branch(&self) -> bool {
195 match *self {
196 Opcode::Jmp(_) | Opcode::JmpIf(_) | Opcode::JmpEither(_, _) | Opcode::JmpTable(_, _) | Opcode::Return => true,
197 _ => false
198 }
199 }
200}
201
202fn scan_basic_blocks(ops: &[Opcode]) -> OptimizeResult<Vec<BasicBlock>> {
204 if ops.len() == 0 {
205 return Ok(Vec::new());
206 }
207
208 let mut jmp_targets: BTreeSet<u32> = BTreeSet::new();
209
210 jmp_targets.insert(0);
212
213 {
214 for (i, op) in ops.iter().enumerate() {
216 if op.is_branch() {
217 match *op {
218 Opcode::Jmp(id) => {
219 jmp_targets.insert(id);
220 },
221 Opcode::JmpIf(id) => {
222 jmp_targets.insert(id);
223 },
224 Opcode::JmpEither(a, b) => {
225 jmp_targets.insert(a);
226 jmp_targets.insert(b);
227 },
228 Opcode::JmpTable(ref targets, otherwise) => {
229 for t in targets {
230 jmp_targets.insert(*t);
231 }
232 jmp_targets.insert(otherwise);
233 },
234 Opcode::Return => {},
235 _ => unreachable!()
236 }
237
238 jmp_targets.insert((i + 1) as u32);
240 }
241 }
242 }
243
244 let (bb_ops, instr_mappings): (Vec<&[Opcode]>, BTreeMap<u32, BlockId>) = {
246 let mut bb_ops: Vec<&[Opcode]> = Vec::new();
247 let mut instr_mappings: BTreeMap<u32, BlockId> = BTreeMap::new();
248
249 let mut jmp_targets: Vec<u32> = jmp_targets.iter().map(|v| *v).collect();
251
252 {
255 let last = *jmp_targets.last().unwrap() as usize;
256 if last > ops.len() {
257 return Err(OptimizeError::InvalidBranchTarget);
258 }
259
260 if last < ops.len() {
264 jmp_targets.push(ops.len() as u32);
265 }
266 }
267
268 for i in 0..jmp_targets.len() - 1 {
269 let st = jmp_targets[i] as usize;
271 let ed = jmp_targets[i + 1] as usize;
272 instr_mappings.insert(st as u32, BlockId(bb_ops.len()));
273 bb_ops.push(&ops[st..ed]);
274 }
275
276 (bb_ops, instr_mappings)
277 };
278
279 let mut bbs: Vec<BasicBlock> = Vec::new();
280
281 for (i, bb) in bb_ops.iter().enumerate() {
282 let mut bb = bb.to_vec();
283
284 let br: Option<Branch> = if let Some(op) = bb.last() {
285 if op.is_branch() {
286 Some(match *op {
287 Opcode::Jmp(target) => Branch::Jmp(instr_mappings.get(&target).checked_branch_target()?),
288 Opcode::JmpIf(target) => Branch::JmpEither(
289 instr_mappings.get(&target).checked_branch_target()?, BlockId(i + 1) ),
292 Opcode::JmpEither(a, b) => Branch::JmpEither(
293 instr_mappings.get(&a).checked_branch_target()?,
294 instr_mappings.get(&b).checked_branch_target()?
295 ),
296 Opcode::JmpTable(ref targets, otherwise) => {
297 let mut br_targets: Vec<BlockId> = Vec::new();
298 for t in targets {
299 br_targets.push(instr_mappings.get(t).checked_branch_target()?);
300 }
301 Branch::JmpTable(
302 br_targets,
303 instr_mappings.get(&otherwise).checked_branch_target()?
304 )
305 },
306 Opcode::Return => Branch::Return,
307 _ => unreachable!()
308 })
309 } else {
310 None
311 }
312 } else {
313 None
314 };
315
316 let br: Branch = if let Some(v) = br {
317 bb.pop().unwrap();
318 v
319 } else {
320 Branch::Jmp(BlockId(i + 1))
321 };
322
323 let mut result = BasicBlock::new();
324 result.opcodes = bb;
325 result.br = Some(br);
326
327 bbs.push(result);
328 }
329
330 Ok(bbs)
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph from `ops` and requires it to pass validation.
    fn build_checked(ops: Vec<Opcode>) -> CFGraph {
        let cfg = CFGraph::from_function(ops.as_slice()).unwrap();
        cfg.validate().unwrap();
        cfg
    }

    /// Collects the terminator of every block, in block order.
    fn terminators(cfg: &CFGraph) -> Vec<Option<Branch>> {
        cfg.blocks.iter().map(|blk| blk.br.clone()).collect()
    }

    /// Prints the graph and its regenerated opcode stream for debugging.
    fn dump(cfg: &CFGraph) {
        eprintln!("{:?}", cfg);

        eprintln!("{:?}", cfg.gen_opcodes());
    }

    #[test]
    fn test_jmp() {
        let cfg = build_checked(vec![
            Opcode::I32Const(100),
            Opcode::Jmp(3),
            Opcode::I32Const(50),
            Opcode::I32Const(25),
            Opcode::Return,
        ]);

        assert_eq!(
            terminators(&cfg),
            vec![
                Some(Branch::Jmp(BlockId(2))),
                Some(Branch::Jmp(BlockId(2))),
                Some(Branch::Return),
            ]
        );

        dump(&cfg);
    }

    #[test]
    fn test_jmp_if() {
        let cfg = build_checked(vec![
            Opcode::I32Const(100),
            Opcode::JmpIf(3),
            Opcode::I32Const(50),
            Opcode::I32Const(25),
            Opcode::Return,
        ]);

        assert_eq!(
            terminators(&cfg),
            vec![
                Some(Branch::JmpEither(BlockId(2), BlockId(1))),
                Some(Branch::Jmp(BlockId(2))),
                Some(Branch::Return),
            ]
        );

        dump(&cfg);
    }

    #[test]
    fn test_circular() {
        let cfg = build_checked(vec![
            Opcode::I32Const(100),
            Opcode::JmpIf(0),
            Opcode::Return,
        ]);

        assert_eq!(cfg.blocks.len(), 2);
        assert_eq!(
            cfg.blocks[0].br,
            Some(Branch::JmpEither(BlockId(0), BlockId(1)))
        );

        dump(&cfg);
    }

    #[test]
    fn test_invalid_branch_target() {
        let opcodes: Vec<Opcode> = vec![Opcode::Jmp(10)];
        match CFGraph::from_function(opcodes.as_slice()) {
            Err(OptimizeError::InvalidBranchTarget) => {}
            _ => panic!("Expecting an InvalidBranchTarget error"),
        }
    }
}