Skip to main content

vyre_foundation/runtime/
program_caps.rs

1//! Program → required-capability analysis.
2//!
3//! Scan a `Program` and report the hardware capabilities its lowering will
4//! need. Callers (backends, conformance harnesses, certificate emitters)
5//! compare the required set against what a backend advertises and surface
6//! `MissingCapability` *before* handing the kernel to the device, avoiding
7//! panics inside `create_shader_module` / `createComputePipeline`.
8//!
9//! The scanner is strictly syntactic: it walks every `Expr` and `Node` in
10//! the program and checks the IR surface. It intentionally does **not**
11//! know anything about backend-specific lowering rules — that would make it
12//! a circular dependency of the very thing it is supposed to gate.
13
14use std::fmt;
15
16use crate::ir::Program;
17
/// Capabilities a `Program` needs from whichever backend executes it.
///
/// This is a structured replacement for hardcoded "exempt op" lists. A
/// universal diff harness asks `scan(program)` which bits the program
/// needs, asks the backend which bits it advertises, and skips the pair
/// when the backend lacks a required bit. The result reasons are attached
/// for telemetry.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct RequiredCapabilities {
    /// The program uses subgroup / wave operations: `Expr::SubgroupAdd`,
    /// `SubgroupBallot`, `SubgroupShuffle`, or a call to a subgroup
    /// intrinsic such as `vyre-intrinsics::math::subgroup_inclusive_add`.
    /// Lowering paths need the SUBGROUP / wave-op feature on the target
    /// device.
    pub subgroup_ops: bool,
    /// The program uses any IEEE 754 binary16 operand.
    pub f16: bool,
    /// The program uses any bfloat16 operand.
    pub bf16: bool,
    /// The program uses 64-bit floats.
    pub f64: bool,
    /// The program dispatches async DMA (`Node::AsyncLoad` / `AsyncStore`).
    pub async_dispatch: bool,
    /// The program emits `Node::IndirectDispatch`.
    pub indirect_dispatch: bool,
    /// The program reaches into tensor / tensor-core operand types.
    pub tensor_ops: bool,
    /// The program uses a `Node::Trap` — backend needs trap propagation.
    pub trap: bool,
    /// Workgroup size required on each axis (`[x, y, z]`). After
    /// [`RequiredCapabilities::union`] this is the per-axis maximum of both
    /// inputs.
    pub max_workgroup_size: [u32; 3],
    /// Sum of `BufferDecl::count * sizeof(DataType)` across every buffer
    /// whose size can be computed statically. `0` means no buffer
    /// contributed a statically known size (e.g. every buffer is dynamic).
    pub static_storage_bytes: u64,
}

impl RequiredCapabilities {
    /// Empty set — the Program needs nothing beyond the minimum substrate.
    ///
    /// Equivalent to `RequiredCapabilities::default()`.
    #[must_use]
    pub fn none() -> Self {
        Self::default()
    }

    /// Combine two capability sets into one that covers both.
    ///
    /// - Every boolean requirement is OR-ed.
    /// - `max_workgroup_size` takes the per-axis maximum.
    /// - `static_storage_bytes` is the saturating **sum** of both inputs
    ///   (not the maximum), so the static buffers of both sides are counted.
    #[must_use]
    pub fn union(mut self, other: RequiredCapabilities) -> Self {
        self.subgroup_ops |= other.subgroup_ops;
        self.f16 |= other.f16;
        self.bf16 |= other.bf16;
        self.f64 |= other.f64;
        self.async_dispatch |= other.async_dispatch;
        self.indirect_dispatch |= other.indirect_dispatch;
        self.tensor_ops |= other.tensor_ops;
        self.trap |= other.trap;
        for (mine, theirs) in self
            .max_workgroup_size
            .iter_mut()
            .zip(other.max_workgroup_size.iter())
        {
            *mine = (*mine).max(*theirs);
        }
        // Saturate rather than wrap so an overflow can never make a huge
        // storage requirement look small.
        self.static_storage_bytes = self
            .static_storage_bytes
            .saturating_add(other.static_storage_bytes);
        self
    }
}
81
/// Explanation of why a particular backend cannot execute a program.
///
/// Produced by [`check_backend_capabilities`] whenever the required set
/// contains something the backend does not advertise. All unmet
/// requirements are reported together, so a caller can surface a single
/// actionable error instead of discovering them one at a time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingCapability {
    /// Identifier of the backend that was asked to run the program.
    pub backend: String,
    /// Human-readable names of every capability the backend lacks.
    pub missing: Vec<&'static str>,
}

