wave_compiler/hir/expr.rs
1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! HIR expression types for WAVE GPU kernels.
5//!
6//! Expressions represent computations that produce values, including
7//! arithmetic, memory access, GPU intrinsics, and type conversions.
8
9use super::types::{AddressSpace, Type};
10
/// Axis selector for a multi-dimensional GPU dispatch.
///
/// Each variant's explicit discriminant is its axis index, so
/// `Dimension::Y as usize == 1`.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq)]
pub enum Dimension {
    /// First axis, index 0.
    X = 0,
    /// Second axis, index 1.
    Y = 1,
    /// Third axis, index 2.
    Z = 2,
}
21
/// Selects how lanes exchange values in a wave-level shuffle.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq)]
pub enum ShuffleMode {
    /// Exchange with one explicitly named lane.
    Direct,
    /// Shift upward by an offset.
    Up,
    /// Shift downward by an offset.
    Down,
    /// Pair lanes by XOR-ing lane IDs.
    Xor,
}
34
/// Operators that combine two operand expressions.
///
/// Variants are grouped below as arithmetic, bitwise/shift, and comparison
/// operators.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq)]
pub enum BinOp {
    // ---- Arithmetic ----
    /// Sum of the operands.
    Add,
    /// Difference of the operands.
    Sub,
    /// Product of the operands.
    Mul,
    /// Quotient of the operands.
    Div,
    /// Division rounded down (floor division).
    FloorDiv,
    /// Modulo.
    Mod,
    /// Exponentiation (power).
    Pow,

    // ---- Bitwise and shift ----
    /// Bitwise AND.
    BitAnd,
    /// Bitwise OR.
    BitOr,
    /// Bitwise exclusive OR.
    BitXor,
    /// Shift left.
    Shl,
    /// Shift right.
    Shr,

