Skip to main content

shape_jit/
context.rs

1//! JIT Context and Related Data Structures
2//!
3//! Contains the runtime context and data structures used by JIT-compiled code.
4
5use super::nan_boxing::*;
6
7// ============================================================================
8// JITContext Field Offsets for Direct Memory Access
9// ============================================================================
10//
11// These must match the #[repr(C)] struct layout of JITContext
12// Regenerate with: rustc --edition 2024 scripts/jit_offsets.rs && ./jit_offsets
13
// NOTE: every value below is a byte offset into the `#[repr(C)]` `JITContext`
// struct and assumes 8-byte pointers and `usize` (i.e. a 64-bit target).
// Reordering, inserting, or resizing `JITContext` fields invalidates them.

// Generic column access - columns are accessed via column_ptrs array indexed by column_map
// Timestamps pointer for time-based access
/// Byte offset of `JITContext::timestamps_ptr`.
pub const TIMESTAMPS_PTR_OFFSET: i32 = 24;

// DataFrame access offsets
/// Byte offset of `JITContext::column_ptrs`.
pub const COLUMN_PTRS_OFFSET: i32 = 32;
/// Byte offset of `JITContext::column_count`.
pub const COLUMN_COUNT_OFFSET: i32 = 40;
/// Byte offset of `JITContext::row_count`.
pub const ROW_COUNT_OFFSET: i32 = 48;
/// Byte offset of `JITContext::current_row`.
pub const CURRENT_ROW_OFFSET: i32 = 56;

// Locals and stack offsets
/// Byte offset of `JITContext::locals` (256 slots of 8 bytes each).
pub const LOCALS_OFFSET: i32 = 64;
/// Byte offset of `JITContext::stack` (512 slots of 8 bytes each).
pub const STACK_OFFSET: i32 = 2112; // 64 + (256 * 8)
/// Byte offset of `JITContext::stack_ptr`.
pub const STACK_PTR_OFFSET: i32 = 6208; // 2112 + (512 * 8)

// GC safepoint flag pointer offset (for inline safepoint check)
/// Byte offset of `JITContext::gc_safepoint_flag_ptr`.
pub const GC_SAFEPOINT_FLAG_PTR_OFFSET: i32 = 6328;
31
32// ============================================================================
33// Type Aliases
34// ============================================================================
35
/// Function pointer type for JIT-compiled strategy functions.
///
/// The compiled code receives a raw pointer to the `JITContext` it operates on
/// and returns an `i32` status code.
///
/// NOTE(review): the meaning of the status code is not defined here; see
/// `OsrEntryFn` for the convention used by OSR entries — confirm it also
/// applies to plain strategy functions.
pub type JittedStrategyFn = unsafe extern "C" fn(*mut JITContext) -> i32;

/// Legacy function signature for simple numeric computations.
///
/// Takes a mutable `f64` pointer, a const `f64` pointer and a `usize`, and
/// returns an `f64`.
///
/// NOTE(review): the roles of the three arguments (presumably output buffer,
/// input buffer and element count) are not documented in this file — confirm
/// against the code generator before relying on them.
pub type JittedFn = unsafe extern "C" fn(*mut f64, *const f64, usize) -> f64;

/// OSR entry function signature.
///
/// This has the same binary signature as `JittedStrategyFn` -- the difference
/// is semantic: for OSR entry, the caller pre-fills `JITContext.locals` from
/// the interpreter's live frame before invocation, and reads modified locals
/// back on return.
///
/// # Arguments
/// * `ctx_ptr` - Pointer to a `JITContext` with locals pre-filled from the
///   interpreter frame (marshaled using the `OsrEntryPoint.local_kinds`).
///
/// # Returns
/// * `0`          - Success: execution completed. Modified locals are in
///                  `JITContext.locals`. The VM reads them back and continues
///                  at `OsrEntryPoint.exit_ip`.
/// * `i32::MIN+1` - Deopt requested: a type guard failed mid-loop. The VM
///                  reads locals from `JITContext.locals` and resumes at
///                  the `DeoptInfo.resume_ip` for the failing guard.
/// * Other negative - Error.
pub type OsrEntryFn = unsafe extern "C" fn(*mut JITContext) -> i32;
62
63// ============================================================================
64// Simulation Kernel ABI (Zero-Allocation Hot Path)
65// ============================================================================
66
/// Function pointer type for simulation kernel functions (single series).
///
/// This is the "fused step" ABI that enables >10M ticks/sec by:
/// - Bypassing JITContext setup overhead
/// - Using direct pointer arithmetic for data access
/// - Avoiding all allocations in the hot loop
///
/// Unlike `JittedStrategyFn`, no `JITContext` pointer is passed: the kernel
/// receives the cursor, the column pointer array and the state pointer
/// directly as arguments.
///
/// # Arguments
/// * `cursor_index` - Current position in the series (0-based)
/// * `series_ptrs` - Pointer to array of column pointers (*const *const f64)
/// * `state_ptr` - Pointer to TypedObject state (*mut u8)
///
/// # Returns
/// * 0 = continue execution
/// * 1 = signal generated (written to state)
/// * negative = error
///
/// # Safety
/// The caller must ensure:
/// - `cursor_index` is within bounds
/// - `series_ptrs` points to valid column pointer array
/// - `state_ptr` points to valid TypedObject with correct schema
pub type SimulationKernelFn = unsafe extern "C" fn(
    cursor_index: usize,
    series_ptrs: *const *const f64,
    state_ptr: *mut u8,
) -> i32;

