Skip to main content

shape_jit/
executor.rs

1//! JIT executor implementing the ProgramExecutor trait
2
3use shape_ast::Program;
4use shape_runtime::engine::{ExecutionType, ProgramExecutor, ShapeEngine};
5use shape_runtime::error::Result;
6use shape_wire::WireValue;
7use std::time::Instant;
8
/// Runs Shape programs through the JIT, compiling functions selectively.
///
/// Every function the JIT can handle is lowered to native code. Functions it
/// cannot handle (for example ones using async, pattern matching, or builtins
/// the JIT does not support) are kept as `Interpreted` entries in the mixed
/// function table, so the VM executes them instead.
pub struct JITExecutor;
15
16impl ProgramExecutor for JITExecutor {
17    fn execute_program(
18        &self,
19        engine: &mut ShapeEngine,
20        program: &Program,
21    ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
22        use shape_vm::BytecodeCompiler;
23        let emit_phase_metrics = std::env::var_os("SHAPE_JIT_PHASE_METRICS").is_some();
24
25        // Capture source text before getting runtime reference (for error messages)
26        let source_for_compilation = engine.current_source().map(|s| s.to_string());
27
28        // Compile to bytecode first to check JIT compatibility
29        let runtime = engine.get_runtime_mut();
30
31        // Get known module bindings — prefer persistent context, fallback to precompiled names
32        let known_bindings: Vec<String> = if let Some(ctx) = runtime.persistent_context() {
33            let names = ctx.root_scope_binding_names();
34            if names.is_empty() {
35                shape_vm::stdlib::core_binding_names()
36            } else {
37                names
38            }
39        } else {
40            shape_vm::stdlib::core_binding_names()
41        };
42
43        // Extract imported functions from ModuleBindingRegistry
44        let module_binding_registry = runtime.module_binding_registry();
45        let imported_program =
46            shape_vm::BytecodeExecutor::create_program_from_imports(&module_binding_registry)?;
47
48        // Merge with main program
49        let mut merged_program = imported_program;
50        merged_program.items.extend(program.items.clone());
51        shape_vm::module_resolution::prepend_prelude_items(&mut merged_program);
52
53        // Compile to bytecode (with source text if available for better error messages)
54        let bytecode_compile_start = Instant::now();
55        let mut compiler = BytecodeCompiler::new();
56        compiler.register_known_bindings(&known_bindings);
57        let mut bytecode = if let Some(source) = &source_for_compilation {
58            compiler.compile_with_source(&merged_program, source)
59        } else {
60            compiler.compile(&merged_program)
61        }
62        .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
63            message: format!("Bytecode compilation failed: {}", e),
64            location: None,
65        })?;
66        let bytecode_compile_ms = bytecode_compile_start.elapsed().as_millis();
67
68        self.execute_with_jit(engine, &bytecode, bytecode_compile_ms, emit_phase_metrics)
69    }
70}
71
72impl JITExecutor {
73    fn execute_with_jit(
74        &self,
75        engine: &mut ShapeEngine,
76        bytecode: &shape_vm::bytecode::BytecodeProgram,
77        bytecode_compile_ms: u128,
78        emit_phase_metrics: bool,
79    ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
80        use crate::JITConfig;
81        use crate::JITContext;
82        use crate::compiler::JITCompiler;
83
84        // JIT compile the bytecode
85        let jit_config = JITConfig::default();
86        let mut jit = JITCompiler::new(jit_config).map_err(|e| {
87            shape_runtime::error::ShapeError::RuntimeError {
88                message: format!("JIT compiler initialization failed: {}", e),
89                location: None,
90            }
91        })?;
92
93        // Use selective compilation: JIT-compatible functions get native code,
94        // incompatible ones get Interpreted entries for VM fallback.
95        if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
96            eprintln!("[jit-debug] starting compile_program_selective with {} instructions, {} functions",
97                bytecode.instructions.len(), bytecode.functions.len());
98            for (i, instr) in bytecode.instructions.iter().enumerate() {
99                eprintln!("[jit-debug] instr[{}]: {:?} {:?}", i, instr.opcode, instr.operand);
100            }
101        }
102        let jit_compile_start = Instant::now();
103        let compile_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
104            jit.compile_program_selective("main", bytecode)
105        }));
106        let jit_compile_ms = jit_compile_start.elapsed().as_millis();
107        let (jit_fn, _mixed_table) = match compile_result {
108            Ok(Ok(result)) => result,
109            Ok(Err(e)) => {
110                return Err(shape_runtime::error::ShapeError::RuntimeError {
111                    message: format!("JIT compilation failed: {}", e),
112                    location: None,
113                });
114            }
115            Err(panic_info) => {
116                let msg = if let Some(s) = panic_info.downcast_ref::<String>() {
117                    s.clone()
118                } else if let Some(s) = panic_info.downcast_ref::<&str>() {
119                    s.to_string()
120                } else {
121                    "unknown panic".to_string()
122                };
123                return Err(shape_runtime::error::ShapeError::RuntimeError {
124                    message: format!("JIT compilation panicked: {}", msg),
125                    location: None,
126                });
127            }
128        };
129
130        let foreign_bridge = {
131            let runtime = engine.get_runtime_mut();
132            crate::foreign_bridge::link_foreign_functions_for_jit(
133                bytecode,
134                runtime.persistent_context(),
135            )
136            .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
137                message: format!("JIT foreign-function linking failed: {}", e),
138                location: None,
139            })?
140        };
141
142        // Create JIT context and execute
143        let mut jit_ctx = JITContext::default();
144        if let Some(state) = foreign_bridge.as_ref() {
145            jit_ctx.foreign_bridge_ptr = state.as_ref() as *const _ as *const std::ffi::c_void;
146        }
147
148        // Set exec_context_ptr so JIT FFI can access cached data
149        {
150            let runtime = engine.get_runtime_mut();
151            if let Some(ctx) = runtime.persistent_context_mut() {
152                jit_ctx.exec_context_ptr = ctx as *mut _ as *mut std::ffi::c_void;
153            }
154        }
155
156        // Execute the JIT-compiled function
157        if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
158            eprintln!("[jit-debug] compilation OK, about to execute...");
159        }
160        let jit_exec_start = Instant::now();
161        let signal = unsafe { jit_fn(&mut jit_ctx) };
162        let jit_exec_ms = jit_exec_start.elapsed().as_millis();
163
164        // Get result from JIT context stack via TypedScalar boundary
165        let raw_result = if jit_ctx.stack_ptr > 0 {
166            jit_ctx.stack[0]
167        } else {
168            crate::nan_boxing::TAG_NULL
169        };
170
171        // Check for errors
172        if signal < 0 {
173            return Err(shape_runtime::error::ShapeError::RuntimeError {
174                message: format!("JIT execution error (code: {})", signal),
175                location: None,
176            });
177        }
178
179        // Use FrameDescriptor hint to preserve integer type identity.
180        // Prefer return_kind when populated; fall back to last slot.
181        let return_hint = bytecode.top_level_frame.as_ref().and_then(|fd| {
182            if fd.return_kind != shape_vm::type_tracking::SlotKind::Unknown {
183                Some(fd.return_kind)
184            } else {
185                fd.slots.last().copied()
186            }
187        });
188        let result_scalar =
189            crate::ffi::object::conversion::jit_bits_to_typed_scalar(raw_result, return_hint);
190
191        // Convert TypedScalar to WireValue
192        let wire_value = self.typed_scalar_to_wire(&result_scalar, raw_result);
193
194        if emit_phase_metrics {
195            let total_ms = bytecode_compile_ms + jit_compile_ms + jit_exec_ms;
196            eprintln!(
197                "[shape-jit-phases] bytecode_compile_ms={} jit_compile_ms={} jit_exec_ms={} total_ms={}",
198                bytecode_compile_ms, jit_compile_ms, jit_exec_ms, total_ms
199            );
200        }
201
202        Ok(shape_runtime::engine::ProgramExecutorResult {
203            wire_value,
204            type_info: None,
205            execution_type: ExecutionType::Script,
206            content_json: None,
207            content_html: None,
208            content_terminal: None,
209        })
210    }
211
212    /// Convert a TypedScalar result to WireValue.
213    ///
214    /// For scalar types, the TypedScalar carries enough information. For heap types
215    /// (strings, arrays) that TypedScalar can't represent, we fall back to raw bits.
216    fn typed_scalar_to_wire(&self, ts: &shape_value::TypedScalar, raw_bits: u64) -> WireValue {
217        use shape_value::ScalarKind;
218
219        match ts.kind {
220            ScalarKind::I8
221            | ScalarKind::I16
222            | ScalarKind::I32
223            | ScalarKind::I64
224            | ScalarKind::U8
225            | ScalarKind::U16
226            | ScalarKind::U32
227            | ScalarKind::U64
228            | ScalarKind::I128
229            | ScalarKind::U128 => {
230                // Integer result — preserve as exact integer in WireValue::Number
231                WireValue::Number(ts.payload_lo as i64 as f64)
232            }
233            ScalarKind::F64 | ScalarKind::F32 => WireValue::Number(f64::from_bits(ts.payload_lo)),
234            ScalarKind::Bool => WireValue::Bool(ts.payload_lo != 0),
235            ScalarKind::Unit => WireValue::Null,
236            ScalarKind::None => {
237                // None could also be a fallback for non-scalar heap types.
238                // Check if raw_bits is actually a heap value.
239                self.nan_boxed_to_wire(raw_bits)
240            }
241        }
242    }
243
244    fn nan_boxed_to_wire(&self, bits: u64) -> WireValue {
245        use crate::nan_boxing::{
246            HK_STRING, TAG_BOOL_FALSE, TAG_BOOL_TRUE, TAG_NULL, is_heap_kind, is_number, jit_unbox,
247            unbox_number,
248        };
249        use shape_value::tags::{get_tag, is_tagged, sign_extend_i48, get_payload, TAG_INT};
250
251        if is_number(bits) {
252            WireValue::Number(unbox_number(bits))
253        } else if bits == TAG_NULL {
254            WireValue::Null
255        } else if bits == TAG_BOOL_TRUE {
256            WireValue::Bool(true)
257        } else if bits == TAG_BOOL_FALSE {
258            WireValue::Bool(false)
259        } else if is_tagged(bits) && get_tag(bits) == TAG_INT {
260            // NaN-boxed i48 integer — sign-extend to i64 and return as integer
261            let int_val = sign_extend_i48(get_payload(bits));
262            WireValue::Integer(int_val)
263        } else if is_heap_kind(bits, HK_STRING) {
264            let s = unsafe { jit_unbox::<String>(bits) };
265            WireValue::String(s.clone())
266        } else {
267            // Default to interpreting as a number for unknown tags
268            WireValue::Number(f64::from_bits(bits))
269        }
270    }
271}