Skip to main content

vyre_emit_naga/
lib.rs

1#![allow(
2    clippy::doc_lazy_continuation,
3    clippy::double_must_use,
4    clippy::manual_div_ceil,
5    clippy::needless_range_loop,
6    clippy::collapsible_if,
7    clippy::match_like_matches_macro,
8    clippy::redundant_closure,
9    clippy::too_many_arguments,
10    clippy::nonminimal_bool,
11    clippy::derivable_impls,
12    clippy::unnecessary_lazy_evaluations,
13    clippy::needless_lifetimes,
14    clippy::bind_instead_of_map,
15    clippy::needless_borrows_for_generic_args,
16    clippy::map_entry,
17    clippy::map_identity,
18    clippy::manual_map,
19    clippy::match_single_binding,
20    clippy::field_reassign_with_default,
21    dead_code,
22    unused_variables
23)]
24//! Naga IR emitter for vyre `KernelDescriptor`.
25//!
26//! Consumes a substrate-neutral `vyre_lower::KernelDescriptor` and
27//! produces a `naga::Module`. The emitter owns only Naga construction;
28//! descriptor shaping and substrate-neutral analyses stay in
29//! `vyre-lower`.
30
31use std::collections::VecDeque;
32use std::sync::{mpsc, Mutex, MutexGuard, OnceLock};
33
34use rustc_hash::FxHashMap;
35use vyre_lower::KernelDescriptor;
36
37mod emitter;
38mod error;
39pub mod patterns;
40pub mod program;
41pub use error::EmitError;
42pub use vyre_lower;
43
/// Serde-serializable record describing a single bind result.
///
/// NOTE(review): nothing visible in this file constructs or reads this type,
/// so the meaning of each field is presumed from its name only — confirm
/// against the code that produces these entries. Field order is kept as
/// declared because it is observable in serialized output.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct BindResultEntry {
    pub vyre_op_id: u32,
    pub op_kind: String,
    pub init_handle: u32,
    pub init_scalar_kind: Option<String>,
    pub child_body_depth: usize,
    pub value_types_at_call: Option<u32>,
    pub publish_path: String,
    pub local_allocated_ty: Option<u32>,
}
55
/// Entry count at which `ModuleCache::insert` evicts the oldest inserted key
/// before storing a new one.
const MODULE_CACHE_CAPACITY: usize = 64;
/// Process-wide module cache; created lazily on first use by `module_cache()`.
static MODULE_CACHE: OnceLock<Mutex<ModuleCache>> = OnceLock::new();
58
/// 16-byte lookup key for the module cache.
///
/// `descriptor_cache_key` fills it with the first 16 bytes of a BLAKE3 digest
/// of the descriptor's `Debug` rendering.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct ModuleCacheKey([u8; 16]);
61
/// Value stored in the module cache.
///
/// The descriptor is kept next to the module so `ModuleCache::get` can compare
/// it with the requested descriptor and reject an entry whose key matches but
/// whose descriptor does not.
#[derive(Clone)]
struct CachedModule {
    descriptor: KernelDescriptor,
    module: naga::Module,
}
67
/// Bounded cache of emitted modules with oldest-inserted-first eviction.
#[derive(Default)]
struct ModuleCache {
    // Cached modules keyed by descriptor digest.
    entries: FxHashMap<ModuleCacheKey, CachedModule>,
    // Keys in insertion order; the front is the next eviction candidate.
    order: VecDeque<ModuleCacheKey>,
    // Test-only count of successful `get` calls.
    #[cfg(test)]
    hits: usize,
}
75
76impl ModuleCache {
77    fn get(&mut self, key: ModuleCacheKey, desc: &KernelDescriptor) -> Option<naga::Module> {
78        let cached = self.entries.get(&key)?;
79        if cached.descriptor != *desc {
80            return None;
81        }
82        #[cfg(test)]
83        {
84            self.hits += 1;
85        }
86        Some(cached.module.clone())
87    }
88
89    fn insert(&mut self, key: ModuleCacheKey, desc: &KernelDescriptor, module: &naga::Module) {
90        if self.entries.contains_key(&key) {
91            self.entries.insert(
92                key,
93                CachedModule {
94                    descriptor: desc.clone(),
95                    module: module.clone(),
96                },
97            );
98            return;
99        }
100        if self.entries.len() >= MODULE_CACHE_CAPACITY {
101            if let Some(oldest) = self.order.pop_front() {
102                self.entries.remove(&oldest);
103            }
104        }
105        self.order.push_back(key);
106        self.entries.insert(
107            key,
108            CachedModule {
109                descriptor: desc.clone(),
110                module: module.clone(),
111            },
112        );
113    }
114}
115
116fn module_cache() -> &'static Mutex<ModuleCache> {
117    MODULE_CACHE.get_or_init(|| Mutex::new(ModuleCache::default()))
118}
119
120fn lock_module_cache() -> MutexGuard<'static, ModuleCache> {
121    module_cache().lock().unwrap_or_else(|error| {
122        panic!(
123            "Vyre Naga module cache lock was poisoned: {error}. Fix: discard the process-local shader module cache after a panic; continuing could reuse corrupted module state."
124        )
125    })
126}
127
128fn descriptor_cache_key(desc: &KernelDescriptor) -> ModuleCacheKey {
129    let mut hasher = blake3::Hasher::new();
130    let stable_debug = format!("{desc:?}");
131    hasher.update(&(stable_debug.len() as u64).to_le_bytes());
132    hasher.update(stable_debug.as_bytes());
133    let digest = hasher.finalize();
134    let mut out = [0u8; 16];
135    out.copy_from_slice(&digest.as_bytes()[..16]);
136    ModuleCacheKey(out)
137}
138
/// Test helper: replaces the shared cache contents with an empty cache, which
/// also resets the hit counter.
#[cfg(test)]
fn clear_module_cache_for_tests() {
    let mut cache = lock_module_cache();
    *cache = ModuleCache::default();
}
143
/// Test helper: reads the number of cache hits recorded so far.
#[cfg(test)]
fn module_cache_hits_for_tests() -> usize {
    let cache = lock_module_cache();
    cache.hits
}
148
149/// Emit a `naga::Module` from a `KernelDescriptor` after running the
150/// full `vyre_lower::rewrites::run_all` optimization pipeline.
151///
152/// This is the recommended emission entry point  -  call this whenever
153/// you don't have a specific reason to emit the raw descriptor. The
154/// optimized form has fewer ops (dead code dropped, identity ops
155/// eliminated, common subexpressions merged, redundant loads
156/// forwarded, etc.) and produces tighter Naga IR with no semantic
157/// change.
158///
159/// # Errors
160///
161/// Same as [`emit`].
162pub fn emit_optimized(desc: &KernelDescriptor) -> Result<naga::Module, EmitError> {
163    emit_optimized_with_stats(desc).map(|(m, _)| m)
164}
165
166/// Like [`emit_optimized`] but also returns
167/// [`vyre_lower::rewrites::OptimizationStats`] so the caller can see
168/// what the rewrite stack did (op count delta, bindings dropped,
169/// fixed-point iterations needed). No duplicate work  -  `emit_optimized`
170/// is now a thin wrapper around this.
171pub fn emit_optimized_with_stats(
172    desc: &KernelDescriptor,
173) -> Result<(naga::Module, vyre_lower::rewrites::OptimizationStats), EmitError> {
174    let (optimized, stats) = vyre_lower::rewrites::run_all_with_stats(desc);
175    debug_assert!(
176        vyre_lower::verify::verify(&optimized).is_ok(),
177        "rewrite pipeline produced an invalid descriptor  -  see vyre_lower::verify for the contract"
178    );
179    let module = emit(&optimized)?;
180    Ok((module, stats))
181}
182
183/// Emit many independent descriptors after running the canonical lower rewrite
184/// pipeline on each descriptor.
185///
186/// Results preserve input order. Each descriptor still flows through the
187/// process-wide module cache, so repeated arms return cached `naga::Module`
188/// clones while unrelated arms can lower concurrently.
189#[must_use]
190pub fn emit_many_optimized(descs: &[KernelDescriptor]) -> Vec<Result<naga::Module, EmitError>> {
191    emit_many_with(descs, emit_optimized)
192}
193
194/// Emit a `naga::Module` from a `KernelDescriptor`.
195///
196/// Lowers the descriptor exactly as given. Use [`emit_optimized`] if
197/// you also want the rewrite stack applied first.
198///
199/// # Errors
200///
201/// Returns [`EmitError`] when a binding layout cannot be represented in
202/// Naga IR or when the descriptor contains an operation outside this emitter's
203/// supported lowering set.
204pub fn emit(desc: &KernelDescriptor) -> Result<naga::Module, EmitError> {
205    let cache_key = descriptor_cache_key(desc);
206    if let Some(module) = lock_module_cache().get(cache_key, desc) {
207        return Ok(module);
208    }
209    let module = emitter::emit_uncached(desc)?;
210    lock_module_cache().insert(cache_key, desc, &module);
211    Ok(module)
212}
213
214/// Emit many independent descriptors exactly as provided.
215///
216/// Results preserve input order and each descriptor uses the same cache path as
217/// [`emit`]. Use [`emit_many_optimized`] for the canonical optimized path.
218#[must_use]
219pub fn emit_many(descs: &[KernelDescriptor]) -> Vec<Result<naga::Module, EmitError>> {
220    emit_many_with(descs, emit)
221}
222
223fn emit_many_with(
224    descs: &[KernelDescriptor],
225    emit_one: fn(&KernelDescriptor) -> Result<naga::Module, EmitError>,
226) -> Vec<Result<naga::Module, EmitError>> {
227    if descs.len() <= 1 {
228        return descs.iter().map(emit_one).collect();
229    }
230    let worker_count = std::thread::available_parallelism()
231        .map(usize::from)
232        .unwrap_or(1)
233        .min(descs.len())
234        .max(1);
235    let chunk_size = descs.len().div_ceil(worker_count);
236    let (tx, rx) = mpsc::channel();
237    std::thread::scope(|scope| {
238        for (chunk_index, chunk) in descs.chunks(chunk_size).enumerate() {
239            let tx = tx.clone();
240            let start = chunk_index * chunk_size;
241            scope.spawn(move || {
242                for (offset, desc) in chunk.iter().enumerate() {
243                    if tx.send((start + offset, emit_one(desc))).is_err() {
244                        break;
245                    }
246                }
247            });
248        }
249    });
250    drop(tx);
251
252    let mut results: Vec<Option<Result<naga::Module, EmitError>>> =
253        std::iter::repeat_with(|| None).take(descs.len()).collect();
254    for (index, result) in rx {
255        if let Some(slot) = results.get_mut(index) {
256            *slot = Some(result);
257        }
258    }
259    results
260        .into_iter()
261        .enumerate()
262        .map(|(index, result)| {
263            result.unwrap_or_else(|| {
264                Err(EmitError::InvalidDescriptor(format!(
265                    "parallel Naga emit worker did not return descriptor index {index}. Fix: keep emit_many chunk scheduling and result collection synchronized."
266                )))
267            })
268        })
269        .collect()
270}
271
272#[cfg(test)]
273mod tests;