/// Function pointer type for correlated (multi-series) kernel functions.
///
/// This extends the simulation kernel ABI to support multiple aligned time series.
/// Each series is accessed via compile-time resolved indices.
///
/// The signature is that of `SimulationKernelFn` with one extra argument,
/// `table_count`, inserted before `state_ptr`.
///
/// # Arguments
/// * `cursor_index` - Current position in all series (0-based, must be aligned)
/// * `series_ptrs` - Pointer to array of series data pointers (*const *const f64)
///                   Each pointer is a single f64 array (one series's data)
/// * `table_count` - Number of series (for bounds checking, known at compile time)
/// * `state_ptr` - Pointer to TypedObject state (*mut u8)
///
/// # Memory Layout
/// ```text
/// series_ptrs[0] -> [spy_close[0], spy_close[1], ..., spy_close[n-1]]
/// series_ptrs[1] -> [vix_close[0], vix_close[1], ..., vix_close[n-1]]
/// ...
/// ```
///
/// # JIT Access Pattern
/// ```asm
/// ; context.spy (series index 0)
/// mov rax, [series_ptrs + 0*8]     ; series pointer
/// mov xmm0, [rax + cursor_index*8] ; value at cursor
/// ```
///
/// # Returns
/// * 0 = continue execution
/// * 1 = signal generated (written to state)
/// * negative = error
///
/// # Safety
/// The caller must ensure:
/// - `cursor_index` is within bounds for ALL series
/// - `series_ptrs` points to valid array of `table_count` data pointers
/// - All series have the same length (aligned timestamps)
/// - `state_ptr` points to valid TypedObject with correct schema
pub type CorrelatedKernelFn = unsafe extern "C" fn(
    cursor_index: usize,
    series_ptrs: *const *const f64,
    table_count: usize,
    state_ptr: *mut u8,
) -> i32;
138
/// Compile-time configuration for a simulation kernel.
///
/// Carries the name-to-index and name-to-offset tables the code generator
/// needs in order to emit direct memory accesses for the kernel ABI.
///
/// Two modes are supported:
/// - **Single-series**: `column_map` maps field names (close, volume) to column indices
/// - **Multi-series**: `table_map` maps series names (spy, vix) to series indices
#[derive(Debug, Clone)]
pub struct SimulationKernelConfig {
    /// Single-series mode: `(field_name, column_index)` pairs,
    /// e.g. `[("close", 3), ("volume", 4)]`.
    pub column_map: Vec<(String, usize)>,

    /// Multi-series mode: `(series_name, series_index)` pairs,
    /// e.g. `[("spy", 0), ("vix", 1), ("temperature", 2)]`.
    ///
    /// CRITICAL for JIT: resolved at compile time, NOT at runtime.
    /// `context.spy` → `series_ptrs[0][cursor_idx]`
    pub table_map: Vec<(String, usize)>,

    /// `(field_name, byte_offset)` pairs for the state object,
    /// e.g. `[("cash", 0), ("position", 8), ("entry_price", 16)]`.
    pub state_field_offsets: Vec<(String, usize)>,

    /// Schema ID for the state TypedObject.
    pub state_schema_id: u32,

    /// Total number of columns in the data (single-series mode).
    pub column_count: usize,

    /// Total number of series (multi-series mode).
    pub table_count: usize,
}

impl SimulationKernelConfig {
    /// Shared constructor: all mapping tables start out empty.
    fn with_counts(state_schema_id: u32, column_count: usize, table_count: usize) -> Self {
        SimulationKernelConfig {
            column_map: Vec::new(),
            table_map: Vec::new(),
            state_field_offsets: Vec::new(),
            state_schema_id,
            column_count,
            table_count,
        }
    }

    /// Returns the value paired with the first entry whose name equals `key`.
    fn lookup(entries: &[(String, usize)], key: &str) -> Option<usize> {
        entries
            .iter()
            .find_map(|(name, value)| (name.as_str() == key).then_some(*value))
    }

    /// Create a new kernel config for single-series mode (`table_count` is 0).
    pub fn new(state_schema_id: u32, column_count: usize) -> Self {
        Self::with_counts(state_schema_id, column_count, 0)
    }

    /// Create a new kernel config for multi-series (correlated) mode
    /// (`column_count` is 0).
    ///
    /// Use this when simulating across multiple aligned time series
    /// (e.g., SPY vs VIX, temperature vs pressure).
    pub fn new_multi_table(state_schema_id: u32, table_count: usize) -> Self {
        Self::with_counts(state_schema_id, 0, table_count)
    }

