Skip to main content

wave_compiler/mir/
instruction.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! MIR instruction definitions in SSA form.
5//!
6//! Each instruction operates on SSA values (virtual registers) and
7//! produces at most one result value. Instructions include arithmetic,
8//! memory operations, comparisons, conversions, and GPU intrinsics.
9
10use super::types::MirType;
11use super::value::ValueId;
12use crate::hir::expr::{BinOp, BuiltinFunc, MemoryScope, ShuffleMode};
13use crate::hir::types::AddressSpace;
14
/// The read-modify-write operations an atomic instruction can perform.
///
/// The enum is a plain tag: it is `Copy`, comparable, and hashable, so it
/// can be passed by value and used as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AtomicOp {
    /// Atomically add the operand to the value in memory.
    Add,
    /// Atomically replace the value in memory with the minimum of it and the operand.
    Min,
    /// Atomically replace the value in memory with the maximum of it and the operand.
    Max,
}
25
/// A literal payload carried by a MIR constant instruction.
///
/// Every variant wraps a scalar that fits in 32 bits; see
/// [`ConstValue::to_bits`] for the raw encoding.
#[derive(Clone, Debug, PartialEq)]
pub enum ConstValue {
    /// Signed 32-bit integer literal.
    I32(i32),
    /// Unsigned 32-bit integer literal.
    U32(u32),
    /// 32-bit floating-point literal.
    F32(f32),
    /// Boolean literal.
    Bool(bool),
}

impl ConstValue {
    /// Encodes this constant as a raw 32-bit pattern.
    ///
    /// * `I32` is reinterpreted as two's complement, so `-1` becomes `0xFFFF_FFFF`.
    /// * `U32` is returned unchanged.
    /// * `F32` yields the same bits as [`f32::to_bits`].
    /// * `Bool` maps `false` to `0` and `true` to `1`.
    #[must_use]
    pub fn to_bits(&self) -> u32 {
        match *self {
            Self::Bool(flag) => u32::from(flag),
            Self::F32(float) => float.to_bits(),
            Self::U32(word) => word,
            // The sign bit is deliberately kept as the top bit of the result:
            // callers want the bit pattern, not a range-checked conversion.
            #[allow(clippy::cast_sign_loss)]
            Self::I32(int) => int as u32,
        }
    }
}
52
53/// A MIR instruction in SSA form.
54#[derive(Debug, Clone, PartialEq)]
55pub enum MirInst {
56    /// Binary arithmetic/logic operation.
57    BinOp {
58        /// Destination value.
59        dest: ValueId,
60        /// Operator.
61        op: BinOp,
62        /// Left operand.
63        lhs: ValueId,
64        /// Right operand.
65        rhs: ValueId,
66        /// Result type.
67        ty: MirType,
68    },
69    /// Unary operation.
70    UnaryOp {
71        /// Destination value.
72        dest: ValueId,
73        /// Operator.
74        op: crate::hir::expr::UnaryOp,
75        /// Operand.
76        operand: ValueId,
77        /// Result type.
78        ty: MirType,
79    },
80    /// Load from memory.
81    Load {
82        /// Destination value.
83        dest: ValueId,
84        /// Address to load from.
85        addr: ValueId,
86        /// Address space.
87        space: AddressSpace,
88        /// Loaded value type.
89        ty: MirType,
90    },
91    /// Store to memory.
92    Store {
93        /// Address to store to.
94        addr: ValueId,
95        /// Value to store.
96        value: ValueId,
97        /// Address space.
98        space: AddressSpace,
99    },
100    /// Built-in function call.
101    Call {
102        /// Optional destination value.
103        dest: Option<ValueId>,
104        /// Function being called.
105        func: BuiltinFunc,
106        /// Arguments.
107        args: Vec<ValueId>,
108    },
109    /// Type cast/conversion.
110    Cast {
111        /// Destination value.
112        dest: ValueId,
113        /// Source value.
114        value: ValueId,
115        /// Source type.
116        from: MirType,
117        /// Target type.
118        to: MirType,
119    },
120    /// Constant value.
121    Const {
122        /// Destination value.
123        dest: ValueId,
124        /// The constant.
125        value: ConstValue,
126    },
127    /// Wave shuffle operation.
128    Shuffle {
129        /// Destination value.
130        dest: ValueId,
131        /// Value to shuffle.
132        value: ValueId,
133        /// Target lane/offset.
134        lane: ValueId,
135        /// Shuffle mode.
136        mode: ShuffleMode,
137    },
138    /// Read a special register (`thread_id`, `workgroup_id`, etc.).
139    ReadSpecialReg {
140        /// Destination value.
141        dest: ValueId,
142        /// Special register index.
143        sr_index: u8,
144    },
145    /// Atomic read-modify-write.
146    AtomicRmw {
147        /// Destination value (old value).
148        dest: ValueId,
149        /// Address.
150        addr: ValueId,
151        /// Operand value.
152        value: ValueId,
153        /// Atomic operation.
154        op: AtomicOp,
155        /// Memory scope.
156        scope: MemoryScope,
157    },
158    /// Workgroup barrier.
159    Barrier,
160    /// Memory fence.
161    Fence {
162        /// Scope of the fence.
163        scope: MemoryScope,
164    },
165}
166
167impl MirInst {
168    /// Returns the destination value ID if this instruction produces one.
169    #[must_use]
170    pub fn dest(&self) -> Option<ValueId> {
171        match self {
172            Self::BinOp { dest, .. }
173            | Self::UnaryOp { dest, .. }
174            | Self::Load { dest, .. }
175            | Self::Cast { dest, .. }
176            | Self::Const { dest, .. }
177            | Self::Shuffle { dest, .. }
178            | Self::ReadSpecialReg { dest, .. }
179            | Self::AtomicRmw { dest, .. } => Some(*dest),
180            Self::Call { dest, .. } => *dest,
181            Self::Store { .. } | Self::Barrier | Self::Fence { .. } => None,
182        }
183    }
184
185    /// Returns all value IDs used as operands by this instruction.
186    #[must_use]
187    pub fn operands(&self) -> Vec<ValueId> {
188        match self {
189            Self::BinOp { lhs, rhs, .. } => vec![*lhs, *rhs],
190            Self::UnaryOp { operand, .. } => vec![*operand],
191            Self::Load { addr, .. } => vec![*addr],
192            Self::Store { addr, value, .. } | Self::AtomicRmw { addr, value, .. } => {
193                vec![*addr, *value]
194            }
195            Self::Call { args, .. } => args.clone(),
196            Self::Cast { value, .. } => vec![*value],
197            Self::Const { .. }
198            | Self::ReadSpecialReg { .. }
199            | Self::Barrier
200            | Self::Fence { .. } => vec![],
201            Self::Shuffle { value, lane, .. } => vec![*value, *lane],
202        }
203    }
204
205    /// Returns true if this instruction has side effects.
206    #[must_use]
207    pub fn has_side_effects(&self) -> bool {
208        matches!(
209            self,
210            Self::Store { .. }
211                | Self::AtomicRmw { .. }
212                | Self::Barrier
213                | Self::Fence { .. }
214                | Self::Call { .. }
215        )
216    }
217}
218
#[cfg(test)]
mod tests {
    //! Unit tests for constant encoding and the `MirInst` query helpers.

