Skip to main content

vyre_driver/backend/
validation.rs

1//! Backend support validation before dispatch.
2
3use super::capability::Backend;
4use std::sync::Arc;
5use vyre_foundation::ir::model::node::Node;
6use vyre_foundation::ir::{OpId, Program, ValidationError};
7
/// Operation ids that make up the core support set for legacy backends.
///
/// This table seeds [`default_supported_ops`]. The `vyre.node.*` entries are
/// the ids [`node_op_id`] returns for the core statement variants, including
/// `vyre.node.region`, which every backend is expected to accept.
///
/// The remaining entries (`vyre.lit_*`, `vyre.var`, `vyre.bin_op`,
/// `vyre.un_op`, `vyre.load`, `vyre.store`) are not produced by
/// [`node_op_id`]. They are presumably expression-level ops; confirm against
/// the expression validator.
///
/// Some ids are intentionally absent:
/// - `vyre.node.trap`: backends opt in via [`default_supported_ops_with_trap`].
/// - `vyre.node.resume` and the collective ids (`all_reduce`, `all_gather`,
///   `reduce_scatter`, `broadcast`): backends that use only the default sets
///   reject these.
const CORE_SUPPORTED_OP_IDS: &[&str] = &[
    "vyre.node.let",
    "vyre.node.assign",
    "vyre.node.store",
    "vyre.node.if",
    "vyre.node.loop",
    "vyre.node.return",
    "vyre.node.block",
    "vyre.node.barrier",
    "vyre.node.indirect_dispatch",
    "vyre.node.async_load",
    "vyre.node.async_wait",
    "vyre.node.region",
    "vyre.lit_u32",
    "vyre.lit_i32",
    "vyre.lit_f32",
    "vyre.lit_bool",
    "vyre.var",
    "vyre.bin_op",
    "vyre.un_op",
    "vyre.load",
    "vyre.store",
];
31
32/// Validate that `backend` supports every operation in `program`.
33pub fn validate_program(program: &Program, backend: &dyn Backend) -> Result<(), ValidationError> {
34    for (index, node) in program.entry().iter().enumerate() {
35        validate_node(node, index, backend.id(), backend.supported_ops())?;
36    }
37    Ok(())
38}
39
40/// Default core operation support set for legacy backends.
41pub fn default_supported_ops() -> &'static std::collections::HashSet<OpId> {
42    static OPS: std::sync::OnceLock<std::collections::HashSet<OpId>> = std::sync::OnceLock::new();
43    OPS.get_or_init(|| {
44        let mut ops = std::collections::HashSet::new();
45        let _ = ops.try_reserve(CORE_SUPPORTED_OP_IDS.len());
46        ops.extend(CORE_SUPPORTED_OP_IDS.iter().copied().map(Arc::<str>::from));
47        ops
48    })
49}
50
51/// Default core operation set plus `Node::Trap`.
52///
53/// `Trap` is a structural control-flow node, not a concrete-driver extension:
54/// backends that lower it as lane termination should use this shared set
55/// instead of carrying a backend-local `OnceLock` and literal allocation.
56pub fn default_supported_ops_with_trap() -> &'static std::collections::HashSet<OpId> {
57    static OPS: std::sync::OnceLock<std::collections::HashSet<OpId>> = std::sync::OnceLock::new();
58    OPS.get_or_init(|| {
59        let base = default_supported_ops();
60        let reserve = base.len().saturating_add(1);
61        let mut ops = std::collections::HashSet::new();
62        let _ = ops.try_reserve(reserve);
63        ops.extend(base.iter().cloned());
64        ops.insert(Arc::<str>::from("vyre.node.trap"));
65        ops
66    })
67}
68
69fn validate_node(
70    node: &Node,
71    index: usize,
72    backend: &'static str,
73    supported: &std::collections::HashSet<OpId>,
74) -> Result<(), ValidationError> {
75    let op = node_op_id(node);
76    if !supported.contains(op) {
77        let op_id = Arc::<str>::from(op);
78        return Err(ValidationError::unsupported_op(backend, &op_id, index));
79    }
80    match node {
81        Node::If {
82            then, otherwise, ..
83        } => {
84            for (offset, nested) in then.iter().enumerate() {
85                validate_node(nested, offset, backend, supported)?;
86            }
87            for (offset, nested) in otherwise.iter().enumerate() {
88                validate_node(nested, offset, backend, supported)?;
89            }
90        }
91        Node::Loop { body, .. } | Node::Block(body) => {
92            for (offset, nested) in body.iter().enumerate() {
93                validate_node(nested, offset, backend, supported)?;
94            }
95        }
96        Node::Region { body, .. } => {
97            for (offset, nested) in body.iter().enumerate() {
98                validate_node(nested, offset, backend, supported)?;
99            }
100        }
101        // Leaf nodes and backend-transparent nodes (opaque extensions
102        // validate themselves via `NodeExtension::validate_extension`).
103        Node::Let { .. }
104        | Node::Assign { .. }
105        | Node::Store { .. }
106        | Node::Return
107        | Node::Barrier { .. }
108        | Node::IndirectDispatch { .. }
109        | Node::AsyncLoad { .. }
110        | Node::AsyncWait { .. }
111        | Node::Opaque(_) => {}
112        // `Node` is `#[non_exhaustive]` in vyre-foundation. Future variants
113        // land here as transparent leaves until a dedicated arm is added.
114        _ => {}
115    }
116    Ok(())
117}
118
/// Return the stable operation id for a legacy statement node.
///
/// This id is the key `validate_node` looks up in a backend's supported-op
/// set. Built-in statement variants map to fixed `vyre.node.*` strings.
/// `Node::Opaque` extensions report their own `extension_kind()`.
///
/// Variants without a dedicated arm map to `"vyre.node.unknown"`. That id is
/// in neither [`default_supported_ops`] nor
/// [`default_supported_ops_with_trap`], so backends using those default sets
/// reject such nodes.
#[must_use]
pub fn node_op_id(node: &Node) -> &'static str {
    match node {
        Node::Let { .. } => "vyre.node.let",
        Node::Assign { .. } => "vyre.node.assign",
        Node::Store { .. } => "vyre.node.store",
        Node::If { .. } => "vyre.node.if",
        Node::Loop { .. } => "vyre.node.loop",
        Node::Return => "vyre.node.return",
        Node::Block(_) => "vyre.node.block",
        Node::Barrier { .. } => "vyre.node.barrier",
        Node::IndirectDispatch { .. } => "vyre.node.indirect_dispatch",
        Node::AsyncLoad { .. } => "vyre.node.async_load",
        Node::AsyncWait { .. } => "vyre.node.async_wait",
        // The ids below are not in the core default set. A backend must list
        // them explicitly; trap is also available via
        // `default_supported_ops_with_trap`.
        Node::Trap { .. } => "vyre.node.trap",
        Node::Resume { .. } => "vyre.node.resume",
        Node::AllReduce { .. } => "vyre.node.all_reduce",
        Node::AllGather { .. } => "vyre.node.all_gather",
        Node::ReduceScatter { .. } => "vyre.node.reduce_scatter",
        Node::Broadcast { .. } => "vyre.node.broadcast",
        // Region is a debug wrapper produced by vyre-libs Cat-A
        // compositions. Every backend must accept it, either by
        // lowering its body transparently or via the region_inline
        // optimizer pass. Treat it as a structural node
        // with no capability requirement.
        Node::Region { .. } => "vyre.node.region",
        Node::Opaque(extension) => extension.extension_kind(),
        // Non-exhaustive safety net: future Node variants added in
        // vyre-foundation must receive a dedicated op id before release.
        _ => "vyre.node.unknown",
    }
}