    /// Builder step: append a `field_name → column_index` mapping
    /// (single-series mode).
    pub fn map_column(mut self, field_name: &str, column_index: usize) -> Self {
        let entry = (field_name.to_owned(), column_index);
        self.column_map.push(entry);
        self
    }

    /// Builder step: append a `series_name → series_index` mapping
    /// (multi-series mode).
    ///
    /// CRITICAL: This mapping is resolved at compile time.
    /// `context.spy` in Shape → `series_ptrs[0][cursor_idx]` in generated code.
    pub fn map_series(mut self, series_name: &str, series_index: usize) -> Self {
        let entry = (series_name.to_owned(), series_index);
        self.table_map.push(entry);
        self
    }

    /// Builder step: append a `field_name → byte offset` mapping for the state.
    pub fn map_state_field(mut self, field_name: &str, offset: usize) -> Self {
        let entry = (field_name.to_owned(), offset);
        self.state_field_offsets.push(entry);
        self
    }

    /// Column index registered for `field_name` (single-series mode).
    ///
    /// Returns the first matching entry, or `None` if the name is unmapped.
    pub fn get_column_index(&self, field_name: &str) -> Option<usize> {
        Self::lookup(&self.column_map, field_name)
    }

    /// Series index registered for `series_name` (multi-series mode).
    ///
    /// This is used by the JIT compiler at compile time. Returns the first
    /// matching entry, or `None` if the name is unmapped.
    pub fn get_series_index(&self, series_name: &str) -> Option<usize> {
        Self::lookup(&self.table_map, series_name)
    }

    /// Byte offset registered for the state field `field_name`.
    ///
    /// Returns the first matching entry, or `None` if the name is unmapped.
    pub fn get_state_offset(&self, field_name: &str) -> Option<usize> {
        Self::lookup(&self.state_field_offsets, field_name)
    }

    /// `true` when either a non-zero `table_count` was given or at least one
    /// series has been mapped.
    pub fn is_multi_table(&self) -> bool {
        !(self.table_count == 0 && self.table_map.is_empty())
    }
}
257
258// ============================================================================
259// JIT Data Structures
260// ============================================================================
261
/// JIT-compatible closure structure.
///
/// Holds a `function_id` and a pointer to a heap-allocated array of captured
/// values (NaN-boxed `u64`s). The number of captures is not bounded by a
/// fixed-size inline array, but it must fit in the `u16` `captures_count`
/// field, i.e. at most `u16::MAX` (65535) captures.
#[repr(C)]
pub struct JITClosure {
    pub function_id: u16,
    pub captures_count: u16,
    pub captures_ptr: *const u64, // Pointer to heap-allocated capture array (NaN-boxed)
}

impl JITClosure {
    /// Create a new JITClosure with dynamically allocated captures.
    ///
    /// The captures slice is copied into a heap-allocated `Box<[u64]>` that is
    /// leaked into a raw pointer. The memory is reclaimed by `drop_captures()`,
    /// which also runs automatically when the closure is dropped.
    ///
    /// # Panics
    /// Panics if `captures.len()` exceeds `u16::MAX`. A silently truncated
    /// count would make `drop_captures()` rebuild the allocation with the
    /// wrong length, so the oversized input is rejected up front instead.
    pub fn new(function_id: u16, captures: &[u64]) -> Box<Self> {
        // Checked before anything is leaked so a panic cannot leak memory.
        assert!(
            captures.len() <= u16::MAX as usize,
            "JITClosure supports at most 65535 captures"
        );
        // Lossless: guarded by the assertion above.
        let captures_count = captures.len() as u16;

        let captures_box: Box<[u64]> = captures.into();
        let captures_ptr = Box::into_raw(captures_box) as *const u64;
        Box::new(JITClosure {
            function_id,
            captures_count,
            captures_ptr,
        })
    }

    /// Read a capture value by index.
    ///
    /// # Safety
    /// The captures_ptr must be valid (in particular, `drop_captures()` must
    /// not have been called yet) and index must be < captures_count. The bound
    /// is only checked in debug builds.
    #[inline]
    pub unsafe fn get_capture(&self, index: usize) -> u64 {
        debug_assert!(index < self.captures_count as usize);
        // SAFETY: the caller guarantees `captures_ptr` is live and `index` is
        // within the `captures_count` elements allocated by `new()`.
        unsafe { *self.captures_ptr.add(index) }
    }

