1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use cranelift_codegen::ir::condcodes::IntCC;
5use cranelift_codegen::ir::*;
6use log::debug;
7
/// Case index of a `Switch` entry.
///
/// Always an unsigned 64-bit value, independent of the type of the value being
/// switched on; indices are turned into instruction immediates with `as i64`.
type EntryIndex = u64;
9
/// Builder for a `switch`-like multi-way branch.
///
/// Each entry maps a case index to a destination block (`set_entry`).
/// `Switch::emit` lowers the entries into compare-and-branch instructions arranged as a
/// binary search tree. The leaves either test a single index, or dispatch through a
/// `br_table` jump table over a run of consecutive indices. A value that matches no
/// entry branches to the `otherwise` block passed to `emit`.
#[derive(Debug, Default)]
pub struct Switch {
    /// Case index -> destination block. Each index may be set at most once.
    cases: HashMap<EntryIndex, Block>,
}
47
48impl Switch {
49 pub fn new() -> Self {
51 Self {
52 cases: HashMap::new(),
53 }
54 }
55
56 pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
58 let prev = self.cases.insert(index, block);
59 assert!(
60 prev.is_none(),
61 "Tried to set the same entry {} twice",
62 index
63 );
64 }
65
66 pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
68 &self.cases
69 }
70
71 fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
80 debug!("build_contiguous_case_ranges before: {:#?}", self.cases);
81 let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
82 cases.sort_by_key(|&(index, _)| index);
83
84 let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
85 let mut last_index = None;
86 for (index, block) in cases {
87 match last_index {
88 None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
89 Some(last_index) => {
90 if index > last_index + 1 {
91 contiguous_case_ranges.push(ContiguousCaseRange::new(index));
92 }
93 }
94 }
95 contiguous_case_ranges
96 .last_mut()
97 .unwrap()
98 .blocks
99 .push(block);
100 last_index = Some(index);
101 }
102
103 debug!(
104 "build_contiguous_case_ranges after: {:#?}",
105 contiguous_case_ranges
106 );
107
108 contiguous_case_ranges
109 }
110
111 fn build_search_tree(
113 bx: &mut FunctionBuilder,
114 val: Value,
115 otherwise: Block,
116 contiguous_case_ranges: Vec<ContiguousCaseRange>,
117 ) -> Vec<(EntryIndex, Block, Vec<Block>)> {
118 let mut cases_and_jt_blocks = Vec::new();
119
120 if contiguous_case_ranges.len() <= 3 {
122 Self::build_search_branches(
123 bx,
124 val,
125 otherwise,
126 contiguous_case_ranges,
127 &mut cases_and_jt_blocks,
128 );
129 return cases_and_jt_blocks;
130 }
131
132 let mut stack: Vec<(Option<Block>, Vec<ContiguousCaseRange>)> = Vec::new();
133 stack.push((None, contiguous_case_ranges));
134
135 while let Some((block, contiguous_case_ranges)) = stack.pop() {
136 if let Some(block) = block {
137 bx.switch_to_block(block);
138 }
139
140 if contiguous_case_ranges.len() <= 3 {
141 Self::build_search_branches(
142 bx,
143 val,
144 otherwise,
145 contiguous_case_ranges,
146 &mut cases_and_jt_blocks,
147 );
148 } else {
149 let split_point = contiguous_case_ranges.len() / 2;
150 let mut left = contiguous_case_ranges;
151 let right = left.split_off(split_point);
152
153 let left_block = bx.create_block();
154 let right_block = bx.create_block();
155
156 let should_take_right_side = bx.ins().icmp_imm(
157 IntCC::UnsignedGreaterThanOrEqual,
158 val,
159 right[0].first_index as i64,
160 );
161 bx.ins().brnz(should_take_right_side, right_block, &[]);
162 bx.ins().jump(left_block, &[]);
163
164 stack.push((Some(left_block), left));
165 stack.push((Some(right_block), right));
166 }
167 }
168
169 cases_and_jt_blocks
170 }
171
172 fn build_search_branches(
174 bx: &mut FunctionBuilder,
175 val: Value,
176 otherwise: Block,
177 contiguous_case_ranges: Vec<ContiguousCaseRange>,
178 cases_and_jt_blocks: &mut Vec<(EntryIndex, Block, Vec<Block>)>,
179 ) {
180 let mut was_branch = false;
181 let ins_fallthrough_jump = |was_branch: bool, bx: &mut FunctionBuilder| {
182 if was_branch {
183 let block = bx.create_block();
184 bx.ins().jump(block, &[]);
185 bx.switch_to_block(block);
186 }
187 };
188 for ContiguousCaseRange {
189 first_index,
190 blocks,
191 } in contiguous_case_ranges.into_iter().rev()
192 {
193 match (blocks.len(), first_index) {
194 (1, 0) => {
195 ins_fallthrough_jump(was_branch, bx);
196 bx.ins().brz(val, blocks[0], &[]);
197 }
198 (1, _) => {
199 ins_fallthrough_jump(was_branch, bx);
200 let is_good_val = bx.ins().icmp_imm(IntCC::Equal, val, first_index as i64);
201 bx.ins().brnz(is_good_val, blocks[0], &[]);
202 }
203 (_, 0) => {
204 let jt_block = bx.create_block();
206 bx.ins().jump(jt_block, &[]);
207 cases_and_jt_blocks.push((first_index, jt_block, blocks));
208 return;
212 }
213 (_, _) => {
214 ins_fallthrough_jump(was_branch, bx);
215 let jt_block = bx.create_block();
216 let is_good_val = bx.ins().icmp_imm(
217 IntCC::UnsignedGreaterThanOrEqual,
218 val,
219 first_index as i64,
220 );
221 bx.ins().brnz(is_good_val, jt_block, &[]);
222 cases_and_jt_blocks.push((first_index, jt_block, blocks));
223 }
224 }
225 was_branch = true;
226 }
227
228 bx.ins().jump(otherwise, &[]);
229 }
230
231 fn build_jump_tables(
233 bx: &mut FunctionBuilder,
234 val: Value,
235 otherwise: Block,
236 cases_and_jt_blocks: Vec<(EntryIndex, Block, Vec<Block>)>,
237 ) {
238 for (first_index, jt_block, blocks) in cases_and_jt_blocks.into_iter().rev() {
239 let mut jt_data = JumpTableData::new();
240 for block in blocks {
241 jt_data.push_entry(block);
242 }
243 let jump_table = bx.create_jump_table(jt_data);
244
245 bx.switch_to_block(jt_block);
246 let discr = if first_index == 0 {
247 val
248 } else {
249 bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
250 };
251 bx.ins().br_table(discr, otherwise, jump_table);
252 }
253 }
254
255 pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
263 let val = match bx.func.dfg.value_type(val) {
265 types::I8 | types::I16 => bx.ins().uextend(types::I32, val),
266 _ => val,
267 };
268
269 let contiguous_case_ranges = self.collect_contiguous_case_ranges();
270 let cases_and_jt_blocks =
271 Self::build_search_tree(bx, val, otherwise, contiguous_case_ranges);
272 Self::build_jump_tables(bx, val, otherwise, cases_and_jt_blocks);
273 }
274}
275
/// A run of consecutive case indices and their destination blocks.
///
/// Index `first_index + i` maps to `blocks[i]`.
#[derive(Debug)]
struct ContiguousCaseRange {
    /// The case index of `blocks[0]`.
    first_index: EntryIndex,

