1use shape_ast::Program;
4use shape_runtime::engine::{ExecutionType, ProgramExecutor, ShapeEngine};
5use shape_runtime::error::Result;
6use shape_wire::WireValue;
7use std::time::Instant;
8
/// [`ProgramExecutor`] that compiles a Shape program to bytecode and then
/// runs that bytecode through the JIT compiler.
///
/// This is a unit struct: it carries no state of its own.
pub struct JITExecutor;
15
16impl ProgramExecutor for JITExecutor {
17 fn execute_program(
18 &mut self,
19 engine: &mut ShapeEngine,
20 program: &Program,
21 ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
22 use shape_vm::BytecodeCompiler;
23 let emit_phase_metrics = std::env::var_os("SHAPE_JIT_PHASE_METRICS").is_some();
24
25 let source_for_compilation = engine.current_source().map(|s| s.to_string());
27
28 let runtime = engine.get_runtime_mut();
30
31 let known_bindings: Vec<String> = if let Some(ctx) = runtime.persistent_context() {
33 let names = ctx.root_scope_binding_names();
34 if names.is_empty() {
35 shape_vm::stdlib::core_binding_names()
36 } else {
37 names
38 }
39 } else {
40 shape_vm::stdlib::core_binding_names()
41 };
42
43 let module_binding_registry = runtime.module_binding_registry();
45 let imported_program =
46 shape_vm::BytecodeExecutor::create_program_from_imports(&module_binding_registry)?;
47
48 let mut merged_program = imported_program;
50 merged_program.items.extend(program.items.clone());
51 let stdlib_names =
52 shape_vm::module_resolution::prepend_prelude_items(&mut merged_program);
53
54 let bytecode_compile_start = Instant::now();
56 let mut compiler = BytecodeCompiler::new();
57 compiler.stdlib_function_names = stdlib_names;
58 compiler.register_known_bindings(&known_bindings);
59 let mut bytecode = if let Some(source) = &source_for_compilation {
60 compiler.compile_with_source(&merged_program, source)
61 } else {
62 compiler.compile(&merged_program)
63 }
64 .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
65 message: format!("Bytecode compilation failed: {}", e),
66 location: None,
67 })?;
68 let bytecode_compile_ms = bytecode_compile_start.elapsed().as_millis();
69
70 self.execute_with_jit(engine, &bytecode, bytecode_compile_ms, emit_phase_metrics)
71 }
72}
73
74impl JITExecutor {
75 fn execute_with_jit(
76 &self,
77 engine: &mut ShapeEngine,
78 bytecode: &shape_vm::bytecode::BytecodeProgram,
79 bytecode_compile_ms: u128,
80 emit_phase_metrics: bool,
81 ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
82 use crate::JITConfig;
83 use crate::JITContext;
84 use crate::compiler::JITCompiler;
85
86 let jit_config = JITConfig::default();
88 let mut jit = JITCompiler::new(jit_config).map_err(|e| {
89 shape_runtime::error::ShapeError::RuntimeError {
90 message: format!("JIT compiler initialization failed: {}", e),
91 location: None,
92 }
93 })?;
94
95 if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
98 eprintln!("[jit-debug] starting compile_program_selective with {} instructions, {} functions",
99 bytecode.instructions.len(), bytecode.functions.len());
100 for (i, instr) in bytecode.instructions.iter().enumerate() {
101 eprintln!("[jit-debug] instr[{}]: {:?} {:?}", i, instr.opcode, instr.operand);
102 }
103 }
104 let jit_compile_start = Instant::now();
105 let compile_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
106 jit.compile_program_selective("main", bytecode)
107 }));
108 let jit_compile_ms = jit_compile_start.elapsed().as_millis();
109 let (jit_fn, _mixed_table) = match compile_result {
110 Ok(Ok(result)) => result,
111 Ok(Err(e)) => {
112 return Err(shape_runtime::error::ShapeError::RuntimeError {
113 message: format!("JIT compilation failed: {}", e),
114 location: None,
115 });
116 }
117 Err(panic_info) => {
118 let msg = if let Some(s) = panic_info.downcast_ref::<String>() {
119 s.clone()
120 } else if let Some(s) = panic_info.downcast_ref::<&str>() {
121 s.to_string()
122 } else {
123 "unknown panic".to_string()
124 };
125 return Err(shape_runtime::error::ShapeError::RuntimeError {
126 message: format!("JIT compilation panicked: {}", msg),
127 location: None,
128 });
129 }
130 };
131
132 let foreign_bridge = {
133 let runtime = engine.get_runtime_mut();
134 crate::foreign_bridge::link_foreign_functions_for_jit(
135 bytecode,
136 runtime.persistent_context(),
137 )
138 .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
139 message: format!("JIT foreign-function linking failed: {}", e),
140 location: None,
141 })?
142 };
143
144 let mut jit_ctx = JITContext::default();
146 if let Some(state) = foreign_bridge.as_ref() {
147 jit_ctx.foreign_bridge_ptr = state.as_ref() as *const _ as *const std::ffi::c_void;
148 }
149
150 {
152 let runtime = engine.get_runtime_mut();
153 if let Some(ctx) = runtime.persistent_context_mut() {
154 jit_ctx.exec_context_ptr = ctx as *mut _ as *mut std::ffi::c_void;
155 }
156 }
157
158 if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
160 eprintln!("[jit-debug] compilation OK, about to execute...");
161 }
162 let jit_exec_start = Instant::now();
163 let signal = unsafe { jit_fn(&mut jit_ctx) };
164 let jit_exec_ms = jit_exec_start.elapsed().as_millis();
165
166 let raw_result = if jit_ctx.stack_ptr > 0 {
168 jit_ctx.stack[0]
169 } else {
170 crate::nan_boxing::TAG_NULL
171 };
172
173 if signal < 0 {
175 return Err(shape_runtime::error::ShapeError::RuntimeError {
176 message: format!("JIT execution error (code: {})", signal),
177 location: None,
178 });
179 }
180
181 let return_hint = bytecode.top_level_frame.as_ref().and_then(|fd| {
184 if fd.return_kind != shape_vm::type_tracking::SlotKind::Unknown {
185 Some(fd.return_kind)
186 } else {
187 fd.slots.last().copied()
188 }
189 });
190 let result_scalar =
191 crate::ffi::object::conversion::jit_bits_to_typed_scalar(raw_result, return_hint);
192
193 let wire_value = self.typed_scalar_to_wire(&result_scalar, raw_result);
195
196 if emit_phase_metrics {
197 let total_ms = bytecode_compile_ms + jit_compile_ms + jit_exec_ms;
198 eprintln!(
199 "[shape-jit-phases] bytecode_compile_ms={} jit_compile_ms={} jit_exec_ms={} total_ms={}",
200 bytecode_compile_ms, jit_compile_ms, jit_exec_ms, total_ms
201 );
202 }
203
204 Ok(shape_runtime::engine::ProgramExecutorResult {
205 wire_value,
206 type_info: None,
207 execution_type: ExecutionType::Script,
208 content_json: None,
209 content_html: None,
210 content_terminal: None,
211 })
212 }
213
214 fn typed_scalar_to_wire(&self, ts: &shape_value::TypedScalar, raw_bits: u64) -> WireValue {
219 use shape_value::ScalarKind;
220
221 match ts.kind {
222 ScalarKind::I8
223 | ScalarKind::I16
224 | ScalarKind::I32
225 | ScalarKind::I64
226 | ScalarKind::U8
227 | ScalarKind::U16
228 | ScalarKind::U32
229 | ScalarKind::U64
230 | ScalarKind::I128
231 | ScalarKind::U128 => {
232 WireValue::Number(ts.payload_lo as i64 as f64)
234 }
235 ScalarKind::F64 | ScalarKind::F32 => WireValue::Number(f64::from_bits(ts.payload_lo)),
236 ScalarKind::Bool => WireValue::Bool(ts.payload_lo != 0),
237 ScalarKind::Unit => WireValue::Null,
238 ScalarKind::None => {
239 self.nan_boxed_to_wire(raw_bits)
242 }
243 }
244 }
245
246 fn nan_boxed_to_wire(&self, bits: u64) -> WireValue {
247 use crate::nan_boxing::{
248 HK_STRING, TAG_BOOL_FALSE, TAG_BOOL_TRUE, TAG_NULL, is_heap_kind, is_number, jit_unbox,
249 unbox_number,
250 };
251 use shape_value::tags::{get_tag, is_tagged, sign_extend_i48, get_payload, TAG_INT};
252
253 if is_number(bits) {
254 WireValue::Number(unbox_number(bits))
255 } else if bits == TAG_NULL {
256 WireValue::Null
257 } else if bits == TAG_BOOL_TRUE {
258 WireValue::Bool(true)
259 } else if bits == TAG_BOOL_FALSE {
260 WireValue::Bool(false)
261 } else if is_tagged(bits) && get_tag(bits) == TAG_INT {
262 let int_val = sign_extend_i48(get_payload(bits));
264 WireValue::Integer(int_val)
265 } else if is_heap_kind(bits, HK_STRING) {
266 let s = unsafe { jit_unbox::<String>(bits) };
267 WireValue::String(s.clone())
268 } else {
269 WireValue::Number(f64::from_bits(bits))
271 }
272 }
273}