wasmer_clif_fork_frontend/
switch.rs

1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use cranelift_codegen::ir::condcodes::IntCC;
5use cranelift_codegen::ir::*;
6use log::debug;
7
/// Index of a switch entry: the value of the switched-on operand that selects that entry.
type EntryIndex = u64;

/// Unlike with `br_table`, `Switch` cases may be sparse or non-0-based.
/// They emit efficient code using branches, jump tables, or a combination of both.
///
/// Register cases with `Switch::set_entry`, then generate the dispatch code with
/// `Switch::emit`, which consumes the switch and needs a fallback block for values that match
/// no entry.
///
/// # Example
///
/// ```rust
/// # use cranelift_codegen::ir::types::*;
/// # use cranelift_codegen::ir::{ExternalName, Function, Signature, InstBuilder};
/// # use cranelift_codegen::isa::CallConv;
/// # use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Position, Switch};
/// #
/// # let sig = Signature::new(CallConv::SystemV);
/// # let mut fn_builder_ctx = FunctionBuilderContext::new();
/// # let mut func = Function::with_name_signature(ExternalName::user(0, 0), sig);
/// # let mut position = Position::default();
/// # let mut builder = FunctionBuilder::new(&mut func, &mut fn_builder_ctx, &mut position);
/// #
/// # let entry = builder.create_block();
/// # builder.switch_to_block(entry);
/// #
/// let block0 = builder.create_block();
/// let block1 = builder.create_block();
/// let block2 = builder.create_block();
/// let fallback = builder.create_block();
///
/// let val = builder.ins().iconst(I32, 1);
///
/// let mut switch = Switch::new();
/// switch.set_entry(0, block0);
/// switch.set_entry(1, block1);
/// switch.set_entry(7, block2);
/// switch.emit(&mut builder, val, fallback);
/// ```
#[derive(Debug, Default)]
pub struct Switch {
    /// Target block for each registered entry index; each index may be set at most once.
    cases: HashMap<EntryIndex, Block>,
}
47
48impl Switch {
49    /// Create a new empty switch
50    pub fn new() -> Self {
51        Self {
52            cases: HashMap::new(),
53        }
54    }
55
56    /// Set a switch entry
57    pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
58        let prev = self.cases.insert(index, block);
59        assert!(
60            prev.is_none(),
61            "Tried to set the same entry {} twice",
62            index
63        );
64    }
65
66    /// Get a reference to all existing entries
67    pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
68        &self.cases
69    }
70
71    /// Turn the `cases` `HashMap` into a list of `ContiguousCaseRange`s.
72    ///
73    /// # Postconditions
74    ///
75    /// * Every entry will be represented.
76    /// * The `ContiguousCaseRange`s will not overlap.
77    /// * Between two `ContiguousCaseRange`s there will be at least one entry index.
78    /// * No `ContiguousCaseRange`s will be empty.
79    fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
80        debug!("build_contiguous_case_ranges before: {:#?}", self.cases);
81        let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
82        cases.sort_by_key(|&(index, _)| index);
83
84        let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
85        let mut last_index = None;
86        for (index, block) in cases {
87            match last_index {
88                None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
89                Some(last_index) => {
90                    if index > last_index + 1 {
91                        contiguous_case_ranges.push(ContiguousCaseRange::new(index));
92                    }
93                }
94            }
95            contiguous_case_ranges
96                .last_mut()
97                .unwrap()
98                .blocks
99                .push(block);
100            last_index = Some(index);
101        }
102
103        debug!(
104            "build_contiguous_case_ranges after: {:#?}",
105            contiguous_case_ranges
106        );
107
108        contiguous_case_ranges
109    }
110
111    /// Binary search for the right `ContiguousCaseRange`.
112    fn build_search_tree(
113        bx: &mut FunctionBuilder,
114        val: Value,
115        otherwise: Block,
116        contiguous_case_ranges: Vec<ContiguousCaseRange>,
117    ) -> Vec<(EntryIndex, Block, Vec<Block>)> {
118        let mut cases_and_jt_blocks = Vec::new();
119
120        // Avoid allocation in the common case
121        if contiguous_case_ranges.len() <= 3 {
122            Self::build_search_branches(
123                bx,
124                val,
125                otherwise,
126                contiguous_case_ranges,
127                &mut cases_and_jt_blocks,
128            );
129            return cases_and_jt_blocks;
130        }
131
132        let mut stack: Vec<(Option<Block>, Vec<ContiguousCaseRange>)> = Vec::new();
133        stack.push((None, contiguous_case_ranges));
134
135        while let Some((block, contiguous_case_ranges)) = stack.pop() {
136            if let Some(block) = block {
137                bx.switch_to_block(block);
138            }
139
140            if contiguous_case_ranges.len() <= 3 {
141                Self::build_search_branches(
142                    bx,
143                    val,
144                    otherwise,
145                    contiguous_case_ranges,
146                    &mut cases_and_jt_blocks,
147                );
148            } else {
149                let split_point = contiguous_case_ranges.len() / 2;
150                let mut left = contiguous_case_ranges;
151                let right = left.split_off(split_point);
152
153                let left_block = bx.create_block();
154                let right_block = bx.create_block();
155
156                let should_take_right_side = bx.ins().icmp_imm(
157                    IntCC::UnsignedGreaterThanOrEqual,
158                    val,
159                    right[0].first_index as i64,
160                );
161                bx.ins().brnz(should_take_right_side, right_block, &[]);
162                bx.ins().jump(left_block, &[]);
163
164                stack.push((Some(left_block), left));
165                stack.push((Some(right_block), right));
166            }
167        }
168
169        cases_and_jt_blocks
170    }
171
172    /// Linear search for the right `ContiguousCaseRange`.
173    fn build_search_branches(
174        bx: &mut FunctionBuilder,
175        val: Value,
176        otherwise: Block,
177        contiguous_case_ranges: Vec<ContiguousCaseRange>,
178        cases_and_jt_blocks: &mut Vec<(EntryIndex, Block, Vec<Block>)>,
179    ) {
180        let mut was_branch = false;
181        let ins_fallthrough_jump = |was_branch: bool, bx: &mut FunctionBuilder| {
182            if was_branch {
183                let block = bx.create_block();
184                bx.ins().jump(block, &[]);
185                bx.switch_to_block(block);
186            }
187        };
188        for ContiguousCaseRange {
189            first_index,
190            blocks,
191        } in contiguous_case_ranges.into_iter().rev()
192        {
193            match (blocks.len(), first_index) {
194                (1, 0) => {
195                    ins_fallthrough_jump(was_branch, bx);
196                    bx.ins().brz(val, blocks[0], &[]);
197                }
198                (1, _) => {
199                    ins_fallthrough_jump(was_branch, bx);
200                    let is_good_val = bx.ins().icmp_imm(IntCC::Equal, val, first_index as i64);
201                    bx.ins().brnz(is_good_val, blocks[0], &[]);
202                }
203                (_, 0) => {
204                    // if `first_index` is 0, then `icmp_imm uge val, first_index` is trivially true
205                    let jt_block = bx.create_block();
206                    bx.ins().jump(jt_block, &[]);
207                    cases_and_jt_blocks.push((first_index, jt_block, blocks));
208                    // `jump otherwise` below must not be hit, because the current block has been
209                    // filled above. This is the last iteration anyway, as 0 is the smallest
210                    // unsigned int, so just return here.
211                    return;
212                }
213                (_, _) => {
214                    ins_fallthrough_jump(was_branch, bx);
215                    let jt_block = bx.create_block();
216                    let is_good_val = bx.ins().icmp_imm(
217                        IntCC::UnsignedGreaterThanOrEqual,
218                        val,
219                        first_index as i64,
220                    );
221                    bx.ins().brnz(is_good_val, jt_block, &[]);
222                    cases_and_jt_blocks.push((first_index, jt_block, blocks));
223                }
224            }
225            was_branch = true;
226        }
227
228        bx.ins().jump(otherwise, &[]);
229    }
230
231    /// For every item in `cases_and_jt_blocks` this will create a jump table in the specified block.
232    fn build_jump_tables(
233        bx: &mut FunctionBuilder,
234        val: Value,
235        otherwise: Block,
236        cases_and_jt_blocks: Vec<(EntryIndex, Block, Vec<Block>)>,
237    ) {
238        for (first_index, jt_block, blocks) in cases_and_jt_blocks.into_iter().rev() {
239            let mut jt_data = JumpTableData::new();
240            for block in blocks {
241                jt_data.push_entry(block);
242            }
243            let jump_table = bx.create_jump_table(jt_data);
244
245            bx.switch_to_block(jt_block);
246            let discr = if first_index == 0 {
247                val
248            } else {
249                bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
250            };
251            bx.ins().br_table(discr, otherwise, jump_table);
252        }
253    }
254
255    /// Build the switch
256    ///
257    /// # Arguments
258    ///
259    /// * The function builder to emit to
260    /// * The value to switch on
261    /// * The default block
262    pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
263        // FIXME icmp(_imm) doesn't have encodings for i8 and i16 on x86(_64) yet
264        let val = match bx.func.dfg.value_type(val) {
265            types::I8 | types::I16 => bx.ins().uextend(types::I32, val),
266            _ => val,
267        };
268
269        let contiguous_case_ranges = self.collect_contiguous_case_ranges();
270        let cases_and_jt_blocks =
271            Self::build_search_tree(bx, val, otherwise, contiguous_case_ranges);
272        Self::build_jump_tables(bx, val, otherwise, cases_and_jt_blocks);
273    }
274}
275
276/// This represents a contiguous range of cases to switch on.
277///
278/// For example 10 => block1, 11 => block2, 12 => block7 will be represented as:
279///
280/// ```plain
281/// ContiguousCaseRange {
282///     first_index: 10,
283///     blocks: vec![Block::from_u32(1), Block::from_u32(2), Block::from_u32(7)]
284/// }
285/// ```
286#[derive(Debug)]
287struct ContiguousCaseRange {
288    /// The entry index of the first case. Eg. 10 when the entry indexes are 10, 11, 12 and 13.
289    first_index: EntryIndex,
290
291    /// The blocks to jump to sorted in ascending order of entry index.
292    blocks: Vec<Block>,
293}
294
295impl ContiguousCaseRange {
296    fn new(first_index: EntryIndex) -> Self {
297        Self {
298            first_index,
299            blocks: Vec::new(),
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::{FunctionBuilderContext, Position};
    use alloc::string::ToString;
    use cranelift_codegen::ir::Function;

    // Builds a function whose entry block switches on the `i8` constant `0`. Each listed entry
    // index is mapped to a freshly created block, and `block$default` is the fallback. Returns
    // the printed IR with the surrounding `function u0:0() fast { ... }` wrapper stripped.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            let mut position = Position::default();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx, &mut position);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                // `mut` is unused when the entry list is empty.
                #[allow(unused_mut)]
                let mut switch = Switch::new();
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }

    /// With no entries the switch degenerates to a jump to the fallback block.
    #[test]
    fn switch_empty() {
        let func = setup!(0, []);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    jump block0"
        );
    }

    /// Registering the same entry index twice is rejected.
    #[test]
    #[should_panic(expected = "Tried to set the same entry 3 twice")]
    fn switch_duplicate_entry() {
        let block = Block::with_number(1).unwrap();
        let mut switch = Switch::new();
        switch.set_entry(3, block);
        switch.set_entry(3, block);
    }

    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    brz v1, block1
    jump block0"
        );
    }

    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, 1
    brnz v2, block1
    jump block0"
        );
    }

    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq!(
            func,
            "    jt0 = jump_table [block1, block2]

block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    jump block3

block3:
    br_table.i32 v1, block0, jt0"
        );
    }

    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, 2
    brnz v2, block2
    jump block3