    /// Free the heap-allocated captures array.
    ///
    /// Idempotent: safe to call multiple times (no-op after first call).
    /// `captures_count` is left untouched; only `captures_ptr` is nulled.
    /// A closure with zero captures owns no allocation, so nothing is freed
    /// for it.
    ///
    /// # Safety
    /// The captures_ptr must point to a valid allocation created by `new()`
    /// with exactly `captures_count` elements, or be null (no-op).
    pub unsafe fn drop_captures(&mut self) {
        if self.captures_ptr.is_null() || self.captures_count == 0 {
            return;
        }
        let len = self.captures_count as usize;
        let slice_ptr = std::ptr::slice_from_raw_parts_mut(self.captures_ptr as *mut u64, len);
        // SAFETY: per the contract above, `slice_ptr` describes exactly the
        // `Box<[u64]>` leaked in `new()`; the pointer is nulled right after so
        // the box is rebuilt (and freed) at most once.
        drop(unsafe { Box::from_raw(slice_ptr) });
        self.captures_ptr = std::ptr::null();
    }
}

impl Drop for JITClosure {
    fn drop(&mut self) {
        // SAFETY: drop_captures is idempotent — if captures_ptr is already null
        // (e.g. from an explicit drop_captures() call), this is a no-op.
        unsafe { self.drop_captures() };
    }
}
325
326/// JIT-compatible duration structure
327#[repr(C)]
328pub struct JITDuration {
329    pub value: f64,
330    pub unit: u8, // 0=seconds, 1=minutes, 2=hours, 3=days, 4=weeks, 5=bars
331}
332
333impl JITDuration {
334    pub fn new(value: f64, unit: u8) -> Box<Self> {
335        Box::new(JITDuration { value, unit })
336    }
337
338    pub fn box_duration(duration: Box<JITDuration>) -> u64 {
339        use crate::nan_boxing::{HK_DURATION, jit_box};
340        jit_box(HK_DURATION, *duration)
341    }
342}
343
344/// JIT-compatible range structure
345/// Represents a range with start and end values (both NaN-boxed)
346#[repr(C)]
347pub struct JITRange {
348    pub start: u64, // NaN-boxed start value
349    pub end: u64,   // NaN-boxed end value
350}
351
352impl JITRange {
353    pub fn new(start: u64, end: u64) -> Box<Self> {
354        Box::new(JITRange { start, end })
355    }
356
357    pub fn box_range(range: Box<JITRange>) -> u64 {
358        use crate::nan_boxing::{HK_RANGE, jit_box};
359        jit_box(HK_RANGE, *range)
360    }
361}
362
363/// JIT-compatible SignalBuilder structure
364/// Represents a signal builder for method chaining (series.where().then().capture())
365#[repr(C)]
366pub struct JITSignalBuilder {
367    pub series: u64,                  // NaN-boxed TAG_TABLE
368    pub conditions: Vec<u64>,         // Array of (condition_type, condition_series) pairs
369    pub captures: Vec<(String, u64)>, // (name, value) pairs for captured values
370}
371
372impl JITSignalBuilder {
373    pub fn new(series: u64) -> Box<Self> {
374        Box::new(JITSignalBuilder {
375            series,
376            conditions: Vec::new(),
377            captures: Vec::new(),
378        })
379    }
380
381    pub fn add_where(&mut self, condition_series: u64) {
382        // 0 = WHERE condition
383        self.conditions.push(0);
384        self.conditions.push(condition_series);
385    }
386
387    pub fn add_then(&mut self, condition_series: u64, max_gap: u64) {
388        // 1 = THEN condition
389        self.conditions.push(1);
390        self.conditions.push(condition_series);
391        self.conditions.push(max_gap);
392    }
393
394    pub fn add_capture(&mut self, name: String, value: u64) {
395        self.captures.push((name, value));
396    }
397
398    pub fn box_builder(builder: Box<JITSignalBuilder>) -> u64 {
399        use crate::nan_boxing::{HK_JIT_SIGNAL_BUILDER, jit_box};
400        jit_box(HK_JIT_SIGNAL_BUILDER, *builder)
401    }
402}
403
404/// JIT-compatible data reference structure
405/// Represents a reference to a specific data row in time
406#[repr(C)]
407pub struct JITDataReference {
408    pub timestamp: i64,
409    pub symbol: *const String, // Pointer to symbol string
410    pub timeframe_value: u32,  // Timeframe value
411    pub timeframe_unit: u8,    // 0=Second, 1=Minute, 2=Hour, 3=Day, 4=Week, 5=Month, 6=Bar
412    pub has_timezone: bool,
413    pub timezone: *const String, // Pointer to timezone string (may be null)
414}
415
416impl JITDataReference {
417    pub fn box_data_ref(data_ref: Box<JITDataReference>) -> u64 {
418        use crate::nan_boxing::{HK_DATA_REFERENCE, jit_box};
419        jit_box(HK_DATA_REFERENCE, *data_ref)
420    }
421}
422
423// ============================================================================
424// JITContext - Main Execution Context
425// ============================================================================
426
427/// JIT execution context passed to compiled functions
428/// This struct must be C-compatible (#[repr(C)]) for FFI
429///
430/// Uses NaN-boxing for full type support
431#[repr(C)]
432#[derive(Debug, Clone)]
433pub struct JITContext {
434    // Position state
435    pub in_position: bool,
436    pub position_side: i8,       // 0=None, 1=Long, -1=Short
437    pub entry_price: u64,        // NaN-boxed f64
438    pub unrealized_pnl_pct: u64, // NaN-boxed f64
439
440    // Timestamps pointer for time-based data access
441    pub timestamps_ptr: *const i64,
442
443    // ========== Generic DataFrame Access (industry-agnostic) ==========
444    /// Array of column pointers (SIMD-aligned f64 arrays)
445    /// Column order matches DataFrameSchema.column_names
446    pub column_ptrs: *const *const f64,
447    /// Number of columns in the DataFrame
448    pub column_count: usize,
449    /// Number of rows in the DataFrame
450    pub row_count: usize,
451    /// Current row index (for backtest iteration)
452    pub current_row: usize,
453
454    // Local variables (NaN-boxed values)
455    pub locals: [u64; 256],
456
457    // NaN-boxed stack for JIT execution
458    pub stack: [u64; 512],
459    pub stack_ptr: usize,
460
461    // Heap object storage (owned by VM, JIT just holds pointers)
462    pub heap_ptr: *mut std::ffi::c_void,
463
464    // Function table for Call opcode (pointer to array of function pointers)
465    pub function_table: *const JittedStrategyFn,
466    pub function_table_len: usize,
467
468    // ExecutionContext pointer for fallback to interpreter
469    pub exec_context_ptr: *mut std::ffi::c_void,
470
471    // Function names for closure-to-Value conversion
472    // Points to contiguous String array from BytecodeProgram.functions
473    pub function_names_ptr: *const String,
474    pub function_names_len: usize,
475
476    // ========== Async Execution Support ==========
477    /// Pointer to event queue (for FFI calls to poll/push events)
478    /// Points to a SharedEventQueue behind the scenes
479    pub event_queue_ptr: *mut std::ffi::c_void,
480
481    /// Suspension state: 0 = running, 1 = yielded, 2 = suspended
482    pub suspension_state: u32,
483
484    /// Iterations since last yield (for cooperative scheduling)
485    pub iterations_since_yield: u64,
486
487    /// Yield threshold - yield after this many iterations
488    /// 0 = never yield automatically
489    pub yield_threshold: u64,
490
491    /// Alert pipeline pointer (for FFI calls to emit alerts)
492    /// Points to AlertRouter behind the scenes
493    pub alert_pipeline_ptr: *mut std::ffi::c_void,
494
495    // ========== Simulation Mode Support ==========
496    /// Simulation mode: 0 = disabled, 1 = DenseKernel, 2 = HybridKernel
497    pub simulation_mode: u32,
498
499    /// Pointer to simulation state (TypedObject for DenseKernel)
500    /// JIT code accesses state fields via direct memory offset
501    pub simulation_state_ptr: *mut u8,
502
503    /// Size of simulation state data (for deallocation)
504    pub simulation_state_size: usize,
505
506    // ========== GC Integration ==========
507    /// Pointer to GC safepoint flag (AtomicBool raw pointer).
508    /// Null when GC is not enabled. The JIT safepoint function reads this
509    /// to determine if a GC cycle is requested.
510    pub gc_safepoint_flag_ptr: *const u8,
511
512    /// Pointer to GcHeap for allocation fast path.
513    /// Null when GC is not enabled.
514    pub gc_heap_ptr: *mut std::ffi::c_void,
515
516    /// Opaque pointer to JIT foreign-call bridge state.
517    /// Null when no foreign functions are linked for this execution.
518    pub foreign_bridge_ptr: *const std::ffi::c_void,
519}
520
521impl Default for JITContext {
522    fn default() -> Self {
523        Self {
524            in_position: false,
525            position_side: 0,
526            entry_price: box_number(0.0),
527            unrealized_pnl_pct: box_number(0.0),
528            // Timestamps pointer
529            timestamps_ptr: std::ptr::null(),
530            // Generic DataFrame access
531            column_ptrs: std::ptr::null(),
532            column_count: 0,
533            row_count: 0,
534            current_row: 0,
535            // Local variables and stack
536            locals: [TAG_NULL; 256],
537            stack: [TAG_NULL; 512],
538            stack_ptr: 0,
539            heap_ptr: std::ptr::null_mut(),
540            function_table: std::ptr::null(),
541            function_table_len: 0,
542            exec_context_ptr: std::ptr::null_mut(),
543            function_names_ptr: std::ptr::null(),
544            function_names_len: 0,
545            // Async execution support
546            event_queue_ptr: std::ptr::null_mut(),
547            suspension_state: 0,
548            iterations_since_yield: 0,
549            yield_threshold: 0, // 0 = no automatic yielding
550            alert_pipeline_ptr: std::ptr::null_mut(),
551            // Simulation mode support
552            simulation_mode: 0,
553            simulation_state_ptr: std::ptr::null_mut(),
554            simulation_state_size: 0,
555            // GC integration
556            gc_safepoint_flag_ptr: std::ptr::null(),
557            gc_heap_ptr: std::ptr::null_mut(),
558            foreign_bridge_ptr: std::ptr::null(),
559        }
560    }
561}
562
563impl JITContext {
564    /// Get column value at offset from current row
565    /// column_index is the column index in the DataFrame schema
566    pub fn get_column_value(&self, column_index: usize, offset: i32) -> f64 {
567        if self.column_ptrs.is_null() || column_index >= self.column_count {
568            return 0.0;
569        }
570        let row_idx = (self.current_row as i32 + offset) as usize;
571        if row_idx < self.row_count {
572            unsafe {
573                let col_ptr = *self.column_ptrs.add(column_index);
574                if !col_ptr.is_null() {
575                    *col_ptr.add(row_idx)
576                } else {
577                    0.0
578                }
579            }
580        } else {
581            0.0
582        }
583    }
584
585    /// Update current row index for DataFrame iteration
586    #[inline]
587    pub fn set_current_row(&mut self, index: usize) {
588        self.current_row = index;
589    }
590
591    /// Update current row for backtest iteration (alias for backward compatibility)
592    #[inline]
593    pub fn update_current_row(&mut self, index: usize) {
594        self.current_row = index;
595    }
596
597    // ========================================================================
598    // Simulation Mode Methods
599    // ========================================================================
600
601    /// Check if in simulation mode
602    #[inline]
603    pub fn is_simulation_mode(&self) -> bool {
604        self.simulation_mode > 0
605    }
606
607    /// Set up context for DenseKernel simulation.
608    ///
609    /// # Arguments
610    /// * `state_ptr` - Pointer to TypedObject state
611    /// * `state_size` - Size of state data
612    /// * `column_ptrs` - Pointers to data columns
613    /// * `column_count` - Number of columns
614    /// * `row_count` - Number of rows
615    /// * `timestamps` - Pointer to timestamp array
616    pub fn setup_simulation(
617        &mut self,
618        state_ptr: *mut u8,
619        state_size: usize,
620        column_ptrs: *const *const f64,
621        column_count: usize,
622        row_count: usize,
623        timestamps: *const i64,
624    ) {
625        self.simulation_mode = 1; // DenseKernel mode
626        self.simulation_state_ptr = state_ptr;
627        self.simulation_state_size = state_size;
628        self.column_ptrs = column_ptrs;
629        self.column_count = column_count;
630        self.row_count = row_count;
631        self.current_row = 0;
632        self.timestamps_ptr = timestamps;
633    }
634
635    /// Get simulation state field as f64.
636    ///
637    /// # Safety
638    /// Caller must ensure offset is valid for the state TypedObject.
639    #[inline]
640    pub unsafe fn get_state_field_f64(&self, offset: usize) -> f64 {
641        if self.simulation_state_ptr.is_null() {
642            return 0.0;
643        }
644        let field_ptr = unsafe { self.simulation_state_ptr.add(8 + offset) } as *const u64;
645        let bits = unsafe { *field_ptr };
646        unbox_number(bits)
647    }
648
649    /// Set simulation state field as f64.
650    ///
651    /// # Safety
652    /// Caller must ensure offset is valid for the state TypedObject.
653    #[inline]
654    pub unsafe fn set_state_field_f64(&mut self, offset: usize, value: f64) {
655        if self.simulation_state_ptr.is_null() {
656            return;
657        }
658        let field_ptr = unsafe { self.simulation_state_ptr.add(8 + offset) } as *mut u64;
659        unsafe { *field_ptr = box_number(value) };
660    }
661
662    /// Clear simulation mode.
663    pub fn clear_simulation(&mut self) {
664        self.simulation_mode = 0;
665        self.simulation_state_ptr = std::ptr::null_mut();
666        self.simulation_state_size = 0;
667    }
668}
669
670// ============================================================================
671// JITDataFrame - Generic DataFrame for JIT (industry-agnostic)
672// ============================================================================
673
674/// Generic DataFrame storage for JIT execution.
675/// Stores data as an array of columns, matching the generic column_ptrs
676/// design in JITContext.
677///
678/// Column order MUST match the DataFrameSchema used during compilation.
679pub struct JITDataFrame {
680    /// Column data arrays (each Vec is one column)
681    /// Columns are ordered by index as defined in DataFrameSchema
682    pub columns: Vec<Vec<f64>>,
683    /// Pointers to column data (for JITContext.column_ptrs)
684    pub column_ptrs: Vec<*const f64>,
685    /// Timestamps (always present, column 0 equivalent)
686    pub timestamps: Vec<i64>,
687    /// Number of rows
688    pub row_count: usize,
689}
690
691impl JITDataFrame {
692    /// Create an empty JITDataFrame
693    pub fn new() -> Self {
694        Self {
695            columns: Vec::new(),
696            column_ptrs: Vec::new(),
697            timestamps: Vec::new(),
698            row_count: 0,
699        }
700    }
701
702    /// Create from ExecutionContext using a schema mapping.
703    /// The schema determines which columns to extract and their order.
704    pub fn from_execution_context(
705        ctx: &shape_runtime::context::ExecutionContext,
706        schema: &shape_vm::bytecode::DataFrameSchema,
707    ) -> Self {
708        let mut data = Self::new();
709
710        // NOTE: Series caching not yet implemented in ExecutionContext
711        // For now, initialize empty columns for each schema column
712        // TODO: Implement series caching when available
713        let _ = (ctx, schema); // Suppress unused warnings
714        for _ in 0..schema.column_names.len() {
715            data.columns.push(Vec::new());
716            data.column_ptrs.push(std::ptr::null());
717        }
718
719        data
720    }
721
722    /// Populate a JITContext with generic DataFrame pointers.
723    /// This sets column_ptrs, column_count, row_count, and timestamps_ptr.
724    pub fn populate_context(&self, ctx: &mut JITContext) {
725        if !self.column_ptrs.is_empty() {
726            ctx.column_ptrs = self.column_ptrs.as_ptr();
727            ctx.column_count = self.column_ptrs.len();
728        }
729        ctx.row_count = self.row_count;
730
731        if !self.timestamps.is_empty() {
732            ctx.timestamps_ptr = self.timestamps.as_ptr();
733        }
734    }
735
736    /// Get the number of rows
737    pub fn len(&self) -> usize {
738        self.row_count
739    }
740
741    /// Check if empty
742    pub fn is_empty(&self) -> bool {
743        self.row_count == 0
744    }
745
746    /// Get number of columns
747    pub fn column_count(&self) -> usize {
748        self.columns.len()
749    }
750
751    /// Create from a DataTable by extracting f64 columns and an optional timestamp column.
752    ///
753    /// All f64 columns are copied into SIMD-aligned buffers. If a column named
754    /// "timestamp" (or typed as Timestamp) exists, it is extracted as i64.
755    pub fn from_datatable(dt: &shape_value::DataTable) -> Self {
756        use arrow_array::cast::AsArray;
757        use arrow_schema::{DataType, TimeUnit};
758
759        let batch = dt.inner();
760        let schema = batch.schema();
761        let num_rows = batch.num_rows();
762        let mut columns = Vec::new();
763        let mut timestamps = Vec::new();
764
765        for (i, field) in schema.fields().iter().enumerate() {
766            match field.data_type() {
767                DataType::Float64 => {
768                    let arr = batch
769                        .column(i)
770                        .as_primitive::<arrow_array::types::Float64Type>();
771                    let col: Vec<f64> = (0..num_rows).map(|r| arr.value(r)).collect();
772                    columns.push(col);
773                }
774                DataType::Timestamp(TimeUnit::Microsecond, _) => {
775                    let arr = batch
776                        .column(i)
777                        .as_primitive::<arrow_array::types::TimestampMicrosecondType>();
778                    timestamps = (0..num_rows).map(|r| arr.value(r)).collect();
779                }
780                DataType::Int64 => {
781                    // Convert i64 to f64 for JIT column access
782                    let arr = batch
783                        .column(i)
784                        .as_primitive::<arrow_array::types::Int64Type>();
785                    let col: Vec<f64> = (0..num_rows).map(|r| arr.value(r) as f64).collect();
786                    columns.push(col);
787                }
788                _ => {
789                    // Skip non-numeric columns (strings, bools, etc.)
790                }
791            }
792        }
793
794        let column_ptrs: Vec<*const f64> = columns.iter().map(|c| c.as_ptr()).collect();
795
796        Self {
797            columns,
798            column_ptrs,
799            timestamps,
800            row_count: num_rows,
801        }
802    }
803}
804
805impl Default for JITDataFrame {
806    fn default() -> Self {
807        Self::new()
808    }
809}
810
811// ============================================================================
812// JITConfig - Compilation Configuration
813// ============================================================================
814
/// Settings that control JIT compilation.
#[derive(Debug, Clone)]
pub struct JITConfig {
    /// Optimization level (0-3)
    pub opt_level: u8,
    /// Enable debug symbols
    pub debug_symbols: bool,
    /// Minimum execution count before JIT compilation
    pub jit_threshold: usize,
}

