// wave_compiler/hir/expr.rs

// Copyright 2026 Ojima Abraham
// SPDX-License-Identifier: Apache-2.0

//! HIR expression types for WAVE GPU kernels.
//!
//! Expressions represent computations that produce values, including
//! arithmetic, memory access, GPU intrinsics, and type conversions.

9use super::types::{AddressSpace, Type};
10
/// Axis of a multi-dimensional GPU dispatch.
///
/// Variants are declared in axis order: `X` is index 0, `Y` is index 1 and
/// `Z` is index 2.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Dimension {
    /// First axis (index 0).
    X,
    /// Second axis (index 1).
    Y,
    /// Third axis (index 2).
    Z,
}

/// Lane-exchange pattern used by a wave-level shuffle.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum ShuffleMode {
    /// Exchange with one explicitly specified lane.
    Direct,
    /// Shuffle upward by a lane offset.
    Up,
    /// Shuffle downward by a lane offset.
    Down,
    /// Shuffle using the XOR of lane IDs.
    Xor,
}

/// Operators that take a left (`lhs`) and a right (`rhs`) operand.
///
/// Grouping helpers: [`BinOp::is_arithmetic`] and [`BinOp::is_comparison`].
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum BinOp {
    // Arithmetic.
    /// `lhs + rhs`.
    Add,
    /// `lhs - rhs`.
    Sub,
    /// `lhs * rhs`.
    Mul,
    /// `lhs / rhs`.
    Div,
    /// Floor division of `lhs` by `rhs`.
    FloorDiv,
    /// Remainder (modulo) of `lhs` by `rhs`.
    Mod,
    /// `lhs` raised to the power `rhs`.
    Pow,

    // Bitwise and shifts.
    /// Bitwise AND of the operands.
    BitAnd,
    /// Bitwise OR of the operands.
    BitOr,
    /// Bitwise exclusive OR of the operands.
    BitXor,
    /// `lhs` shifted left by `rhs`.
    Shl,
    /// `lhs` shifted right by `rhs`.
    Shr,

    // Comparisons; these produce a boolean.
    /// `lhs == rhs`.
    Eq,
    /// `lhs != rhs`.
    Ne,
    /// `lhs < rhs`.
    Lt,
    /// `lhs <= rhs`.
    Le,
    /// `lhs > rhs`.
    Gt,
    /// `lhs >= rhs`.
    Ge,
}

/// Operators that take a single operand.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-x`).
    Neg,
    /// Bitwise complement.
    BitNot,
    /// Logical negation.
    Not,
}

/// Built-in functions callable from a kernel: math routines and GPU operations.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum BuiltinFunc {
    // Math.
    /// Square root.
    Sqrt,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Two raised to the argument.
    Exp2,
    /// Logarithm to base two.
    Log2,
    /// Absolute value.
    Abs,
    /// Smaller of two values.
    Min,
    /// Larger of two values.
    Max,

    // Atomics.
    /// Atomic addition.
    AtomicAdd,
}

/// Constant values written directly in an expression.
///
/// Only `PartialEq` (not `Eq`) is derived because `f64` has no total
/// equality: a `Float` holding NaN never compares equal to itself.
#[derive(Clone, Debug, PartialEq)]
pub enum Literal {
    /// Signed 64-bit integer constant.
    Int(i64),
    /// Unsigned 64-bit integer constant.
    UInt(u64),
    /// 64-bit floating-point constant.
    Float(f64),
    /// `true` or `false`.
    Bool(bool),
}

/// Scope that a memory fence applies to.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum MemoryScope {
    /// A single wave.
    Wave,
    /// A whole workgroup.
    Workgroup,
    /// The entire device.
    Device,
}

