Skip to main content

vyre_foundation/ir_inner/model/program/
builder.rs

1use std::sync::{Arc, OnceLock};
2
3use rustc_hash::FxHashMap;
4
5use crate::ir_inner::model::arena::{ArenaProgram, ExprArena};
6use crate::ir_inner::model::node::Node;
7
8use super::{BufferDecl, Program};
9
10impl Program {
11    /// Synthetic generator id used when callers submit a raw top-level body
12    /// instead of an explicit `Node::Region`.
13    pub const ROOT_REGION_GENERATOR: &'static str = "vyre.program.root";
14
15    /// Create a complete program from buffer declarations, workgroup size, and
16    /// entry-point nodes, auto-wrapping the top-level body in a root Region
17    /// when necessary.
18    ///
19    /// This is the default construction path for runnable Programs.
20    ///
21    /// # Examples
22    ///
23    /// ```
24    /// use vyre::ir::{BufferAccess, BufferDecl, DataType, Node, Program};
25    ///
26    /// let program = Program::wrapped(
27    ///     vec![BufferDecl::storage(
28    ///         "output",
29    ///         0,
30    ///         BufferAccess::ReadWrite,
31    ///         DataType::U32,
32    ///     )],
33    ///     [64, 1, 1],
34    ///     Vec::new(),
35    /// );
36    ///
37    /// assert_eq!(program.workgroup_size(), [64, 1, 1]);
38    /// assert_eq!(program.buffers().len(), 1);
39    /// assert!(matches!(program.entry(), [Node::Region { .. }]));
40    /// ```
41    #[must_use]
42    #[inline]
43    pub fn wrapped(buffers: Vec<BufferDecl>, workgroup_size: [u32; 3], entry: Vec<Node>) -> Self {
44        Self::new_raw(buffers, workgroup_size, Self::wrap_entry(entry))
45    }
46
47    /// Create a complete program from buffer declarations, workgroup size, and
48    /// entry-point nodes.
49    ///
50    /// # Examples
51    ///
52    /// ```
53    /// use vyre::ir::{BufferAccess, BufferDecl, DataType, Node, Program};
54    ///
55    /// let program = Program::wrapped(
56    ///     vec![BufferDecl::storage(
57    ///         "output",
58    ///         0,
59    ///         BufferAccess::ReadWrite,
60    ///         DataType::U32,
61    ///     )],
62    ///     [64, 1, 1],
63    ///     Vec::new(),
64    /// );
65    ///
66    /// assert_eq!(program.workgroup_size(), [64, 1, 1]);
67    /// assert_eq!(program.buffers().len(), 1);
68    /// assert!(matches!(program.entry(), [Node::Region { .. }]));
69    /// ```
70    #[deprecated(
71        note = "Program::new preserves raw top-level entry nodes. Use Program::wrapped for runnable programs; reserve Program::new for wire decode and negative tests."
72    )]
73    #[must_use]
74    #[inline]
75    pub fn new(buffers: Vec<BufferDecl>, workgroup_size: [u32; 3], entry: Vec<Node>) -> Self {
76        Self::new_raw(buffers, workgroup_size, entry)
77    }
78
79    #[must_use]
80    #[inline]
81    pub(crate) fn new_raw(
82        buffers: Vec<BufferDecl>,
83        workgroup_size: [u32; 3],
84        entry: Vec<Node>,
85    ) -> Self {
86        let mut interner = FxHashMap::<Arc<str>, Arc<str>>::default();
87        interner.reserve(buffers.len());
88        let buffers: Vec<BufferDecl> = buffers
89            .into_iter()
90            .map(|mut b| {
91                let arc = interner
92                    .entry(Arc::clone(&b.name))
93                    .or_insert_with(|| Arc::clone(&b.name))
94                    .clone();
95                b.name = arc;
96                b
97            })
98            .collect();
99        let buffer_index = Self::build_buffer_index(&buffers);
100        Self {
101            entry_op_id: None,
102            buffers: Arc::from(buffers),
103            buffer_index: Arc::new(buffer_index),
104            workgroup_size,
105            entry: Arc::new(entry),
106            hash: OnceLock::new(),
107            validation_set: OnceLock::new(),
108            structural_validated: std::sync::atomic::AtomicBool::new(false),
109            fingerprint: OnceLock::new(),
110            output_buffer_index: OnceLock::new(),
111            has_indirect_dispatch: OnceLock::new(),
112            stats: OnceLock::new(),
113            non_composable_with_self: false,
114        }
115    }
116
117    /// Same as [`Self::with_rewritten_entry`] but wraps the entry first via
118    /// the runnable-Region root contract (matches [`Self::wrapped`]). Use
119    /// from passes that produce a fully fresh entry body but want to reuse
120    /// the existing buffer Arc instead of paying for a full
121    /// [`Self::wrapped`] (which deep-clones buffers, re-interns names, and
122    /// rebuilds the buffer index).
123    #[must_use]
124    #[inline]
125    pub fn with_rewritten_wrapped_entry(&self, entry: Vec<Node>) -> Self {
126        self.with_rewritten_entry(Self::wrap_entry(entry))
127    }
128
129    /// Consume this program and rebuild it with `f` applied to the owned
130    /// entry vec. Reuses the entry Arc when uniquely owned (the common
131    /// case under the optimizer fixpoint)  -  no deep clone of the entry
132    /// body and no scaffold allocation. Equivalent to:
133    ///
134    /// ```ignore
135    /// let scaffold = program.with_rewritten_entry(Vec::new());
136    /// let entry = f(program.into_entry_vec());
137    /// scaffold.with_rewritten_entry(entry)
138    /// ```
139    ///
140    /// but produces only one new `Program` value instead of two.
141    #[must_use]
142    #[inline]
143    pub fn map_entry<F: FnOnce(Vec<Node>) -> Vec<Node>>(self, f: F) -> Self {
144        let entry_op_id = self.entry_op_id.clone();
145        let buffers = Arc::clone(&self.buffers);
146        let buffer_index = Arc::clone(&self.buffer_index);
147        let workgroup_size = self.workgroup_size;
148        let non_composable_with_self = self.non_composable_with_self;
149        let entry = f(self.into_entry_vec());
150        Self {
151            entry_op_id,
152            buffers,
153            buffer_index,
154            workgroup_size,
155            entry: Arc::new(entry),
156            hash: OnceLock::new(),
157            validation_set: OnceLock::new(),
158            structural_validated: std::sync::atomic::AtomicBool::new(false),
159            fingerprint: OnceLock::new(),
160            output_buffer_index: OnceLock::new(),
161            has_indirect_dispatch: OnceLock::new(),
162            stats: OnceLock::new(),
163            non_composable_with_self,
164        }
165    }
166
167    /// Clone this program with a replacement entry body while preserving the
168    /// existing buffer table, workgroup size, and optional certified op id.
169    #[must_use]
170    #[inline]
171    pub fn with_rewritten_entry(&self, entry: Vec<Node>) -> Self {
172        Self {
173            entry_op_id: self.entry_op_id.clone(),
174            buffers: Arc::clone(&self.buffers),
175            buffer_index: Arc::clone(&self.buffer_index),
176            workgroup_size: self.workgroup_size,
177            entry: Arc::new(entry),
178            hash: OnceLock::new(),
179            validation_set: OnceLock::new(),
180            structural_validated: std::sync::atomic::AtomicBool::new(false),
181            fingerprint: OnceLock::new(),
182            output_buffer_index: OnceLock::new(),
183            has_indirect_dispatch: OnceLock::new(),
184            stats: OnceLock::new(),
185            non_composable_with_self: self.non_composable_with_self,
186        }
187    }
188
189    /// Clone this program with replacement buffer declarations while
190    /// preserving the entry body, workgroup size, and metadata flags.
191    #[must_use]
192    #[inline]
193    pub fn with_rewritten_buffers(&self, buffers: Vec<BufferDecl>) -> Self {
194        let buffer_index = Self::build_buffer_index(&buffers);
195        Self {
196            entry_op_id: self.entry_op_id.clone(),
197            buffers: Arc::from(buffers),
198            buffer_index: Arc::new(buffer_index),
199            workgroup_size: self.workgroup_size,
200            entry: Arc::clone(&self.entry),
201            hash: OnceLock::new(),
202            validation_set: OnceLock::new(),
203            structural_validated: std::sync::atomic::AtomicBool::new(false),
204            fingerprint: OnceLock::new(),
205            output_buffer_index: OnceLock::new(),
206            has_indirect_dispatch: OnceLock::new(),
207            stats: OnceLock::new(),
208            non_composable_with_self: self.non_composable_with_self,
209        }
210    }
211
212    /// Clone this program with replacement dispatch dimensions and entry body
213    /// while preserving the existing buffer table, indexes, and metadata flags.
214    #[must_use]
215    #[inline]
216    pub fn with_rewritten_workgroup_size_and_entry(
217        &self,
218        workgroup_size: [u32; 3],
219        entry: Vec<Node>,
220    ) -> Self {
221        Self {
222            entry_op_id: self.entry_op_id.clone(),
223            buffers: Arc::clone(&self.buffers),
224            buffer_index: Arc::clone(&self.buffer_index),
225            workgroup_size,
226            entry: Arc::new(entry),
227            hash: OnceLock::new(),
228            validation_set: OnceLock::new(),
229            structural_validated: std::sync::atomic::AtomicBool::new(false),
230            fingerprint: OnceLock::new(),
231            output_buffer_index: OnceLock::new(),
232            has_indirect_dispatch: OnceLock::new(),
233            stats: OnceLock::new(),
234            non_composable_with_self: self.non_composable_with_self,
235        }
236    }
237
238    /// Consume the program and return its entry nodes, reusing the
239    /// backing vector when this program owns the entry body uniquely.
240    #[must_use]
241    #[inline]
242    pub fn into_entry_vec(self) -> Vec<Node> {
243        Arc::try_unwrap(self.entry).unwrap_or_else(|entry| entry.as_ref().clone())
244    }
245
246    /// Create an arena-backed program scaffold.
247    ///
248    /// This constructor is the opt-in migration path for builders that want
249    /// [`ExprRef`](crate::ir_inner::model::arena::ExprRef) handles instead of boxed
250    /// expression trees. [`Program::new`] remains the boxed-tree constructor.
251    #[must_use]
252    #[inline]
253    pub fn with_arena(
254        arena: &ExprArena,
255        buffers: Vec<BufferDecl>,
256        workgroup_size: [u32; 3],
257    ) -> ArenaProgram<'_> {
258        ArenaProgram::new(arena, buffers, workgroup_size)
259    }
260
261    /// Create a minimal program with no buffers and an empty body.
262    ///
263    /// # Examples
264    ///
265    /// ```
266    /// use vyre::ir::Program;
267    ///
268    /// let program = Program::empty();
269    ///
270    /// assert!(program.buffers().is_empty());
271    /// assert_eq!(program.workgroup_size(), [1, 1, 1]);
272    /// assert!(program.is_explicit_noop());
273    /// ```
274    #[must_use]
275    #[inline]
276    pub fn empty() -> Self {
277        Self::wrapped(Vec::new(), [1, 1, 1], Vec::new())
278    }
279
280    /// Attach the stable operation ID whose conform registry entry certifies
281    /// this program for runtime lowering.
282    #[must_use]
283    #[inline]
284    pub fn with_entry_op_id(mut self, op_id: impl Into<String>) -> Self {
285        self.entry_op_id = Some(op_id.into());
286        self.invalidate_caches();
287        self
288    }
289
290    /// Stable operation ID required by the conform gate.
291    #[must_use]
292    #[inline]
293    pub fn entry_op_id(&self) -> Option<&str> {
294        self.entry_op_id.as_deref()
295    }
296
297    /// Attach an optional operation ID while preserving anonymous test IR.
298    #[must_use]
299    #[inline]
300    pub(crate) fn with_optional_entry_op_id(mut self, op_id: Option<String>) -> Self {
301        self.entry_op_id = op_id;
302        self.invalidate_caches();
303        self
304    }
305}