Skip to main content

shape_jit/
executor.rs

1//! JIT executor implementing the ProgramExecutor trait
2
3use shape_ast::Program;
4use shape_runtime::engine::{ExecutionType, ProgramExecutor, ShapeEngine};
5use shape_runtime::error::Result;
6use shape_wire::WireValue;
7use std::time::Instant;
8
/// JIT executor with selective per-function compilation.
///
/// Functions the JIT supports are compiled to native code; the rest (for example
/// those using async, pattern matching, or unsupported builtins) are left as
/// `Interpreted` entries in the mixed function table so the VM can run them.
///
/// The executor carries no state of its own, so it is cheap to copy and can be
/// constructed simply as `JITExecutor`.
#[derive(Debug, Clone, Copy)]
pub struct JITExecutor;
15
16impl ProgramExecutor for JITExecutor {
17    fn execute_program(
18        &mut self,
19        engine: &mut ShapeEngine,
20        program: &Program,
21    ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
22        use shape_vm::BytecodeCompiler;
23        let emit_phase_metrics = std::env::var_os("SHAPE_JIT_PHASE_METRICS").is_some();
24
25        // Capture source text before getting runtime reference (for error messages)
26        let source_for_compilation = engine.current_source().map(|s| s.to_string());
27
28        // Compile to bytecode first to check JIT compatibility
29        let runtime = engine.get_runtime_mut();
30
31        // Get known module bindings — prefer persistent context, fallback to precompiled names
32        let known_bindings: Vec<String> = if let Some(ctx) = runtime.persistent_context() {
33            let names = ctx.root_scope_binding_names();
34            if names.is_empty() {
35                shape_vm::stdlib::core_binding_names()
36            } else {
37                names
38            }
39        } else {
40            shape_vm::stdlib::core_binding_names()
41        };
42
43        // Extract imported functions from ModuleBindingRegistry
44        let module_binding_registry = runtime.module_binding_registry();
45        let imported_program =
46            shape_vm::BytecodeExecutor::create_program_from_imports(&module_binding_registry)?;
47
48        // Merge with main program
49        let mut merged_program = imported_program;
50        merged_program.items.extend(program.items.clone());
51        let stdlib_names =
52            shape_vm::module_resolution::prepend_prelude_items(&mut merged_program);
53
54        // Compile to bytecode (with source text if available for better error messages)
55        let bytecode_compile_start = Instant::now();
56        let mut compiler = BytecodeCompiler::new();
57        compiler.stdlib_function_names = stdlib_names;
58        compiler.register_known_bindings(&known_bindings);
59        let mut bytecode = if let Some(source) = &source_for_compilation {
60            compiler.compile_with_source(&merged_program, source)
61        } else {
62            compiler.compile(&merged_program)
63        }
64        .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
65            message: format!("Bytecode compilation failed: {}", e),
66            location: None,
67        })?;
68        let bytecode_compile_ms = bytecode_compile_start.elapsed().as_millis();
69
70        self.execute_with_jit(engine, &bytecode, bytecode_compile_ms, emit_phase_metrics)
71    }
72}
73
74impl JITExecutor {
75    fn execute_with_jit(
76        &self,
77        engine: &mut ShapeEngine,
78        bytecode: &shape_vm::bytecode::BytecodeProgram,
79        bytecode_compile_ms: u128,
80        emit_phase_metrics: bool,
81    ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
82        use crate::JITConfig;
83        use crate::JITContext;
84        use crate::compiler::JITCompiler;
85
86        // JIT compile the bytecode
87        let jit_config = JITConfig::default();
88        let mut jit = JITCompiler::new(jit_config).map_err(|e| {
89            shape_runtime::error::ShapeError::RuntimeError {
90                message: format!("JIT compiler initialization failed: {}", e),
91                location: None,
92            }
93        })?;
94
95        // Use selective compilation: JIT-compatible functions get native code,
96        // incompatible ones get Interpreted entries for VM fallback.
97        if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
98            eprintln!("[jit-debug] starting compile_program_selective with {} instructions, {} functions",
99                bytecode.instructions.len(), bytecode.functions.len());
100            for (i, instr) in bytecode.instructions.iter().enumerate() {
101                eprintln!("[jit-debug] instr[{}]: {:?} {:?}", i, instr.opcode, instr.operand);
102            }
103        }
104        let jit_compile_start = Instant::now();
105        let compile_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
106            jit.compile_program_selective("main", bytecode)
107        }));
108        let jit_compile_ms = jit_compile_start.elapsed().as_millis();
109        let (jit_fn, _mixed_table) = match compile_result {
110            Ok(Ok(result)) => result,
111            Ok(Err(e)) => {
112                return Err(shape_runtime::error::ShapeError::RuntimeError {
113                    message: format!("JIT compilation failed: {}", e),
114                    location: None,
115                });
116            }
117            Err(panic_info) => {
118                let msg = if let Some(s) = panic_info.downcast_ref::<String>() {
119                    s.clone()
120                } else if let Some(s) = panic_info.downcast_ref::<&str>() {
121                    s.to_string()
122                } else {
123                    "unknown panic".to_string()
124                };
125                return Err(shape_runtime::error::ShapeError::RuntimeError {
126                    message: format!("JIT compilation panicked: {}", msg),
127                    location: None,
128                });
129            }
130        };
131
132        let foreign_bridge = {
133            let runtime = engine.get_runtime_mut();
134            crate::foreign_bridge::link_foreign_functions_for_jit(
135                bytecode,
136                runtime.persistent_context(),
137            )
138            .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
139                message: format!("JIT foreign-function linking failed: {}", e),
140                location: None,
141            })?
142        };
143
144        // Create JIT context and execute
145        let mut jit_ctx = JITContext::default();
146        if let Some(state) = foreign_bridge.as_ref() {
147            jit_ctx.foreign_bridge_ptr = state.as_ref() as *const _ as *const std::ffi::c_void;
148        }
149
150        // Set exec_context_ptr so JIT FFI can access cached data
151        {
152            let runtime = engine.get_runtime_mut();
153            if let Some(ctx) = runtime.persistent_context_mut() {
154                jit_ctx.exec_context_ptr = ctx as *mut _ as *mut std::ffi::c_void;
155            }
156        }
157
158        // Execute the JIT-compiled function
159        if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
160            eprintln!("[jit-debug] compilation OK, about to execute...");
161        }
162        let jit_exec_start = Instant::now();
163        let signal = unsafe { jit_fn(&mut jit_ctx) };
164        let jit_exec_ms = jit_exec_start.elapsed().as_millis();
165
166        // Get result from JIT context stack via TypedScalar boundary
167        let raw_result = if jit_ctx.stack_ptr > 0 {
168            jit_ctx.stack[0]
169        } else {
170            crate::nan_boxing::TAG_NULL
171        };
172
173        // Check for errors
174        if signal < 0 {
175            return Err(shape_runtime::error::ShapeError::RuntimeError {
176                message: format!("JIT execution error (code: {})", signal),
177                location: None,
178            });
179        }
180
181        // Use FrameDescriptor hint to preserve integer type identity.
182        // Prefer return_kind when populated; fall back to last slot.
183        let return_hint = bytecode.top_level_frame.as_ref().and_then(|fd| {
184            if fd.return_kind != shape_vm::type_tracking::SlotKind::Unknown {
185                Some(fd.return_kind)
186            } else {
187                fd.slots.last().copied()
188            }
189        });
190        let result_scalar =
191            crate::ffi::object::conversion::jit_bits_to_typed_scalar(raw_result, return_hint);
192
193        // Convert TypedScalar to WireValue
194        let wire_value = self.typed_scalar_to_wire(&result_scalar, raw_result);
195
196        if emit_phase_metrics {
197            let total_ms = bytecode_compile_ms + jit_compile_ms + jit_exec_ms;
198            eprintln!(
199                "[shape-jit-phases] bytecode_compile_ms={} jit_compile_ms={} jit_exec_ms={} total_ms={}",
200                bytecode_compile_ms, jit_compile_ms, jit_exec_ms, total_ms
201            );
202        }
203
204        Ok(shape_runtime::engine::ProgramExecutorResult {
205            wire_value,
206            type_info: None,
207            execution_type: ExecutionType::Script,
208            content_json: None,
209            content_html: None,
210            content_terminal: None,
211        })
212    }
213
214    /// Convert a TypedScalar result to WireValue.
215    ///
216    /// For scalar types, the TypedScalar carries enough information. For heap types
217    /// (strings, arrays) that TypedScalar can't represent, we fall back to raw bits.
218    fn typed_scalar_to_wire(&self, ts: &shape_value::TypedScalar, raw_bits: u64) -> WireValue {
219        use shape_value::ScalarKind;
220
221        match ts.kind {
222            ScalarKind::I8
223            | ScalarKind::I16
224            | ScalarKind::I32
225            | ScalarKind::I64
226            | ScalarKind::U8
227            | ScalarKind::U16
228            | ScalarKind::U32
229            | ScalarKind::U64
230            | ScalarKind::I128
231            | ScalarKind::U128 => {
232                // Integer result — preserve as exact integer in WireValue::Number
233                WireValue::Number(ts.payload_lo as i64 as f64)
234            }
235            ScalarKind::F64 | ScalarKind::F32 => WireValue::Number(f64::from_bits(ts.payload_lo)),
236            ScalarKind::Bool => WireValue::Bool(ts.payload_lo != 0),
237            ScalarKind::Unit => WireValue::Null,
238            ScalarKind::None => {
239                // None could also be a fallback for non-scalar heap types.
240                // Check if raw_bits is actually a heap value.
241                self.nan_boxed_to_wire(raw_bits)
242            }
243        }
244    }
245
246    fn nan_boxed_to_wire(&self, bits: u64) -> WireValue {
247        use crate::nan_boxing::{
248            HK_STRING, TAG_BOOL_FALSE, TAG_BOOL_TRUE, TAG_NULL, is_heap_kind, is_number, jit_unbox,
249            unbox_number,
250        };
251        use shape_value::tags::{get_tag, is_tagged, sign_extend_i48, get_payload, TAG_INT};
252
253        if is_number(bits) {
254            WireValue::Number(unbox_number(bits))
255        } else if bits == TAG_NULL {
256            WireValue::Null
257        } else if bits == TAG_BOOL_TRUE {
258            WireValue::Bool(true)
259        } else if bits == TAG_BOOL_FALSE {
260            WireValue::Bool(false)
261        } else if is_tagged(bits) && get_tag(bits) == TAG_INT {
262            // NaN-boxed i48 integer — sign-extend to i64 and return as integer
263            let int_val = sign_extend_i48(get_payload(bits));
264            WireValue::Integer(int_val)
265        } else if is_heap_kind(bits, HK_STRING) {
266            let s = unsafe { jit_unbox::<String>(bits) };
267            WireValue::String(s.clone())
268        } else {
269            // Default to interpreting as a number for unknown tags
270            WireValue::Number(f64::from_bits(bits))
271        }
272    }
273}