Skip to main content

shape_vm/
deopt.rs

1//! Deoptimization tracking for JIT-compiled functions.
2//!
3//! Tracks which JIT-compiled functions depend on specific shape IDs,
4//! so that when a shape transitions (e.g., a HashMap gains a property),
5//! all functions that guarded on that shape can be invalidated.
6
7use std::collections::{HashMap, HashSet};
8
9use shape_value::shape_graph::ShapeId;
10
11/// Tracks shape dependencies for JIT-compiled functions.
12///
13/// When a function is compiled with shape guards (e.g., guarding that an
14/// object has shape X for inline caching), the shape IDs it depends on
15/// are registered here. When a shape transition occurs, all functions
16/// that depend on the transitioning shape are invalidated.
17pub struct DeoptTracker {
18    /// function_id → set of ShapeIds it depends on
19    dependencies: HashMap<u16, HashSet<ShapeId>>,
20    /// shape_id → set of function_ids that depend on it
21    shape_dependents: HashMap<ShapeId, HashSet<u16>>,
22}
23
24impl DeoptTracker {
25    /// Create an empty deopt tracker.
26    pub fn new() -> Self {
27        Self {
28            dependencies: HashMap::new(),
29            shape_dependents: HashMap::new(),
30        }
31    }
32
33    /// Register shape dependencies for a compiled function.
34    ///
35    /// Called after successful JIT compilation when the compilation result
36    /// includes shape guard IDs.
37    pub fn register(&mut self, function_id: u16, shape_ids: &[ShapeId]) {
38        if shape_ids.is_empty() {
39            return;
40        }
41        let dep_set = self
42            .dependencies
43            .entry(function_id)
44            .or_insert_with(HashSet::new);
45        for &sid in shape_ids {
46            dep_set.insert(sid);
47            self.shape_dependents
48                .entry(sid)
49                .or_insert_with(HashSet::new)
50                .insert(function_id);
51        }
52    }
53
54    /// Invalidate all functions that depend on the given shape.
55    ///
56    /// Returns the list of function IDs that were invalidated (need to
57    /// have their JIT code removed from the native_code_table).
58    pub fn invalidate_shape(&mut self, shape_id: ShapeId) -> Vec<u16> {
59        let dependents = match self.shape_dependents.remove(&shape_id) {
60            Some(set) => set,
61            None => return Vec::new(),
62        };
63
64        let mut invalidated = Vec::with_capacity(dependents.len());
65        for func_id in dependents {
66            // Remove all of this function's dependencies
67            if let Some(dep_shapes) = self.dependencies.remove(&func_id) {
68                // Clean up reverse mappings for other shapes this function depended on
69                for sid in &dep_shapes {
70                    if *sid != shape_id {
71                        if let Some(funcs) = self.shape_dependents.get_mut(sid) {
72                            funcs.remove(&func_id);
73                            if funcs.is_empty() {
74                                self.shape_dependents.remove(sid);
75                            }
76                        }
77                    }
78                }
79            }
80            invalidated.push(func_id);
81        }
82
83        invalidated
84    }
85
86    /// Clear all dependencies for a function (e.g., when it's recompiled).
87    pub fn clear_function(&mut self, function_id: u16) {
88        if let Some(dep_shapes) = self.dependencies.remove(&function_id) {
89            for sid in dep_shapes {
90                if let Some(funcs) = self.shape_dependents.get_mut(&sid) {
91                    funcs.remove(&function_id);
92                    if funcs.is_empty() {
93                        self.shape_dependents.remove(&sid);
94                    }
95                }
96            }
97        }
98    }
99
100    /// Number of functions being tracked.
101    pub fn tracked_function_count(&self) -> usize {
102        self.dependencies.len()
103    }
104
105    /// Number of shapes being watched.
106    pub fn watched_shape_count(&self) -> usize {
107        self.shape_dependents.len()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_register_and_invalidate() {
        let mut tracker = DeoptTracker::new();
        let s1 = ShapeId(1);
        let s2 = ShapeId(2);

        tracker.register(0, &[s1, s2]);
        tracker.register(1, &[s1]);
        assert_eq!(tracker.tracked_function_count(), 2);
        assert_eq!(tracker.watched_shape_count(), 2);

        // Invalidate shape 1 — both functions depend on it
        let invalidated = tracker.invalidate_shape(s1);
        assert_eq!(invalidated.len(), 2);
        assert!(invalidated.contains(&0));
        assert!(invalidated.contains(&1));

        // Both functions fully removed
        assert_eq!(tracker.tracked_function_count(), 0);
        // Shape 2 no longer watched (function 0 was the only dependent)
        assert_eq!(tracker.watched_shape_count(), 0);
    }

    #[test]
    fn test_invalidate_no_dependents() {
        let mut tracker = DeoptTracker::new();
        let invalidated = tracker.invalidate_shape(ShapeId(99));
        assert!(invalidated.is_empty());
    }

    #[test]
    fn test_invalidate_twice_is_noop() {
        let mut tracker = DeoptTracker::new();
        let s1 = ShapeId(1);
        tracker.register(0, &[s1]);

        assert_eq!(tracker.invalidate_shape(s1), vec![0]);
        // The shape is no longer watched, so a second transition finds nothing.
        assert!(tracker.invalidate_shape(s1).is_empty());
    }

    #[test]
    fn test_invalidate_leaves_unrelated_functions() {
        let mut tracker = DeoptTracker::new();
        let s1 = ShapeId(1);
        let s2 = ShapeId(2);
        tracker.register(0, &[s1]);
        tracker.register(1, &[s2]);

        // Only function 0 guards on s1.
        assert_eq!(tracker.invalidate_shape(s1), vec![0]);
        assert_eq!(tracker.tracked_function_count(), 1);
        assert_eq!(tracker.watched_shape_count(), 1);

        // Function 1 is still registered against s2.
        assert_eq!(tracker.invalidate_shape(s2), vec![1]);
    }

    #[test]
    fn test_clear_function() {
        let mut tracker = DeoptTracker::new();
        let s1 = ShapeId(1);
        tracker.register(0, &[s1]);
        tracker.register(1, &[s1]);

        tracker.clear_function(0);
        assert_eq!(tracker.tracked_function_count(), 1);

        // Shape 1 still watched by function 1
        let invalidated = tracker.invalidate_shape(s1);
        assert_eq!(invalidated, vec![1]);
    }

    #[test]
    fn test_clear_sole_dependent_stops_watching_shape() {
        let mut tracker = DeoptTracker::new();
        let s1 = ShapeId(1);
        let s2 = ShapeId(2);
        tracker.register(0, &[s1, s2]);

        tracker.clear_function(0);
        assert_eq!(tracker.tracked_function_count(), 0);
        assert_eq!(tracker.watched_shape_count(), 0);
        assert!(tracker.invalidate_shape(s1).is_empty());
    }

    #[test]
    fn test_clear_untracked_function_is_noop() {
        let mut tracker = DeoptTracker::new();
        let s1 = ShapeId(1);
        tracker.register(0, &[s1]);

        tracker.clear_function(7);
        assert_eq!(tracker.tracked_function_count(), 1);
        assert_eq!(tracker.watched_shape_count(), 1);
    }

    #[test]
    fn test_register_empty_shapes() {
        let mut tracker = DeoptTracker::new();
        tracker.register(0, &[]);
        assert_eq!(tracker.tracked_function_count(), 0);
    }

    #[test]
    fn test_duplicate_registration() {
        let mut tracker = DeoptTracker::new();
        let s1 = ShapeId(1);
        tracker.register(0, &[s1]);
        tracker.register(0, &[s1]); // duplicate
        assert_eq!(tracker.tracked_function_count(), 1);
        assert_eq!(tracker.watched_shape_count(), 1);
    }

    #[test]
    fn test_duplicate_shape_ids_invalidate_once() {
        let mut tracker = DeoptTracker::new();
        let s1 = ShapeId(1);
        tracker.register(0, &[s1, s1]); // duplicate within one call
        tracker.register(0, &[s1]); // duplicate across calls

        // The function is reported exactly once.
        assert_eq!(tracker.invalidate_shape(s1), vec![0]);
        assert_eq!(tracker.tracked_function_count(), 0);
    }

    #[test]
    fn test_invalidate_partial_overlap() {
        let mut tracker = DeoptTracker::new();
        let s1 = ShapeId(1);
        let s2 = ShapeId(2);
        let s3 = ShapeId(3);

        tracker.register(0, &[s1, s2]); // depends on s1, s2
        tracker.register(1, &[s2, s3]); // depends on s2, s3

        // Invalidate s2 — both functions invalidated
        let invalidated = tracker.invalidate_shape(s2);
        assert_eq!(invalidated.len(), 2);

        // All cleaned up
        assert_eq!(tracker.tracked_function_count(), 0);
        assert_eq!(tracker.watched_shape_count(), 0);
    }
}