impl fmt::Display for MissingCapability {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let capability_list = self.missing.join(", ");
        // Emit the message in three pieces; together they form one sentence
        // naming the backend, the missing capabilities, and the remedy.
        write!(
            f,
            "backend `{}` is missing required capabilities: {}. ",
            self.backend, capability_list
        )?;
        f.write_str("Fix: pick a backend that advertises these capabilities, ")?;
        f.write_str("or run the program on the CPU reference.")
    }
}

impl std::error::Error for MissingCapability {}
109
110/// Walk the program and collect the union of capabilities it requires.
111#[must_use]
112pub fn scan(program: &Program) -> RequiredCapabilities {
113    let stats = program.stats();
114    RequiredCapabilities {
115        subgroup_ops: stats.subgroup_ops(),
116        f16: stats.f16(),
117        bf16: stats.bf16(),
118        f64: stats.f64(),
119        async_dispatch: stats.async_dispatch(),
120        indirect_dispatch: stats.indirect_dispatch(),
121        tensor_ops: stats.tensor_ops(),
122        trap: stats.trap(),
123        max_workgroup_size: program.workgroup_size,
124        static_storage_bytes: stats.static_storage_bytes,
125    }
126}
127
128/// Return `Ok(())` when a backend with the given advertised capabilities
129/// can run a program whose required set is `required`, otherwise return
130/// the missing-capability explanation.
131///
132/// The caller passes in the boolean capability queries from
133/// [`crate::ir::Program`]'s backend trait (`supports_subgroup_ops`,
134/// `supports_f16`, etc.) so this function stays free of the
135/// `VyreBackend` trait import and can live in vyre-foundation.
136pub fn check_backend_capabilities(
137    backend_id: &str,
138    supports_subgroup_ops: bool,
139    supports_f16: bool,
140    supports_bf16: bool,
141    supports_indirect_dispatch: bool,
142    supports_trap_propagation: bool,
143    max_workgroup_size: [u32; 3],
144    required: &RequiredCapabilities,
145) -> Result<(), MissingCapability> {
146    let mut missing = Vec::new();
147    if required.subgroup_ops && !supports_subgroup_ops {
148        missing.push("subgroup_ops");
149    }
150    if required.f16 && !supports_f16 {
151        missing.push("f16");
152    }
153    if required.bf16 && !supports_bf16 {
154        missing.push("bf16");
155    }
156    if required.indirect_dispatch && !supports_indirect_dispatch {
157        missing.push("indirect_dispatch");
158    }
159    if required.trap && !supports_trap_propagation {
160        missing.push("trap_propagation");
161    }
162    for (req_size, max_size) in required
163        .max_workgroup_size
164        .iter()
165        .zip(max_workgroup_size.iter())
166    {
167        if *req_size > *max_size && *max_size != 0 {
168            missing.push("workgroup_size");
169            break;
170        }
171    }
172
173    if missing.is_empty() {
174        Ok(())
175    } else {
176        Err(MissingCapability {
177            backend: backend_id.to_string(),
178            missing,
179        })
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr as IrExpr, Node as IrNode, Program};

    fn empty_program() -> Program {
        Program::wrapped(
            vec![BufferDecl::storage(
                "out",
                0,
                BufferAccess::ReadWrite,
                DataType::U32,
            )],
            [1, 1, 1],
            vec![IrNode::let_bind("x", IrExpr::u32(0))],
        )
    }

    #[test]
    fn scan_scalar_program_declares_no_capabilities() {
        let caps = scan(&empty_program());
        assert!(!caps.subgroup_ops);
        assert!(!caps.f16);
        assert!(!caps.async_dispatch);
    }

    #[test]
    fn scan_subgroup_add_requires_subgroup_ops() {
        let program = Program::wrapped(
            vec![BufferDecl::storage(
                "out",
                0,
                BufferAccess::ReadWrite,
                DataType::U32,
            )],
            [1, 1, 1],
            vec![IrNode::let_bind(
                "s",
                IrExpr::SubgroupAdd {
                    value: Box::new(IrExpr::u32(1)),
                },
            )],
        );
        let caps = scan(&program);
        assert!(caps.subgroup_ops);
    }

    #[test]
    fn scan_call_to_subgroup_intrinsic_requires_subgroup_ops() {
        let program = Program::wrapped(
            vec![BufferDecl::storage(
                "out",
                0,
                BufferAccess::ReadWrite,
                DataType::U32,
            )],
            [1, 1, 1],
            vec![IrNode::let_bind(
                "s",
                IrExpr::call(
                    "vyre-intrinsics::math::subgroup_inclusive_add",
                    vec![IrExpr::u32(1)],
                ),
            )],
        );
        let caps = scan(&program);
        assert!(caps.subgroup_ops);
    }

    #[test]
    fn check_backend_reports_every_missing_bit() {
        let required = RequiredCapabilities {
            subgroup_ops: true,
            f16: true,
            trap: true,
            ..RequiredCapabilities::default()
        };
        let error = check_backend_capabilities(
            "test_backend",
            false,
            false,
            false,
            false,
            false,
            [64, 1, 1],
            &required,
        )
        .unwrap_err();
        assert_eq!(error.backend, "test_backend");
        assert!(error.missing.contains(&"subgroup_ops"));
        assert!(error.missing.contains(&"f16"));
        assert!(error.missing.contains(&"trap_propagation"));
    }

    // Covers the success path: every required bit is advertised and the
    // workgroup fits inside the backend limits.
    #[test]
    fn check_backend_accepts_fully_capable_backend() {
        let required = RequiredCapabilities {
            subgroup_ops: true,
            f16: true,
            bf16: true,
            indirect_dispatch: true,
            trap: true,
            max_workgroup_size: [64, 1, 1],
            ..RequiredCapabilities::default()
        };
        let result = check_backend_capabilities(
            "capable_backend",
            true,
            true,
            true,
            true,
            true,
            [256, 256, 64],
            &required,
        );
        assert_eq!(result, Ok(()));
    }

    // Covers the workgroup branch: one axis over the limit yields exactly one
    // `workgroup_size` entry.
    #[test]
    fn check_backend_rejects_oversized_workgroup() {
        let required = RequiredCapabilities {
            max_workgroup_size: [1, 128, 1],
            ..RequiredCapabilities::default()
        };
        let error = check_backend_capabilities(
            "small_workgroup_backend",
            true,
            true,
            true,
            true,
            true,
            [256, 64, 64],
            &required,
        )
        .unwrap_err();
        assert_eq!(error.missing, vec!["workgroup_size"]);
    }

    // A backend limit of `0` is treated as "no limit" on that axis.
    #[test]
    fn check_backend_treats_zero_limit_as_unbounded() {
        let required = RequiredCapabilities {
            max_workgroup_size: [1024, 1024, 64],
            ..RequiredCapabilities::default()
        };
        let result = check_backend_capabilities(
            "unbounded_backend",
            false,
            false,
            false,
            false,
            false,
            [0, 0, 0],
            &required,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn union_ors_flags_maxes_workgroup_and_sums_storage() {
        let left = RequiredCapabilities {
            subgroup_ops: true,
            max_workgroup_size: [64, 1, 1],
            static_storage_bytes: 16,
            ..RequiredCapabilities::default()
        };
        let right = RequiredCapabilities {
            f64: true,
            max_workgroup_size: [8, 8, 2],
            static_storage_bytes: 32,
            ..RequiredCapabilities::default()
        };
        let merged = left.union(right);
        assert!(merged.subgroup_ops);
        assert!(merged.f64);
        assert!(!merged.f16);
        assert_eq!(merged.max_workgroup_size, [64, 8, 2]);
        assert_eq!(merged.static_storage_bytes, 48);
    }
}