1#![allow(
2 clippy::doc_lazy_continuation,
3 clippy::double_must_use,
4 clippy::manual_div_ceil,
5 clippy::needless_range_loop,
6 clippy::collapsible_if,
7 clippy::match_like_matches_macro,
8 clippy::redundant_closure,
9 clippy::too_many_arguments,
10 clippy::nonminimal_bool,
11 clippy::derivable_impls,
12 clippy::unnecessary_lazy_evaluations,
13 clippy::needless_lifetimes,
14 clippy::bind_instead_of_map,
15 clippy::needless_borrows_for_generic_args,
16 clippy::map_entry,
17 clippy::map_identity,
18 clippy::manual_map,
19 clippy::match_single_binding,
20 clippy::field_reassign_with_default,
21 dead_code,
22 unused_variables
23)]
24use std::collections::VecDeque;
32use std::sync::{mpsc, Mutex, MutexGuard, OnceLock};
33
34use rustc_hash::FxHashMap;
35use vyre_lower::KernelDescriptor;
36
37mod emitter;
38mod error;
39pub mod patterns;
40pub mod program;
41pub use error::EmitError;
42pub use vyre_lower;
43
/// One serializable record describing a bind result.
///
/// The struct derives `serde::Serialize`/`serde::Deserialize` and `Debug`. Serde's
/// derive serializes fields in declaration order, so reordering or renaming fields
/// changes the wire format.
///
/// NOTE(review): nothing in this file constructs or reads this struct; the field
/// descriptions below are inferred from names only — confirm against the producer.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct BindResultEntry {
    /// Presumably the id of the Vyre op being bound — confirm.
    pub vyre_op_id: u32,
    /// Op kind as a string (assumed to be a human-readable op name — confirm).
    pub op_kind: String,
    /// Presumably a handle to the op's initializer value — confirm.
    pub init_handle: u32,
    /// Scalar kind of the initializer, when one applies (optional).
    pub init_scalar_kind: Option<String>,
    /// Presumably the nesting depth of the child body — confirm.
    pub child_body_depth: usize,
    /// Optional type id observed at the call site (semantics unconfirmed).
    pub value_types_at_call: Option<u32>,
    /// Presumably identifies the path through which the result was published — confirm.
    pub publish_path: String,
    /// Optional type id of a locally allocated value (semantics unconfirmed).
    pub local_allocated_ty: Option<u32>,
}
55
56const MODULE_CACHE_CAPACITY: usize = 64;
57static MODULE_CACHE: OnceLock<Mutex<ModuleCache>> = OnceLock::new();
58
/// 128-bit cache key: the leading 16 bytes of a BLAKE3 digest computed over a
/// kernel descriptor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct ModuleCacheKey([u8; 16]);
61
62#[derive(Clone)]
63struct CachedModule {
64 descriptor: KernelDescriptor,
65 module: naga::Module,
66}
67
68#[derive(Default)]
69struct ModuleCache {
70 entries: FxHashMap<ModuleCacheKey, CachedModule>,
71 order: VecDeque<ModuleCacheKey>,
72 #[cfg(test)]
73 hits: usize,
74}
75
76impl ModuleCache {
77 fn get(&mut self, key: ModuleCacheKey, desc: &KernelDescriptor) -> Option<naga::Module> {
78 let cached = self.entries.get(&key)?;
79 if cached.descriptor != *desc {
80 return None;
81 }
82 #[cfg(test)]
83 {
84 self.hits += 1;
85 }
86 Some(cached.module.clone())
87 }
88
89 fn insert(&mut self, key: ModuleCacheKey, desc: &KernelDescriptor, module: &naga::Module) {
90 if self.entries.contains_key(&key) {
91 self.entries.insert(
92 key,
93 CachedModule {
94 descriptor: desc.clone(),
95 module: module.clone(),
96 },
97 );
98 return;
99 }
100 if self.entries.len() >= MODULE_CACHE_CAPACITY {
101 if let Some(oldest) = self.order.pop_front() {
102 self.entries.remove(&oldest);
103 }
104 }
105 self.order.push_back(key);
106 self.entries.insert(
107 key,
108 CachedModule {
109 descriptor: desc.clone(),
110 module: module.clone(),
111 },
112 );
113 }
114}
115
116fn module_cache() -> &'static Mutex<ModuleCache> {
117 MODULE_CACHE.get_or_init(|| Mutex::new(ModuleCache::default()))
118}
119
120fn lock_module_cache() -> MutexGuard<'static, ModuleCache> {
121 module_cache().lock().unwrap_or_else(|error| {
122 panic!(
123 "Vyre Naga module cache lock was poisoned: {error}. Fix: discard the process-local shader module cache after a panic; continuing could reuse corrupted module state."
124 )
125 })
126}
127
128fn descriptor_cache_key(desc: &KernelDescriptor) -> ModuleCacheKey {
129 let mut hasher = blake3::Hasher::new();
130 let stable_debug = format!("{desc:?}");
131 hasher.update(&(stable_debug.len() as u64).to_le_bytes());
132 hasher.update(stable_debug.as_bytes());
133 let digest = hasher.finalize();
134 let mut out = [0u8; 16];
135 out.copy_from_slice(&digest.as_bytes()[..16]);
136 ModuleCacheKey(out)
137}
138
/// Test helper: empties the module cache and resets its hit counter.
#[cfg(test)]
fn clear_module_cache_for_tests() {
    let mut cache = lock_module_cache();
    *cache = ModuleCache::default();
}
143
/// Test helper: number of cache hits recorded since the cache was last reset.
#[cfg(test)]
fn module_cache_hits_for_tests() -> usize {
    let cache = lock_module_cache();
    cache.hits
}
148
149pub fn emit_optimized(desc: &KernelDescriptor) -> Result<naga::Module, EmitError> {
163 emit_optimized_with_stats(desc).map(|(m, _)| m)
164}
165
166pub fn emit_optimized_with_stats(
172 desc: &KernelDescriptor,
173) -> Result<(naga::Module, vyre_lower::rewrites::OptimizationStats), EmitError> {
174 let (optimized, stats) = vyre_lower::rewrites::run_all_with_stats(desc);
175 debug_assert!(
176 vyre_lower::verify::verify(&optimized).is_ok(),
177 "rewrite pipeline produced an invalid descriptor - see vyre_lower::verify for the contract"
178 );
179 let module = emit(&optimized)?;
180 Ok((module, stats))
181}
182
183#[must_use]
190pub fn emit_many_optimized(descs: &[KernelDescriptor]) -> Vec<Result<naga::Module, EmitError>> {
191 emit_many_with(descs, emit_optimized)
192}
193
194pub fn emit(desc: &KernelDescriptor) -> Result<naga::Module, EmitError> {
205 let cache_key = descriptor_cache_key(desc);
206 if let Some(module) = lock_module_cache().get(cache_key, desc) {
207 return Ok(module);
208 }
209 let module = emitter::emit_uncached(desc)?;
210 lock_module_cache().insert(cache_key, desc, &module);
211 Ok(module)
212}
213
214#[must_use]
219pub fn emit_many(descs: &[KernelDescriptor]) -> Vec<Result<naga::Module, EmitError>> {
220 emit_many_with(descs, emit)
221}
222
223fn emit_many_with(
224 descs: &[KernelDescriptor],
225 emit_one: fn(&KernelDescriptor) -> Result<naga::Module, EmitError>,
226) -> Vec<Result<naga::Module, EmitError>> {
227 if descs.len() <= 1 {
228 return descs.iter().map(emit_one).collect();
229 }
230 let worker_count = std::thread::available_parallelism()
231 .map(usize::from)
232 .unwrap_or(1)
233 .min(descs.len())
234 .max(1);
235 let chunk_size = descs.len().div_ceil(worker_count);
236 let (tx, rx) = mpsc::channel();
237 std::thread::scope(|scope| {
238 for (chunk_index, chunk) in descs.chunks(chunk_size).enumerate() {
239 let tx = tx.clone();
240 let start = chunk_index * chunk_size;
241 scope.spawn(move || {
242 for (offset, desc) in chunk.iter().enumerate() {
243 if tx.send((start + offset, emit_one(desc))).is_err() {
244 break;
245 }
246 }
247 });
248 }
249 });
250 drop(tx);
251
252 let mut results: Vec<Option<Result<naga::Module, EmitError>>> =
253 std::iter::repeat_with(|| None).take(descs.len()).collect();
254 for (index, result) in rx {
255 if let Some(slot) = results.get_mut(index) {
256 *slot = Some(result);
257 }
258 }
259 results
260 .into_iter()
261 .enumerate()
262 .map(|(index, result)| {
263 result.unwrap_or_else(|| {
264 Err(EmitError::InvalidDescriptor(format!(
265 "parallel Naga emit worker did not return descriptor index {index}. Fix: keep emit_many chunk scheduling and result collection synchronized."
266 )))
267 })
268 })
269 .collect()
270}
271
272#[cfg(test)]
273mod tests;