impl Default for JITConfig {
    /// Defaults: optimization level 3, no debug symbols, and a JIT
    /// threshold of 100 executions.
    fn default() -> Self {
        let opt_level = 3;
        let jit_threshold = 100;
        JITConfig {
            jit_threshold,
            debug_symbols: false,
            opt_level,
        }
    }
}
835
836#[cfg(test)]
837mod tests {
838    use super::*;
839
840    #[test]
841    fn test_closure_dynamic_captures_0() {
842        // Zero captures — captures_ptr should be a valid (empty) allocation
843        let closure = JITClosure::new(42, &[]);
844        assert_eq!(closure.function_id, 42);
845        assert_eq!(closure.captures_count, 0);
846        // Drop is safe even with 0 captures
847        let mut closure = closure;
848        unsafe { closure.drop_captures() };
849    }
850
851    #[test]
852    fn test_closure_dynamic_captures_5() {
853        // Typical case: 5 captures
854        let captures = [
855            box_number(1.0),
856            box_number(2.0),
857            box_number(3.0),
858            TAG_BOOL_TRUE,
859            TAG_NULL,
860        ];
861        let closure = JITClosure::new(7, &captures);
862        assert_eq!(closure.function_id, 7);
863        assert_eq!(closure.captures_count, 5);
864
865        unsafe {
866            assert_eq!(unbox_number(closure.get_capture(0)), 1.0);
867            assert_eq!(unbox_number(closure.get_capture(1)), 2.0);
868            assert_eq!(unbox_number(closure.get_capture(2)), 3.0);
869            assert_eq!(closure.get_capture(3), TAG_BOOL_TRUE);
870            assert_eq!(closure.get_capture(4), TAG_NULL);
871        }
872    }
873
874    #[test]
875    fn test_closure_dynamic_captures_20() {
876        // Exceeds old 16-capture limit
877        let captures: Vec<u64> = (0..20).map(|i| box_number(i as f64)).collect();
878        let closure = JITClosure::new(99, &captures);
879        assert_eq!(closure.captures_count, 20);
880
881        unsafe {
882            for i in 0..20 {
883                assert_eq!(unbox_number(closure.get_capture(i)), i as f64);
884            }
885        }
886    }
887
888    #[test]
889    fn test_closure_dynamic_captures_64() {
890        // Stress test: 64 captures
891        let captures: Vec<u64> = (0..64).map(|i| box_number(i as f64 * 10.0)).collect();
892        let closure = JITClosure::new(1, &captures);
893        assert_eq!(closure.captures_count, 64);
894
895        unsafe {
896            for i in 0..64 {
897                assert_eq!(unbox_number(closure.get_capture(i)), i as f64 * 10.0);
898            }
899        }
900    }
901
902    #[test]
903    fn test_closure_captures_drop() {
904        // Verify memory is properly freed (no leak under Miri/ASAN)
905        let captures: Vec<u64> = (0..32).map(|i| box_number(i as f64)).collect();
906        let mut closure = JITClosure::new(5, &captures);
907        assert_eq!(closure.captures_count, 32);
908
909        // Verify captures are valid before drop
910        unsafe {
911            assert_eq!(unbox_number(closure.get_capture(0)), 0.0);
912            assert_eq!(unbox_number(closure.get_capture(31)), 31.0);
913        }
914
915        // Drop captures
916        unsafe { closure.drop_captures() };
917        assert!(closure.captures_ptr.is_null());
918        assert_eq!(closure.captures_count, 32); // count unchanged, ptr nulled
919    }
920
921    #[test]
922    fn test_closure_jit_box_roundtrip() {
923        // Verify JITClosure survives jit_box/jit_unbox roundtrip
924        let captures = [box_number(42.0), TAG_BOOL_FALSE];
925        let closure = JITClosure::new(10, &captures);
926        let bits = jit_box(HK_CLOSURE, *closure);
927
928        assert!(is_heap_kind(bits, HK_CLOSURE));
929
930        let recovered = unsafe { jit_unbox::<JITClosure>(bits) };
931        assert_eq!(recovered.function_id, 10);
932        assert_eq!(recovered.captures_count, 2);
933        unsafe {
934            assert_eq!(unbox_number(recovered.get_capture(0)), 42.0);
935            assert_eq!(recovered.get_capture(1), TAG_BOOL_FALSE);
936        }
937    }
938
939    #[test]
940    fn test_closure_drop_impl_frees_captures_via_jit_drop() {
941        // Verify the Drop impl on JITClosure frees the captures array
942        // when the owning JitAlloc is freed via jit_drop.
943        // Under Miri/ASAN this would catch a leak if Drop didn't work.
944        let captures: Vec<u64> = (0..24).map(|i| box_number(i as f64)).collect();
945        let closure = JITClosure::new(3, &captures);
946        let bits = jit_box(HK_CLOSURE, *closure);
947
948        // Read captures to confirm they're valid
949        let recovered = unsafe { jit_unbox::<JITClosure>(bits) };
950        assert_eq!(recovered.captures_count, 24);
951        unsafe {
952            assert_eq!(unbox_number(recovered.get_capture(23)), 23.0);
953        }
954
955        // jit_drop frees JitAlloc<JITClosure>, which calls Drop::drop on
956        // JITClosure, which frees the captures array.
957        unsafe { jit_drop::<JITClosure>(bits) };
958    }
959
960    #[test]
961    fn test_closure_implicit_drop_on_box() {
962        // Verify that simply dropping a Box<JITClosure> frees the captures.
963        // (This tests the Drop impl without jit_box involvement.)
964        let captures: Vec<u64> = (0..10).map(|i| box_number(i as f64)).collect();
965        let closure = JITClosure::new(1, &captures);
966        // closure is Box<JITClosure>, dropping it should free captures via Drop
967        drop(closure);
968        // No leak under Miri/ASAN
969    }
970}