vyre_foundation/ir_inner/model/
arena.rs1#![allow(unsafe_code)]
3
4use crate::ir_inner::model::expr::Expr;
5use crate::ir_inner::model::program::BufferDecl;
6use bumpalo::Bump;
7use rustc_hash::FxHashMap;
8use std::cell::{Cell, UnsafeCell};
9use std::sync::Arc;
10
/// Lightweight, copyable handle to an expression stored in an [`ExprArena`].
///
/// The handle only wraps the zero-based position the expression was given
/// when it was allocated; it carries no reference to the arena itself.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ExprRef {
    /// Zero-based allocation position inside the owning arena.
    index: usize,
}
16
17impl ExprRef {
18 #[must_use]
20 #[inline]
21 pub fn index(self) -> usize {
22 self.index
23 }
24}
25
26#[derive(Default)]
32pub struct ExprArena {
33 bump: Bump,
34 exprs: UnsafeCell<Vec<*const Expr>>,
35 len: Cell<usize>,
36}
37
38impl ExprArena {
39 #[must_use]
41 #[inline]
42 pub fn new() -> Self {
43 Self::default()
44 }
45
46 #[must_use]
48 pub fn alloc(&self, expr: Expr) -> ExprRef {
49 let index = self.len.get();
50 let ptr = self.bump.alloc(expr) as *const Expr;
51 unsafe {
53 (*self.exprs.get()).push(ptr);
54 }
55 self.len.set(index + 1);
56 ExprRef { index }
57 }
58
59 #[must_use]
61 pub fn get(&self, expr_ref: ExprRef) -> Option<&Expr> {
62 unsafe {
65 let vec: &Vec<*const Expr> = &*self.exprs.get();
66 vec.get(expr_ref.index).and_then(|ptr| ptr.as_ref())
67 }
68 }
69
70 pub fn reset(&mut self) {
72 self.exprs.get_mut().clear();
73 self.len.set(0);
74 self.bump.reset();
75 }
76
77 #[must_use]
79 #[inline]
80 pub fn len(&self) -> usize {
81 self.len.get()
82 }
83
84 #[must_use]
86 #[inline]
87 pub fn is_empty(&self) -> bool {
88 self.len() == 0
89 }
90}
91
92pub struct ArenaProgram<'a> {
94 arena: &'a ExprArena,
95 buffers: Vec<BufferDecl>,
96 buffer_index: FxHashMap<Arc<str>, usize>,
97 workgroup_size: [u32; 3],
98 entry: Vec<ExprRef>,
99}
100
101impl<'a> ArenaProgram<'a> {
102 pub(crate) fn new(
103 arena: &'a ExprArena,
104 buffers: Vec<BufferDecl>,
105 workgroup_size: [u32; 3],
106 ) -> Self {
107 let mut buffer_index = FxHashMap::default();
108 buffer_index.reserve(buffers.len());
109 for (index, buffer) in buffers.iter().enumerate() {
110 buffer_index
111 .entry(Arc::clone(&buffer.name))
112 .or_insert(index);
113 }
114 Self {
115 arena,
116 buffers,
117 buffer_index,
118 workgroup_size,
119 entry: Vec::new(),
120 }
121 }
122
123 #[must_use]
125 pub fn push_expr(&mut self, expr: Expr) -> ExprRef {
126 let expr_ref = self.arena.alloc(expr);
127 self.entry.push(expr_ref);
128 expr_ref
129 }
130
131 #[must_use]
133 pub fn expr(&self, expr_ref: ExprRef) -> Option<&Expr> {
134 self.arena.get(expr_ref)
135 }
136
137 #[must_use]
139 pub fn buffers(&self) -> &[BufferDecl] {
140 &self.buffers
141 }
142
143 #[must_use]
145 pub fn buffer(&self, name: &str) -> Option<&BufferDecl> {
146 self.buffer_index
147 .get(name)
148 .and_then(|&index| self.buffers.get(index))
149 }
150
151 #[must_use]
153 pub fn workgroup_size(&self) -> [u32; 3] {
154 self.workgroup_size
155 }
156
157 #[must_use]
159 pub fn entry(&self) -> &[ExprRef] {
160 &self.entry
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::{ArenaProgram, ExprArena};
    use crate::ir_inner::model::expr::Expr;
    use crate::ir_inner::model::program::BufferDecl;
    use crate::ir_inner::model::types::DataType;

    /// Handles are numbered in allocation order and resolve back to the
    /// expressions they were created for.
    #[test]
    fn arena_allocates_stable_expression_refs() {
        let arena = ExprArena::new();

        let literal = arena.alloc(Expr::u32(7));
        let variable = arena.alloc(Expr::var("x"));

        assert_eq!(literal.index(), 0);
        assert_eq!(variable.index(), 1);

        let expected_literal = Expr::u32(7);
        let expected_variable = Expr::var("x");
        assert_eq!(arena.get(literal), Some(&expected_literal));
        assert_eq!(arena.get(variable), Some(&expected_variable));
    }

    /// A program records pushed handles, finds buffers by name and resolves
    /// handles through its arena.
    #[test]
    fn arena_program_keeps_buffers_and_expression_handles() {
        let arena = ExprArena::new();
        let declared = vec![BufferDecl::read("input", 0, DataType::U32)];
        let mut program = ArenaProgram::new(&arena, declared, [64, 1, 1]);

        let handle = program.push_expr(Expr::load("input", Expr::u32(0)));

        assert_eq!(program.entry(), &[handle]);

        let binding = program.buffer("input").map(BufferDecl::binding);
        assert_eq!(binding, Some(0));

        let expected = Expr::load("input", Expr::u32(0));
        assert_eq!(program.expr(handle), Some(&expected));
    }
}