    use super::*;

    #[test]
    fn test_const_value_bits() {
        assert_eq!(ConstValue::I32(42).to_bits(), 42);
        assert_eq!(ConstValue::U32(0xFF).to_bits(), 0xFF);
        assert_eq!(ConstValue::F32(1.0).to_bits(), 0x3F80_0000);
        assert_eq!(ConstValue::Bool(true).to_bits(), 1);
        assert_eq!(ConstValue::Bool(false).to_bits(), 0);
    }

    /// Negative integers and negative zero exercise the sign bit, which is
    /// the only part of `to_bits` that a lossy conversion could get wrong.
    #[test]
    fn test_const_value_bits_sign_bit() {
        assert_eq!(ConstValue::I32(-1).to_bits(), u32::MAX);
        assert_eq!(ConstValue::I32(i32::MIN).to_bits(), 0x8000_0000);
        assert_eq!(ConstValue::F32(-0.0).to_bits(), 0x8000_0000);
    }

    #[test]
    fn test_instruction_dest_and_operands() {
        let inst = MirInst::BinOp {
            dest: ValueId(3),
            op: BinOp::Add,
            lhs: ValueId(1),
            rhs: ValueId(2),
            ty: MirType::I32,
        };
        assert_eq!(inst.dest(), Some(ValueId(3)));
        assert_eq!(inst.operands(), vec![ValueId(1), ValueId(2)]);
        assert!(!inst.has_side_effects());
    }

    /// Instructions that read no SSA values still define a result and are
    /// not side-effecting.
    #[test]
    fn test_operand_free_instructions() {
        let constant = MirInst::Const {
            dest: ValueId(7),
            value: ConstValue::U32(9),
        };
        assert_eq!(constant.dest(), Some(ValueId(7)));
        assert!(constant.operands().is_empty());
        assert!(!constant.has_side_effects());

        let special = MirInst::ReadSpecialReg {
            dest: ValueId(8),
            sr_index: 0,
        };
        assert_eq!(special.dest(), Some(ValueId(8)));
        assert!(special.operands().is_empty());
        assert!(!special.has_side_effects());
    }

    #[test]
    fn test_store_has_side_effects() {
        let inst = MirInst::Store {
            addr: ValueId(0),
            value: ValueId(1),
            space: AddressSpace::Device,
        };
        assert!(inst.has_side_effects());
        assert_eq!(inst.dest(), None);
        assert_eq!(inst.operands(), vec![ValueId(0), ValueId(1)]);
    }

    #[test]
    fn test_barrier_has_side_effects() {
        let inst = MirInst::Barrier;
        assert!(inst.has_side_effects());
        assert_eq!(inst.dest(), None);
        assert!(inst.operands().is_empty());
    }

    /// `AtomicOp` is a plain tag: distinct variants compare unequal and a
    /// copy compares equal to its source.
    #[test]
    fn test_atomic_op_equality() {
        let op = AtomicOp::Add;
        let copy = op;
        assert_eq!(op, copy);
        assert_ne!(AtomicOp::Min, AtomicOp::Max);
    }
}