134/// HIR expression producing a value.
135#[derive(Debug, Clone, PartialEq)]
136pub enum Expr {
137    /// Variable reference.
138    Var(String),
139    /// Literal constant.
140    Literal(Literal),
141    /// Binary operation.
142    BinOp {
143        /// The operator.
144        op: BinOp,
145        /// Left operand.
146        lhs: Box<Expr>,
147        /// Right operand.
148        rhs: Box<Expr>,
149    },
150    /// Unary operation.
151    UnaryOp {
152        /// The operator.
153        op: UnaryOp,
154        /// Operand.
155        operand: Box<Expr>,
156    },
157    /// Built-in function call.
158    Call {
159        /// The function being called.
160        func: BuiltinFunc,
161        /// Arguments to the function.
162        args: Vec<Expr>,
163    },
164    /// Array/buffer indexing.
165    Index {
166        /// Base pointer/array.
167        base: Box<Expr>,
168        /// Index expression.
169        index: Box<Expr>,
170    },
171    /// Type cast.
172    Cast {
173        /// Expression to cast.
174        expr: Box<Expr>,
175        /// Target type.
176        to: Type,
177    },
178    /// Thread ID in a given dimension.
179    ThreadId(Dimension),
180    /// Workgroup ID in a given dimension.
181    WorkgroupId(Dimension),
182    /// Workgroup size in a given dimension.
183    WorkgroupSize(Dimension),
184    /// Lane ID within wave.
185    LaneId,
186    /// Wave width (number of lanes).
187    WaveWidth,
188    /// Load from memory.
189    Load {
190        /// Address to load from.
191        addr: Box<Expr>,
192        /// Address space.
193        space: AddressSpace,
194    },
195    /// Wave shuffle operation.
196    Shuffle {
197        /// Value to shuffle.
198        value: Box<Expr>,
199        /// Target lane or offset.
200        lane: Box<Expr>,
201        /// Shuffle mode.
202        mode: ShuffleMode,
203    },
204}
205
206impl BinOp {
207    /// Returns true if this operator produces a boolean result.
208    #[must_use]
209    pub fn is_comparison(&self) -> bool {
210        matches!(
211            self,
212            Self::Eq | Self::Ne | Self::Lt | Self::Le | Self::Gt | Self::Ge
213        )
214    }
215
216    /// Returns true if this operator is arithmetic.
217    #[must_use]
218    pub fn is_arithmetic(&self) -> bool {
219        matches!(
220            self,
221            Self::Add | Self::Sub | Self::Mul | Self::Div | Self::FloorDiv | Self::Mod | Self::Pow
222        )
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `BinOp` variant, in declaration order.
    const ALL_BINOPS: [BinOp; 18] = [
        BinOp::Add,
        BinOp::Sub,
        BinOp::Mul,
        BinOp::Div,
        BinOp::FloorDiv,
        BinOp::Mod,
        BinOp::Pow,
        BinOp::BitAnd,
        BinOp::BitOr,
        BinOp::BitXor,
        BinOp::Shl,
        BinOp::Shr,
        BinOp::Eq,
        BinOp::Ne,
        BinOp::Lt,
        BinOp::Le,
        BinOp::Gt,
        BinOp::Ge,
    ];

    #[test]
    fn test_binop_classification() {
        assert!(BinOp::Eq.is_comparison());
        assert!(BinOp::Lt.is_comparison());
        assert!(!BinOp::Add.is_comparison());
        assert!(BinOp::Add.is_arithmetic());
        assert!(BinOp::Mul.is_arithmetic());
        assert!(!BinOp::Eq.is_arithmetic());
    }

    #[test]
    fn test_binop_families_are_disjoint() {
        for op in &ALL_BINOPS {
            assert!(
                !(op.is_arithmetic() && op.is_comparison()),
                "{:?} is classified as both arithmetic and comparison",
                op
            );
        }
        // Bitwise and shift operators belong to neither family.
        for op in &[BinOp::BitAnd, BinOp::BitOr, BinOp::BitXor, BinOp::Shl, BinOp::Shr] {
            assert!(!op.is_arithmetic(), "{:?} should not be arithmetic", op);
            assert!(!op.is_comparison(), "{:?} should not be a comparison", op);
        }
        assert_eq!(ALL_BINOPS.iter().filter(|op| op.is_comparison()).count(), 6);
        assert_eq!(ALL_BINOPS.iter().filter(|op| op.is_arithmetic()).count(), 7);
    }

    #[test]
    fn test_expr_construction() {
        let add_expr = Expr::BinOp {
            op: BinOp::Add,
            lhs: Box::new(Expr::Var("a".into())),
            rhs: Box::new(Expr::Var("b".into())),
        };
        match &add_expr {
            Expr::BinOp { op, lhs, rhs } => {
                assert_eq!(*op, BinOp::Add);
                // Compare the boxed values directly rather than allocating
                // fresh Boxes just for the comparison.
                assert_eq!(**lhs, Expr::Var("a".into()));
                assert_eq!(**rhs, Expr::Var("b".into()));
            }
            _ => panic!("expected BinOp"),
        }
    }

    #[test]
    fn test_index_expr() {
        let idx = Expr::Index {
            base: Box::new(Expr::Var("arr".into())),
            index: Box::new(Expr::ThreadId(Dimension::X)),
        };
        match &idx {
            Expr::Index { base, index } => {
                assert_eq!(**base, Expr::Var("arr".into()));
                assert_eq!(**index, Expr::ThreadId(Dimension::X));
            }
            _ => panic!("expected Index"),
        }
    }

    #[test]
    fn test_literal_variants() {
        // 2.5 is exactly representable. Unlike 3.14, it does not trigger
        // Clippy's deny-by-default `approx_constant` lint, which fails
        // `cargo clippy --tests`.
        let int_lit = Literal::Int(42);
        let uint_lit = Literal::UInt(42);
        let float_lit = Literal::Float(2.5);
        let bool_lit = Literal::Bool(true);
        assert_eq!(int_lit, Literal::Int(42));
        assert_eq!(uint_lit, Literal::UInt(42));
        assert_eq!(float_lit, Literal::Float(2.5));
        assert_eq!(bool_lit, Literal::Bool(true));
        // The same payload in different variants must not compare equal.
        assert_ne!(int_lit, uint_lit);
    }
}