1use shape_vm::bytecode::FunctionHash;
13use std::collections::HashMap;
14
15use crate::optimizer::Tier2CacheKey;
16
/// A single cached compilation result plus the metadata needed to decide
/// when it must be invalidated.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Entry point of the compiled machine code for this function.
    pub code_ptr: *const u8,
    /// Hash identifying the source function this code was compiled from.
    pub function_hash: FunctionHash,
    /// Schema version the code was compiled against (0 for baseline
    /// `JitCodeCache::insert`, which records no version info).
    pub schema_version: u32,
    /// Feedback epoch at compile time (0 for baseline `insert`).
    pub feedback_epoch: u32,
    /// Hashes of functions this entry depends on (e.g. inlined callees);
    /// a change to any of them must evict this entry.
    pub dependencies: Vec<FunctionHash>,
    /// Tier-2 optimizer cache key; `None` for non-optimized entries.
    pub tier2_key: Option<Tier2CacheKey>,
}
34
// SAFETY: `CacheEntry` is only `!Send`/`!Sync` because of the raw `code_ptr`.
// NOTE(review): this assumes the pointer targets immutable, executable JIT
// code whose lifetime is managed elsewhere (so sharing the pointer *value*
// across threads cannot race) — confirm against the code-buffer owner.
unsafe impl Send for CacheEntry {}
unsafe impl Sync for CacheEntry {}
39
/// Cache of JIT-compiled code keyed by function hash, with reverse
/// dependency edges so a changed function can transitively evict every
/// compiled entry that depends on it.
pub struct JitCodeCache {
    /// Primary store: function hash -> cached compiled entry.
    entries: HashMap<FunctionHash, CacheEntry>,
    /// Reverse edges: dependency hash -> hashes of cached entries that
    /// listed it in `CacheEntry::dependencies`.
    dependents: HashMap<FunctionHash, Vec<FunctionHash>>,
}
59
// SAFETY: `JitCodeCache` is only `!Send`/`!Sync` because `CacheEntry` holds a
// raw pointer; the maps themselves are ordinary owned data. NOTE(review):
// `Sync` here permits `&JitCodeCache` sharing without a lock — all mutation
// goes through `&mut self`, so callers must provide external synchronization;
// confirm that is the intended usage.
unsafe impl Send for JitCodeCache {}
unsafe impl Sync for JitCodeCache {}
65
66impl JitCodeCache {
67 pub fn new() -> Self {
69 Self {
70 entries: HashMap::new(),
71 dependents: HashMap::new(),
72 }
73 }
74
75 pub fn with_capacity(capacity: usize) -> Self {
77 Self {
78 entries: HashMap::with_capacity(capacity),
79 dependents: HashMap::new(),
80 }
81 }
82
83 pub fn get(&self, hash: &FunctionHash) -> Option<*const u8> {
85 self.entries.get(hash).map(|e| e.code_ptr)
86 }
87
88 pub fn insert(&mut self, hash: FunctionHash, ptr: *const u8) {
94 self.remove_dependency_edges(&hash);
96 self.entries.insert(
97 hash,
98 CacheEntry {
99 code_ptr: ptr,
100 function_hash: hash,
101 schema_version: 0,
102 feedback_epoch: 0,
103 dependencies: Vec::new(),
104 tier2_key: None,
105 },
106 );
107 }
108
109 pub fn insert_entry(&mut self, entry: CacheEntry) {
114 let hash = entry.function_hash;
115 self.remove_dependency_edges(&hash);
117 for dep in &entry.dependencies {
119 self.dependents.entry(*dep).or_default().push(hash);
120 }
121 self.entries.insert(hash, entry);
122 }
123
124 pub fn invalidate_by_dependency(&mut self, changed_hash: &FunctionHash) -> Vec<FunctionHash> {
131 let mut invalidated = Vec::new();
132 let mut worklist = vec![*changed_hash];
133
134 while let Some(current) = worklist.pop() {
135 if let Some(deps) = self.dependents.remove(¤t) {
136 for dep_hash in deps {
137 if self.entries.remove(&dep_hash).is_some() {
138 invalidated.push(dep_hash);
139 worklist.push(dep_hash);
142 }
143 }
144 }
145 }
146
147 for inv in &invalidated {
149 self.remove_dependency_edges(inv);
150 }
151
152 invalidated
153 }
154
155 pub fn get_entry(&self, hash: &FunctionHash) -> Option<&CacheEntry> {
157 self.entries.get(hash)
158 }
159
160 pub fn contains(&self, hash: &FunctionHash) -> bool {
162 self.entries.contains_key(hash)
163 }
164
165 pub fn len(&self) -> usize {
167 self.entries.len()
168 }
169
170 pub fn is_empty(&self) -> bool {
172 self.entries.is_empty()
173 }
174
175 pub fn clear(&mut self) {
180 self.entries.clear();
181 self.dependents.clear();
182 }
183
184 fn remove_dependency_edges(&mut self, hash: &FunctionHash) {
186 if let Some(entry) = self.entries.get(hash) {
187 let deps: Vec<FunctionHash> = entry.dependencies.clone();
188 for dep in &deps {
189 if let Some(rev) = self.dependents.get_mut(dep) {
190 rev.retain(|h| h != hash);
191 if rev.is_empty() {
192 self.dependents.remove(dep);
193 }
194 }
195 }
196 }
197 }
198}
199
200impl Default for JitCodeCache {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed cache reports empty and misses every lookup.
    #[test]
    fn empty_cache() {
        let cache = JitCodeCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert!(cache.get(&FunctionHash::ZERO).is_none());
    }

    // Basic insert/lookup round trip via the dependency-free `insert` path.
    #[test]
    fn insert_and_get() {
        let mut cache = JitCodeCache::new();
        let hash = FunctionHash([0xAB; 32]);
        let fake_ptr = 0xDEAD_BEEF_usize as *const u8;

        cache.insert(hash, fake_ptr);
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());
        assert!(cache.contains(&hash));
        assert_eq!(cache.get(&hash), Some(fake_ptr));
    }

    // Lookups for a hash that was never inserted must miss.
    #[test]
    fn missing_hash_returns_none() {
        let mut cache = JitCodeCache::new();
        let hash_a = FunctionHash([1u8; 32]);
        let hash_b = FunctionHash([2u8; 32]);
        cache.insert(hash_a, 0x1 as *const u8);

        assert!(cache.get(&hash_b).is_none());
        assert!(!cache.contains(&hash_b));
    }

    // Re-inserting under the same hash replaces the pointer without growing
    // the cache.
    #[test]
    fn overwrite_entry() {
        let mut cache = JitCodeCache::new();
        let hash = FunctionHash([0xCC; 32]);
        let ptr1 = 0x1000_usize as *const u8;
        let ptr2 = 0x2000_usize as *const u8;

        cache.insert(hash, ptr1);
        assert_eq!(cache.get(&hash), Some(ptr1));

        cache.insert(hash, ptr2);
        assert_eq!(cache.get(&hash), Some(ptr2));
        assert_eq!(cache.len(), 1);
    }

    // `clear` drops every entry.
    #[test]
    fn clear_removes_all() {
        let mut cache = JitCodeCache::new();
        cache.insert(FunctionHash([1; 32]), 0x1 as *const u8);
        cache.insert(FunctionHash([2; 32]), 0x2 as *const u8);
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    // Pre-sizing does not affect the observable (empty) state.
    #[test]
    fn with_capacity() {
        let cache = JitCodeCache::with_capacity(64);
        assert!(cache.is_empty());
    }

    // Helper: builds a CacheEntry with fixed version/epoch (1/1), the given
    // dependencies, and no tier-2 key.
    fn make_entry(hash: FunctionHash, ptr: usize, deps: Vec<FunctionHash>) -> CacheEntry {
        CacheEntry {
            code_ptr: ptr as *const u8,
            function_hash: hash,
            schema_version: 1,
            feedback_epoch: 1,
            dependencies: deps,
            tier2_key: None,
        }
    }

    // `insert_entry` stores metadata intact and records the caller->callee
    // dependency.
    #[test]
    fn test_insert_entry_with_dependencies() {
        let mut cache = JitCodeCache::new();
        let callee = FunctionHash([0x01; 32]);
        let caller = FunctionHash([0x02; 32]);

        cache.insert_entry(make_entry(callee, 0x1000, vec![]));

        cache.insert_entry(make_entry(caller, 0x2000, vec![callee]));

        assert_eq!(cache.len(), 2);
        assert!(cache.contains(&callee));
        assert!(cache.contains(&caller));

        let entry = cache.get_entry(&caller).unwrap();
        assert_eq!(entry.schema_version, 1);
        assert_eq!(entry.feedback_epoch, 1);
        assert_eq!(entry.dependencies, vec![callee]);
    }

    // Invalidating the callee evicts both direct dependents but leaves the
    // callee itself and unrelated entries cached.
    #[test]
    fn test_invalidate_by_dependency() {
        let mut cache = JitCodeCache::new();
        let callee = FunctionHash([0x01; 32]);
        let caller_a = FunctionHash([0x02; 32]);
        let caller_b = FunctionHash([0x03; 32]);
        let unrelated = FunctionHash([0x04; 32]);

        cache.insert_entry(make_entry(callee, 0x1000, vec![]));
        cache.insert_entry(make_entry(caller_a, 0x2000, vec![callee]));
        cache.insert_entry(make_entry(caller_b, 0x3000, vec![callee]));
        cache.insert_entry(make_entry(unrelated, 0x4000, vec![]));
        assert_eq!(cache.len(), 4);

        let mut invalidated = cache.invalidate_by_dependency(&callee);
        // Order of eviction is unspecified; sort for a stable comparison.
        invalidated.sort_by_key(|h| h.0);

        assert_eq!(invalidated.len(), 2);
        assert!(invalidated.contains(&caller_a));
        assert!(invalidated.contains(&caller_b));

        assert!(cache.contains(&callee));
        assert!(cache.contains(&unrelated));
        assert!(!cache.contains(&caller_a));
        assert!(!cache.contains(&caller_b));
        assert_eq!(cache.len(), 2);
    }

    // Transitive chain a -> b -> c: invalidating c must also evict a, which
    // depends on c only indirectly through b.
    #[test]
    fn test_invalidate_cascading() {
        let mut cache = JitCodeCache::new();
        let c = FunctionHash([0x01; 32]);
        let b = FunctionHash([0x02; 32]);
        let a = FunctionHash([0x03; 32]);

        cache.insert_entry(make_entry(c, 0x1000, vec![]));
        cache.insert_entry(make_entry(b, 0x2000, vec![c]));
        cache.insert_entry(make_entry(a, 0x3000, vec![b]));
        assert_eq!(cache.len(), 3);

        let mut invalidated = cache.invalidate_by_dependency(&c);
        invalidated.sort_by_key(|h| h.0);

        assert_eq!(invalidated.len(), 2);
        assert!(invalidated.contains(&b));
        assert!(invalidated.contains(&a));

        assert!(cache.contains(&c));
        assert!(!cache.contains(&b));
        assert!(!cache.contains(&a));
        assert_eq!(cache.len(), 1);
    }

    // `get_entry` exposes all stored metadata; `get` still returns just the
    // code pointer; missing hashes return None.
    #[test]
    fn test_get_entry_returns_metadata() {
        let mut cache = JitCodeCache::new();
        let hash = FunctionHash([0xAA; 32]);
        let dep = FunctionHash([0xBB; 32]);

        cache.insert_entry(CacheEntry {
            code_ptr: 0x5000 as *const u8,
            function_hash: hash,
            schema_version: 42,
            feedback_epoch: 7,
            dependencies: vec![dep],
            tier2_key: None,
        });

        let entry = cache.get_entry(&hash).unwrap();
        assert_eq!(entry.code_ptr, 0x5000 as *const u8);
        assert_eq!(entry.function_hash, hash);
        assert_eq!(entry.schema_version, 42);
        assert_eq!(entry.feedback_epoch, 7);
        assert_eq!(entry.dependencies, vec![dep]);

        assert_eq!(cache.get(&hash), Some(0x5000 as *const u8));

        assert!(cache.get_entry(&FunctionHash([0xFF; 32])).is_none());
    }

    // A tier-2 key stored in an entry round-trips unchanged, and a key built
    // without version info hashes differently from one built with it.
    #[test]
    fn test_tier2_cache_key_stored_in_entry() {
        let mut cache = JitCodeCache::new();
        let root = FunctionHash([0x10; 32]);
        let inlined_callee = FunctionHash([0x20; 32]);

        // with_versions(root, inlined, compiler_version, schema, epoch) —
        // argument meaning inferred from the assertions below.
        let key = Tier2CacheKey::with_versions(
            root.0,
            vec![inlined_callee.0],
            1,
            5,
            3,
        );

        cache.insert_entry(CacheEntry {
            code_ptr: 0x8000 as *const u8,
            function_hash: root,
            schema_version: 5,
            feedback_epoch: 3,
            dependencies: vec![inlined_callee],
            tier2_key: Some(key.clone()),
        });

        let entry = cache.get_entry(&root).unwrap();
        let stored_key = entry.tier2_key.as_ref().unwrap();
        assert_eq!(stored_key.root_hash, root.0);
        assert_eq!(stored_key.inlined_hashes, vec![inlined_callee.0]);
        assert_eq!(stored_key.schema_version, 5);
        assert_eq!(stored_key.feedback_epoch, 3);
        assert_eq!(stored_key.compiler_version, 1);

        let key_no_versions = Tier2CacheKey::new(root.0, vec![inlined_callee.0], 1);
        assert_ne!(stored_key.combined_hash(), key_no_versions.combined_hash());
    }

    // Dependency invalidation also evicts optimized (tier-2) entries that
    // inlined the changed callee.
    #[test]
    fn test_invalidate_with_tier2_entries() {
        let mut cache = JitCodeCache::new();
        let callee = FunctionHash([0x01; 32]);
        let optimized = FunctionHash([0x02; 32]);

        let key = Tier2CacheKey::with_versions(optimized.0, vec![callee.0], 1, 0, 0);

        cache.insert_entry(CacheEntry {
            code_ptr: 0x1000 as *const u8,
            function_hash: callee,
            schema_version: 0,
            feedback_epoch: 0,
            dependencies: vec![],
            tier2_key: None,
        });
        cache.insert_entry(CacheEntry {
            code_ptr: 0x2000 as *const u8,
            function_hash: optimized,
            schema_version: 0,
            feedback_epoch: 0,
            dependencies: vec![callee],
            tier2_key: Some(key),
        });

        let invalidated = cache.invalidate_by_dependency(&callee);
        assert_eq!(invalidated.len(), 1);
        assert_eq!(invalidated[0], optimized);
        assert!(!cache.contains(&optimized));
        assert!(cache.contains(&callee));
    }
}