// shape_jit/jit_cache.rs

1//! Content-addressed JIT code cache with dependency-based invalidation.
2//!
3//! Caches compiled native function pointers keyed by `FunctionHash`.
4//! When the same function blob appears in a subsequent compilation
5//! (same content hash), we skip recompilation and reuse the existing
6//! native code pointer.
7//!
8//! Supports dependency tracking: when a function is inlined into another,
9//! the caller records the callee hash as a dependency. If the callee
10//! changes, all dependents can be invalidated via `invalidate_by_dependency()`.
11
12use shape_vm::bytecode::FunctionHash;
13use std::collections::HashMap;
14
15use crate::optimizer::Tier2CacheKey;
16
17/// Extended cache entry with dependency tracking.
18/// Extended cache entry with dependency tracking.
19///
20/// One entry per compiled function blob, stored in the cache keyed by
21/// `function_hash`.
22#[derive(Debug, Clone)]
23pub struct CacheEntry {
24    /// Native code pointer. Points into memory owned by a Cranelift
25    /// `JITModule`; only valid while that module is alive (see the
26    /// `# Safety` section on `JitCodeCache`).
27    pub code_ptr: *const u8,
28    /// Content hash of the function blob.
29    pub function_hash: FunctionHash,
30    /// Schema version at compilation time (for shape guard invalidation).
31    pub schema_version: u32,
32    /// Feedback epoch at compilation time (for speculation invalidation).
33    pub feedback_epoch: u32,
34    /// Hashes of functions this compiled code depends on (e.g., inlined callees).
35    pub dependencies: Vec<FunctionHash>,
36    /// Tier 2 cache key, present when this entry was produced by the
37    /// optimizing compiler with cross-function inlining.
38    pub tier2_key: Option<Tier2CacheKey>,
39}
40
41// SAFETY: CacheEntry contains a raw pointer produced by Cranelift.
42// The same safety argument as JitCodeCache applies (see below).
43// NOTE(review): CacheEntry is Clone, so clones share `code_ptr`; the
44// "JITModule outlives all uses" argument must hold for every clone,
45// not just the copy held by the cache.
46unsafe impl Send for CacheEntry {}
47unsafe impl Sync for CacheEntry {}
39
40/// Cache of JIT-compiled function pointers, keyed by content hash.
41///
42/// Same blob hash = skip recompilation, reuse function pointer.
43///
44/// Tracks dependency edges so that when an inlined callee changes,
45/// all callers that embedded it can be invalidated.
46///
47/// # Safety
48///
49/// The raw `*const u8` pointers stored here point into Cranelift
50/// `JITModule` memory regions. Callers must ensure that the
51/// `JITModule` that produced a pointer outlives any use of that
52/// pointer through this cache.
53pub struct JitCodeCache {
54    /// Forward map: content hash -> compiled entry with metadata.
55    entries: HashMap<FunctionHash, CacheEntry>,
56    /// Reverse index: dependency_hash -> set of dependent function hashes.
57    /// Used by `invalidate_by_dependency()` to find affected entries.
58    dependents: HashMap<FunctionHash, Vec<FunctionHash>>,
59}
60
61// SAFETY: The function pointers are produced by Cranelift and are
62// valid for the lifetime of the owning JITModule. The cache itself
63// does not execute code, it only stores and returns pointers.
64unsafe impl Send for JitCodeCache {}
65unsafe impl Sync for JitCodeCache {}
65
66impl JitCodeCache {
67    /// Create an empty cache.
68    pub fn new() -> Self {
69        Self {
70            entries: HashMap::new(),
71            dependents: HashMap::new(),
72        }
73    }
74
75    /// Create a cache pre-sized for `capacity` entries.
76    pub fn with_capacity(capacity: usize) -> Self {
77        Self {
78            entries: HashMap::with_capacity(capacity),
79            dependents: HashMap::new(),
80        }
81    }
82
83    /// Look up a cached native code pointer by content hash.
84    pub fn get(&self, hash: &FunctionHash) -> Option<*const u8> {
85        self.entries.get(hash).map(|e| e.code_ptr)
86    }
87
88    /// Insert a compiled function pointer for the given content hash.
89    ///
90    /// Creates a minimal `CacheEntry` with no dependencies and zero
91    /// version/epoch. If an entry with the same hash already exists
92    /// it is overwritten.
93    pub fn insert(&mut self, hash: FunctionHash, ptr: *const u8) {
94        // Remove old dependency edges if overwriting.
95        self.remove_dependency_edges(&hash);
96        self.entries.insert(
97            hash,
98            CacheEntry {
99                code_ptr: ptr,
100                function_hash: hash,
101                schema_version: 0,
102                feedback_epoch: 0,
103                dependencies: Vec::new(),
104                tier2_key: None,
105            },
106        );
107    }
108
109    /// Insert a cache entry with full dependency information.
110    ///
111    /// Builds reverse-index edges so that `invalidate_by_dependency()`
112    /// can find this entry when any of its dependencies change.
113    pub fn insert_entry(&mut self, entry: CacheEntry) {
114        let hash = entry.function_hash;
115        // Remove stale dependency edges if overwriting.
116        self.remove_dependency_edges(&hash);
117        // Build reverse edges for the new entry.
118        for dep in &entry.dependencies {
119            self.dependents.entry(*dep).or_default().push(hash);
120        }
121        self.entries.insert(hash, entry);
122    }
123
124    /// Invalidate all entries that depend on the given function hash.
125    ///
126    /// Performs a transitive walk: if A depends on B and B depends on C,
127    /// invalidating C will remove both B and A.
128    ///
129    /// Returns the list of invalidated function hashes.
130    pub fn invalidate_by_dependency(&mut self, changed_hash: &FunctionHash) -> Vec<FunctionHash> {
131        let mut invalidated = Vec::new();
132        let mut worklist = vec![*changed_hash];
133
134        while let Some(current) = worklist.pop() {
135            if let Some(deps) = self.dependents.remove(&current) {
136                for dep_hash in deps {
137                    if self.entries.remove(&dep_hash).is_some() {
138                        invalidated.push(dep_hash);
139                        // Cascade: anything that depended on the now-removed
140                        // entry must also be invalidated.
141                        worklist.push(dep_hash);
142                    }
143                }
144            }
145        }
146
147        // Clean up reverse edges for invalidated entries.
148        for inv in &invalidated {
149            self.remove_dependency_edges(inv);
150        }
151
152        invalidated
153    }
154
155    /// Get a cache entry with full metadata.
156    pub fn get_entry(&self, hash: &FunctionHash) -> Option<&CacheEntry> {
157        self.entries.get(hash)
158    }
159
160    /// Check whether a function with the given hash has been compiled.
161    pub fn contains(&self, hash: &FunctionHash) -> bool {
162        self.entries.contains_key(hash)
163    }
164
165    /// Number of cached entries.
166    pub fn len(&self) -> usize {
167        self.entries.len()
168    }
169
170    /// Returns `true` if the cache is empty.
171    pub fn is_empty(&self) -> bool {
172        self.entries.is_empty()
173    }
174
175    /// Remove all entries from the cache.
176    ///
177    /// This does **not** free the underlying native code memory (that
178    /// is owned by the Cranelift `JITModule`).
179    pub fn clear(&mut self) {
180        self.entries.clear();
181        self.dependents.clear();
182    }
183
184    /// Remove reverse-index edges for a given function hash.
185    fn remove_dependency_edges(&mut self, hash: &FunctionHash) {
186        if let Some(entry) = self.entries.get(hash) {
187            let deps: Vec<FunctionHash> = entry.dependencies.clone();
188            for dep in &deps {
189                if let Some(rev) = self.dependents.get_mut(dep) {
190                    rev.retain(|h| h != hash);
191                    if rev.is_empty() {
192                        self.dependents.remove(dep);
193                    }
194                }
195            }
196        }
197    }
198}
199
200impl Default for JitCodeCache {
201    fn default() -> Self {
202        Self::new()
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_cache() {
        let cache = JitCodeCache::new();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert!(cache.get(&FunctionHash::ZERO).is_none());
    }

    #[test]
    fn insert_and_get() {
        let mut cache = JitCodeCache::new();
        let h = FunctionHash([0xAB; 32]);
        let ptr = 0xDEAD_BEEF_usize as *const u8;

        cache.insert(h, ptr);

        assert!(cache.contains(&h));
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&h), Some(ptr));
    }

    #[test]
    fn missing_hash_returns_none() {
        let mut cache = JitCodeCache::new();
        let present = FunctionHash([1u8; 32]);
        let absent = FunctionHash([2u8; 32]);

        cache.insert(present, 0x1 as *const u8);

        assert!(!cache.contains(&absent));
        assert!(cache.get(&absent).is_none());
    }

    #[test]
    fn overwrite_entry() {
        let mut cache = JitCodeCache::new();
        let h = FunctionHash([0xCC; 32]);
        let first = 0x1000_usize as *const u8;
        let second = 0x2000_usize as *const u8;

        cache.insert(h, first);
        assert_eq!(cache.get(&h), Some(first));

        // Re-inserting under the same hash replaces the pointer
        // without growing the cache.
        cache.insert(h, second);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&h), Some(second));
    }

    #[test]
    fn clear_removes_all() {
        let mut cache = JitCodeCache::new();
        for (byte, addr) in [(1u8, 0x1_usize), (2u8, 0x2_usize)] {
            cache.insert(FunctionHash([byte; 32]), addr as *const u8);
        }
        assert_eq!(cache.len(), 2);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn with_capacity() {
        assert!(JitCodeCache::with_capacity(64).is_empty());
    }

    // --- Dependency-tracking tests ---

    /// Build a `CacheEntry` with fixed version/epoch and the given deps.
    fn make_entry(hash: FunctionHash, ptr: usize, deps: Vec<FunctionHash>) -> CacheEntry {
        CacheEntry {
            code_ptr: ptr as *const u8,
            function_hash: hash,
            schema_version: 1,
            feedback_epoch: 1,
            dependencies: deps,
            tier2_key: None,
        }
    }

    #[test]
    fn test_insert_entry_with_dependencies() {
        let mut cache = JitCodeCache::new();
        let callee = FunctionHash([0x01; 32]);
        let caller = FunctionHash([0x02; 32]);

        // Callee has no deps; caller records the callee as a dependency.
        cache.insert_entry(make_entry(callee, 0x1000, vec![]));
        cache.insert_entry(make_entry(caller, 0x2000, vec![callee]));

        assert_eq!(cache.len(), 2);
        assert!(cache.contains(&caller));
        assert!(cache.contains(&callee));

        // Full metadata is retrievable through get_entry().
        let meta = cache.get_entry(&caller).expect("caller entry present");
        assert_eq!(meta.dependencies, vec![callee]);
        assert_eq!(meta.schema_version, 1);
        assert_eq!(meta.feedback_epoch, 1);
    }

    #[test]
    fn test_invalidate_by_dependency() {
        let mut cache = JitCodeCache::new();
        let callee = FunctionHash([0x01; 32]);
        let caller_a = FunctionHash([0x02; 32]);
        let caller_b = FunctionHash([0x03; 32]);
        let unrelated = FunctionHash([0x04; 32]);

        cache.insert_entry(make_entry(callee, 0x1000, vec![]));
        cache.insert_entry(make_entry(caller_a, 0x2000, vec![callee]));
        cache.insert_entry(make_entry(caller_b, 0x3000, vec![callee]));
        cache.insert_entry(make_entry(unrelated, 0x4000, vec![]));
        assert_eq!(cache.len(), 4);

        let mut removed = cache.invalidate_by_dependency(&callee);
        removed.sort_by_key(|h| h.0);

        // Exactly the two callers go away.
        assert_eq!(removed.len(), 2);
        assert!(removed.contains(&caller_a));
        assert!(removed.contains(&caller_b));
        assert!(!cache.contains(&caller_a));
        assert!(!cache.contains(&caller_b));

        // The changed function itself and the unrelated entry survive.
        assert!(cache.contains(&callee));
        assert!(cache.contains(&unrelated));
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_invalidate_cascading() {
        // Chain: a -> b -> c. Invalidating c takes out b directly and
        // a transitively.
        let mut cache = JitCodeCache::new();
        let c = FunctionHash([0x01; 32]);
        let b = FunctionHash([0x02; 32]);
        let a = FunctionHash([0x03; 32]);

        cache.insert_entry(make_entry(c, 0x1000, vec![]));
        cache.insert_entry(make_entry(b, 0x2000, vec![c]));
        cache.insert_entry(make_entry(a, 0x3000, vec![b]));
        assert_eq!(cache.len(), 3);

        let mut removed = cache.invalidate_by_dependency(&c);
        removed.sort_by_key(|h| h.0);

        assert_eq!(removed.len(), 2);
        assert!(removed.contains(&a));
        assert!(removed.contains(&b));

        // Only c itself remains cached.
        assert_eq!(cache.len(), 1);
        assert!(cache.contains(&c));
        assert!(!cache.contains(&a));
        assert!(!cache.contains(&b));
    }

    #[test]
    fn test_get_entry_returns_metadata() {
        let mut cache = JitCodeCache::new();
        let hash = FunctionHash([0xAA; 32]);
        let dep = FunctionHash([0xBB; 32]);

        cache.insert_entry(CacheEntry {
            code_ptr: 0x5000 as *const u8,
            function_hash: hash,
            schema_version: 42,
            feedback_epoch: 7,
            dependencies: vec![dep],
            tier2_key: None,
        });

        let entry = cache.get_entry(&hash).expect("entry just inserted");
        assert_eq!(entry.function_hash, hash);
        assert_eq!(entry.code_ptr, 0x5000 as *const u8);
        assert_eq!(entry.dependencies, vec![dep]);
        assert_eq!(entry.schema_version, 42);
        assert_eq!(entry.feedback_epoch, 7);

        // The pointer-only accessor agrees with the full entry.
        assert_eq!(cache.get(&hash), Some(0x5000 as *const u8));

        // Unknown hashes yield no entry.
        assert!(cache.get_entry(&FunctionHash([0xFF; 32])).is_none());
    }

    #[test]
    fn test_tier2_cache_key_stored_in_entry() {
        let mut cache = JitCodeCache::new();
        let root = FunctionHash([0x10; 32]);
        let inlined_callee = FunctionHash([0x20; 32]);

        let key = Tier2CacheKey::with_versions(
            root.0,
            vec![inlined_callee.0],
            1, // compiler_version
            5, // schema_version
            3, // feedback_epoch
        );

        cache.insert_entry(CacheEntry {
            code_ptr: 0x8000 as *const u8,
            function_hash: root,
            schema_version: 5,
            feedback_epoch: 3,
            dependencies: vec![inlined_callee],
            tier2_key: Some(key.clone()),
        });

        let stored = cache
            .get_entry(&root)
            .and_then(|e| e.tier2_key.as_ref())
            .expect("tier 2 key stored");
        assert_eq!(stored.root_hash, root.0);
        assert_eq!(stored.inlined_hashes, vec![inlined_callee.0]);
        assert_eq!(stored.compiler_version, 1);
        assert_eq!(stored.schema_version, 5);
        assert_eq!(stored.feedback_epoch, 3);

        // A key built without version metadata must hash differently.
        let versionless = Tier2CacheKey::new(root.0, vec![inlined_callee.0], 1);
        assert_ne!(stored.combined_hash(), versionless.combined_hash());
    }

    #[test]
    fn test_invalidate_with_tier2_entries() {
        // A tier 2 entry that inlined a callee is dropped when that
        // callee is invalidated.
        let mut cache = JitCodeCache::new();
        let callee = FunctionHash([0x01; 32]);
        let optimized = FunctionHash([0x02; 32]);

        let key = Tier2CacheKey::with_versions(optimized.0, vec![callee.0], 1, 0, 0);

        cache.insert_entry(CacheEntry {
            code_ptr: 0x1000 as *const u8,
            function_hash: callee,
            schema_version: 0,
            feedback_epoch: 0,
            dependencies: vec![],
            tier2_key: None,
        });
        cache.insert_entry(CacheEntry {
            code_ptr: 0x2000 as *const u8,
            function_hash: optimized,
            schema_version: 0,
            feedback_epoch: 0,
            dependencies: vec![callee],
            tier2_key: Some(key),
        });

        let removed = cache.invalidate_by_dependency(&callee);
        assert_eq!(removed, vec![optimized]);
        assert!(cache.contains(&callee));
        assert!(!cache.contains(&optimized));
    }
}