    // ---- Comparison ----
    /// Equality test.
    Eq,
    /// Inequality test.
    Ne,
    /// Strictly-less-than test.
    Lt,
    /// Less-than-or-equal test.
    Le,
    /// Strictly-greater-than test.
    Gt,
    /// Greater-than-or-equal test.
    Ge,
}
75
/// Operators that act on a single operand expression.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation.
    Neg,
    /// Bitwise complement.
    BitNot,
    /// Logical negation.
    Not,
}
86
/// Intrinsic functions known to the compiler: math routines plus GPU
/// operations such as atomics.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq)]
pub enum BuiltinFunc {
    /// Square root.
    Sqrt,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Two raised to the given power.
    Exp2,
    /// Logarithm in base two.
    Log2,
    /// Absolute value.
    Abs,
    /// Smaller of two values.
    Min,
    /// Larger of two values.
    Max,
    /// Atomic addition.
    AtomicAdd,
}
109
/// Constant values that can appear directly in an expression.
///
/// Only `PartialEq` (not `Eq`) is derived because the `Float` payload is an
/// `f64`.
#[derive(Clone, Debug)]
#[derive(PartialEq)]
pub enum Literal {
    /// Signed integer constant.
    Int(i64),
    /// Unsigned integer constant.
    UInt(u64),
    /// Floating-point constant.
    Float(f64),
    /// Boolean constant.
    Bool(bool),
}
122
/// Scope that a memory fence operation applies to.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq)]
pub enum MemoryScope {
    /// A single wave.
    Wave,
    /// A whole workgroup.
    Workgroup,
    /// The entire device.
    Device,
}
133
134/// HIR expression producing a value.
135#[derive(Debug, Clone, PartialEq)]
136pub enum Expr {
137 /// Variable reference.
138 Var(String),
139 /// Literal constant.
140 Literal(Literal),
141 /// Binary operation.
142 BinOp {
143 /// The operator.
144 op: BinOp,
145 /// Left operand.
146 lhs: Box<Expr>,
147 /// Right operand.
148 rhs: Box<Expr>,
149 },
150 /// Unary operation.
151 UnaryOp {
152 /// The operator.
153 op: UnaryOp,
154 /// Operand.
155 operand: Box<Expr>,
156 },
157 /// Built-in function call.
158 Call {
159 /// The function being called.
160 func: BuiltinFunc,
161 /// Arguments to the function.
162 args: Vec<Expr>,
163 },
164 /// Array/buffer indexing.
165 Index {
166 /// Base pointer/array.
167 base: Box<Expr>,
168 /// Index expression.
169 index: Box<Expr>,
170 },
171 /// Type cast.
172 Cast {
173 /// Expression to cast.
174 expr: Box<Expr>,
175 /// Target type.
176 to: Type,
177 },
178 /// Thread ID in a given dimension.
179 ThreadId(Dimension),
180 /// Workgroup ID in a given dimension.
181 WorkgroupId(Dimension),
182 /// Workgroup size in a given dimension.
183 WorkgroupSize(Dimension),
184 /// Lane ID within wave.
185 LaneId,
186 /// Wave width (number of lanes).
187 WaveWidth,
188 /// Load from memory.
189 Load {
190 /// Address to load from.
191 addr: Box<Expr>,
192 /// Address space.
193 space: AddressSpace,
194 },
195 /// Wave shuffle operation.
196 Shuffle {
197 /// Value to shuffle.
198 value: Box<Expr>,
199 /// Target lane or offset.
200 lane: Box<Expr>,
201 /// Shuffle mode.
202 mode: ShuffleMode,
203 },
204}
205
206impl BinOp {
207 /// Returns true if this operator produces a boolean result.
208 #[must_use]
209 pub fn is_comparison(&self) -> bool {
210 matches!(
211 self,
212 Self::Eq | Self::Ne | Self::Lt | Self::Le | Self::Gt | Self::Ge
213 )
214 }
215
216 /// Returns true if this operator is arithmetic.
217 #[must_use]
218 pub fn is_arithmetic(&self) -> bool {
219 matches!(
220 self,
221 Self::Add | Self::Sub | Self::Mul | Self::Div | Self::FloorDiv | Self::Mod | Self::Pow
222 )
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `BinOp` variant, so the classification checks cover the whole
    /// enum rather than a hand-picked sample.
    const ALL_BIN_OPS: [BinOp; 18] = [
        BinOp::Add,
        BinOp::Sub,
        BinOp::Mul,
        BinOp::Div,
        BinOp::FloorDiv,
        BinOp::Mod,
        BinOp::Pow,
        BinOp::BitAnd,
        BinOp::BitOr,
        BinOp::BitXor,
        BinOp::Shl,
        BinOp::Shr,
        BinOp::Eq,
        BinOp::Ne,
        BinOp::Lt,
        BinOp::Le,
        BinOp::Gt,
        BinOp::Ge,
    ];

    #[test]
    fn test_binop_classification() {
        assert!(BinOp::Eq.is_comparison());
        assert!(BinOp::Lt.is_comparison());
        assert!(!BinOp::Add.is_comparison());
        assert!(BinOp::Add.is_arithmetic());
        assert!(BinOp::Mul.is_arithmetic());
        assert!(!BinOp::Eq.is_arithmetic());

        // Pin the exact size of each class across all variants.
        let comparisons = ALL_BIN_OPS.iter().filter(|op| op.is_comparison()).count();
        let arithmetic = ALL_BIN_OPS.iter().filter(|op| op.is_arithmetic()).count();
        assert_eq!(comparisons, 6);
        assert_eq!(arithmetic, 7);
    }

    #[test]
    fn test_binop_classes_are_disjoint() {
        for op in &ALL_BIN_OPS {
            assert!(
                !(op.is_comparison() && op.is_arithmetic()),
                "{:?} is classified as both comparison and arithmetic",
                op
            );
        }
        // Bitwise and shift operators fall into neither class.
        for op in &[BinOp::BitAnd, BinOp::BitOr, BinOp::BitXor, BinOp::Shl, BinOp::Shr] {
            assert!(!op.is_comparison(), "{:?}", op);
            assert!(!op.is_arithmetic(), "{:?}", op);
        }
    }

    #[test]
    fn test_expr_construction() {
        let add_expr = Expr::BinOp {
            op: BinOp::Add,
            lhs: Box::new(Expr::Var("a".into())),
            rhs: Box::new(Expr::Var("b".into())),
        };
        match &add_expr {
            Expr::BinOp { op, lhs, rhs } => {
                assert_eq!(*op, BinOp::Add);
                // Compare the boxed contents directly, matching `test_index_expr`,
                // instead of allocating a fresh `Box` just for the comparison.
                assert_eq!(**lhs, Expr::Var("a".into()));
                assert_eq!(**rhs, Expr::Var("b".into()));
            }
            _ => panic!("expected BinOp"),
        }
    }

    #[test]
    fn test_index_expr() {
        let idx = Expr::Index {
            base: Box::new(Expr::Var("arr".into())),
            index: Box::new(Expr::ThreadId(Dimension::X)),
        };
        match &idx {
            Expr::Index { base, index } => {
                assert_eq!(**base, Expr::Var("arr".into()));
                assert_eq!(**index, Expr::ThreadId(Dimension::X));
            }
            _ => panic!("expected Index"),
        }
    }

    #[test]
    fn test_literal_variants() {
        // 2.5 is exactly representable and, unlike 3.14, does not trip clippy's
        // deny-by-default `approx_constant` lint (which would fail `cargo clippy`).
        let int_lit = Literal::Int(42);
        let uint_lit = Literal::UInt(42);
        let float_lit = Literal::Float(2.5);
        let bool_lit = Literal::Bool(true);
        assert_eq!(int_lit, Literal::Int(42));
        assert_eq!(uint_lit, Literal::UInt(42));
        assert_eq!(float_lit, Literal::Float(2.5));
        assert_eq!(bool_lit, Literal::Bool(true));
        // The same payload in a different variant is a different literal.
        assert_ne!(int_lit, uint_lit);
    }
}