1use std::fmt;
15
16use crate::ir::Program;
17
/// Capability requirements of a program, to be compared against what a backend offers.
///
/// The boolean fields flag optional features; the remaining fields carry sizes.
/// The `Default` value (also returned by [`RequiredCapabilities::none`]) requires
/// nothing: every flag is `false` and every size is zero.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct RequiredCapabilities {
    /// Subgroup operations are required.
    pub subgroup_ops: bool,
    /// `f16` support is required.
    pub f16: bool,
    /// `bf16` support is required.
    pub bf16: bool,
    /// `f64` support is required.
    pub f64: bool,
    /// Async dispatch is required.
    pub async_dispatch: bool,
    /// Indirect dispatch is required.
    pub indirect_dispatch: bool,
    /// Tensor operations are required.
    pub tensor_ops: bool,
    /// Trap support is required.
    pub trap: bool,
    /// Workgroup size needed on each of the three axes.
    pub max_workgroup_size: [u32; 3],
    /// Static storage needed, in bytes.
    pub static_storage_bytes: u64,
}

impl RequiredCapabilities {
    /// Returns the requirement set that asks for nothing (identical to `Default`).
    #[must_use]
    pub fn none() -> Self {
        Self::default()
    }

    /// Combines two requirement sets into one that satisfies both.
    ///
    /// * Feature flags are OR-ed together.
    /// * `max_workgroup_size` takes the larger value on each axis.
    /// * `static_storage_bytes` is the sum of both sides, saturating at
    ///   `u64::MAX` instead of overflowing.
    #[must_use]
    pub fn union(self, other: RequiredCapabilities) -> Self {
        // Start from our own sizes and raise each axis to the other side's
        // value where that one is larger.
        let mut max_workgroup_size = self.max_workgroup_size;
        for (axis_size, other_size) in max_workgroup_size
            .iter_mut()
            .zip(other.max_workgroup_size.iter())
        {
            *axis_size = (*axis_size).max(*other_size);
        }

        // A full struct expression (rather than field-by-field mutation) makes
        // the compiler reject this function if a new field is added but not merged.
        Self {
            subgroup_ops: self.subgroup_ops || other.subgroup_ops,
            f16: self.f16 || other.f16,
            bf16: self.bf16 || other.bf16,
            f64: self.f64 || other.f64,
            async_dispatch: self.async_dispatch || other.async_dispatch,
            indirect_dispatch: self.indirect_dispatch || other.indirect_dispatch,
            tensor_ops: self.tensor_ops || other.tensor_ops,
            trap: self.trap || other.trap,
            max_workgroup_size,
            static_storage_bytes: self
                .static_storage_bytes
                .saturating_add(other.static_storage_bytes),
        }
    }
}
81
/// Error describing which required capabilities a backend does not provide.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MissingCapability {
    /// Identifier of the backend that was checked.
    pub backend: String,
    /// Names of the capabilities the backend lacks.
    pub missing: Vec<&'static str>,
}

impl fmt::Display for MissingCapability {
    /// Writes a one-line message naming the backend, the comma-separated
    /// missing capabilities, and a suggested fix.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let capability_list = self.missing.join(", ");
        write!(
            f,
            "backend `{}` is missing required capabilities: {}. \
             Fix: pick a backend that advertises these capabilities, \
             or run the program on the CPU reference.",
            self.backend, capability_list
        )
    }
}

