Skip to main content

vyre_foundation/ir_inner/model/program/
builder.rs

1use std::sync::{Arc, OnceLock};
2
3use rustc_hash::FxHashMap;
4
5use crate::ir_inner::model::arena::{ArenaProgram, ExprArena};
6use crate::ir_inner::model::node::Node;
7
8use super::{BufferDecl, Program};
9
10impl Program {
11    /// Synthetic generator id used when callers submit a raw top-level body
12    /// instead of an explicit `Node::Region`.
13    pub const ROOT_REGION_GENERATOR: &'static str = "vyre.program.root";
14
15    /// Create a complete program from buffer declarations, workgroup size, and
16    /// entry-point nodes, auto-wrapping the top-level body in a root Region
17    /// when necessary.
18    ///
19    /// This is the default construction path for runnable Programs.
20    ///
21    /// # Examples
22    ///
23    /// ```
24    /// use vyre::ir::{BufferAccess, BufferDecl, DataType, Node, Program};
25    ///
26    /// let program = Program::wrapped(
27    ///     vec![BufferDecl::storage(
28    ///         "output",
29    ///         0,
30    ///         BufferAccess::ReadWrite,
31    ///         DataType::U32,
32    ///     )],
33    ///     [64, 1, 1],
34    ///     Vec::new(),
35    /// );
36    ///
37    /// assert_eq!(program.workgroup_size(), [64, 1, 1]);
38    /// assert_eq!(program.buffers().len(), 1);
39    /// assert!(matches!(program.entry(), [Node::Region { .. }]));
40    /// ```
41    #[must_use]
42    #[inline]
43    pub fn wrapped(buffers: Vec<BufferDecl>, workgroup_size: [u32; 3], entry: Vec<Node>) -> Self {
44        Self::new_raw(buffers, workgroup_size, Self::wrap_entry(entry))
45    }
46
47    /// Create a complete program from buffer declarations, workgroup size, and
48    /// entry-point nodes.
49    ///
50    /// # Examples
51    ///
52    /// ```
53    /// use vyre::ir::{BufferAccess, BufferDecl, DataType, Program};
54    ///
55    /// let program = Program::wrapped(
56    ///     vec![BufferDecl::storage(
57    ///         "output",
58    ///         0,
59    ///         BufferAccess::ReadWrite,
60    ///         DataType::U32,
61    ///     )],
62    ///     [64, 1, 1],
63    ///     Vec::new(),
64    /// );
65    ///
66    /// assert_eq!(program.workgroup_size(), [64, 1, 1]);
67    /// assert_eq!(program.buffers().len(), 1);
68    /// assert!(matches!(program.entry(), [Node::Region { .. }]));
69    /// ```
70    #[deprecated(
71        note = "Program::new preserves raw top-level entry nodes. Use Program::wrapped for runnable programs; reserve Program::new for wire decode and negative tests."
72    )]
73    #[must_use]
74    #[inline]
75    pub fn new(buffers: Vec<BufferDecl>, workgroup_size: [u32; 3], entry: Vec<Node>) -> Self {
76        Self::new_raw(buffers, workgroup_size, entry)
77    }
78
79    #[must_use]
80    #[inline]
81    pub(crate) fn new_raw(
82        buffers: Vec<BufferDecl>,
83        workgroup_size: [u32; 3],
84        entry: Vec<Node>,
85    ) -> Self {
86        let mut interner = FxHashMap::<Arc<str>, Arc<str>>::default();
87        interner.reserve(buffers.len());
88        let buffers: Vec<BufferDecl> = buffers
89            .into_iter()
90            .map(|mut b| {
91                let arc = interner
92                    .entry(Arc::clone(&b.name))
93                    .or_insert_with(|| Arc::clone(&b.name))
94                    .clone();
95                b.name = arc;
96                b
97            })
98            .collect();
99        let buffer_index = Self::build_buffer_index(&buffers);
100        Self {
101            entry_op_id: None,
102            buffers: Arc::from(buffers),
103            buffer_index: Arc::new(buffer_index),
104            workgroup_size,
105            entry: Arc::new(entry),
106            hash: OnceLock::new(),
107            validation_set: Arc::new(dashmap::DashSet::new()),
108            structural_validated: std::sync::atomic::AtomicBool::new(false),
109            fingerprint: OnceLock::new(),
110            output_buffer_index: OnceLock::new(),
111            has_indirect_dispatch: OnceLock::new(),
112            stats: OnceLock::new(),
113            non_composable_with_self: false,
114        }
115    }
116
117    /// Clone this program with a replacement entry body while preserving the
118    /// existing buffer table, workgroup size, and optional certified op id.
119    #[must_use]
120    #[inline]
121    pub fn with_rewritten_entry(&self, entry: Vec<Node>) -> Self {
122        Self {
123            entry_op_id: self.entry_op_id.clone(),
124            buffers: Arc::clone(&self.buffers),
125            buffer_index: Arc::clone(&self.buffer_index),
126            workgroup_size: self.workgroup_size,
127            entry: Arc::new(entry),
128            hash: OnceLock::new(),
129            validation_set: Arc::new(dashmap::DashSet::new()),
130            structural_validated: std::sync::atomic::AtomicBool::new(false),
131            fingerprint: OnceLock::new(),
132            output_buffer_index: OnceLock::new(),
133            has_indirect_dispatch: OnceLock::new(),
134            stats: OnceLock::new(),
135            non_composable_with_self: self.non_composable_with_self,
136        }
137    }
138
139    /// Clone this program with replacement buffer declarations while
140    /// preserving the entry body, workgroup size, and metadata flags.
141    #[must_use]
142    #[inline]
143    pub fn with_rewritten_buffers(&self, buffers: Vec<BufferDecl>) -> Self {
144        let buffer_index = Self::build_buffer_index(&buffers);
145        Self {
146            entry_op_id: self.entry_op_id.clone(),
147            buffers: Arc::from(buffers),
148            buffer_index: Arc::new(buffer_index),
149            workgroup_size: self.workgroup_size,
150            entry: Arc::clone(&self.entry),
151            hash: OnceLock::new(),
152            validation_set: Arc::new(dashmap::DashSet::new()),
153            structural_validated: std::sync::atomic::AtomicBool::new(false),
154            fingerprint: OnceLock::new(),
155            output_buffer_index: OnceLock::new(),
156            has_indirect_dispatch: OnceLock::new(),
157            stats: OnceLock::new(),
158            non_composable_with_self: self.non_composable_with_self,
159        }
160    }
161
162    /// Clone this program with replacement dispatch dimensions and entry body
163    /// while preserving the existing buffer table, indexes, and metadata flags.
164    #[must_use]
165    #[inline]
166    pub fn with_rewritten_workgroup_size_and_entry(
167        &self,
168        workgroup_size: [u32; 3],
169        entry: Vec<Node>,
170    ) -> Self {
171        Self {
172            entry_op_id: self.entry_op_id.clone(),
173            buffers: Arc::clone(&self.buffers),
174            buffer_index: Arc::clone(&self.buffer_index),
175            workgroup_size,
176            entry: Arc::new(entry),
177            hash: OnceLock::new(),
178            validation_set: Arc::new(dashmap::DashSet::new()),
179            structural_validated: std::sync::atomic::AtomicBool::new(false),
180            fingerprint: OnceLock::new(),
181            output_buffer_index: OnceLock::new(),
182            has_indirect_dispatch: OnceLock::new(),
183            stats: OnceLock::new(),
184            non_composable_with_self: self.non_composable_with_self,
185        }
186    }
187
188    /// Consume the program and return its entry nodes, reusing the
189    /// backing vector when this program owns the entry body uniquely.
190    #[must_use]
191    #[inline]
192    pub fn into_entry_vec(self) -> Vec<Node> {
193        Arc::try_unwrap(self.entry).unwrap_or_else(|entry| entry.as_ref().clone())
194    }
195
196    /// Create an arena-backed program scaffold.
197    ///
198    /// This constructor is the opt-in migration path for builders that want
199    /// [`ExprRef`](crate::ir_inner::model::arena::ExprRef) handles instead of boxed
200    /// expression trees. [`Program::new`] remains the boxed-tree constructor.
201    #[must_use]
202    #[inline]
203    pub fn with_arena(
204        arena: &ExprArena,
205        buffers: Vec<BufferDecl>,
206        workgroup_size: [u32; 3],
207    ) -> ArenaProgram<'_> {
208        ArenaProgram::new(arena, buffers, workgroup_size)
209    }
210
211    /// Create a minimal program with no buffers and an empty body.
212    ///
213    /// # Examples
214    ///
215    /// ```
216    /// use vyre::ir::Program;
217    ///
218    /// let program = Program::empty();
219    ///
220    /// assert!(program.buffers().is_empty());
221    /// assert_eq!(program.workgroup_size(), [1, 1, 1]);
222    /// assert!(program.is_explicit_noop());
223    /// ```
224    #[must_use]
225    #[inline]
226    pub fn empty() -> Self {
227        Self::wrapped(Vec::new(), [1, 1, 1], Vec::new())
228    }
229
230    /// Attach the stable operation ID whose conform registry entry certifies
231    /// this program for runtime lowering.
232    #[must_use]
233    #[inline]
234    pub fn with_entry_op_id(mut self, op_id: impl Into<String>) -> Self {
235        self.entry_op_id = Some(op_id.into());
236        self.invalidate_caches();
237        self
238    }
239
240    /// Stable operation ID required by the conform gate.
241    #[must_use]
242    #[inline]
243    pub fn entry_op_id(&self) -> Option<&str> {
244        self.entry_op_id.as_deref()
245    }
246
247    /// Attach an optional operation ID while preserving anonymous test IR.
248    #[must_use]
249    #[inline]
250    pub(crate) fn with_optional_entry_op_id(mut self, op_id: Option<String>) -> Self {
251        self.entry_op_id = op_id;
252        self.invalidate_caches();
253        self
254    }
255}