block3:
    brz.i32 v1, block1
    jump block0"
        );
    }

    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq!(
            func,
            "    jt0 = jump_table [block1, block2]
    jt1 = jump_table [block5, block6, block7]

block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm uge v1, 7
    brnz v2, block9
    jump block8

block9:
    v3 = icmp_imm.i32 uge v1, 10
    brnz v3, block10
    jump block11

block11:
    v4 = icmp_imm.i32 eq v1, 7
    brnz v4, block4
    jump block0

block8:
    v5 = icmp_imm.i32 eq v1, 5
    brnz v5, block3
    jump block12

block12:
    br_table.i32 v1, block0, jt0

block10:
    v6 = iadd_imm.i32 v1, -10
    br_table v6, block0, jt1"
        );
    }

    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [::core::i64::MIN as u64, 1,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, 0x8000_0000_0000_0000
    brnz v2, block1
    jump block3

block3:
    v3 = icmp_imm.i32 eq v1, 1
    brnz v3, block2
    jump block0"
        );
    }

    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [::core::i64::MAX as u64, 1,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, 0x7fff_ffff_ffff_ffff
    brnz v2, block1
    jump block3

block3:
    v3 = icmp_imm.i32 eq v1, 1
    brnz v3, block2
    jump block0"
        );
    }

    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i64 as u64, 0, 1,]);
        assert_eq!(
            func,
            "    jt0 = jump_table [block2, block3]

block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, -1
    brnz v2, block1
    jump block4

block4:
    br_table.i32 v1, block0, jt0"
        );
    }
}