Skip to main content

vyre_foundation/ir_inner/model/program/
core.rs

1use std::sync::{Arc, OnceLock};
2
3use rustc_hash::FxHashMap;
4
5use crate::ir_inner::model::node::Node;
6
7use super::BufferDecl;
8
/// A complete vyre program.
///
/// Contains everything needed to execute a GPU compute dispatch:
/// buffer declarations, workgroup configuration, and the entry point body.
///
/// Besides that payload, the struct carries several lazily filled caches
/// (`OnceLock` fields) and a validation flag; the manual `Clone` impl in this
/// file carries already-populated caches over to the copy.
///
/// # Example
///
/// A program that XORs two input buffers element-wise:
///
/// ```rust
/// use vyre::ir::{Program, BufferDecl, BufferAccess, DataType, Node, Expr, BinOp};
///
/// let program = Program::wrapped(
///     vec![
///         BufferDecl::storage("a", 0, BufferAccess::ReadOnly, DataType::U32),
///         BufferDecl::storage("b", 1, BufferAccess::ReadOnly, DataType::U32),
///         BufferDecl::storage("out", 2, BufferAccess::ReadWrite, DataType::U32),
///     ],
///     [64, 1, 1],
///     vec![
///         Node::let_bind("idx", Expr::gid_x()),
///         Node::if_then(
///             Expr::lt(Expr::var("idx"), Expr::buf_len("out")),
///             vec![
///                 Node::store("out", Expr::var("idx"),
///                     Expr::bitxor(
///                         Expr::load("a", Expr::var("idx")),
///                         Expr::load("b", Expr::var("idx")),
///                     ),
///                 ),
///             ],
///         ),
///     ],
/// );
/// assert_eq!(program.buffers().len(), 3);
/// ```
#[derive(Debug)]
pub struct Program {
    /// Stable ID of the certified operation this program implements.
    ///
    /// Runtime lowering must reject programs without an ID because anonymous IR
    /// cannot be tied back to a conform registry entry.
    pub entry_op_id: Option<String>,
    /// Buffer declarations. Each declares a named, typed, bound memory region.
    pub buffers: Arc<[BufferDecl]>,
    /// Sidecar index for O(1) buffer lookup by name.
    pub(crate) buffer_index: Arc<FxHashMap<Arc<str>, usize>>,
    /// Workgroup size: `[x, y, z]`. Controls `@workgroup_size` in target-text.
    pub workgroup_size: [u32; 3],
    /// Entry point body. Executes once per invocation.
    pub entry: Arc<Vec<Node>>,
    /// Cached blake3 hash of the program for fast equality and cache lookups.
    pub(crate) hash: OnceLock<blake3::Hash>,
    /// Concurrent set of string keys held behind an `Arc`; `Clone` shares the
    /// same set with the copy rather than duplicating it.
    ///
    /// NOTE(review): what the keys denote is not visible in this file —
    /// presumably identifiers of validations that already passed; confirm
    /// against the validation code.
    #[doc(hidden)]
    pub(crate) validation_set: Arc<dashmap::DashSet<Arc<str>>>,
    /// True once structural validation has succeeded for this immutable shape.
    ///
    /// `Clone` seeds the copy's flag from `is_structurally_validated()`.
    pub(crate) structural_validated: std::sync::atomic::AtomicBool,
    /// Lazily populated 32-byte fingerprint; copied to clones when already set.
    ///
    /// NOTE(review): how this differs from `hash` is not visible here — confirm
    /// against the code that fills it.
    pub(crate) fingerprint: OnceLock<[u8; 32]>,
    // VYRE_IR_HOTSPOTS HIGH (core.rs:100-117): both caches
    // (`output_buffer_index` and `stats`) were plain values, so
    // `Program::clone` copied the whole Vec / whole ProgramStats by value.
    // Wrapping them in Arc turns the clone into a refcount bump, keeping
    // Program::clone O(1) on every field.
    /// Lazily populated list of `u32` values, shared with clones via `Arc`.
    ///
    /// NOTE(review): presumably the indices of the program's output buffers —
    /// confirm against the code that fills it.
    pub(crate) output_buffer_index: OnceLock<Arc<Vec<u32>>>,
    /// Lazily populated flag; copied to clones when already set.
    ///
    /// NOTE(review): presumably records whether the program uses indirect
    /// dispatch — confirm against the code that fills it.
    pub(crate) has_indirect_dispatch: OnceLock<bool>,
    /// Cached statistics computed from a single walk of the program.
    ///
    /// This is a transient cache: it is not serialized to wire format and is
    /// invalidated whenever the program shape mutates.
    pub(crate) stats: OnceLock<Arc<super::ProgramStats>>,
    /// When true, this program must not be fused with another copy of itself
    /// in the same megakernel. Parser programs that use workgroup-local scratch
    /// buffers set this to avoid state corruption when two invocations share
    /// the same workgroup memory.
    pub non_composable_with_self: bool,
}
85
86impl Default for Program {
87    #[inline]
88    fn default() -> Self {
89        Self::empty()
90    }
91}
92
93impl Clone for Program {
94    fn clone(&self) -> Self {
95        let cloned = Self {
96            entry_op_id: self.entry_op_id.clone(),
97            buffers: Arc::clone(&self.buffers),
98            buffer_index: Arc::clone(&self.buffer_index),
99            workgroup_size: self.workgroup_size,
100            entry: Arc::clone(&self.entry),
101            hash: OnceLock::new(),
102            validation_set: Arc::clone(&self.validation_set),
103            structural_validated: std::sync::atomic::AtomicBool::new(
104                self.is_structurally_validated(),
105            ),
106            fingerprint: OnceLock::new(),
107            output_buffer_index: OnceLock::new(),
108            has_indirect_dispatch: OnceLock::new(),
109            stats: OnceLock::new(),
110            non_composable_with_self: self.non_composable_with_self,
111        };
112        if let Some(hash) = self.hash.get() {
113            let _ = cloned.hash.set(*hash);
114        }
115        if let Some(fingerprint) = self.fingerprint.get() {
116            let _ = cloned.fingerprint.set(*fingerprint);
117        }
118        if let Some(output_buffer_index) = self.output_buffer_index.get() {
119            // Arc::clone = refcount bump, no Vec<u32> copy.
120            let _ = cloned
121                .output_buffer_index
122                .set(Arc::clone(output_buffer_index));
123        }
124        if let Some(has_indirect_dispatch) = self.has_indirect_dispatch.get() {
125            let _ = cloned.has_indirect_dispatch.set(*has_indirect_dispatch);
126        }
127        if let Some(stats) = self.stats.get() {
128            // Arc::clone = refcount bump, no ProgramStats copy.
129            let _ = cloned.stats.set(Arc::clone(stats));
130        }
131        cloned
132    }
133}
134
135impl PartialEq for Program {
136    fn eq(&self, other: &Self) -> bool {
137        self.structural_eq(other)
138    }
139}
140
// Marker impl: declares that `Program`'s `PartialEq` is a full equivalence
// relation. NOTE(review): this relies on `structural_eq` (defined elsewhere)
// being reflexive, symmetric, and transitive — confirm.
impl Eq for Program {}