wave_compiler/mir/
function.rs1use std::collections::HashMap;
10
11use super::basic_block::BasicBlock;
12use super::types::MirType;
13use super::value::{BlockId, ValueId};
14
15#[derive(Debug, Clone, PartialEq)]
17pub struct MirParam {
18 pub value: ValueId,
20 pub ty: MirType,
22 pub name: String,
24}
25
26#[derive(Debug, Clone)]
28pub struct MirFunction {
29 pub name: String,
31 pub params: Vec<MirParam>,
33 pub blocks: Vec<BasicBlock>,
35 pub entry: BlockId,
37 pub value_types: HashMap<ValueId, MirType>,
39}
40
41impl MirFunction {
42 #[must_use]
44 pub fn new(name: String, entry: BlockId) -> Self {
45 Self {
46 name,
47 params: Vec::new(),
48 blocks: Vec::new(),
49 entry,
50 value_types: HashMap::new(),
51 }
52 }
53
54 #[must_use]
56 pub fn block(&self, id: BlockId) -> Option<&BasicBlock> {
57 self.blocks.iter().find(|b| b.id == id)
58 }
59
60 pub fn block_mut(&mut self, id: BlockId) -> Option<&mut BasicBlock> {
62 self.blocks.iter_mut().find(|b| b.id == id)
63 }
64
65 #[must_use]
67 pub fn block_ids(&self) -> Vec<BlockId> {
68 self.blocks.iter().map(|b| b.id).collect()
69 }
70
71 #[must_use]
73 pub fn block_count(&self) -> usize {
74 self.blocks.len()
75 }
76
77 pub fn set_type(&mut self, value: ValueId, ty: MirType) {
79 self.value_types.insert(value, ty);
80 }
81
82 #[must_use]
84 pub fn get_type(&self, value: ValueId) -> Option<MirType> {
85 self.value_types.get(&value).copied()
86 }
87
88 #[must_use]
90 pub fn predecessors(&self) -> HashMap<BlockId, Vec<BlockId>> {
91 let mut preds: HashMap<BlockId, Vec<BlockId>> = HashMap::new();
92 for block in &self.blocks {
93 preds.entry(block.id).or_default();
94 for succ in block.successors() {
95 preds.entry(succ).or_default().push(block.id);
96 }
97 }
98 preds
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};

    /// Builds a function named `"test"` with entry `BlockId(0)` holding `blocks` in order.
    fn function_with_blocks(blocks: Vec<BasicBlock>) -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        func.blocks.extend(blocks);
        func
    }

    #[test]
    fn test_mir_function_blocks() {
        let mut entry = BasicBlock::new(BlockId(0));
        entry.terminator = Terminator::Branch { target: BlockId(1) };
        let func = function_with_blocks(vec![entry, BasicBlock::new(BlockId(1))]);

        assert_eq!(func.block_count(), 2);
        assert!(func.block(BlockId(0)).is_some());
        // An id that no block carries must not be found.
        assert!(func.block(BlockId(2)).is_none());
        assert_eq!(func.block_ids(), vec![BlockId(0), BlockId(1)]);
    }

    #[test]
    fn test_predecessors() {
        let mut entry = BasicBlock::new(BlockId(0));
        entry.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(1),
            false_target: BlockId(2),
        };
        let func = function_with_blocks(vec![
            entry,
            BasicBlock::new(BlockId(1)),
            BasicBlock::new(BlockId(2)),
        ]);

        let preds = func.predecessors();
        // The entry block has no incoming edges but is still present in the map.
        assert!(preds[&BlockId(0)].is_empty());
        // Both arms of the conditional branch are reached only from the entry block.
        for target in [BlockId(1), BlockId(2)] {
            assert_eq!(preds[&target], vec![BlockId(0)]);
        }
    }

    #[test]
    fn test_value_types() {
        let mut func = function_with_blocks(Vec::new());
        func.set_type(ValueId(0), MirType::I32);
        func.set_type(ValueId(1), MirType::F32);

        assert_eq!(func.get_type(ValueId(0)), Some(MirType::I32));
        assert_eq!(func.get_type(ValueId(1)), Some(MirType::F32));
        // Values that were never typed report `None`.
        assert_eq!(func.get_type(ValueId(99)), None);
    }
}