Skip to main content

vyre_foundation/ir_inner/model/
arena.rs

1//! Arena-backed expression storage for opt-in IR construction.
2#![allow(unsafe_code)]
3
4use crate::ir_inner::model::expr::Expr;
5use crate::ir_inner::model::program::BufferDecl;
6use bumpalo::Bump;
7use rustc_hash::FxHashMap;
8use std::cell::{Cell, UnsafeCell};
9use std::sync::Arc;
10
/// Copyable, hashable handle naming one expression stored in an [`ExprArena`].
///
/// The handle only records a position; it does not borrow the arena.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ExprRef {
    /// Position of the expression in allocation order.
    index: usize,
}

impl ExprRef {
    /// Position of the referenced expression, counting from zero in allocation order.
    #[inline]
    #[must_use]
    pub fn index(self) -> usize {
        let Self { index } = self;
        index
    }
}
25
26/// Bump-allocated expression arena.
27///
28/// This is an opt-in migration path for builders that create many temporary
29/// expression nodes. Existing callers can continue to use boxed [`Expr`] trees
30/// through `Program::new`.
31#[derive(Default)]
32pub struct ExprArena {
33    bump: Bump,
34    exprs: UnsafeCell<Vec<*const Expr>>,
35    len: Cell<usize>,
36}
37
38impl ExprArena {
39    /// Create an empty expression arena.
40    #[must_use]
41    #[inline]
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    /// Allocate an expression and return its stable arena handle.
47    #[must_use]
48    pub fn alloc(&self, expr: Expr) -> ExprRef {
49        let index = self.len.get();
50        let ptr = self.bump.alloc(expr) as *const Expr;
51        // SAFETY: ExprArena is a single-writer builder and never shared across writer threads.
52        unsafe {
53            (*self.exprs.get()).push(ptr);
54        }
55        self.len.set(index + 1);
56        ExprRef { index }
57    }
58
59    /// Borrow an allocated expression by handle.
60    #[must_use]
61    pub fn get(&self, expr_ref: ExprRef) -> Option<&Expr> {
62        // SAFETY: pointers are produced only by `self.bump.alloc` and remain
63        // stable until `reset(&mut self)`, which requires exclusive access.
64        unsafe {
65            let vec: &Vec<*const Expr> = &*self.exprs.get();
66            vec.get(expr_ref.index).and_then(|ptr| ptr.as_ref())
67        }
68    }
69
70    /// Clear allocated expressions.
71    pub fn reset(&mut self) {
72        self.exprs.get_mut().clear();
73        self.len.set(0);
74        self.bump.reset();
75    }
76
77    /// Number of expressions allocated in this arena.
78    #[must_use]
79    #[inline]
80    pub fn len(&self) -> usize {
81        self.len.get()
82    }
83
84    /// Return true if no expressions have been allocated.
85    #[must_use]
86    #[inline]
87    pub fn is_empty(&self) -> bool {
88        self.len() == 0
89    }
90}
91
92/// Lightweight program scaffold for arena-backed expression builders.
93pub struct ArenaProgram<'a> {
94    arena: &'a ExprArena,
95    buffers: Vec<BufferDecl>,
96    buffer_index: FxHashMap<Arc<str>, usize>,
97    workgroup_size: [u32; 3],
98    entry: Vec<ExprRef>,
99}
100
101impl<'a> ArenaProgram<'a> {
102    pub(crate) fn new(
103        arena: &'a ExprArena,
104        buffers: Vec<BufferDecl>,
105        workgroup_size: [u32; 3],
106    ) -> Self {
107        let mut buffer_index = FxHashMap::default();
108        buffer_index.reserve(buffers.len());
109        for (index, buffer) in buffers.iter().enumerate() {
110            buffer_index
111                .entry(Arc::clone(&buffer.name))
112                .or_insert(index);
113        }
114        Self {
115            arena,
116            buffers,
117            buffer_index,
118            workgroup_size,
119            entry: Vec::new(),
120        }
121    }
122
123    /// Allocate `expr` in the backing arena and append it to the entry list.
124    #[must_use]
125    pub fn push_expr(&mut self, expr: Expr) -> ExprRef {
126        let expr_ref = self.arena.alloc(expr);
127        self.entry.push(expr_ref);
128        expr_ref
129    }
130
131    /// Return an expression previously appended to this arena program.
132    #[must_use]
133    pub fn expr(&self, expr_ref: ExprRef) -> Option<&Expr> {
134        self.arena.get(expr_ref)
135    }
136
137    /// Declared buffers.
138    #[must_use]
139    pub fn buffers(&self) -> &[BufferDecl] {
140        &self.buffers
141    }
142
143    /// Look up a declared buffer by name.
144    #[must_use]
145    pub fn buffer(&self, name: &str) -> Option<&BufferDecl> {
146        self.buffer_index
147            .get(name)
148            .and_then(|&index| self.buffers.get(index))
149    }
150
151    /// Workgroup dimensions.
152    #[must_use]
153    pub fn workgroup_size(&self) -> [u32; 3] {
154        self.workgroup_size
155    }
156
157    /// Entry expression handles in append order.
158    #[must_use]
159    pub fn entry(&self) -> &[ExprRef] {
160        &self.entry
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::{ArenaProgram, ExprArena};
    use crate::ir_inner::model::expr::Expr;
    use crate::ir_inner::model::program::BufferDecl;
    use crate::ir_inner::model::types::DataType;

    #[test]
    fn arena_allocates_stable_expression_refs() {
        let arena = ExprArena::new();
        let first = arena.alloc(Expr::u32(7));
        let second = arena.alloc(Expr::var("x"));
        assert_eq!(first.index(), 0);
        assert_eq!(second.index(), 1);
        assert_eq!(arena.get(first), Some(&Expr::u32(7)));
        assert_eq!(arena.get(second), Some(&Expr::var("x")));
    }

    #[test]
    fn arena_program_keeps_buffers_and_expression_handles() {
        let arena = ExprArena::new();
        let mut program = ArenaProgram::new(
            &arena,
            vec![BufferDecl::read("input", 0, DataType::U32)],
            [64, 1, 1],
        );
        let expr_ref = program.push_expr(Expr::load("input", Expr::u32(0)));
        assert_eq!(program.entry(), &[expr_ref]);
        assert_eq!(program.buffer("input").map(BufferDecl::binding), Some(0));
        assert_eq!(
            program.expr(expr_ref),
            Some(&Expr::load("input", Expr::u32(0)))
        );
    }

    /// `reset` must empty the arena, make old handles unresolvable, and restart indices.
    #[test]
    fn arena_reset_clears_expressions_and_restarts_indices() {
        let mut arena = ExprArena::new();
        let stale = arena.alloc(Expr::load("input", Expr::u32(0)));
        let _ = arena.alloc(Expr::var("x"));
        assert_eq!(arena.len(), 2);

        arena.reset();
        assert!(arena.is_empty());
        assert_eq!(arena.get(stale), None);

        let fresh = arena.alloc(Expr::u32(1));
        assert_eq!(fresh.index(), 0);
        assert_eq!(arena.get(fresh), Some(&Expr::u32(1)));
    }

    /// Buffer lookup by name resolves duplicates to the first declaration.
    #[test]
    fn arena_program_resolves_duplicate_buffer_names_to_first_declaration() {
        let arena = ExprArena::new();
        let program = ArenaProgram::new(
            &arena,
            vec![
                BufferDecl::read("input", 0, DataType::U32),
                BufferDecl::read("input", 1, DataType::U32),
            ],
            [1, 1, 1],
        );
        assert_eq!(program.buffers().len(), 2);
        assert_eq!(program.buffer("input").map(BufferDecl::binding), Some(0));
        assert!(program.buffer("missing").is_none());
    }
}