    /// Destination blocks for indices `first_index`, `first_index + 1`, ..., in order.
    blocks: Vec<Block>,
}
294
295impl ContiguousCaseRange {
296 fn new(first_index: EntryIndex) -> Self {
297 Self {
298 first_index,
299 blocks: Vec::new(),
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::{FunctionBuilderContext, Position};
    use alloc::string::ToString;
    use cranelift_codegen::ir::Function;

    // Builds a function whose entry block (block0) switches on an `iconst.i8 0`
    // value. Each listed index gets a new block, so the Nth index maps to block N.
    // `$default` is the block number used as `otherwise`. The macro returns the
    // printed function body with the header line and closing brace trimmed off.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            let mut position = Position::default();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx, &mut position);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                let mut switch = Switch::new();
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }

    // A single entry at index 0 lowers to `brz`.
    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    brz v1, block1
    jump block0"
        );
    }

    // A single non-zero entry lowers to `icmp_imm eq` + `brnz`.
    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, 1
    brnz v2, block1
    jump block0"
        );
    }

    // Two consecutive entries starting at 0 become one jump table with no
    // range check.
    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq!(
            func,
            "    jt0 = jump_table [block1, block2]

block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    jump block3

block3:
    br_table.i32 v1, block0, jt0"
        );
    }

    // Non-consecutive entries stay separate tests, highest index first, and each
    // conditional branch gets its own block.
    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, 2
    brnz v2, block2
    jump block3

block3:
    brz.i32 v1, block1
    jump block0"
        );
    }

    // More than three ranges: exercises the binary search tree split, jump
    // tables at index 0 and at a non-zero index (with `iadd_imm` rebasing), and
    // single-entry tests.
    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq!(
            func,
            "    jt0 = jump_table [block1, block2]
    jt1 = jump_table [block5, block6, block7]

block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm uge v1, 7
    brnz v2, block9
    jump block8

block9:
    v3 = icmp_imm.i32 uge v1, 10
    brnz v3, block10
    jump block11

block11:
    v4 = icmp_imm.i32 eq v1, 7
    brnz v4, block4
    jump block0

block8:
    v5 = icmp_imm.i32 eq v1, 5
    brnz v5, block3
    jump block12

block12:
    br_table.i32 v1, block0, jt0

block10:
    v6 = iadd_imm.i32 v1, -10
    br_table v6, block0, jt1"
        );
    }

    // An index with the top bit set is passed through as a 64-bit immediate,
    // even though the switched-on value is i32.
    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [::core::i64::MIN as u64, 1,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, 0x8000_0000_0000_0000
    brnz v2, block1
    jump block3

block3:
    v3 = icmp_imm.i32 eq v1, 1
    brnz v3, block2
    jump block0"
        );
    }

    // Same as above for `i64::MAX`.
    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [::core::i64::MAX as u64, 1,]);
        assert_eq!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, 0x7fff_ffff_ffff_ffff
    brnz v2, block1
    jump block3

block3:
    v3 = icmp_imm.i32 eq v1, 1
    brnz v3, block2
    jump block0"
        )
    }

    // `u64::MAX` (printed as -1) is not merged with 0 and 1: ranges do not wrap
    // around, so it stays a separate equality test.
    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i64 as u64, 0, 1,]);
        assert_eq!(
            func,
            "    jt0 = jump_table [block2, block3]

block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0
    v2 = icmp_imm eq v1, -1
    brnz v2, block1
    jump block4

block4:
    br_table.i32 v1, block0, jt0"
        );
    }
}