// No underlying source error: the default `Error` methods are sufficient.
impl std::error::Error for MissingCapability {}
109
110#[must_use]
112pub fn scan(program: &Program) -> RequiredCapabilities {
113 let stats = program.stats();
114 RequiredCapabilities {
115 subgroup_ops: stats.subgroup_ops(),
116 f16: stats.f16(),
117 bf16: stats.bf16(),
118 f64: stats.f64(),
119 async_dispatch: stats.async_dispatch(),
120 indirect_dispatch: stats.indirect_dispatch(),
121 tensor_ops: stats.tensor_ops(),
122 trap: stats.trap(),
123 max_workgroup_size: program.workgroup_size,
124 static_storage_bytes: stats.static_storage_bytes,
125 }
126}
127
128pub fn check_backend_capabilities(
137 backend_id: &str,
138 supports_subgroup_ops: bool,
139 supports_f16: bool,
140 supports_bf16: bool,
141 supports_indirect_dispatch: bool,
142 supports_trap_propagation: bool,
143 max_workgroup_size: [u32; 3],
144 required: &RequiredCapabilities,
145) -> Result<(), MissingCapability> {
146 let mut missing = Vec::new();
147 if required.subgroup_ops && !supports_subgroup_ops {
148 missing.push("subgroup_ops");
149 }
150 if required.f16 && !supports_f16 {
151 missing.push("f16");
152 }
153 if required.bf16 && !supports_bf16 {
154 missing.push("bf16");
155 }
156 if required.indirect_dispatch && !supports_indirect_dispatch {
157 missing.push("indirect_dispatch");
158 }
159 if required.trap && !supports_trap_propagation {
160 missing.push("trap_propagation");
161 }
162 for (req_size, max_size) in required
163 .max_workgroup_size
164 .iter()
165 .zip(max_workgroup_size.iter())
166 {
167 if *req_size > *max_size && *max_size != 0 {
168 missing.push("workgroup_size");
169 break;
170 }
171 }
172
173 if missing.is_empty() {
174 Ok(())
175 } else {
176 Err(MissingCapability {
177 backend: backend_id.to_string(),
178 missing,
179 })
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr as IrExpr, Node as IrNode, Program};

    /// Builds the fixture shared by the scan tests: one read-write `u32`
    /// storage buffer named `out`, the `[1, 1, 1]` size argument, and a body
    /// consisting of a single `let name = value` binding.
    fn single_let_program(name: &'static str, value: IrExpr) -> Program {
        Program::wrapped(
            vec![BufferDecl::storage(
                "out",
                0,
                BufferAccess::ReadWrite,
                DataType::U32,
            )],
            [1, 1, 1],
            vec![IrNode::let_bind(name, value)],
        )
    }

    /// Program whose only statement binds a plain `u32` literal.
    fn empty_program() -> Program {
        single_let_program("x", IrExpr::u32(0))
    }

    #[test]
    fn scan_scalar_program_declares_no_capabilities() {
        let caps = scan(&empty_program());
        assert!(!caps.subgroup_ops);
        assert!(!caps.f16);
        assert!(!caps.async_dispatch);
    }

    #[test]
    fn scan_subgroup_add_requires_subgroup_ops() {
        let subgroup_add = IrExpr::SubgroupAdd {
            value: Box::new(IrExpr::u32(1)),
        };
        let caps = scan(&single_let_program("s", subgroup_add));
        assert!(caps.subgroup_ops);
    }

    #[test]
    fn scan_call_to_subgroup_intrinsic_requires_subgroup_ops() {
        let intrinsic_call = IrExpr::call(
            "vyre-intrinsics::math::subgroup_inclusive_add",
            vec![IrExpr::u32(1)],
        );
        let caps = scan(&single_let_program("s", intrinsic_call));
        assert!(caps.subgroup_ops);
    }

    #[test]
    fn check_backend_reports_every_missing_bit() {
        let required = RequiredCapabilities {
            subgroup_ops: true,
            f16: true,
            trap: true,
            ..RequiredCapabilities::default()
        };
        // A backend that supports none of the optional features.
        let error = check_backend_capabilities(
            "test_backend",
            false,
            false,
            false,
            false,
            false,
            [64, 1, 1],
            &required,
        )
        .unwrap_err();
        assert_eq!(error.backend, "test_backend");
        for expected in ["subgroup_ops", "f16", "trap_propagation"].iter() {
            assert!(error.missing.contains(expected));
        }
    }
}