// shape_vm/tier.rs
1//! Tiered compilation support for Shape VM.
2//!
3//! Functions start in Tier 0 (interpreted) and are promoted to JIT compilation
4//! tiers based on call frequency:
5//!
6//! - Tier 0: Interpreted (all functions start here)
7//! - Tier 1: Baseline JIT (per-function, no cross-function optimization) — after 100 calls
8//! - Tier 2: Optimizing JIT (inlining, constant propagation) — after 10,000 calls
9
10use std::collections::{HashMap, HashSet};
11use std::sync::{Arc, mpsc};
12
13use crate::bytecode::BytecodeProgram;
14use crate::deopt::DeoptTracker;
15use crate::feedback::FeedbackVector;
16
/// Execution tier for a function.
///
/// Variant order matters: the derived `Ord` makes
/// `Interpreted < BaselineJit < OptimizingJit`, which callers use to compare
/// tiers. Do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Tier {
    /// Interpreted execution (default).
    Interpreted,
    /// Baseline JIT — per-function compilation, no cross-function optimization.
    BaselineJit,
    /// Optimizing JIT — inlining, constant propagation, devirtualization.
    OptimizingJit,
}

impl Tier {
    /// The call-count threshold at which a function is promoted *to* this tier.
    ///
    /// `Interpreted` returns 0 because every function starts there.
    /// Declared `const` so thresholds can also be used in constant contexts.
    pub const fn threshold(&self) -> u32 {
        match self {
            Self::Interpreted => 0,
            Self::BaselineJit => 100,
            Self::OptimizingJit => 10_000,
        }
    }
}
38
/// Per-function call counter and tier tracking.
///
/// One of these exists per function in the program (see
/// `TierManager::function_states`); it records where the function currently
/// executes and how hot it is.
#[derive(Debug)]
pub struct FunctionTierState {
    /// Current execution tier.
    pub tier: Tier,
    /// Total call count since program start. Saturates rather than wraps.
    pub call_count: u32,
    /// Whether a compilation request is pending for this function.
    /// Used to de-duplicate requests while one is in flight.
    pub compilation_pending: bool,
}
49
50impl Default for FunctionTierState {
51    fn default() -> Self {
52        Self {
53            tier: Tier::Interpreted,
54            call_count: 0,
55            compilation_pending: false,
56        }
57    }
58}
59
/// Request to compile a function at a higher tier.
///
/// Sent from the VM thread to the background compilation worker over the
/// `TierManager`'s request channel.
#[derive(Debug)]
pub struct CompilationRequest {
    /// Function index in the program.
    pub function_id: u16,
    /// Target tier for compilation.
    pub target_tier: Tier,
    /// Content hash of the function blob (for cache lookup).
    /// `None` when the caller has no hash available.
    pub blob_hash: Option<[u8; 32]>,
    /// If true, this is an OSR compilation request for a specific loop.
    /// The `loop_header_ip` field specifies which loop to compile.
    pub osr: bool,
    /// Bytecode IP of the loop header for OSR compilation.
    /// Only meaningful when `osr == true`.
    pub loop_header_ip: Option<usize>,
    /// Feedback vector snapshot for this function (Tier 2+ only).
    ///
    /// At Tier 1, feedback is not yet collected, so this is `None`.
    /// At Tier 2 (optimizing), the JIT reads IC state from this vector
    /// to emit speculative guards and type-specialized code paths.
    pub feedback: Option<FeedbackVector>,
    /// Feedback vectors for inline callee functions (Tier 2+ only).
    ///
    /// Maps callee function_id → its FeedbackVector. When the Tier 2 JIT
    /// inlines a callee, it merges the callee's feedback into the compilation
    /// so speculative guards can fire inside inlined code.
    pub callee_feedback: HashMap<u16, FeedbackVector>,
}
88
/// Result of background compilation.
///
/// Sent back from the compilation worker to the VM thread; applied by
/// `TierManager::poll_completions`.
#[derive(Debug)]
pub struct CompilationResult {
    /// Function index that was compiled.
    pub function_id: u16,
    /// Tier that was compiled to.
    pub compiled_tier: Tier,
    /// Native code pointer if JIT compilation succeeded.
    /// `None` means compilation failed (see `error`).
    pub native_code: Option<*const u8>,
    /// Error message if compilation failed (function stays at current tier).
    pub error: Option<String>,
    /// If this is an OSR compilation, the entry point metadata.
    pub osr_entry: Option<crate::bytecode::OsrEntryPoint>,
    /// Deopt info for all guard points in the compiled code.
    /// Each entry describes how to reconstruct interpreter state when a
    /// speculative guard fails inside the JIT-compiled code.
    pub deopt_points: Vec<crate::bytecode::DeoptInfo>,
    /// Bytecode IP of the loop header for OSR results. Used for blacklisting
    /// failed loops so compilation is not re-attempted.
    pub loop_header_ip: Option<usize>,
    /// Shape IDs guarded by this compilation. Used by DeoptTracker to
    /// invalidate the function when a guarded shape transitions.
    pub shape_guards: Vec<shape_value::shape_graph::ShapeId>,
}
113
/// Backend trait for pluggable JIT compilation.
///
/// Implementations receive compilation requests and produce results. The
/// `TierManager` owns the worker thread that drives the backend (see
/// `TierManager::set_backend`), hence the `Send + 'static` bound.
pub trait CompilationBackend: Send + 'static {
    /// Compile a function or loop according to the request.
    ///
    /// Must always return a result: on failure, set `error` (and
    /// `loop_header_ip` for OSR requests) rather than panicking, so the
    /// VM can blacklist the loop / keep the function interpreted.
    fn compile(
        &mut self,
        request: &CompilationRequest,
        program: &BytecodeProgram,
    ) -> CompilationResult;
}
126
// SAFETY: `CompilationResult` is only non-Send because of the raw
// `native_code` pointer. The pointer is produced on the worker thread and
// handed across the channel to the VM thread, which is its sole consumer.
// NOTE(review): soundness also relies on the JIT keeping the code region
// alive for as long as the VM may call it — confirm against the backend's
// code allocator.
unsafe impl Send for CompilationResult {}
130
/// Default OSR back-edge threshold: 1000 iterations triggers OSR compilation.
/// Overridable per-manager via `TierManager::set_osr_threshold`.
const DEFAULT_OSR_THRESHOLD: u32 = 1000;
133
/// Manages tiered compilation state for all functions in a program.
///
/// Owned by the VM thread. Hotness counters are updated synchronously by the
/// interpreter; compilation itself happens on a background worker connected
/// via the request/result channels.
pub struct TierManager {
    /// Per-function tier state, indexed by function_id.
    function_states: Vec<FunctionTierState>,
    /// Channel to send compilation requests to the background thread.
    /// `None` until `set_channels`/`set_backend` is called.
    compilation_tx: Option<mpsc::Sender<CompilationRequest>>,
    /// Channel to receive compilation results from the background thread.
    compilation_rx: Option<mpsc::Receiver<CompilationResult>>,
    /// Native function pointers from JIT compilation (function_id -> code pointer).
    native_code_table: HashMap<u16, *const u8>,
    /// Whether tiered compilation is enabled. When false, all recording
    /// methods are no-ops and everything stays interpreted.
    enabled: bool,
    /// Per-function, per-loop back-edge counters: (func_id, loop_ip) -> count.
    /// Incremented each time the interpreter executes a loop back-edge.
    loop_counters: HashMap<(u16, usize), u32>,
    /// OSR-compiled loop entries: (func_id, loop_ip) -> native code pointer.
    /// Populated when an OSR compilation completes successfully.
    osr_table: HashMap<(u16, usize), *const u8>,
    /// Number of loop back-edge iterations before requesting OSR compilation.
    osr_threshold: u32,
    /// Loops that failed compilation and should not be retried.
    /// Key is (function_id, loop_header_ip).
    osr_blacklist: HashSet<(u16, usize)>,
    /// Deopt info tables for Tier 2 compiled functions.
    /// function_id -> Vec<DeoptInfo> (indexed by deopt_id).
    /// Populated by `poll_completions()` when a compilation result includes
    /// deopt_points (from speculative guard emission).
    deopt_tables: HashMap<u16, Vec<crate::bytecode::DeoptInfo>>,
    /// Shape dependency tracker for JIT invalidation.
    /// Tracks which functions depend on which shape IDs, enabling
    /// automatic invalidation when shape transitions occur.
    deopt_tracker: DeoptTracker,
}
167
// SAFETY: The raw pointers in native_code_table / osr_table are JIT-compiled
// code that lives for the duration of the VM, and the TierManager is only
// ever accessed from the VM thread after being moved there.
// NOTE(review): this impl permits moving the manager between threads but not
// sharing it (`Sync` is not implemented) — confirm callers never alias it
// across threads.
unsafe impl Send for TierManager {}
171
172impl TierManager {
173    /// Create a new tier manager for a program with the given number of functions.
174    pub fn new(function_count: usize, enabled: bool) -> Self {
175        let mut function_states = Vec::with_capacity(function_count);
176        function_states.resize_with(function_count, FunctionTierState::default);
177
178        Self {
179            function_states,
180            compilation_tx: None,
181            compilation_rx: None,
182            native_code_table: HashMap::new(),
183            enabled,
184            loop_counters: HashMap::new(),
185            osr_table: HashMap::new(),
186            osr_threshold: DEFAULT_OSR_THRESHOLD,
187            osr_blacklist: HashSet::new(),
188            deopt_tables: HashMap::new(),
189            deopt_tracker: DeoptTracker::new(),
190        }
191    }
192
193    /// Set up the background compilation channels.
194    ///
195    /// The caller is responsible for spawning the background thread that reads
196    /// from `request_rx` and sends results to `result_tx`.
197    pub fn set_channels(
198        &mut self,
199        compilation_tx: mpsc::Sender<CompilationRequest>,
200        compilation_rx: mpsc::Receiver<CompilationResult>,
201    ) {
202        self.compilation_tx = Some(compilation_tx);
203        self.compilation_rx = Some(compilation_rx);
204    }
205
206    /// Record a function call and check for tier promotion.
207    ///
208    /// Returns `true` if the function should be compiled at a higher tier.
209    /// This is called in the `Call` opcode handler.
210    ///
211    /// When promoting to OptimizingJit and a feedback vector is available,
212    /// it is attached to the compilation request for speculative optimization.
213    #[inline]
214    pub fn record_call(&mut self, function_id: u16, feedback: Option<&FeedbackVector>) -> bool {
215        if !self.enabled {
216            return false;
217        }
218
219        let idx = function_id as usize;
220        if idx >= self.function_states.len() {
221            return false;
222        }
223
224        let state = &mut self.function_states[idx];
225        state.call_count = state.call_count.saturating_add(1);
226
227        // Check if promotion is warranted.
228        let next_tier = match state.tier {
229            Tier::Interpreted if state.call_count >= Tier::BaselineJit.threshold() => {
230                Some(Tier::BaselineJit)
231            }
232            Tier::BaselineJit if state.call_count >= Tier::OptimizingJit.threshold() => {
233                Some(Tier::OptimizingJit)
234            }
235            _ => None,
236        };
237
238        if let Some(target) = next_tier {
239            if !state.compilation_pending {
240                state.compilation_pending = true;
241                // Tier 2 (OptimizingJit) benefits from feedback for speculation.
242                // Tier 1 (BaselineJit) compiles without feedback.
243                if target == Tier::OptimizingJit {
244                    if let Some(fv) = feedback {
245                        self.request_compilation_with_feedback(function_id, target, fv.clone());
246                        return true;
247                    }
248                }
249                self.request_compilation(function_id, target);
250                return true;
251            }
252        }
253
254        false
255    }
256
257    /// Send a compilation request to the background thread.
258    fn request_compilation(&self, function_id: u16, target_tier: Tier) {
259        if let Some(ref tx) = self.compilation_tx {
260            let _ = tx.send(CompilationRequest {
261                function_id,
262                target_tier,
263                blob_hash: None, // Caller can set this from function metadata
264                osr: false,
265                loop_header_ip: None,
266                callee_feedback: HashMap::new(),
267                feedback: None, // Feedback attached by executor when available
268            });
269        }
270    }
271
272    /// Send a compilation request with a feedback vector snapshot.
273    ///
274    /// Used for Tier 2 (optimizing) promotion when the executor has collected
275    /// enough type feedback to enable speculative optimization.
276    pub fn request_compilation_with_feedback(
277        &self,
278        function_id: u16,
279        target_tier: Tier,
280        feedback: FeedbackVector,
281    ) {
282        if let Some(ref tx) = self.compilation_tx {
283            let _ = tx.send(CompilationRequest {
284                function_id,
285                target_tier,
286                blob_hash: None,
287                osr: false,
288                loop_header_ip: None,
289                feedback: Some(feedback),
290                callee_feedback: HashMap::new(),
291            });
292        }
293    }
294
295    /// Poll for completed compilations (non-blocking).
296    ///
297    /// Called at safe points: function entry, loop back-edges.
298    /// Applies any completed compilations by updating the native code table.
299    /// Also handles OSR compilation results by updating the osr_table.
300    pub fn poll_completions(&mut self) -> Vec<CompilationResult> {
301        let mut results = Vec::new();
302
303        if let Some(ref rx) = self.compilation_rx {
304            while let Ok(result) = rx.try_recv() {
305                let idx = result.function_id as usize;
306                if idx < self.function_states.len() {
307                    let state = &mut self.function_states[idx];
308                    state.compilation_pending = false;
309
310                    if let Some(code_ptr) = result.native_code {
311                        // Check if this is an OSR compilation result
312                        if let Some(ref osr_entry) = result.osr_entry {
313                            // Register OSR code for this loop
314                            self.osr_table
315                                .insert((result.function_id, osr_entry.bytecode_ip), code_ptr);
316                        } else {
317                            // Regular whole-function compilation
318                            state.tier = result.compiled_tier;
319                            self.native_code_table.insert(result.function_id, code_ptr);
320                        }
321
322                        // Store deopt points for speculative guard recovery
323                        if !result.deopt_points.is_empty() {
324                            self.deopt_tables
325                                .insert(result.function_id, result.deopt_points.clone());
326                        }
327
328                        // Register shape dependencies for invalidation tracking
329                        if !result.shape_guards.is_empty() {
330                            self.deopt_tracker
331                                .register(result.function_id, &result.shape_guards);
332                        }
333                    }
334                    // Blacklist failed OSR loops so we don't retry them
335                    if result.error.is_some() {
336                        if let Some(loop_ip) = result.loop_header_ip {
337                            self.osr_blacklist.insert((result.function_id, loop_ip));
338                        }
339                    }
340                }
341                results.push(result);
342            }
343        }
344
345        // Check for shape transitions that invalidate JIT-compiled code
346        self.check_shape_invalidations();
347
348        results
349    }
350
351    /// Check for shape transitions and invalidate dependent JIT code.
352    ///
353    /// Drains the global shape transition log and uses the DeoptTracker to
354    /// find functions that depend on the transitioned shapes. Those functions
355    /// are invalidated (reverted to interpreter) so they can be recompiled
356    /// with updated shape assumptions.
357    fn check_shape_invalidations(&mut self) {
358        let transitions = shape_value::shape_graph::drain_shape_transitions();
359        for (parent_id, _child_id) in transitions {
360            let invalidated = self.deopt_tracker.invalidate_shape(parent_id);
361            for func_id in invalidated {
362                self.invalidate_all(func_id);
363            }
364        }
365    }
366
367    /// Look up native code for a function, if available.
368    #[inline]
369    pub fn get_native_code(&self, function_id: u16) -> Option<*const u8> {
370        self.native_code_table.get(&function_id).copied()
371    }
372
373    /// Look up a DeoptInfo entry for a specific guard deopt point.
374    ///
375    /// `deopt_id` is an index into the `deopt_points` vector stored when the
376    /// Tier 2 compilation result was installed. Returns `None` if no deopt
377    /// table exists for this function or the index is out of bounds.
378    pub fn get_deopt_info(
379        &self,
380        function_id: u16,
381        deopt_id: usize,
382    ) -> Option<&crate::bytecode::DeoptInfo> {
383        self.deopt_tables
384            .get(&function_id)
385            .and_then(|points| points.get(deopt_id))
386    }
387
388    /// Get the current tier of a function.
389    pub fn get_tier(&self, function_id: u16) -> Tier {
390        self.function_states
391            .get(function_id as usize)
392            .map(|s| s.tier)
393            .unwrap_or(Tier::Interpreted)
394    }
395
396    /// Get the call count for a function.
397    pub fn get_call_count(&self, function_id: u16) -> u32 {
398        self.function_states
399            .get(function_id as usize)
400            .map(|s| s.call_count)
401            .unwrap_or(0)
402    }
403
404    /// Whether tiered compilation is enabled.
405    pub fn is_enabled(&self) -> bool {
406        self.enabled
407    }
408
409    /// Number of functions with native JIT code.
410    pub fn jit_compiled_count(&self) -> usize {
411        self.native_code_table.len()
412    }
413
414    // =====================================================================
415    // OSR (On-Stack Replacement) — hot loop detection and dispatch
416    // =====================================================================
417
418    /// Record a loop back-edge iteration and check if OSR compilation should
419    /// be requested.
420    ///
421    /// Returns `true` if this iteration crosses the OSR threshold and no
422    /// OSR code has been compiled for this loop yet (i.e., we should send
423    /// a compilation request).
424    #[inline]
425    pub fn record_loop_iteration(&mut self, func_id: u16, loop_ip: usize) -> bool {
426        if !self.enabled {
427            return false;
428        }
429        // Never retry blacklisted loops (compilation previously failed)
430        if self.osr_blacklist.contains(&(func_id, loop_ip)) {
431            return false;
432        }
433        let counter = self.loop_counters.entry((func_id, loop_ip)).or_insert(0);
434        *counter += 1;
435        // Only trigger once: when the counter first reaches the threshold
436        // and OSR code has not already been compiled.
437        *counter == self.osr_threshold && !self.osr_table.contains_key(&(func_id, loop_ip))
438    }
439
440    /// Register OSR-compiled native code for a specific loop.
441    pub fn register_osr_code(&mut self, func_id: u16, loop_ip: usize, code: *const u8) {
442        self.osr_table.insert((func_id, loop_ip), code);
443    }
444
445    /// Look up OSR-compiled native code for a specific loop.
446    #[inline]
447    pub fn get_osr_code(&self, func_id: u16, loop_ip: usize) -> Option<*const u8> {
448        self.osr_table.get(&(func_id, loop_ip)).copied()
449    }
450
451    /// Get the current OSR threshold.
452    pub fn osr_threshold(&self) -> u32 {
453        self.osr_threshold
454    }
455
456    /// Override the OSR threshold (useful for testing).
457    pub fn set_osr_threshold(&mut self, threshold: u32) {
458        self.osr_threshold = threshold;
459    }
460
461    /// Get the loop iteration count for a specific loop.
462    pub fn get_loop_count(&self, func_id: u16, loop_ip: usize) -> u32 {
463        self.loop_counters
464            .get(&(func_id, loop_ip))
465            .copied()
466            .unwrap_or(0)
467    }
468
469    /// Number of OSR-compiled loop entries.
470    pub fn osr_compiled_count(&self) -> usize {
471        self.osr_table.len()
472    }
473
474    /// Access the compilation request sender (for OSR requests from the executor).
475    pub fn compilation_sender(&self) -> Option<&mpsc::Sender<CompilationRequest>> {
476        self.compilation_tx.as_ref()
477    }
478
479    /// Set a compilation backend. Spawns a worker thread that drives the
480    /// backend, processing requests from the TierManager's channel.
481    ///
482    /// When the TierManager is dropped, `compilation_tx` is dropped, which
483    /// causes `req_rx.recv()` to return `Err` and the worker thread exits.
484    pub fn set_backend(
485        &mut self,
486        backend: Box<dyn CompilationBackend>,
487        program: Arc<BytecodeProgram>,
488    ) {
489        let (req_tx, req_rx) = mpsc::channel();
490        let (res_tx, res_rx) = mpsc::channel();
491        self.compilation_tx = Some(req_tx);
492        self.compilation_rx = Some(res_rx);
493        std::thread::Builder::new()
494            .name("shape-jit-worker".into())
495            .spawn(move || {
496                let mut backend = backend;
497                while let Ok(request) = req_rx.recv() {
498                    let result = backend.compile(&request, &program);
499                    if res_tx.send(result).is_err() {
500                        break;
501                    }
502                }
503            })
504            .expect("Failed to spawn JIT worker thread");
505    }
506
507    /// Check whether a loop is blacklisted (compilation previously failed).
508    pub fn is_osr_blacklisted(&self, func_id: u16, loop_ip: usize) -> bool {
509        self.osr_blacklist.contains(&(func_id, loop_ip))
510    }
511
512    // =====================================================================
513    // Invalidation — deopt/dependency tracking
514    // =====================================================================
515
516    /// Invalidate a compiled function (remove from native_code_table).
517    ///
518    /// Called when the DeoptTracker determines a dependency changed (e.g., a
519    /// global was reassigned that the JIT specialized on).
520    pub fn invalidate_function(&mut self, func_id: u16) {
521        self.native_code_table.remove(&func_id);
522        self.deopt_tables.remove(&func_id);
523        self.deopt_tracker.clear_function(func_id);
524        // Reset the function tier state so it can be recompiled
525        if let Some(state) = self.function_states.get_mut(func_id as usize) {
526            state.tier = Tier::Interpreted;
527            state.compilation_pending = false;
528        }
529    }
530
531    /// Invalidate all OSR entries for a function.
532    ///
533    /// Removes compiled loop code and resets loop counters so the loops
534    /// can be re-profiled and recompiled if still hot.
535    pub fn invalidate_osr(&mut self, func_id: u16) {
536        self.osr_table.retain(|&(fid, _), _| fid != func_id);
537        self.loop_counters.retain(|&(fid, _), _| fid != func_id);
538    }
539
540    /// Bulk invalidation: invalidate function + all its OSR entries.
541    pub fn invalidate_all(&mut self, func_id: u16) {
542        self.invalidate_function(func_id);
543        self.invalidate_osr(func_id);
544    }
545
546    /// Summary statistics.
547    pub fn stats(&self) -> TierStats {
548        let mut interpreted = 0usize;
549        let mut baseline = 0usize;
550        let mut optimizing = 0usize;
551        let mut pending = 0usize;
552
553        for state in &self.function_states {
554            match state.tier {
555                Tier::Interpreted => interpreted += 1,
556                Tier::BaselineJit => baseline += 1,
557                Tier::OptimizingJit => optimizing += 1,
558            }
559            if state.compilation_pending {
560                pending += 1;
561            }
562        }
563
564        TierStats {
565            interpreted,
566            baseline_jit: baseline,
567            optimizing_jit: optimizing,
568            pending_compilations: pending,
569            total_functions: self.function_states.len(),
570        }
571    }
572}
573
/// Summary statistics for the tiered compilation system.
///
/// Produced by `TierManager::stats()`; counts are per-function snapshots at
/// the time of the call.
#[derive(Debug, Clone)]
pub struct TierStats {
    /// Functions currently executing interpreted (Tier 0).
    pub interpreted: usize,
    /// Functions promoted to the baseline JIT (Tier 1).
    pub baseline_jit: usize,
    /// Functions promoted to the optimizing JIT (Tier 2).
    pub optimizing_jit: usize,
    /// Functions with a compilation request currently in flight.
    pub pending_compilations: usize,
    /// Total number of functions tracked.
    pub total_functions: usize,
}
583
584#[cfg(test)]
585mod tests {
586    use super::*;
587
    #[test]
    fn test_default_tier_is_interpreted() {
        // Every function starts at Tier 0, including never-called ones.
        let mgr = TierManager::new(10, true);
        assert_eq!(mgr.get_tier(0), Tier::Interpreted);
        assert_eq!(mgr.get_tier(5), Tier::Interpreted);
    }
594
    #[test]
    fn test_call_count_tracking() {
        // Counts are tracked per function and are independent.
        let mut mgr = TierManager::new(5, true);
        for _ in 0..50 {
            mgr.record_call(0, None);
        }
        assert_eq!(mgr.get_call_count(0), 50);
        assert_eq!(mgr.get_call_count(1), 0);
    }
604
    #[test]
    fn test_promotion_threshold() {
        let mut mgr = TierManager::new(5, true);

        // Not promoted at 99 calls
        for _ in 0..99 {
            mgr.record_call(0, None);
        }
        // Without channels, tier stays as Interpreted but compilation_pending is set
        let promoted = mgr.record_call(0, None); // 100th call
        assert!(promoted);

        // Still Interpreted because no background compiler responded
        assert_eq!(mgr.get_tier(0), Tier::Interpreted);
    }
620
    #[test]
    fn test_disabled_manager_no_promotion() {
        // With `enabled == false`, record_call is a no-op and never promotes.
        let mut mgr = TierManager::new(5, false);
        for _ in 0..200 {
            assert!(!mgr.record_call(0, None));
        }
    }
628
    #[test]
    fn test_out_of_bounds_function_id() {
        // Unknown function ids are ignored gracefully (no panic, defaults).
        let mut mgr = TierManager::new(5, true);
        assert!(!mgr.record_call(100, None)); // beyond function_states
        assert_eq!(mgr.get_tier(100), Tier::Interpreted);
        assert_eq!(mgr.get_call_count(100), 0);
    }
636
    #[test]
    fn test_stats() {
        // Fresh manager: everything interpreted, nothing pending.
        let mgr = TierManager::new(10, true);
        let stats = mgr.stats();
        assert_eq!(stats.total_functions, 10);
        assert_eq!(stats.interpreted, 10);
        assert_eq!(stats.baseline_jit, 0);
        assert_eq!(stats.pending_compilations, 0);
    }
646
    #[test]
    fn test_channel_compilation_flow() {
        // End-to-end: promotion emits a request; a simulated result installs
        // native code and promotes the tier on poll_completions().
        let mut mgr = TierManager::new(5, true);

        let (req_tx, req_rx) = mpsc::channel();
        let (res_tx, res_rx) = mpsc::channel();
        mgr.set_channels(req_tx, res_rx);

        // Trigger promotion
        for _ in 0..100 {
            mgr.record_call(0, None);
        }

        // Background thread would receive request
        let request = req_rx.try_recv().unwrap();
        assert_eq!(request.function_id, 0);
        assert_eq!(request.target_tier, Tier::BaselineJit);

        // Simulate background compilation result
        res_tx
            .send(CompilationResult {
                function_id: 0,
                compiled_tier: Tier::BaselineJit,
                native_code: Some(0x1000 as *const u8),
                error: None,
                osr_entry: None,
                deopt_points: Vec::new(),
                loop_header_ip: None,
                shape_guards: Vec::new(),
            })
            .unwrap();

        // Poll completions
        let results = mgr.poll_completions();
        assert_eq!(results.len(), 1);
        assert_eq!(mgr.get_tier(0), Tier::BaselineJit);
        assert!(mgr.get_native_code(0).is_some());
    }
685
    #[test]
    fn test_tier_ordering() {
        // Derived Ord depends on variant declaration order — pin it.
        assert!(Tier::Interpreted < Tier::BaselineJit);
        assert!(Tier::BaselineJit < Tier::OptimizingJit);
    }
691
    #[test]
    fn test_get_native_code_before_and_after_promotion() {
        // Native code is only visible after a successful result is polled.
        let mut mgr = TierManager::new(5, true);

        // No native code before promotion
        assert!(mgr.get_native_code(0).is_none());

        let (req_tx, req_rx) = mpsc::channel();
        let (res_tx, res_rx) = mpsc::channel();
        mgr.set_channels(req_tx, res_rx);

        // Drive calls to threshold
        for _ in 0..100 {
            mgr.record_call(0, None);
        }

        // Verify request was sent
        let request = req_rx.try_recv().unwrap();
        assert_eq!(request.function_id, 0);

        // Still no native code (compilation pending)
        assert!(mgr.get_native_code(0).is_none());

        // Simulate compilation result
        let fake_ptr = 0xDEAD_BEEF as *const u8;
        res_tx
            .send(CompilationResult {
                function_id: 0,
                compiled_tier: Tier::BaselineJit,
                native_code: Some(fake_ptr),
                error: None,
                osr_entry: None,
                deopt_points: Vec::new(),
                loop_header_ip: None,
                shape_guards: Vec::new(),
            })
            .unwrap();

        // Poll completions
        mgr.poll_completions();

        // Now native code is available
        assert_eq!(mgr.get_native_code(0), Some(fake_ptr));
        assert_eq!(mgr.get_tier(0), Tier::BaselineJit);
    }
737
    #[test]
    fn test_compilation_failure_no_native_code() {
        // A failed result must not install code or change the tier.
        let mut mgr = TierManager::new(5, true);

        let (req_tx, _req_rx) = mpsc::channel();
        let (res_tx, res_rx) = mpsc::channel();
        mgr.set_channels(req_tx, res_rx);

        // Drive to threshold
        for _ in 0..100 {
            mgr.record_call(0, None);
        }

        // Simulate compilation failure
        res_tx
            .send(CompilationResult {
                function_id: 0,
                compiled_tier: Tier::BaselineJit,
                native_code: None,
                error: Some("compilation failed".to_string()),
                osr_entry: None,
                deopt_points: Vec::new(),
                loop_header_ip: None,
                shape_guards: Vec::new(),
            })
            .unwrap();

        mgr.poll_completions();

        // No native code installed, tier stays Interpreted
        assert!(mgr.get_native_code(0).is_none());
        assert_eq!(mgr.get_tier(0), Tier::Interpreted);
    }
771
    #[test]
    fn test_no_duplicate_compilation_requests() {
        // While a compilation is in flight, further calls must not re-request.
        let mut mgr = TierManager::new(5, true);

        let (req_tx, req_rx) = mpsc::channel();
        let (_res_tx, res_rx) = mpsc::channel();
        mgr.set_channels(req_tx, res_rx);

        // Drive past threshold
        for _ in 0..200 {
            mgr.record_call(0, None);
        }

        // Should only get one request (compilation_pending prevents duplicates)
        let first = req_rx.try_recv();
        assert!(first.is_ok());
        let second = req_rx.try_recv();
        assert!(second.is_err()); // No second request
    }
791
    #[test]
    fn test_optimizing_tier_promotion() {
        // Two-step promotion: Interpreted -> BaselineJit -> OptimizingJit.
        let mut mgr = TierManager::new(5, true);

        let (req_tx, req_rx) = mpsc::channel();
        let (res_tx, res_rx) = mpsc::channel();
        mgr.set_channels(req_tx, res_rx);

        // First: promote to BaselineJit
        for _ in 0..100 {
            mgr.record_call(0, None);
        }
        let request = req_rx.try_recv().unwrap();
        assert_eq!(request.target_tier, Tier::BaselineJit);

        // Complete baseline compilation
        res_tx
            .send(CompilationResult {
                function_id: 0,
                compiled_tier: Tier::BaselineJit,
                native_code: Some(0x1000 as *const u8),
                error: None,
                osr_entry: None,
                deopt_points: Vec::new(),
                loop_header_ip: None,
                shape_guards: Vec::new(),
            })
            .unwrap();
        mgr.poll_completions();
        assert_eq!(mgr.get_tier(0), Tier::BaselineJit);

        // Continue calling until OptimizingJit threshold
        for _ in 100..10_000 {
            mgr.record_call(0, None);
        }
        let request = req_rx.try_recv().unwrap();
        assert_eq!(request.target_tier, Tier::OptimizingJit);
    }
830
831    // =====================================================================
832    // OSR tests
833    // =====================================================================
834
835    #[test]
836    fn test_loop_counter_threshold() {
837        let mut mgr = TierManager::new(5, true);
838
839        // Below threshold: should not trigger
840        for _ in 0..999 {
841            assert!(!mgr.record_loop_iteration(0, 42));
842        }
843        assert_eq!(mgr.get_loop_count(0, 42), 999);
844
845        // At threshold: should trigger exactly once
846        assert!(mgr.record_loop_iteration(0, 42));
847        assert_eq!(mgr.get_loop_count(0, 42), 1000);
848
849        // Past threshold: should not trigger again
850        assert!(!mgr.record_loop_iteration(0, 42));
851        assert_eq!(mgr.get_loop_count(0, 42), 1001);
852    }
853
854    #[test]
855    fn test_loop_counter_different_loops() {
856        let mut mgr = TierManager::new(5, true);
857        mgr.set_osr_threshold(10);
858
859        // Two different loops in the same function
860        for _ in 0..10 {
861            mgr.record_loop_iteration(0, 100);
862        }
863        assert_eq!(mgr.get_loop_count(0, 100), 10);
864        assert_eq!(mgr.get_loop_count(0, 200), 0);
865
866        // Different function, same loop IP
867        for _ in 0..5 {
868            mgr.record_loop_iteration(1, 100);
869        }
870        assert_eq!(mgr.get_loop_count(1, 100), 5);
871        assert_eq!(mgr.get_loop_count(0, 100), 10); // unchanged
872    }
873
874    #[test]
875    fn test_osr_table_registration() {
876        let mut mgr = TierManager::new(5, true);
877
878        // No OSR code initially
879        assert!(mgr.get_osr_code(0, 42).is_none());
880        assert_eq!(mgr.osr_compiled_count(), 0);
881
882        // Register OSR code
883        let fake_code = 0xBEEF as *const u8;
884        mgr.register_osr_code(0, 42, fake_code);
885
886        assert_eq!(mgr.get_osr_code(0, 42), Some(fake_code));
887        assert_eq!(mgr.osr_compiled_count(), 1);
888
889        // Different loop
890        assert!(mgr.get_osr_code(0, 100).is_none());
891    }
892
893    #[test]
894    fn test_osr_threshold_prevents_duplicate_request() {
895        let mut mgr = TierManager::new(5, true);
896        mgr.set_osr_threshold(10);
897
898        // Hit threshold
899        for _ in 0..9 {
900            mgr.record_loop_iteration(0, 42);
901        }
902        assert!(mgr.record_loop_iteration(0, 42)); // 10th: triggers
903
904        // Register OSR code (simulating compilation completed)
905        mgr.register_osr_code(0, 42, 0x1000 as *const u8);
906
907        // Further iterations should not trigger again
908        for _ in 0..100 {
909            assert!(!mgr.record_loop_iteration(0, 42));
910        }
911    }
912
913    #[test]
914    fn test_invalidate_function_clears_native_code() {
915        let mut mgr = TierManager::new(5, true);
916
917        let (req_tx, _req_rx) = mpsc::channel();
918        let (res_tx, res_rx) = mpsc::channel();
919        mgr.set_channels(req_tx, res_rx);
920
921        // Simulate a promoted function
922        for _ in 0..100 {
923            mgr.record_call(0, None);
924        }
925        res_tx
926            .send(CompilationResult {
927                function_id: 0,
928                compiled_tier: Tier::BaselineJit,
929                native_code: Some(0x1000 as *const u8),
930                error: None,
931                osr_entry: None,
932                deopt_points: Vec::new(),
933                loop_header_ip: None,
934                shape_guards: Vec::new(),
935            })
936            .unwrap();
937        mgr.poll_completions();
938        assert!(mgr.get_native_code(0).is_some());
939        assert_eq!(mgr.get_tier(0), Tier::BaselineJit);
940
941        // Invalidate
942        mgr.invalidate_function(0);
943        assert!(mgr.get_native_code(0).is_none());
944        assert_eq!(mgr.get_tier(0), Tier::Interpreted);
945        assert!(!mgr.function_states[0].compilation_pending);
946    }
947
948    #[test]
949    fn test_invalidate_osr_clears_loop_entries() {
950        let mut mgr = TierManager::new(5, true);
951        mgr.set_osr_threshold(10);
952
953        // Register OSR code for two loops in function 0
954        mgr.register_osr_code(0, 42, 0x1000 as *const u8);
955        mgr.register_osr_code(0, 100, 0x2000 as *const u8);
956        // And one in function 1
957        mgr.register_osr_code(1, 42, 0x3000 as *const u8);
958
959        // Set up some loop counters
960        for _ in 0..50 {
961            mgr.record_loop_iteration(0, 42);
962            mgr.record_loop_iteration(0, 100);
963            mgr.record_loop_iteration(1, 42);
964        }
965
966        // Invalidate OSR for function 0 only
967        mgr.invalidate_osr(0);
968
969        assert!(mgr.get_osr_code(0, 42).is_none());
970        assert!(mgr.get_osr_code(0, 100).is_none());
971        assert_eq!(mgr.get_loop_count(0, 42), 0);
972        assert_eq!(mgr.get_loop_count(0, 100), 0);
973
974        // Function 1 unaffected
975        assert!(mgr.get_osr_code(1, 42).is_some());
976        assert_eq!(mgr.get_loop_count(1, 42), 50);
977    }
978
979    #[test]
980    fn test_invalidate_all() {
981        let mut mgr = TierManager::new(5, true);
982
983        let (req_tx, _req_rx) = mpsc::channel();
984        let (res_tx, res_rx) = mpsc::channel();
985        mgr.set_channels(req_tx, res_rx);
986
987        // Set up whole-function JIT
988        for _ in 0..100 {
989            mgr.record_call(0, None);
990        }
991        res_tx
992            .send(CompilationResult {
993                function_id: 0,
994                compiled_tier: Tier::BaselineJit,
995                native_code: Some(0x1000 as *const u8),
996                error: None,
997                osr_entry: None,
998                deopt_points: Vec::new(),
999                loop_header_ip: None,
1000                shape_guards: Vec::new(),
1001            })
1002            .unwrap();
1003        mgr.poll_completions();
1004
1005        // Set up OSR entries
1006        mgr.register_osr_code(0, 42, 0x2000 as *const u8);
1007        for _ in 0..50 {
1008            mgr.record_loop_iteration(0, 42);
1009        }
1010
1011        // Invalidate everything for function 0
1012        mgr.invalidate_all(0);
1013
1014        assert!(mgr.get_native_code(0).is_none());
1015        assert_eq!(mgr.get_tier(0), Tier::Interpreted);
1016        assert!(mgr.get_osr_code(0, 42).is_none());
1017        assert_eq!(mgr.get_loop_count(0, 42), 0);
1018    }
1019
1020    #[test]
1021    fn test_loop_counter_disabled_manager() {
1022        let mut mgr = TierManager::new(5, false);
1023        // Should never trigger when disabled
1024        for _ in 0..2000 {
1025            assert!(!mgr.record_loop_iteration(0, 42));
1026        }
1027    }
1028
1029    #[test]
1030    fn test_custom_osr_threshold() {
1031        let mut mgr = TierManager::new(5, true);
1032        assert_eq!(mgr.osr_threshold(), DEFAULT_OSR_THRESHOLD);
1033
1034        mgr.set_osr_threshold(50);
1035        assert_eq!(mgr.osr_threshold(), 50);
1036
1037        for _ in 0..49 {
1038            assert!(!mgr.record_loop_iteration(0, 10));
1039        }
1040        assert!(mgr.record_loop_iteration(0, 10)); // 50th
1041    }
1042
1043    #[test]
1044    fn test_poll_completions_handles_osr_result() {
1045        let mut mgr = TierManager::new(5, true);
1046
1047        let (req_tx, _req_rx) = mpsc::channel();
1048        let (res_tx, res_rx) = mpsc::channel();
1049        mgr.set_channels(req_tx, res_rx);
1050
1051        // Drive calls to threshold so compilation_pending is set
1052        for _ in 0..100 {
1053            mgr.record_call(0, None);
1054        }
1055
1056        // Simulate an OSR compilation result
1057        let osr_entry = crate::bytecode::OsrEntryPoint {
1058            bytecode_ip: 42,
1059            live_locals: vec![0, 1],
1060            local_kinds: vec![
1061                crate::type_tracking::SlotKind::Int64,
1062                crate::type_tracking::SlotKind::Float64,
1063            ],
1064            exit_ip: 100,
1065        };
1066
1067        res_tx
1068            .send(CompilationResult {
1069                function_id: 0,
1070                compiled_tier: Tier::BaselineJit,
1071                native_code: Some(0xCAFE as *const u8),
1072                error: None,
1073                osr_entry: Some(osr_entry),
1074                deopt_points: Vec::new(),
1075                loop_header_ip: None,
1076                shape_guards: Vec::new(),
1077            })
1078            .unwrap();
1079
1080        mgr.poll_completions();
1081
1082        // OSR code should be in the osr_table, NOT in native_code_table
1083        assert!(mgr.get_native_code(0).is_none());
1084        assert_eq!(mgr.get_osr_code(0, 42), Some(0xCAFE as *const u8));
1085        // Tier should NOT be promoted (OSR is per-loop, not per-function)
1086        assert_eq!(mgr.get_tier(0), Tier::Interpreted);
1087    }
1088
1089    #[test]
1090    fn test_osr_blacklist_on_compilation_failure() {
1091        let mut mgr = TierManager::new(5, true);
1092        mgr.set_osr_threshold(10);
1093
1094        let (req_tx, _req_rx) = mpsc::channel();
1095        let (res_tx, res_rx) = mpsc::channel();
1096        mgr.set_channels(req_tx, res_rx);
1097
1098        // Drive loop to threshold
1099        for _ in 0..100 {
1100            mgr.record_call(0, None);
1101        }
1102
1103        // Simulate a failed OSR compilation with loop_header_ip
1104        res_tx
1105            .send(CompilationResult {
1106                function_id: 0,
1107                compiled_tier: Tier::BaselineJit,
1108                native_code: None,
1109                error: Some("unsupported opcode CallMethod".to_string()),
1110                osr_entry: None,
1111                deopt_points: Vec::new(),
1112                loop_header_ip: Some(42),
1113                shape_guards: Vec::new(),
1114            })
1115            .unwrap();
1116
1117        mgr.poll_completions();
1118
1119        // Loop should be blacklisted
1120        assert!(mgr.is_osr_blacklisted(0, 42));
1121        // Further iterations should not trigger compilation
1122        for _ in 0..2000 {
1123            assert!(!mgr.record_loop_iteration(0, 42));
1124        }
1125        // Different loop in same function is not blacklisted
1126        assert!(!mgr.is_osr_blacklisted(0, 100));
1127    }
1128
1129    #[test]
1130    fn test_compilation_result_loop_header_ip_roundtrip() {
1131        let mut mgr = TierManager::new(5, true);
1132
1133        let (req_tx, _req_rx) = mpsc::channel();
1134        let (res_tx, res_rx) = mpsc::channel();
1135        mgr.set_channels(req_tx, res_rx);
1136
1137        for _ in 0..100 {
1138            mgr.record_call(0, None);
1139        }
1140
1141        // Send result with loop_header_ip set
1142        res_tx
1143            .send(CompilationResult {
1144                function_id: 0,
1145                compiled_tier: Tier::BaselineJit,
1146                native_code: Some(0xABCD as *const u8),
1147                error: None,
1148                osr_entry: Some(crate::bytecode::OsrEntryPoint {
1149                    bytecode_ip: 55,
1150                    live_locals: vec![0],
1151                    local_kinds: vec![crate::type_tracking::SlotKind::Int64],
1152                    exit_ip: 80,
1153                }),
1154                deopt_points: Vec::new(),
1155                loop_header_ip: Some(55),
1156                shape_guards: Vec::new(),
1157            })
1158            .unwrap();
1159
1160        let results = mgr.poll_completions();
1161        assert_eq!(results.len(), 1);
1162        assert_eq!(results[0].loop_header_ip, Some(55));
1163        assert_eq!(mgr.get_osr_code(0, 55), Some(0xABCD as *const u8));
1164    }
1165
1166    #[test]
1167    fn test_deopt_table_stored_on_compilation() {
1168        let mut mgr = TierManager::new(5, true);
1169
1170        let (req_tx, _req_rx) = mpsc::channel();
1171        let (res_tx, res_rx) = mpsc::channel();
1172        mgr.set_channels(req_tx, res_rx);
1173
1174        for _ in 0..100 {
1175            mgr.record_call(0, None);
1176        }
1177
1178        let deopt_info = crate::bytecode::DeoptInfo {
1179            resume_ip: 42,
1180            local_mapping: vec![(0, 0), (1, 2)],
1181            local_kinds: vec![
1182                crate::type_tracking::SlotKind::Int64,
1183                crate::type_tracking::SlotKind::Float64,
1184            ],
1185            stack_depth: 1,
1186            innermost_function_id: None,
1187            inline_frames: Vec::new(),
1188        };
1189
1190        res_tx
1191            .send(CompilationResult {
1192                function_id: 0,
1193                compiled_tier: Tier::BaselineJit,
1194                native_code: Some(0xBEEF as *const u8),
1195                error: None,
1196                osr_entry: None,
1197                deopt_points: vec![deopt_info.clone()],
1198                loop_header_ip: None,
1199                shape_guards: Vec::new(),
1200            })
1201            .unwrap();
1202
1203        mgr.poll_completions();
1204
1205        // Deopt info should be retrievable by function_id + deopt_id
1206        let retrieved = mgr.get_deopt_info(0, 0);
1207        assert!(retrieved.is_some());
1208        assert_eq!(retrieved.unwrap().resume_ip, 42);
1209        assert_eq!(retrieved.unwrap().local_mapping.len(), 2);
1210        assert_eq!(retrieved.unwrap().stack_depth, 1);
1211
1212        // Out-of-bounds deopt_id returns None
1213        assert!(mgr.get_deopt_info(0, 1).is_none());
1214
1215        // Unknown function_id returns None
1216        assert!(mgr.get_deopt_info(1, 0).is_none());
1217    }
1218
1219    #[test]
1220    fn test_deopt_table_cleared_on_invalidation() {
1221        let mut mgr = TierManager::new(5, true);
1222
1223        let (req_tx, _req_rx) = mpsc::channel();
1224        let (res_tx, res_rx) = mpsc::channel();
1225        mgr.set_channels(req_tx, res_rx);
1226
1227        for _ in 0..100 {
1228            mgr.record_call(0, None);
1229        }
1230
1231        let deopt_info = crate::bytecode::DeoptInfo {
1232            resume_ip: 10,
1233            local_mapping: vec![(0, 0)],
1234            local_kinds: vec![crate::type_tracking::SlotKind::Int64],
1235            stack_depth: 0,
1236            innermost_function_id: None,
1237            inline_frames: Vec::new(),
1238        };
1239
1240        res_tx
1241            .send(CompilationResult {
1242                function_id: 0,
1243                compiled_tier: Tier::BaselineJit,
1244                native_code: Some(0xCAFE as *const u8),
1245                error: None,
1246                osr_entry: None,
1247                deopt_points: vec![deopt_info],
1248                loop_header_ip: None,
1249                shape_guards: Vec::new(),
1250            })
1251            .unwrap();
1252
1253        mgr.poll_completions();
1254        assert!(mgr.get_deopt_info(0, 0).is_some());
1255
1256        // Invalidating the function should clear deopt table
1257        mgr.invalidate_function(0);
1258        assert!(mgr.get_deopt_info(0, 0).is_none());
1259    }
1260
1261    #[test]
1262    fn test_deopt_table_empty_not_stored() {
1263        let mut mgr = TierManager::new(5, true);
1264
1265        let (req_tx, _req_rx) = mpsc::channel();
1266        let (res_tx, res_rx) = mpsc::channel();
1267        mgr.set_channels(req_tx, res_rx);
1268
1269        for _ in 0..100 {
1270            mgr.record_call(0, None);
1271        }
1272
1273        // Compilation with empty deopt_points
1274        res_tx
1275            .send(CompilationResult {
1276                function_id: 0,
1277                compiled_tier: Tier::BaselineJit,
1278                native_code: Some(0x1234 as *const u8),
1279                error: None,
1280                osr_entry: None,
1281                deopt_points: Vec::new(),
1282                loop_header_ip: None,
1283                shape_guards: Vec::new(),
1284            })
1285            .unwrap();
1286
1287        mgr.poll_completions();
1288
1289        // No deopt table stored for empty deopt_points
1290        assert!(mgr.get_deopt_info(0, 0).is_none());
1291    }
1292}