Skip to main content

shape_jit/compiler/
strategy.rs

//! Strategy compilation
//!
//! This module provides three compilation entry points:
//!
//! 1. **Standard ABI** (`compile_strategy`): Uses `fn(*mut JITContext) -> i32`
//!    - Full access to VM features (closures, FFI, etc.)
//!    - Suitable for general-purpose JIT compilation
//!
//! 2. **Kernel ABI** (`compile_simulation_kernel`): Uses `fn(usize, *const *const f64, *mut u8) -> i32`
//!    - Zero-allocation hot path for simulation
//!    - Direct memory access to series data and state
//!    - Enables >10M ticks/sec performance
//!
//! 3. **Correlated kernel ABI** (`compile_correlated_kernel`): Uses
//!    `fn(usize, *const *const f64, usize, *mut u8) -> i32`
//!    - Kernel ABI extended with a `table_count` argument for multi-series strategies
13
14use cranelift::codegen::ir::FuncRef;
15use cranelift::prelude::*;
16use cranelift_module::{Linkage, Module};
17use std::collections::HashMap;
18
19use super::setup::JITCompiler;
20use crate::context::{
21    CorrelatedKernelFn, JittedStrategyFn, SimulationKernelConfig, SimulationKernelFn,
22};
23use crate::translator::BytecodeToIR;
24use shape_vm::bytecode::BytecodeProgram;
25
26impl JITCompiler {
27    #[inline(always)]
28    pub fn compile_strategy(
29        &mut self,
30        name: &str,
31        program: &BytecodeProgram,
32    ) -> Result<JittedStrategyFn, String> {
33        let mut sig = self.module.make_signature();
34        sig.params.push(AbiParam::new(types::I64));
35        sig.returns.push(AbiParam::new(types::I32));
36
37        let func_id = self
38            .module
39            .declare_function(name, Linkage::Export, &sig)
40            .map_err(|e| format!("Failed to declare function: {}", e))?;
41
42        let mut ctx = self.module.make_context();
43        ctx.func.signature = sig;
44
45        let mut func_builder_ctx = FunctionBuilderContext::new();
46        {
47            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut func_builder_ctx);
48            let entry_block = builder.create_block();
49            builder.append_block_params_for_function_params(entry_block);
50            builder.switch_to_block(entry_block);
51            builder.seal_block(entry_block);
52
53            let ctx_ptr = builder.block_params(entry_block)[0];
54
55            let ffi = self.build_ffi_refs(&mut builder);
56
57            let mut compiler = BytecodeToIR::new(
58                &mut builder,
59                program,
60                ctx_ptr,
61                ffi,
62                HashMap::new(),
63                HashMap::new(),
64            );
65            let result = compiler.compile()?;
66
67            builder.ins().return_(&[result]);
68            builder.finalize();
69        }
70
71        self.module
72            .define_function(func_id, &mut ctx)
73            .map_err(|e| format!("Failed to define function (strategy): {:?}", e))?;
74
75        self.module.clear_context(&mut ctx);
76        self.module
77            .finalize_definitions()
78            .map_err(|e| format!("Failed to finalize (strategy): {:?}", e))?;
79
80        let code_ptr = self.module.get_finalized_function(func_id);
81        self.compiled_functions.insert(name.to_string(), code_ptr);
82
83        Ok(unsafe { std::mem::transmute(code_ptr) })
84    }
85
86    #[inline(always)]
87    pub(super) fn compile_strategy_with_user_funcs(
88        &mut self,
89        name: &str,
90        program: &BytecodeProgram,
91        user_func_ids: &HashMap<u16, cranelift_module::FuncId>,
92        user_func_arities: &HashMap<u16, u16>,
93    ) -> Result<cranelift_module::FuncId, String> {
94        let mut sig = self.module.make_signature();
95        sig.params.push(AbiParam::new(types::I64));
96        sig.returns.push(AbiParam::new(types::I32));
97
98        let func_id = self
99            .module
100            .declare_function(name, Linkage::Export, &sig)
101            .map_err(|e| format!("Failed to declare function: {}", e))?;
102
103        let mut ctx = self.module.make_context();
104        ctx.func.signature = sig;
105
106        let mut func_builder_ctx = FunctionBuilderContext::new();
107        {
108            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut func_builder_ctx);
109            let entry_block = builder.create_block();
110            builder.append_block_params_for_function_params(entry_block);
111            builder.switch_to_block(entry_block);
112            builder.seal_block(entry_block);
113
114            let ctx_ptr = builder.block_params(entry_block)[0];
115
116            let mut user_func_refs: HashMap<u16, FuncRef> = HashMap::new();
117            for (&fn_idx, &fn_id) in user_func_ids {
118                let func_ref = self.module.declare_func_in_func(fn_id, builder.func);
119                user_func_refs.insert(fn_idx, func_ref);
120            }
121
122            let ffi = self.build_ffi_refs(&mut builder);
123
124            let mut compiler = BytecodeToIR::new(
125                &mut builder,
126                program,
127                ctx_ptr,
128                ffi,
129                user_func_refs,
130                user_func_arities.clone(),
131            );
132
133            // Set skip ranges so the main function compilation ignores function
134            // body instructions (they are compiled separately via
135            // compile_function_with_user_funcs). Without this, LoopStart/LoopEnd
136            // and Jump targets inside function bodies create blocks in the main
137            // function context, causing dead code compilation and stack corruption.
138            compiler.skip_ranges = Self::compute_skip_ranges(program);
139
140            let result = compiler.compile()?;
141
142            builder.ins().return_(&[result]);
143            builder.finalize();
144        }
145
146        self.module
147            .define_function(func_id, &mut ctx)
148            .map_err(|e| format!("Failed to define function (strategy): {:?}", e))?;
149
150        self.module.clear_context(&mut ctx);
151
152        Ok(func_id)
153    }
154
155    /// Compute instruction index ranges to skip when compiling the main strategy.
156    ///
157    /// Bytecode layout for programs with user functions:
158    /// ```text
159    /// [0]             Jump → trampoline1     (skip func0 body)
160    /// [entry0 .. t1)  func0 body
161    /// [t1]            Jump → trampoline2     (skip func1 body)
162    /// [entry1 .. t2)  func1 body
163    /// ...
164    /// [main_start ..) main code
165    /// ```
166    ///
167    /// Returns the function body ranges (excluding trampoline jumps between them).
168    pub(super) fn compute_skip_ranges(program: &BytecodeProgram) -> Vec<(usize, usize)> {
169        let mut ranges = Vec::new();
170
171        // Skip function bodies (they are compiled separately).
172        for f in program.functions.iter() {
173            if f.body_length == 0 {
174                continue;
175            }
176            ranges.push((f.entry_point, f.entry_point + f.body_length));
177        }
178
179        ranges
180    }
181
182    // ========================================================================
183    // Simulation Kernel Compilation (Zero-Allocation Hot Path)
184    // ========================================================================
185
186    /// Compile a simulation kernel with the specialized kernel ABI.
187    ///
188    /// The kernel ABI bypasses JITContext to achieve maximum throughput:
189    /// - Direct pointer arithmetic for data access
190    /// - No allocations in the hot path
191    /// - Inlined field access with known offsets
192    ///
193    /// # Arguments
194    /// * `name` - Function name for the compiled kernel
195    /// * `program` - Bytecode program containing the strategy
196    /// * `config` - Kernel configuration with field offset mappings
197    ///
198    /// # Returns
199    /// A function pointer with signature: `fn(usize, *const *const f64, *mut u8) -> i32`
200    ///
201    /// # Generated Code Pattern
202    ///
203    /// For a strategy like:
204    /// ```shape
205    /// let price = candle.close
206    /// if price > state.threshold {
207    ///     state.signal = 1.0
208    /// }
209    /// ```
210    ///
211    /// The kernel generates:
212    /// ```asm
213    /// ; price = candle.close (column 3)
214    /// mov rax, [series_ptrs + 3*8]     ; column pointer
215    /// mov xmm0, [rax + cursor_index*8] ; price value
216    ///
217    /// ; state.threshold (offset 16)
218    /// mov xmm1, [state_ptr + 16]       ; threshold value
219    ///
220    /// ; comparison and store
221    /// ucomisd xmm0, xmm1
222    /// jbe skip
223    /// mov qword [state_ptr + 24], 1.0  ; state.signal
224    /// skip:
225    /// ```
226    #[inline(always)]
227    pub fn compile_simulation_kernel(
228        &mut self,
229        name: &str,
230        program: &BytecodeProgram,
231        config: &SimulationKernelConfig,
232    ) -> Result<SimulationKernelFn, String> {
233        // Kernel ABI signature: fn(cursor_index: usize, series_ptrs: *const *const f64, state_ptr: *mut u8) -> i32
234        let mut sig = self.module.make_signature();
235        sig.params.push(AbiParam::new(types::I64)); // cursor_index
236        sig.params.push(AbiParam::new(types::I64)); // series_ptrs
237        sig.params.push(AbiParam::new(types::I64)); // state_ptr
238        sig.returns.push(AbiParam::new(types::I32)); // result code
239
240        let func_id = self
241            .module
242            .declare_function(name, Linkage::Export, &sig)
243            .map_err(|e| format!("Failed to declare kernel function: {}", e))?;
244
245        let mut ctx = self.module.make_context();
246        ctx.func.signature = sig;
247
248        let mut func_builder_ctx = FunctionBuilderContext::new();
249        {
250            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut func_builder_ctx);
251            let entry_block = builder.create_block();
252            builder.append_block_params_for_function_params(entry_block);
253            builder.switch_to_block(entry_block);
254            builder.seal_block(entry_block);
255
256            // Get kernel parameters
257            let cursor_index = builder.block_params(entry_block)[0];
258            let series_ptrs = builder.block_params(entry_block)[1];
259            let state_ptr = builder.block_params(entry_block)[2];
260
261            // Build kernel-specific IR
262            let result = self.build_kernel_ir(
263                &mut builder,
264                program,
265                config,
266                cursor_index,
267                series_ptrs,
268                state_ptr,
269            )?;
270
271            builder.ins().return_(&[result]);
272            builder.finalize();
273        }
274
275        self.module
276            .define_function(func_id, &mut ctx)
277            .map_err(|e| format!("Failed to define kernel function: {:?}", e))?;
278
279        self.module.clear_context(&mut ctx);
280        self.module
281            .finalize_definitions()
282            .map_err(|e| format!("Failed to finalize kernel: {:?}", e))?;
283
284        let code_ptr = self.module.get_finalized_function(func_id);
285        self.compiled_functions.insert(name.to_string(), code_ptr);
286
287        Ok(unsafe { std::mem::transmute(code_ptr) })
288    }
289
290    /// Build kernel IR using BytecodeToIR in kernel mode.
291    ///
292    /// This compiles bytecode to kernel ABI IR with direct memory access:
293    /// - GetFieldTyped → state_ptr + offset
294    /// - GetDataField → series_ptrs[col][cursor]
295    /// - All locals as Cranelift variables
296    fn build_kernel_ir(
297        &mut self,
298        builder: &mut FunctionBuilder,
299        program: &BytecodeProgram,
300        config: &SimulationKernelConfig,
301        cursor_index: Value,
302        series_ptrs: Value,
303        state_ptr: Value,
304    ) -> Result<Value, String> {
305        // Build FFI refs (some may still be needed for complex operations)
306        let ffi = self.build_ffi_refs(builder);
307
308        // Create BytecodeToIR in kernel mode
309        let mut compiler = BytecodeToIR::new_kernel_mode(
310            builder,
311            program,
312            cursor_index,
313            series_ptrs,
314            state_ptr,
315            ffi,
316            config.clone(),
317        );
318
319        // Compile bytecode to kernel IR
320        compiler.compile_kernel()
321    }
322
323    // ========================================================================
324    // Correlated Kernel Compilation (Multi-Series Simulation)
325    // ========================================================================
326
327    /// Compile a correlated (multi-series) simulation kernel.
328    ///
329    /// This extends the simulation kernel to support multiple aligned time series,
330    /// enabling cross-series strategies (e.g., SPY vs VIX, temperature vs pressure).
331    ///
332    /// # Arguments
333    /// * `name` - Function name for the compiled kernel
334    /// * `program` - Bytecode program containing the strategy
335    /// * `config` - Kernel configuration with series mappings
336    ///
337    /// # Returns
338    /// A function pointer with signature:
339    /// `fn(cursor_index: usize, series_ptrs: *const *const f64, table_count: usize, state_ptr: *mut u8) -> i32`
340    ///
341    /// # Generated Code Pattern
342    ///
343    /// For a strategy like:
344    /// ```shape
345    /// let spy_price = context.spy    // series index 0
346    /// let vix_level = context.vix    // series index 1
347    /// if vix_level > 25.0 && state.position == 0 {
348    ///     state.signal = 1.0
349    /// }
350    /// ```
351    ///
352    /// The kernel generates:
353    /// ```asm
354    /// ; spy_price = context.spy (series index 0)
355    /// mov rax, [series_ptrs + 0*8]     ; series 0 pointer
356    /// mov xmm0, [rax + cursor_index*8] ; spy value
357    ///
358    /// ; vix_level = context.vix (series index 1)
359    /// mov rax, [series_ptrs + 1*8]     ; series 1 pointer
360    /// mov xmm1, [rax + cursor_index*8] ; vix value
361    ///
362    /// ; comparison and conditional store
363    /// mov xmm2, [const_25.0]
364    /// ucomisd xmm1, xmm2
365    /// jbe skip
366    /// ; ... check state.position == 0 ...
367    /// mov qword [state_ptr + signal_offset], 1.0
368    /// skip:
369    /// ```
370    #[inline(always)]
371    pub fn compile_correlated_kernel(
372        &mut self,
373        name: &str,
374        program: &BytecodeProgram,
375        config: &SimulationKernelConfig,
376    ) -> Result<CorrelatedKernelFn, String> {
377        // Validate config is for multi-series mode
378        if !config.is_multi_table() {
379            return Err(
380                "compile_correlated_kernel requires multi-series config (use new_multi_table)"
381                    .to_string(),
382            );
383        }
384
385        // Correlated kernel ABI:
386        // fn(cursor_index: usize, series_ptrs: *const *const f64, table_count: usize, state_ptr: *mut u8) -> i32
387        let mut sig = self.module.make_signature();
388        sig.params.push(AbiParam::new(types::I64)); // cursor_index
389        sig.params.push(AbiParam::new(types::I64)); // series_ptrs
390        sig.params.push(AbiParam::new(types::I64)); // table_count
391        sig.params.push(AbiParam::new(types::I64)); // state_ptr
392        sig.returns.push(AbiParam::new(types::I32)); // result code
393
394        let func_id = self
395            .module
396            .declare_function(name, Linkage::Export, &sig)
397            .map_err(|e| format!("Failed to declare correlated kernel function: {}", e))?;
398
399        let mut ctx = self.module.make_context();
400        ctx.func.signature = sig;
401
402        let mut func_builder_ctx = FunctionBuilderContext::new();
403        {
404            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut func_builder_ctx);
405            let entry_block = builder.create_block();
406            builder.append_block_params_for_function_params(entry_block);
407            builder.switch_to_block(entry_block);
408            builder.seal_block(entry_block);
409
410            // Get kernel parameters
411            let cursor_index = builder.block_params(entry_block)[0];
412            let series_ptrs = builder.block_params(entry_block)[1];
413            let _table_count = builder.block_params(entry_block)[2]; // For validation/debugging
414            let state_ptr = builder.block_params(entry_block)[3];
415
416            // Build correlated kernel IR
417            // Note: table_count is known at compile time from config, used for validation
418            let result = self.build_correlated_kernel_ir(
419                &mut builder,
420                program,
421                config,
422                cursor_index,
423                series_ptrs,
424                state_ptr,
425            )?;
426
427            builder.ins().return_(&[result]);
428            builder.finalize();
429        }
430
431        self.module
432            .define_function(func_id, &mut ctx)
433            .map_err(|e| format!("Failed to define correlated kernel function: {:?}", e))?;
434
435        self.module.clear_context(&mut ctx);
436        self.module
437            .finalize_definitions()
438            .map_err(|e| format!("Failed to finalize correlated kernel: {:?}", e))?;
439
440        let code_ptr = self.module.get_finalized_function(func_id);
441        self.compiled_functions.insert(name.to_string(), code_ptr);
442
443        Ok(unsafe { std::mem::transmute(code_ptr) })
444    }
445
446    /// Build correlated kernel IR for multi-series access.
447    ///
448    /// Handles series access via compile-time resolved indices:
449    /// - `context.spy` → `series_ptrs[0][cursor_idx]` (if spy mapped to index 0)
450    /// - `context.vix` → `series_ptrs[1][cursor_idx]` (if vix mapped to index 1)
451    fn build_correlated_kernel_ir(
452        &mut self,
453        builder: &mut FunctionBuilder,
454        program: &BytecodeProgram,
455        config: &SimulationKernelConfig,
456        cursor_index: Value,
457        series_ptrs: Value,
458        state_ptr: Value,
459    ) -> Result<Value, String> {
460        // Build FFI refs
461        let ffi = self.build_ffi_refs(builder);
462
463        // Create BytecodeToIR in correlated kernel mode
464        // The translator will use config.table_map to resolve series names to indices
465        let mut compiler = BytecodeToIR::new_kernel_mode(
466            builder,
467            program,
468            cursor_index,
469            series_ptrs,
470            state_ptr,
471            ffi,
472            config.clone(),
473        );
474
475        // Compile bytecode to correlated kernel IR
476        // The translator handles GetSeriesValue opcode by:
477        // 1. Looking up series name in config.table_map to get index
478        // 2. Generating: series_ptrs[index][cursor_index]
479        compiler.compile_kernel()
480    }
481}