1use shape_ast::Program;
4use shape_runtime::engine::{ExecutionType, ProgramExecutor, ShapeEngine};
5use shape_runtime::error::Result;
6use shape_wire::WireValue;
7use std::time::Instant;
8
/// [`ProgramExecutor`] that compiles a program to bytecode and then runs it
/// through this crate's JIT compiler.
///
/// The type carries no state; everything needed for a run (bytecode compiler,
/// JIT compiler, JIT context) is created inside `execute_program`.
pub struct JITExecutor;

16impl ProgramExecutor for JITExecutor {
17 fn execute_program(
18 &self,
19 engine: &mut ShapeEngine,
20 program: &Program,
21 ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
22 use shape_vm::BytecodeCompiler;
23 let emit_phase_metrics = std::env::var_os("SHAPE_JIT_PHASE_METRICS").is_some();
24
25 let source_for_compilation = engine.current_source().map(|s| s.to_string());
27
28 let runtime = engine.get_runtime_mut();
30
31 let known_bindings: Vec<String> = if let Some(ctx) = runtime.persistent_context() {
33 let names = ctx.root_scope_binding_names();
34 if names.is_empty() {
35 shape_vm::stdlib::core_binding_names()
36 } else {
37 names
38 }
39 } else {
40 shape_vm::stdlib::core_binding_names()
41 };
42
43 let module_binding_registry = runtime.module_binding_registry();
45 let imported_program =
46 shape_vm::BytecodeExecutor::create_program_from_imports(&module_binding_registry)?;
47
48 let mut merged_program = imported_program;
50 merged_program.items.extend(program.items.clone());
51 shape_vm::module_resolution::prepend_prelude_items(&mut merged_program);
52
53 let bytecode_compile_start = Instant::now();
55 let mut compiler = BytecodeCompiler::new();
56 compiler.register_known_bindings(&known_bindings);
57 let mut bytecode = if let Some(source) = &source_for_compilation {
58 compiler.compile_with_source(&merged_program, source)
59 } else {
60 compiler.compile(&merged_program)
61 }
62 .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
63 message: format!("Bytecode compilation failed: {}", e),
64 location: None,
65 })?;
66 let bytecode_compile_ms = bytecode_compile_start.elapsed().as_millis();
67
68 self.execute_with_jit(engine, &bytecode, bytecode_compile_ms, emit_phase_metrics)
69 }
70}
71
72impl JITExecutor {
73 fn execute_with_jit(
74 &self,
75 engine: &mut ShapeEngine,
76 bytecode: &shape_vm::bytecode::BytecodeProgram,
77 bytecode_compile_ms: u128,
78 emit_phase_metrics: bool,
79 ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
80 use crate::JITConfig;
81 use crate::JITContext;
82 use crate::compiler::JITCompiler;
83
84 let jit_config = JITConfig::default();
86 let mut jit = JITCompiler::new(jit_config).map_err(|e| {
87 shape_runtime::error::ShapeError::RuntimeError {
88 message: format!("JIT compiler initialization failed: {}", e),
89 location: None,
90 }
91 })?;
92
93 if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
96 eprintln!("[jit-debug] starting compile_program_selective with {} instructions, {} functions",
97 bytecode.instructions.len(), bytecode.functions.len());
98 for (i, instr) in bytecode.instructions.iter().enumerate() {
99 eprintln!("[jit-debug] instr[{}]: {:?} {:?}", i, instr.opcode, instr.operand);
100 }
101 }
102 let jit_compile_start = Instant::now();
103 let compile_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
104 jit.compile_program_selective("main", bytecode)
105 }));
106 let jit_compile_ms = jit_compile_start.elapsed().as_millis();
107 let (jit_fn, _mixed_table) = match compile_result {
108 Ok(Ok(result)) => result,
109 Ok(Err(e)) => {
110 return Err(shape_runtime::error::ShapeError::RuntimeError {
111 message: format!("JIT compilation failed: {}", e),
112 location: None,
113 });
114 }
115 Err(panic_info) => {
116 let msg = if let Some(s) = panic_info.downcast_ref::<String>() {
117 s.clone()
118 } else if let Some(s) = panic_info.downcast_ref::<&str>() {
119 s.to_string()
120 } else {
121 "unknown panic".to_string()
122 };
123 return Err(shape_runtime::error::ShapeError::RuntimeError {
124 message: format!("JIT compilation panicked: {}", msg),
125 location: None,
126 });
127 }
128 };
129
130 let foreign_bridge = {
131 let runtime = engine.get_runtime_mut();
132 crate::foreign_bridge::link_foreign_functions_for_jit(
133 bytecode,
134 runtime.persistent_context(),
135 )
136 .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
137 message: format!("JIT foreign-function linking failed: {}", e),
138 location: None,
139 })?
140 };
141
142 let mut jit_ctx = JITContext::default();
144 if let Some(state) = foreign_bridge.as_ref() {
145 jit_ctx.foreign_bridge_ptr = state.as_ref() as *const _ as *const std::ffi::c_void;
146 }
147
148 {
150 let runtime = engine.get_runtime_mut();
151 if let Some(ctx) = runtime.persistent_context_mut() {
152 jit_ctx.exec_context_ptr = ctx as *mut _ as *mut std::ffi::c_void;
153 }
154 }
155
156 if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
158 eprintln!("[jit-debug] compilation OK, about to execute...");
159 }
160 let jit_exec_start = Instant::now();
161 let signal = unsafe { jit_fn(&mut jit_ctx) };
162 let jit_exec_ms = jit_exec_start.elapsed().as_millis();
163
164 let raw_result = if jit_ctx.stack_ptr > 0 {
166 jit_ctx.stack[0]
167 } else {
168 crate::nan_boxing::TAG_NULL
169 };
170
171 if signal < 0 {
173 return Err(shape_runtime::error::ShapeError::RuntimeError {
174 message: format!("JIT execution error (code: {})", signal),
175 location: None,
176 });
177 }
178
179 let return_hint = bytecode.top_level_frame.as_ref().and_then(|fd| {
182 if fd.return_kind != shape_vm::type_tracking::SlotKind::Unknown {
183 Some(fd.return_kind)
184 } else {
185 fd.slots.last().copied()
186 }
187 });
188 let result_scalar =
189 crate::ffi::object::conversion::jit_bits_to_typed_scalar(raw_result, return_hint);
190
191 let wire_value = self.typed_scalar_to_wire(&result_scalar, raw_result);
193
194 if emit_phase_metrics {
195 let total_ms = bytecode_compile_ms + jit_compile_ms + jit_exec_ms;
196 eprintln!(
197 "[shape-jit-phases] bytecode_compile_ms={} jit_compile_ms={} jit_exec_ms={} total_ms={}",
198 bytecode_compile_ms, jit_compile_ms, jit_exec_ms, total_ms
199 );
200 }
201
202 Ok(shape_runtime::engine::ProgramExecutorResult {
203 wire_value,
204 type_info: None,
205 execution_type: ExecutionType::Script,
206 content_json: None,
207 content_html: None,
208 content_terminal: None,
209 })
210 }
211
212 fn typed_scalar_to_wire(&self, ts: &shape_value::TypedScalar, raw_bits: u64) -> WireValue {
217 use shape_value::ScalarKind;
218
219 match ts.kind {
220 ScalarKind::I8
221 | ScalarKind::I16
222 | ScalarKind::I32
223 | ScalarKind::I64
224 | ScalarKind::U8
225 | ScalarKind::U16
226 | ScalarKind::U32
227 | ScalarKind::U64
228 | ScalarKind::I128
229 | ScalarKind::U128 => {
230 WireValue::Number(ts.payload_lo as i64 as f64)
232 }
233 ScalarKind::F64 | ScalarKind::F32 => WireValue::Number(f64::from_bits(ts.payload_lo)),
234 ScalarKind::Bool => WireValue::Bool(ts.payload_lo != 0),
235 ScalarKind::Unit => WireValue::Null,
236 ScalarKind::None => {
237 self.nan_boxed_to_wire(raw_bits)
240 }
241 }
242 }
243
244 fn nan_boxed_to_wire(&self, bits: u64) -> WireValue {
245 use crate::nan_boxing::{
246 HK_STRING, TAG_BOOL_FALSE, TAG_BOOL_TRUE, TAG_NULL, is_heap_kind, is_number, jit_unbox,
247 unbox_number,
248 };
249 use shape_value::tags::{get_tag, is_tagged, sign_extend_i48, get_payload, TAG_INT};
250
251 if is_number(bits) {
252 WireValue::Number(unbox_number(bits))
253 } else if bits == TAG_NULL {
254 WireValue::Null
255 } else if bits == TAG_BOOL_TRUE {
256 WireValue::Bool(true)
257 } else if bits == TAG_BOOL_FALSE {
258 WireValue::Bool(false)
259 } else if is_tagged(bits) && get_tag(bits) == TAG_INT {
260 let int_val = sign_extend_i48(get_payload(bits));
262 WireValue::Integer(int_val)
263 } else if is_heap_kind(bits, HK_STRING) {
264 let s = unsafe { jit_unbox::<String>(bits) };
265 WireValue::String(s.clone())
266 } else {
267 WireValue::Number(f64::from_bits(bits))
269 }
270 }
271}