Skip to main content

wave_compiler/hir/
types.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! Type system for WAVE GPU kernels.
5//!
6//! All types must be representable in WAVE registers. The type system
7//! covers scalar types, pointers with address spaces, fixed-size arrays,
8//! and void for functions with no return value.
9
/// Address spaces of the GPU memory hierarchy that a [`Type::Ptr`] can point into.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AddressSpace {
    /// Memory private to a single thread (registers or stack).
    Private,
    /// Shared scratchpad memory of a workgroup.
    Local,
    /// Global memory of the device.
    Device,
}
20
21/// GPU-specific type system where all types map to WAVE registers.
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23pub enum Type {
24    /// Unsigned 32-bit integer.
25    U32,
26    /// Signed 32-bit integer.
27    I32,
28    /// 32-bit floating point.
29    F32,
30    /// 16-bit floating point (stored in lower half of 32-bit register).
31    F16,
32    /// 64-bit floating point (register pair).
33    F64,
34    /// Boolean predicate (maps to WAVE predicate register).
35    Bool,
36    /// Pointer to memory in a specific address space.
37    Ptr(AddressSpace),
38    /// Fixed-size array of a given element type.
39    Array(Box<Type>, usize),
40    /// No return value.
41    Void,
42}
43
44impl Type {
45    /// Returns the size in bytes of this type.
46    #[must_use]
47    pub fn size_bytes(&self) -> usize {
48        match self {
49            Self::U32 | Self::I32 | Self::F32 | Self::Bool | Self::Ptr(_) => 4,
50            Self::F16 => 2,
51            Self::F64 => 8,
52            Self::Array(elem, count) => elem.size_bytes() * count,
53            Self::Void => 0,
54        }
55    }
56
57    /// Returns true if this is a floating-point type.
58    #[must_use]
59    pub fn is_float(&self) -> bool {
60        matches!(self, Self::F32 | Self::F16 | Self::F64)
61    }
62
63    /// Returns true if this is an integer type.
64    #[must_use]
65    pub fn is_integer(&self) -> bool {
66        matches!(self, Self::U32 | Self::I32)
67    }
68
69    /// Returns true if this is a pointer type.
70    #[must_use]
71    pub fn is_pointer(&self) -> bool {
72        matches!(self, Self::Ptr(_))
73    }
74
75    /// Returns the number of 32-bit registers needed to hold this type.
76    #[must_use]
77    pub fn register_count(&self) -> usize {
78        match self {
79            Self::U32 | Self::I32 | Self::F32 | Self::F16 | Self::Bool | Self::Ptr(_) => 1,
80            Self::F64 => 2,
81            Self::Array(elem, count) => elem.register_count() * count,
82            Self::Void => 0,
83        }
84    }
85}
86
87impl std::fmt::Display for Type {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        match self {
90            Self::U32 => write!(f, "u32"),
91            Self::I32 => write!(f, "i32"),
92            Self::F32 => write!(f, "f32"),
93            Self::F16 => write!(f, "f16"),
94            Self::F64 => write!(f, "f64"),
95            Self::Bool => write!(f, "bool"),
96            Self::Ptr(space) => write!(f, "ptr<{space:?}>"),
97            Self::Array(elem, size) => write!(f, "[{elem}; {size}]"),
98            Self::Void => write!(f, "void"),
99        }
100    }
101}
102
103impl std::fmt::Display for AddressSpace {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        match self {
106            Self::Private => write!(f, "private"),
107            Self::Local => write!(f, "local"),
108            Self::Device => write!(f, "device"),
109        }
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a fixed-size array type in the tests below.
    fn array(elem: Type, len: usize) -> Type {
        Type::Array(Box::new(elem), len)
    }

    #[test]
    fn test_type_sizes() {
        assert_eq!(Type::U32.size_bytes(), 4);
        assert_eq!(Type::I32.size_bytes(), 4);
        assert_eq!(Type::F32.size_bytes(), 4);
        assert_eq!(Type::F16.size_bytes(), 2);
        assert_eq!(Type::F64.size_bytes(), 8);
        assert_eq!(Type::Bool.size_bytes(), 4);
        assert_eq!(Type::Ptr(AddressSpace::Device).size_bytes(), 4);
        assert_eq!(Type::Ptr(AddressSpace::Local).size_bytes(), 4);
        assert_eq!(array(Type::F32, 4).size_bytes(), 16);
        // Nested arrays multiply out, and an empty array occupies no space.
        assert_eq!(array(array(Type::F16, 3), 2).size_bytes(), 12);
        assert_eq!(array(Type::F64, 0).size_bytes(), 0);
        assert_eq!(Type::Void.size_bytes(), 0);
    }

    #[test]
    fn test_type_classification() {
        assert!(Type::F32.is_float());
        assert!(Type::F16.is_float());
        assert!(Type::F64.is_float());
        assert!(!Type::U32.is_float());
        assert!(Type::U32.is_integer());
        assert!(Type::I32.is_integer());
        assert!(!Type::F32.is_integer());
        assert!(Type::Ptr(AddressSpace::Device).is_pointer());
        assert!(!Type::U32.is_pointer());
        assert!(!Type::Ptr(AddressSpace::Private).is_integer());
        // Bool, Void and arrays belong to none of the three categories.
        for ty in [Type::Bool, Type::Void, array(Type::F32, 4)] {
            assert!(!ty.is_float() && !ty.is_integer() && !ty.is_pointer(), "{ty}");
        }
    }

    #[test]
    fn test_register_count() {
        assert_eq!(Type::U32.register_count(), 1);
        assert_eq!(Type::F64.register_count(), 2);
        // F16 still takes a whole 32-bit register, and arrays do not pack it.
        assert_eq!(Type::F16.register_count(), 1);
        assert_eq!(array(Type::F16, 4).register_count(), 4);
        assert_eq!(Type::Bool.register_count(), 1);
        assert_eq!(Type::Ptr(AddressSpace::Private).register_count(), 1);
        assert_eq!(array(Type::F32, 4).register_count(), 4);
        assert_eq!(array(Type::F64, 3).register_count(), 6);
        assert_eq!(Type::Void.register_count(), 0);
    }

    #[test]
    fn test_type_display() {
        assert_eq!(format!("{}", Type::U32), "u32");
        assert_eq!(
            format!("{}", Type::Ptr(AddressSpace::Device)),
            "ptr<Device>"
        );
        assert_eq!(format!("{}", array(Type::F32, 4)), "[f32; 4]");
        assert_eq!(format!("{}", array(array(Type::I32, 2), 3)), "[[i32; 2]; 3]");
        assert_eq!(Type::Void.to_string(), "void");
    }

    #[test]
    fn test_address_space_display() {
        assert_eq!(AddressSpace::Private.to_string(), "private");
        assert_eq!(AddressSpace::Local.to_string(), "local");
        assert_eq!(AddressSpace::Device.to_string(), "device");
    }
}