Skip to main content

vyre_foundation/ir_inner/model/program/
builder.rs

1use std::sync::{Arc, OnceLock};
2
3use rustc_hash::FxHashMap;
4
5use crate::ir_inner::model::arena::{ArenaProgram, ExprArena};
6use crate::ir_inner::model::node::Node;
7
8use super::{BufferDecl, Program};
9
10impl Program {
11    /// Synthetic generator id used when callers submit a raw top-level body
12    /// instead of an explicit `Node::Region`.
13    pub const ROOT_REGION_GENERATOR: &'static str = "vyre.program.root";
14
15    /// Create a complete program from buffer declarations, workgroup size, and
16    /// entry-point nodes, auto-wrapping the top-level body in a root Region
17    /// when necessary.
18    ///
19    /// This is the default construction path for runnable Programs.
20    ///
21    /// # Examples
22    ///
23    /// ```
24    /// use vyre::ir::{BufferAccess, BufferDecl, DataType, Node, Program};
25    ///
26    /// let program = Program::wrapped(
27    ///     vec![BufferDecl::storage(
28    ///         "output",
29    ///         0,
30    ///         BufferAccess::ReadWrite,
31    ///         DataType::U32,
32    ///     )],
33    ///     [64, 1, 1],
34    ///     Vec::new(),
35    /// );
36    ///
37    /// assert_eq!(program.workgroup_size(), [64, 1, 1]);
38    /// assert_eq!(program.buffers().len(), 1);
39    /// assert!(matches!(program.entry(), [Node::Region { .. }]));
40    /// ```
41    #[must_use]
42    #[inline]
43    pub fn wrapped(buffers: Vec<BufferDecl>, workgroup_size: [u32; 3], entry: Vec<Node>) -> Self {
44        Self::new_raw(buffers, workgroup_size, Self::wrap_entry(entry))
45    }
46
47    /// Create a complete program from buffer declarations, workgroup size, and
48    /// entry-point nodes.
49    ///
50    /// # Examples
51    ///
52    /// ```
53    /// use vyre::ir::{BufferAccess, BufferDecl, DataType, Node, Program};
54    ///
55    /// let program = Program::wrapped(
56    ///     vec![BufferDecl::storage(
57    ///         "output",
58    ///         0,
59    ///         BufferAccess::ReadWrite,
60    ///         DataType::U32,
61    ///     )],
62    ///     [64, 1, 1],
63    ///     Vec::new(),
64    /// );
65    ///
66    /// assert_eq!(program.workgroup_size(), [64, 1, 1]);
67    /// assert_eq!(program.buffers().len(), 1);
68    /// assert!(matches!(program.entry(), [Node::Region { .. }]));
69    /// ```
70    #[deprecated(
71        note = "Program::new preserves raw top-level entry nodes. Use Program::wrapped for runnable programs; reserve Program::new for wire decode and negative tests."
72    )]
73    #[must_use]
74    #[inline]
75    pub fn new(buffers: Vec<BufferDecl>, workgroup_size: [u32; 3], entry: Vec<Node>) -> Self {
76        Self::new_raw(buffers, workgroup_size, entry)
77    }
78
79    #[must_use]
80    #[inline]
81    pub(crate) fn new_raw(
82        buffers: Vec<BufferDecl>,
83        workgroup_size: [u32; 3],
84        entry: Vec<Node>,
85    ) -> Self {
86        let mut interner = FxHashMap::<Arc<str>, Arc<str>>::default();
87        interner.reserve(buffers.len());
88        let buffers: Vec<BufferDecl> = buffers
89            .into_iter()
90            .map(|mut b| {
91                let arc = interner
92                    .entry(Arc::clone(&b.name))
93                    .or_insert_with(|| Arc::clone(&b.name))
94                    .clone();
95                b.name = arc;
96                b
97            })
98            .collect();
99        let buffer_index = Self::build_buffer_index(&buffers);
100        Self {
101            entry_op_id: None,
102            buffers: Arc::from(buffers),
103            buffer_index: Arc::new(buffer_index),
104            workgroup_size,
105            entry: Arc::new(entry),
106            hash: OnceLock::new(),
107            validation_set: OnceLock::new(),
108            structural_validated: std::sync::atomic::AtomicBool::new(false),
109            structural_validation_fingerprint: std::sync::atomic::AtomicU64::new(0),
110            mutation_provenance: std::sync::atomic::AtomicU8::new(0),
111            fingerprint: OnceLock::new(),
112            output_buffer_index: OnceLock::new(),
113            has_indirect_dispatch: OnceLock::new(),
114            stats: OnceLock::new(),
115            non_composable_with_self: false,
116        }
117    }
118
119    /// Same as [`Self::with_rewritten_entry`] but wraps the entry first via
120    /// the runnable-Region root contract (matches [`Self::wrapped`]). Use
121    /// from passes that produce a fully fresh entry body but want to reuse
122    /// the existing buffer Arc instead of paying for a full
123    /// [`Self::wrapped`] (which deep-clones buffers, re-interns names, and
124    /// rebuilds the buffer index).
125    #[must_use]
126    #[inline]
127    pub fn with_rewritten_wrapped_entry(&self, entry: Vec<Node>) -> Self {
128        self.with_rewritten_entry(Self::wrap_entry(entry))
129    }
130
131    /// Consume this program and rebuild it with `f` applied to the owned
132    /// entry vec. Reuses the entry Arc when uniquely owned (the common
133    /// case under the optimizer fixpoint)  -  no deep clone of the entry
134    /// body and no scaffold allocation. Equivalent to:
135    ///
136    /// ```ignore
137    /// let scaffold = program.with_rewritten_entry(Vec::new());
138    /// let entry = f(program.into_entry_vec());
139    /// scaffold.with_rewritten_entry(entry)
140    /// ```
141    ///
142    /// but produces only one new `Program` value instead of two.
143    #[must_use]
144    #[inline]
145    pub fn map_entry<F: FnOnce(Vec<Node>) -> Vec<Node>>(self, f: F) -> Self {
146        let entry_op_id = self.entry_op_id.clone();
147        let buffers = Arc::clone(&self.buffers);
148        let buffer_index = Arc::clone(&self.buffer_index);
149        let workgroup_size = self.workgroup_size;
150        let non_composable_with_self = self.non_composable_with_self;
151        let entry = f(self.into_entry_vec());
152        Self {
153            entry_op_id,
154            buffers,
155            buffer_index,
156            workgroup_size,
157            entry: Arc::new(entry),
158            hash: OnceLock::new(),
159            validation_set: OnceLock::new(),
160            structural_validated: std::sync::atomic::AtomicBool::new(false),
161            structural_validation_fingerprint: std::sync::atomic::AtomicU64::new(0),
162            mutation_provenance: std::sync::atomic::AtomicU8::new(0),
163            fingerprint: OnceLock::new(),
164            output_buffer_index: OnceLock::new(),
165            has_indirect_dispatch: OnceLock::new(),
166            stats: OnceLock::new(),
167            non_composable_with_self,
168        }
169    }
170
171    /// Clone this program with a replacement entry body while preserving the
172    /// existing buffer table, workgroup size, and optional certified op id.
173    #[must_use]
174    #[inline]
175    pub fn with_rewritten_entry(&self, entry: Vec<Node>) -> Self {
176        Self {
177            entry_op_id: self.entry_op_id.clone(),
178            buffers: Arc::clone(&self.buffers),
179            buffer_index: Arc::clone(&self.buffer_index),
180            workgroup_size: self.workgroup_size,
181            entry: Arc::new(entry),
182            hash: OnceLock::new(),
183            validation_set: OnceLock::new(),
184            structural_validated: std::sync::atomic::AtomicBool::new(false),
185            structural_validation_fingerprint: std::sync::atomic::AtomicU64::new(0),
186            mutation_provenance: std::sync::atomic::AtomicU8::new(0),
187            fingerprint: OnceLock::new(),
188            output_buffer_index: OnceLock::new(),
189            has_indirect_dispatch: OnceLock::new(),
190            stats: OnceLock::new(),
191            non_composable_with_self: self.non_composable_with_self,
192        }
193    }
194
195    /// Clone this program with replacement buffer declarations while
196    /// preserving the entry body, workgroup size, and metadata flags.
197    #[must_use]
198    #[inline]
199    pub fn with_rewritten_buffers(&self, buffers: Vec<BufferDecl>) -> Self {
200        let buffer_index = Self::build_buffer_index(&buffers);
201        Self {
202            entry_op_id: self.entry_op_id.clone(),
203            buffers: Arc::from(buffers),
204            buffer_index: Arc::new(buffer_index),
205            workgroup_size: self.workgroup_size,
206            entry: Arc::clone(&self.entry),
207            hash: OnceLock::new(),
208            validation_set: OnceLock::new(),
209            structural_validated: std::sync::atomic::AtomicBool::new(false),
210            structural_validation_fingerprint: std::sync::atomic::AtomicU64::new(0),
211            mutation_provenance: std::sync::atomic::AtomicU8::new(0),
212            fingerprint: OnceLock::new(),
213            output_buffer_index: OnceLock::new(),
214            has_indirect_dispatch: OnceLock::new(),
215            stats: OnceLock::new(),
216            non_composable_with_self: self.non_composable_with_self,
217        }
218    }
219
220    /// Clone this program with replacement dispatch dimensions and entry body
221    /// while preserving the existing buffer table, indexes, and metadata flags.
222    #[must_use]
223    #[inline]
224    pub fn with_rewritten_workgroup_size_and_entry(
225        &self,
226        workgroup_size: [u32; 3],
227        entry: Vec<Node>,
228    ) -> Self {
229        Self {
230            entry_op_id: self.entry_op_id.clone(),
231            buffers: Arc::clone(&self.buffers),
232            buffer_index: Arc::clone(&self.buffer_index),
233            workgroup_size,
234            entry: Arc::new(entry),
235            hash: OnceLock::new(),
236            validation_set: OnceLock::new(),
237            structural_validated: std::sync::atomic::AtomicBool::new(false),
238            structural_validation_fingerprint: std::sync::atomic::AtomicU64::new(0),
239            mutation_provenance: std::sync::atomic::AtomicU8::new(0),
240            fingerprint: OnceLock::new(),
241            output_buffer_index: OnceLock::new(),
242            has_indirect_dispatch: OnceLock::new(),
243            stats: OnceLock::new(),
244            non_composable_with_self: self.non_composable_with_self,
245        }
246    }
247
248    /// Consume the program and return its entry nodes, reusing the
249    /// backing vector when this program owns the entry body uniquely.
250    #[must_use]
251    #[inline]
252    pub fn into_entry_vec(self) -> Vec<Node> {
253        Arc::try_unwrap(self.entry).unwrap_or_else(|entry| entry.as_ref().clone())
254    }
255
256    /// Create an arena-backed program scaffold.
257    ///
258    /// This constructor is the opt-in migration path for builders that want
259    /// [`ExprRef`](crate::ir_inner::model::arena::ExprRef) handles instead of boxed
260    /// expression trees. [`Program::new`] remains the boxed-tree constructor.
261    #[must_use]
262    #[inline]
263    pub fn with_arena(
264        arena: &ExprArena,
265        buffers: Vec<BufferDecl>,
266        workgroup_size: [u32; 3],
267    ) -> ArenaProgram<'_> {
268        ArenaProgram::new(arena, buffers, workgroup_size)
269    }
270
271    /// Create a minimal program with no buffers and an empty body.
272    ///
273    /// # Examples
274    ///
275    /// ```
276    /// use vyre::ir::Program;
277    ///
278    /// let program = Program::empty();
279    ///
280    /// assert!(program.buffers().is_empty());
281    /// assert_eq!(program.workgroup_size(), [1, 1, 1]);
282    /// assert!(program.is_explicit_noop());
283    /// ```
284    #[must_use]
285    #[inline]
286    pub fn empty() -> Self {
287        Self::wrapped(Vec::new(), [1, 1, 1], Vec::new())
288    }
289
290    /// Attach the stable operation ID whose conform registry entry certifies
291    /// this program for runtime lowering.
292    #[must_use]
293    #[inline]
294    pub fn with_entry_op_id(mut self, op_id: impl Into<String>) -> Self {
295        self.entry_op_id = Some(op_id.into());
296        self.invalidate_caches();
297        self
298    }
299
300    /// Stable operation ID required by the conform gate.
301    #[must_use]
302    #[inline]
303    pub fn entry_op_id(&self) -> Option<&str> {
304        self.entry_op_id.as_deref()
305    }
306
307    /// Attach an optional operation ID while preserving anonymous test IR.
308    #[must_use]
309    #[inline]
310    pub(crate) fn with_optional_entry_op_id(mut self, op_id: Option<String>) -> Self {
311        self.entry_op_id = op_id;
312        self.invalidate_caches();
313        self
314    }
315}