// wave_compiler/hir/types.rs

/// The address space a [`Type::Ptr`] points into.
///
/// This is a small fieldless enum that derives `Copy`, `Eq` and `Hash`. It can be
/// passed by value and used as a key in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AddressSpace {
    /// The `private` address space.
    Private,
    /// The `local` address space.
    Local,
    /// The `device` address space.
    Device,
}

21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23pub enum Type {
24 U32,
26 I32,
28 F32,
30 F16,
32 F64,
34 Bool,
36 Ptr(AddressSpace),
38 Array(Box<Type>, usize),
40 Void,
42}
43
44impl Type {
45 #[must_use]
47 pub fn size_bytes(&self) -> usize {
48 match self {
49 Self::U32 | Self::I32 | Self::F32 | Self::Bool | Self::Ptr(_) => 4,
50 Self::F16 => 2,
51 Self::F64 => 8,
52 Self::Array(elem, count) => elem.size_bytes() * count,
53 Self::Void => 0,
54 }
55 }
56
57 #[must_use]
59 pub fn is_float(&self) -> bool {
60 matches!(self, Self::F32 | Self::F16 | Self::F64)
61 }
62
63 #[must_use]
65 pub fn is_integer(&self) -> bool {
66 matches!(self, Self::U32 | Self::I32)
67 }
68
69 #[must_use]
71 pub fn is_pointer(&self) -> bool {
72 matches!(self, Self::Ptr(_))
73 }
74
75 #[must_use]
77 pub fn register_count(&self) -> usize {
78 match self {
79 Self::U32 | Self::I32 | Self::F32 | Self::F16 | Self::Bool | Self::Ptr(_) => 1,
80 Self::F64 => 2,
81 Self::Array(elem, count) => elem.register_count() * count,
82 Self::Void => 0,
83 }
84 }
85}
86
87impl std::fmt::Display for Type {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 match self {
90 Self::U32 => write!(f, "u32"),
91 Self::I32 => write!(f, "i32"),
92 Self::F32 => write!(f, "f32"),
93 Self::F16 => write!(f, "f16"),
94 Self::F64 => write!(f, "f64"),
95 Self::Bool => write!(f, "bool"),
96 Self::Ptr(space) => write!(f, "ptr<{space:?}>"),
97 Self::Array(elem, size) => write!(f, "[{elem}; {size}]"),
98 Self::Void => write!(f, "void"),
99 }
100 }
101}
102
103impl std::fmt::Display for AddressSpace {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 Self::Private => write!(f, "private"),
107 Self::Local => write!(f, "local"),
108 Self::Device => write!(f, "device"),
109 }
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `[elem; len]` without repeating the `Box::new` boilerplate.
    fn array_of(elem: Type, len: usize) -> Type {
        Type::Array(Box::new(elem), len)
    }

    #[test]
    fn test_type_sizes() {
        let cases: [(Type, usize); 9] = [
            (Type::U32, 4),
            (Type::I32, 4),
            (Type::F32, 4),
            (Type::F16, 2),
            (Type::F64, 8),
            (Type::Bool, 4),
            (Type::Ptr(AddressSpace::Device), 4),
            (array_of(Type::F32, 4), 16),
            (Type::Void, 0),
        ];
        for (ty, bytes) in &cases {
            assert_eq!(ty.size_bytes(), *bytes, "size_bytes of {}", ty);
        }
    }

    #[test]
    fn test_type_classification() {
        for ty in &[Type::F32, Type::F16, Type::F64] {
            assert!(ty.is_float(), "{} should be a float", ty);
        }
        assert!(!Type::U32.is_float());

        for ty in &[Type::U32, Type::I32] {
            assert!(ty.is_integer(), "{} should be an integer", ty);
        }
        assert!(!Type::F32.is_integer());

        assert!(Type::Ptr(AddressSpace::Device).is_pointer());
        assert!(!Type::U32.is_pointer());
    }

    #[test]
    fn test_register_count() {
        let cases: [(Type, usize); 4] = [
            (Type::U32, 1),
            (Type::F64, 2),
            (array_of(Type::F32, 4), 4),
            (Type::Void, 0),
        ];
        for (ty, regs) in &cases {
            assert_eq!(ty.register_count(), *regs, "register_count of {}", ty);
        }
    }

    #[test]
    fn test_type_display() {
        let cases: [(Type, &str); 3] = [
            (Type::U32, "u32"),
            (Type::Ptr(AddressSpace::Device), "ptr<Device>"),
            (array_of(Type::F32, 4), "[f32; 4]"),
        ];
        for (ty, text) in &cases {
            assert_eq!(ty.to_string(), *text);
        }
    }
}