Skip to main content

wave_compiler/mir/
function.rs

// Copyright 2026 Ojima Abraham
// SPDX-License-Identifier: Apache-2.0

//! MIR function/kernel representation with a control flow graph.
//!
//! A `MirFunction` contains basic blocks forming a CFG, along with
//! type information for all SSA values and kernel parameter metadata.

9use std::collections::HashMap;
10
11use super::basic_block::BasicBlock;
12use super::types::MirType;
13use super::value::{BlockId, ValueId};
14
15/// A kernel parameter in MIR form.
16#[derive(Debug, Clone, PartialEq)]
17pub struct MirParam {
18    /// SSA value representing this parameter.
19    pub value: ValueId,
20    /// Parameter type.
21    pub ty: MirType,
22    /// Parameter name (for debugging).
23    pub name: String,
24}
25
26/// A function/kernel in MIR with a control flow graph.
27#[derive(Debug, Clone)]
28pub struct MirFunction {
29    /// Kernel name.
30    pub name: String,
31    /// Kernel parameters.
32    pub params: Vec<MirParam>,
33    /// Basic blocks forming the CFG.
34    pub blocks: Vec<BasicBlock>,
35    /// Entry block ID.
36    pub entry: BlockId,
37    /// Type mapping for all SSA values.
38    pub value_types: HashMap<ValueId, MirType>,
39}
40
41impl MirFunction {
42    /// Create a new MIR function.
43    #[must_use]
44    pub fn new(name: String, entry: BlockId) -> Self {
45        Self {
46            name,
47            params: Vec::new(),
48            blocks: Vec::new(),
49            entry,
50            value_types: HashMap::new(),
51        }
52    }
53
54    /// Get a basic block by ID.
55    #[must_use]
56    pub fn block(&self, id: BlockId) -> Option<&BasicBlock> {
57        self.blocks.iter().find(|b| b.id == id)
58    }
59
60    /// Get a mutable reference to a basic block by ID.
61    pub fn block_mut(&mut self, id: BlockId) -> Option<&mut BasicBlock> {
62        self.blocks.iter_mut().find(|b| b.id == id)
63    }
64
65    /// Returns all block IDs in the function.
66    #[must_use]
67    pub fn block_ids(&self) -> Vec<BlockId> {
68        self.blocks.iter().map(|b| b.id).collect()
69    }
70
71    /// Returns the number of basic blocks.
72    #[must_use]
73    pub fn block_count(&self) -> usize {
74        self.blocks.len()
75    }
76
77    /// Add a type mapping for a value.
78    pub fn set_type(&mut self, value: ValueId, ty: MirType) {
79        self.value_types.insert(value, ty);
80    }
81
82    /// Get the type of a value.
83    #[must_use]
84    pub fn get_type(&self, value: ValueId) -> Option<MirType> {
85        self.value_types.get(&value).copied()
86    }
87
88    /// Compute predecessor blocks for each block.
89    #[must_use]
90    pub fn predecessors(&self) -> HashMap<BlockId, Vec<BlockId>> {
91        let mut preds: HashMap<BlockId, Vec<BlockId>> = HashMap::new();
92        for block in &self.blocks {
93            preds.entry(block.id).or_default();
94            for succ in block.successors() {
95                preds.entry(succ).or_default().push(block.id);
96            }
97        }
98        preds
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};

    #[test]
    fn test_mir_function_blocks() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut bb0 = BasicBlock::new(BlockId(0));
        bb0.terminator = Terminator::Branch { target: BlockId(1) };
        let bb1 = BasicBlock::new(BlockId(1));
        func.blocks.push(bb0);
        func.blocks.push(bb1);

        assert_eq!(func.block_count(), 2);
        assert!(func.block(BlockId(0)).is_some());
        assert!(func.block(BlockId(2)).is_none());
        assert_eq!(func.block_ids(), vec![BlockId(0), BlockId(1)]);
    }

    #[test]
    fn test_predecessors() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut bb0 = BasicBlock::new(BlockId(0));
        bb0.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(1),
            false_target: BlockId(2),
        };
        let bb1 = BasicBlock::new(BlockId(1));
        let bb2 = BasicBlock::new(BlockId(2));
        func.blocks.push(bb0);
        func.blocks.push(bb1);
        func.blocks.push(bb2);

        let preds = func.predecessors();
        assert!(preds[&BlockId(0)].is_empty());
        assert_eq!(preds[&BlockId(1)], vec![BlockId(0)]);
        assert_eq!(preds[&BlockId(2)], vec![BlockId(0)]);
    }

    #[test]
    fn test_value_types() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        func.set_type(ValueId(0), MirType::I32);
        func.set_type(ValueId(1), MirType::F32);
        assert_eq!(func.get_type(ValueId(0)), Some(MirType::I32));
        assert_eq!(func.get_type(ValueId(1)), Some(MirType::F32));
        assert_eq!(func.get_type(ValueId(99)), None);
    }

    /// A freshly created function has no blocks, so every query is empty.
    #[test]
    fn test_empty_function() {
        let func = MirFunction::new("empty".into(), BlockId(0));
        assert_eq!(func.block_count(), 0);
        assert!(func.block_ids().is_empty());
        assert!(func.block(BlockId(0)).is_none());
        assert!(func.predecessors().is_empty());
    }

    /// Edits made through `block_mut` are visible to later CFG queries, and
    /// unknown IDs yield `None`.
    #[test]
    fn test_block_mut_updates_cfg() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        func.blocks.push(BasicBlock::new(BlockId(0)));
        func.blocks.push(BasicBlock::new(BlockId(1)));

        let bb0 = func
            .block_mut(BlockId(0))
            .expect("block 0 was pushed above");
        bb0.terminator = Terminator::Branch { target: BlockId(1) };
        assert!(func.block_mut(BlockId(7)).is_none());

        let preds = func.predecessors();
        assert!(preds[&BlockId(0)].is_empty());
        assert_eq!(preds[&BlockId(1)], vec![BlockId(0)]);
    }

    /// Setting a type twice keeps only the most recent one.
    #[test]
    fn test_set_type_overwrites() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        func.set_type(ValueId(0), MirType::I32);
        func.set_type(ValueId(0), MirType::F32);
        assert_eq!(func.get_type(ValueId(0)), Some(MirType::F32));
    }
}