1use std::{collections::BTreeSet, ops};
8
9use cranelift_entity::SecondaryMap;
10
11use crate::cfg::ControlFlowGraph;
12
13use sonatina_ir::{
14 func_cursor::{CursorLocation, FuncCursor, InsnInserter},
15 insn::{BinaryOp, CastOp, InsnData, UnaryOp},
16 Block, Function, Immediate, Insn, Type, Value,
17};
18
/// Sparse conditional constant propagation (SCCP) solver.
///
/// Lattice values are propagated along SSA def-use chains and along
/// control-flow edges together, so a phi only joins values arriving over
/// edges that have been found executable.
#[derive(Debug)]
pub struct SccpSolver {
    /// Current lattice cell of each SSA value; entries never written read as
    /// `LatticeCell::default()`.
    lattice: SecondaryMap<Value, LatticeCell>,
    /// Control-flow edges found to be executable.
    reachable_edges: BTreeSet<FlowEdge>,
    /// Blocks found to be reachable from the entry block.
    reachable_blocks: BTreeSet<Block>,

    /// Pending control-flow edges waiting to be evaluated.
    flow_work: Vec<FlowEdge>,
    /// Values whose lattice cell changed; their users must be re-evaluated.
    ssa_work: Vec<Value>,
}
28
29impl SccpSolver {
30 pub fn new() -> Self {
31 Self {
32 lattice: SecondaryMap::default(),
33 reachable_edges: BTreeSet::default(),
34 reachable_blocks: BTreeSet::default(),
35 flow_work: Vec::default(),
36 ssa_work: Vec::default(),
37 }
38 }
39 pub fn run(&mut self, func: &mut Function, cfg: &mut ControlFlowGraph) {
40 self.clear();
41
42 let entry_block = match func.layout.entry_block() {
43 Some(block) => block,
44 _ => return,
45 };
46
47 for arg in &func.arg_values {
49 self.lattice[*arg] = LatticeCell::Top;
50 }
51
52 self.reachable_blocks.insert(entry_block);
54 self.eval_insns_in(func, entry_block);
55
56 let mut changed = true;
57 while changed {
58 changed = false;
59
60 while let Some(edge) = self.flow_work.pop() {
61 changed = true;
62 self.eval_edge(func, edge);
63 }
64
65 while let Some(value) = self.ssa_work.pop() {
66 changed = true;
67 for &user in func.dfg.users(value) {
68 let user_block = func.layout.insn_block(user);
69 if self.reachable_blocks.contains(&user_block) {
70 if func.dfg.is_phi(user) {
71 self.eval_phi(func, user);
72 } else {
73 self.eval_insn(func, user);
74 }
75 }
76 }
77 }
78 }
79
80 self.remove_unreachable_edges(func);
81 cfg.compute(func);
82 self.fold_insns(func, cfg);
83 }
84
85 pub fn clear(&mut self) {
86 self.lattice.clear();
87 self.reachable_edges.clear();
88 self.reachable_blocks.clear();
89 self.flow_work.clear();
90 self.ssa_work.clear();
91 }
92
93 fn eval_edge(&mut self, func: &mut Function, edge: FlowEdge) {
94 let dest = edge.to;
95
96 if self.reachable_edges.contains(&edge) {
97 return;
98 }
99 self.reachable_edges.insert(edge);
100
101 if self.reachable_blocks.contains(&dest) {
102 self.eval_phis_in(func, dest);
103 } else {
104 self.reachable_blocks.insert(dest);
105 self.eval_insns_in(func, dest);
106 }
107
108 if let Some(last_insn) = func.layout.last_insn_of(dest) {
109 let branch_info = func.dfg.analyze_branch(last_insn);
110 if branch_info.dests_num() == 1 {
111 self.flow_work.push(FlowEdge::new(
112 last_insn,
113 branch_info.iter_dests().next().unwrap(),
114 ))
115 }
116 }
117 }
118
119 fn eval_phis_in(&mut self, func: &Function, block: Block) {
120 for insn in func.layout.iter_insn(block) {
121 if func.dfg.is_phi(insn) {
122 self.eval_phi(func, insn);
123 }
124 }
125 }
126
127 fn eval_phi(&mut self, func: &Function, insn: Insn) {
128 debug_assert!(func.dfg.is_phi(insn));
129 debug_assert!(self
130 .reachable_blocks
131 .contains(&func.layout.insn_block(insn)));
132
133 for &arg in func.dfg.insn_args(insn) {
134 if let Some(imm) = func.dfg.value_imm(arg) {
135 self.set_lattice_cell(arg, LatticeCell::Const(imm));
136 }
137 }
138
139 let block = func.layout.insn_block(insn);
140
141 let mut eval_result = LatticeCell::Bot;
142 for (i, from) in func.dfg.phi_blocks(insn).iter().enumerate() {
143 if self.is_reachable(func, *from, block) {
144 let phi_arg = func.dfg.insn_arg(insn, i);
145 let v_cell = self.lattice[phi_arg];
146 eval_result = eval_result.join(v_cell);
147 }
148 }
149
150 let phi_value = func.dfg.insn_result(insn).unwrap();
151 if eval_result != self.lattice[phi_value] {
152 self.ssa_work.push(phi_value);
153 self.lattice[phi_value] = eval_result;
154 }
155 }
156
157 fn eval_insns_in(&mut self, func: &Function, block: Block) {
158 for insn in func.layout.iter_insn(block) {
159 if func.dfg.is_phi(insn) {
160 self.eval_phi(func, insn);
161 } else {
162 self.eval_insn(func, insn);
163 }
164 }
165 }
166
167 fn eval_insn(&mut self, func: &Function, insn: Insn) {
168 debug_assert!(!func.dfg.is_phi(insn));
169 for &arg in func.dfg.insn_args(insn) {
170 if let Some(imm) = func.dfg.value_imm(arg) {
171 self.set_lattice_cell(arg, LatticeCell::Const(imm));
172 }
173 }
174
175 let cell = match func.dfg.insn_data(insn) {
176 InsnData::Unary { code, args } => {
177 let arg_cell = self.lattice[args[0]];
178 match *code {
179 UnaryOp::Not => arg_cell.not(),
180 UnaryOp::Neg => arg_cell.neg(),
181 }
182 }
183
184 InsnData::Binary { code, args } => {
185 let lhs = self.lattice[args[0]];
186 let rhs = self.lattice[args[1]];
187 match *code {
188 BinaryOp::Add => lhs.add(rhs),
189 BinaryOp::Sub => lhs.sub(rhs),
190 BinaryOp::Mul => lhs.mul(rhs),
191 BinaryOp::Udiv => lhs.udiv(rhs),
192 BinaryOp::Sdiv => lhs.sdiv(rhs),
193 BinaryOp::Lt => lhs.lt(rhs),
194 BinaryOp::Gt => lhs.gt(rhs),
195 BinaryOp::Slt => lhs.slt(rhs),
196 BinaryOp::Sgt => lhs.sgt(rhs),
197 BinaryOp::Le => lhs.le(rhs),
198 BinaryOp::Ge => lhs.ge(rhs),
199 BinaryOp::Sle => lhs.sle(rhs),
200 BinaryOp::Sge => lhs.sge(rhs),
201 BinaryOp::Eq => lhs.eq(rhs),
202 BinaryOp::Ne => lhs.ne(rhs),
203 BinaryOp::And => lhs.and(rhs),
204 BinaryOp::Or => lhs.or(rhs),
205 BinaryOp::Xor => lhs.xor(rhs),
206 }
207 }
208
209 InsnData::Cast { code, args, ty } => {
210 let arg_cell = self.lattice[args[0]];
211 match code {
212 CastOp::Sext => arg_cell.sext(*ty),
213 CastOp::Zext => arg_cell.zext(*ty),
214 CastOp::Trunc => arg_cell.trunc(*ty),
215 CastOp::BitCast => LatticeCell::Top,
216 }
217 }
218
219 InsnData::Load { .. } => LatticeCell::Top,
220
221 InsnData::Call { .. } => LatticeCell::Top,
222
223 InsnData::Jump { dests, .. } => {
224 self.flow_work.push(FlowEdge::new(insn, dests[0]));
225 return;
226 }
227
228 InsnData::Branch { args, dests } => {
229 let v_cell = self.lattice[args[0]];
230
231 if v_cell.is_top() {
232 self.flow_work.push(FlowEdge::new(insn, dests[0]));
234 self.flow_work.push(FlowEdge::new(insn, dests[1]));
235 } else if v_cell.is_bot() {
236 unreachable!();
237 } else if v_cell.is_zero() {
238 self.flow_work.push(FlowEdge::new(insn, dests[1]));
240 } else {
241 self.flow_work.push(FlowEdge::new(insn, dests[0]));
243 }
244
245 return;
246 }
247
248 InsnData::BrTable {
249 args,
250 default,
251 table,
252 } => {
253 let mut add_all_dest = || {
255 if let Some(default) = default {
256 self.flow_work.push(FlowEdge::new(insn, *default));
257 }
258 for dest in table {
259 self.flow_work.push(FlowEdge::new(insn, *dest));
260 }
261 };
262
263 let v_cell = self.lattice[args[0]];
264
265 if v_cell.is_top() {
267 add_all_dest();
268 return;
269 }
270
271 if v_cell.is_bot() {
274 unreachable!()
275 }
276
277 let mut contains_top = false;
278 for (value, dest) in args[1..].iter().zip(table.iter()) {
279 if self.lattice[*value] == v_cell {
280 self.flow_work.push(FlowEdge::new(insn, *dest));
281 return;
282 } else if v_cell.is_top() {
283 contains_top = true;
284 }
285 }
286
287 if contains_top {
288 add_all_dest();
290 } else {
291 if let Some(default) = default {
293 self.flow_work.push(FlowEdge::new(insn, *default));
294 }
295 }
296
297 return;
298 }
299
300 InsnData::Alloca { .. } | InsnData::Gep { .. } => LatticeCell::Top,
301
302 InsnData::Store { .. } | InsnData::Return { .. } => {
303 return;
305 }
306
307 InsnData::Phi { .. } => unreachable!(),
308 };
309
310 let insn_result = func.dfg.insn_result(insn).unwrap();
311 self.set_lattice_cell(insn_result, cell);
312 }
313
314 fn remove_unreachable_edges(&self, func: &mut Function) {
316 let entry_block = func.layout.entry_block().unwrap();
317 let mut inserter = InsnInserter::new(func, CursorLocation::BlockTop(entry_block));
318
319 loop {
320 match inserter.loc() {
321 CursorLocation::BlockTop(block) => {
322 if !self.reachable_blocks.contains(&block) {
323 inserter.remove_block();
324 } else {
325 inserter.proceed();
326 }
327 }
328
329 CursorLocation::BlockBottom(..) => inserter.proceed(),
330
331 CursorLocation::At(insn) => {
332 if inserter.func().dfg.is_branch(insn) {
333 let branch_info = inserter.func().dfg.analyze_branch(insn);
334 for dest in branch_info.iter_dests().collect::<Vec<_>>() {
335 if !self.is_reachable_edge(insn, dest) {
336 inserter.func_mut().dfg.remove_branch_dest(insn, dest);
337 }
338 }
339 }
340 inserter.proceed();
341 }
342
343 CursorLocation::NoWhere => break,
344 }
345 }
346 }
347
348 fn is_reachable_edge(&self, insn: Insn, dest: Block) -> bool {
349 self.reachable_edges.contains(&FlowEdge::new(insn, dest))
350 }
351
352 fn fold_insns(&mut self, func: &mut Function, cfg: &ControlFlowGraph) {
353 let mut rpo: Vec<_> = cfg.post_order().collect();
354 rpo.reverse();
355
356 for block in rpo {
357 let mut next_insn = func.layout.first_insn_of(block);
358 while let Some(insn) = next_insn {
359 next_insn = func.layout.next_insn_of(insn);
360 self.fold(func, insn);
361 }
362 }
363 }
364
365 fn fold(&self, func: &mut Function, insn: Insn) {
366 let insn_result = match func.dfg.insn_result(insn) {
367 Some(result) => result,
368 None => return,
369 };
370
371 match self.lattice[insn_result].to_imm() {
372 Some(imm) => {
373 InsnInserter::new(func, CursorLocation::At(insn)).remove_insn();
374 let new_value = func.dfg.make_imm_value(imm);
375 func.dfg.change_to_alias(insn_result, new_value);
376 }
377 None => {
378 if func.dfg.is_phi(insn) {
379 self.try_fold_phi(func, insn)
380 }
381 }
382 }
383 }
384
385 fn try_fold_phi(&self, func: &mut Function, insn: Insn) {
386 debug_assert!(func.dfg.is_phi(insn));
387
388 let mut blocks = func.dfg.phi_blocks(insn).to_vec();
389 blocks.retain(|block| !self.reachable_blocks.contains(block));
390 for block in blocks {
391 func.dfg.remove_phi_arg(insn, block);
392 }
393
394 if func.dfg.insn_args_num(insn) == 1 {
396 let phi_value = func.dfg.insn_result(insn).unwrap();
397 func.dfg
398 .change_to_alias(phi_value, func.dfg.insn_arg(insn, 0));
399 InsnInserter::new(func, CursorLocation::At(insn)).remove_insn();
400 }
401 }
402
403 fn is_reachable(&self, func: &Function, from: Block, to: Block) -> bool {
404 let last_insn = if let Some(insn) = func.layout.last_insn_of(from) {
405 insn
406 } else {
407 return false;
408 };
409 for dest in func.dfg.analyze_branch(last_insn).iter_dests() {
410 if dest == to
411 && self
412 .reachable_edges
413 .contains(&FlowEdge::new(last_insn, dest))
414 {
415 return true;
416 }
417 }
418 false
419 }
420
421 fn set_lattice_cell(&mut self, value: Value, cell: LatticeCell) {
422 let old_cell = &self.lattice[value];
423 if old_cell != &cell {
424 self.lattice[value] = cell;
425 self.ssa_work.push(value);
426 }
427 }
428}
429
430impl Default for SccpSolver {
431 fn default() -> Self {
432 Self::new()
433 }
434}
435
/// A control-flow edge, identified by the branching instruction `insn` and
/// its destination block `to`.
///
/// Ordering is derived so edges can be kept in a `BTreeSet`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct FlowEdge {
    /// Branching instruction the edge leaves from.
    insn: Insn,
    /// Destination block of the edge.
    to: Block,
}
441
442impl FlowEdge {
443 fn new(insn: Insn, to: Block) -> Self {
444 Self { insn, to }
445 }
446}
447
/// Abstract value tracked for each SSA value by the SCCP solver.
///
/// Naming note: `Top` plays the role usually called "overdefined" (not a
/// compile-time constant) and `Bot` means "no information yet". In `join`,
/// `Top` absorbs every other cell and `Bot` is the identity.
#[derive(Debug, Clone, Copy)]
enum LatticeCell {
    /// The value may vary at run time (e.g. function arguments, loads, calls).
    Top,
    /// The value is always the given immediate.
    Const(Immediate),
    /// Nothing has been derived for the value yet; this is the default cell.
    Bot,
}
454
455impl PartialEq for LatticeCell {
456 fn eq(&self, rhs: &Self) -> bool {
457 match (self, rhs) {
458 (Self::Top, Self::Top) | (Self::Bot, Self::Bot) => true,
459 (Self::Const(v1), Self::Const(v2)) => v1 == v2,
460 _ => false,
461 }
462 }
463}
464
465impl LatticeCell {
466 fn to_imm(self) -> Option<Immediate> {
467 match self {
468 Self::Top | Self::Bot => None,
469 Self::Const(imm) => Some(imm),
470 }
471 }
472
473 fn is_zero(self) -> bool {
474 match self {
475 Self::Top | Self::Bot => false,
476 Self::Const(c) => c.is_zero(),
477 }
478 }
479
480 fn is_top(self) -> bool {
481 matches!(self, Self::Top)
482 }
483
484 fn is_bot(self) -> bool {
485 matches!(self, Self::Bot)
486 }
487
488 fn join(self, rhs: Self) -> Self {
489 match (self, rhs) {
490 (Self::Top, _) | (_, Self::Top) => Self::Top,
491 (Self::Const(v1), Self::Const(v2)) => {
492 if v1 == v2 {
493 self
494 } else {
495 Self::Top
496 }
497 }
498 (Self::Bot, other) | (other, Self::Bot) => other,
499 }
500 }
501
502 fn apply_unop<F>(self, f: F) -> Self
503 where
504 F: FnOnce(Immediate) -> Immediate,
505 {
506 match self {
507 Self::Top => Self::Top,
508 Self::Const(lhs) => Self::Const(f(lhs)),
509 Self::Bot => Self::Bot,
510 }
511 }
512
513 fn apply_binop<F>(self, rhs: Self, f: F) -> Self
514 where
515 F: FnOnce(Immediate, Immediate) -> Immediate,
516 {
517 match (self, rhs) {
518 (Self::Top, _) | (_, Self::Top) => Self::Top,
519 (Self::Const(lhs), Self::Const(rhs)) => Self::Const(f(lhs, rhs)),
520 (Self::Bot, _) | (_, Self::Bot) => Self::Bot,
521 }
522 }
523
524 fn not(self) -> Self {
525 self.apply_unop(ops::Not::not)
526 }
527
528 fn neg(self) -> Self {
529 self.apply_unop(ops::Neg::neg)
530 }
531
532 fn add(self, rhs: Self) -> Self {
533 self.apply_binop(rhs, ops::Add::add)
534 }
535
536 fn sub(self, rhs: Self) -> Self {
537 self.apply_binop(rhs, ops::Sub::sub)
538 }
539
540 fn mul(self, rhs: Self) -> Self {
541 self.apply_binop(rhs, ops::Mul::mul)
542 }
543
544 fn udiv(self, rhs: Self) -> Self {
545 self.apply_binop(rhs, Immediate::udiv)
546 }
547
548 fn sdiv(self, rhs: Self) -> Self {
549 self.apply_binop(rhs, Immediate::sdiv)
550 }
551
552 fn lt(self, rhs: Self) -> Self {
553 self.apply_binop(rhs, Immediate::lt)
554 }
555
556 fn gt(self, rhs: Self) -> Self {
557 self.apply_binop(rhs, Immediate::gt)
558 }
559
560 fn slt(self, rhs: Self) -> Self {
561 self.apply_binop(rhs, Immediate::slt)
562 }
563
564 fn sgt(self, rhs: Self) -> Self {
565 self.apply_binop(rhs, Immediate::sgt)
566 }
567
568 fn le(self, rhs: Self) -> Self {
569 self.apply_binop(rhs, Immediate::le)
570 }
571
572 fn ge(self, rhs: Self) -> Self {
573 self.apply_binop(rhs, Immediate::ge)
574 }
575
576 fn sle(self, rhs: Self) -> Self {
577 self.apply_binop(rhs, Immediate::sle)
578 }
579
580 fn sge(self, rhs: Self) -> Self {
581 self.apply_binop(rhs, Immediate::sge)
582 }
583
584 fn eq(self, rhs: Self) -> Self {
585 self.apply_binop(rhs, Immediate::imm_eq)
586 }
587
588 fn ne(self, rhs: Self) -> Self {
589 self.apply_binop(rhs, Immediate::imm_ne)
590 }
591
592 fn and(self, rhs: Self) -> Self {
593 self.apply_binop(rhs, ops::BitAnd::bitand)
594 }
595
596 fn or(self, rhs: Self) -> Self {
597 self.apply_binop(rhs, ops::BitOr::bitor)
598 }
599
600 fn xor(self, rhs: Self) -> Self {
601 self.apply_binop(rhs, ops::BitXor::bitxor)
602 }
603
604 fn sext(self, ty: Type) -> Self {
605 self.apply_unop(|val| Immediate::sext(val, ty))
606 }
607
608 fn zext(self, ty: Type) -> Self {
609 self.apply_unop(|val| Immediate::zext(val, ty))
610 }
611
612 fn trunc(self, ty: Type) -> Self {
613 self.apply_unop(|val| Immediate::trunc(val, ty))
614 }
615}
616
617impl Default for LatticeCell {
618 fn default() -> Self {
619 Self::Bot
620 }
621}