Skip to main content

shape_runtime/simulation/
correlated_kernel.rs

1//! CorrelatedKernel - Multi-Table Simulation
2//!
//! This module runs a single strategy over several row-aligned time series at
//! once, for cross-series work such as cross-sensor analysis or multi-asset
//! backtesting.
5
6use shape_ast::error::{Result, ShapeError};
7use shape_value::DataTable;
8use std::collections::HashMap;
9
10/// Schema for multi-table context, defining table names and their order.
11///
12/// CRITICAL: Table order is fixed at schema creation time and used for
13/// JIT compilation. The JIT compiler maps series names to indices at
14/// compile time, enabling `context.temperature` → `series_ptrs[0][cursor_idx]`.
15#[derive(Debug, Clone)]
16pub struct TableSchema {
17    /// Ordered list of table names (index = position in series_ptrs array)
18    names: Vec<String>,
19    /// Name to index mapping for compile-time resolution
20    name_to_index: HashMap<String, usize>,
21}
22
23impl TableSchema {
24    /// Create a new table schema from a list of names.
25    ///
26    /// The order of names determines their indices for JIT compilation.
27    pub fn new(names: Vec<String>) -> Self {
28        let name_to_index = names
29            .iter()
30            .enumerate()
31            .map(|(idx, name)| (name.clone(), idx))
32            .collect();
33        Self {
34            names,
35            name_to_index,
36        }
37    }
38
39    /// Create from a slice of string slices.
40    pub fn from_names(names: &[&str]) -> Self {
41        Self::new(names.iter().map(|s| s.to_string()).collect())
42    }
43
44    /// Get the index for a series name (used by JIT at compile time).
45    #[inline]
46    pub fn get_index(&self, name: &str) -> Option<usize> {
47        self.name_to_index.get(name).copied()
48    }
49
50    /// Get the number of series.
51    #[inline]
52    pub fn len(&self) -> usize {
53        self.names.len()
54    }
55
56    /// Check if schema is empty.
57    #[inline]
58    pub fn is_empty(&self) -> bool {
59        self.names.is_empty()
60    }
61
62    /// Get all series names in order.
63    pub fn names(&self) -> &[String] {
64        &self.names
65    }
66}
67
68/// Configuration for correlated kernel execution.
69#[derive(Debug, Clone)]
70pub struct CorrelatedKernelConfig {
71    /// Start tick (inclusive)
72    pub start: usize,
73    /// End tick (exclusive)
74    pub end: usize,
75    /// Warmup period
76    pub warmup: usize,
77}
78
79impl CorrelatedKernelConfig {
80    /// Create a config for the full range.
81    pub fn full(len: usize) -> Self {
82        Self {
83            start: 0,
84            end: len,
85            warmup: 0,
86        }
87    }
88
89    /// Create with warmup period.
90    pub fn with_warmup(len: usize, warmup: usize) -> Self {
91        Self {
92            start: 0,
93            end: len,
94            warmup,
95        }
96    }
97}
98
99/// Result of correlated kernel execution.
100#[derive(Debug)]
101pub struct CorrelatedKernelResult<S> {
102    /// Final state after all ticks
103    pub final_state: S,
104    /// Number of ticks processed
105    pub ticks_processed: usize,
106    /// Whether simulation completed
107    pub completed: bool,
108}
109
110/// The correlated simulation kernel for multi-series processing.
111///
112/// Enables correlation analysis across multiple aligned time series.
113pub struct CorrelatedKernel {
114    config: CorrelatedKernelConfig,
115}
116
117impl CorrelatedKernel {
118    /// Create a new correlated kernel.
119    pub fn new(config: CorrelatedKernelConfig) -> Self {
120        Self { config }
121    }
122
123    /// Run correlated simulation across multiple DataTables.
124    ///
125    /// Each DataTable represents a separate series. All tables must have
126    /// equal row counts. The strategy receives (tick_index, all_column_ptrs, state).
127    #[inline(always)]
128    pub fn run<S, F>(
129        &self,
130        tables: &[&DataTable],
131        schema: TableSchema,
132        mut initial_state: S,
133        mut strategy: F,
134    ) -> Result<CorrelatedKernelResult<S>>
135    where
136        F: FnMut(usize, &[*const f64], &TableSchema, &mut S) -> i32,
137    {
138        if tables.is_empty() {
139            return Err(ShapeError::RuntimeError {
140                message: "CorrelatedKernel requires at least one DataTable".to_string(),
141                location: None,
142            });
143        }
144
145        // Validate equal row counts
146        let row_count = tables[0].row_count();
147        for (i, table) in tables.iter().enumerate().skip(1) {
148            if table.row_count() != row_count {
149                return Err(ShapeError::RuntimeError {
150                    message: format!(
151                        "Table {} has {} rows but table 0 has {} rows",
152                        i,
153                        table.row_count(),
154                        row_count
155                    ),
156                    location: None,
157                });
158            }
159        }
160
161        // Flatten all f64 column pointers across all tables
162        let col_ptrs: Vec<*const f64> = tables
163            .iter()
164            .flat_map(|t| {
165                t.column_ptrs()
166                    .iter()
167                    .filter(|cp| cp.stride == 8)
168                    .map(|cp| cp.values_ptr as *const f64)
169            })
170            .collect();
171
172        let effective_start = self.config.start + self.config.warmup;
173        if effective_start >= self.config.end {
174            return Err(ShapeError::RuntimeError {
175                message: format!(
176                    "Warmup ({}) exceeds available range ({} - {})",
177                    self.config.warmup, self.config.start, self.config.end
178                ),
179                location: None,
180            });
181        }
182
183        let mut ticks_processed = 0;
184
185        for cursor_index in effective_start..self.config.end {
186            let result = strategy(cursor_index, &col_ptrs, &schema, &mut initial_state);
187            if result != 0 {
188                return Ok(CorrelatedKernelResult {
189                    final_state: initial_state,
190                    ticks_processed,
191                    completed: result == 1,
192                });
193            }
194            ticks_processed += 1;
195        }
196
197        Ok(CorrelatedKernelResult {
198            final_state: initial_state,
199            ticks_processed,
200            completed: true,
201        })
202    }
203}
204
205/// Convenience function to run a correlated simulation.
206pub fn simulate_correlated<S, F>(
207    tables: &[&DataTable],
208    schema: TableSchema,
209    initial_state: S,
210    strategy: F,
211) -> Result<CorrelatedKernelResult<S>>
212where
213    F: FnMut(usize, &[*const f64], &TableSchema, &mut S) -> i32,
214{
215    if tables.is_empty() {
216        return Err(ShapeError::RuntimeError {
217            message: "simulate_correlated requires at least one DataTable".to_string(),
218            location: None,
219        });
220    }
221    let config = CorrelatedKernelConfig::full(tables[0].row_count());
222    let kernel = CorrelatedKernel::new(config);
223    kernel.run(tables, schema, initial_state, strategy)
224}
225
226/// JIT-compiled correlated kernel function type.
227pub type CorrelatedKernelFn = unsafe extern "C" fn(
228    cursor_index: usize,
229    series_ptrs: *const *const f64,
230    series_count: usize,
231    state_ptr: *mut u8,
232) -> i32;
233
234#[cfg(test)]
235mod tests {
236    use super::*;
237
238    #[test]
239    fn test_series_schema_basic() {
240        let schema = TableSchema::from_names(&["temp", "pressure"]);
241        assert_eq!(schema.len(), 2);
242        assert!(!schema.is_empty());
243        assert_eq!(schema.get_index("temp"), Some(0));
244        assert_eq!(schema.get_index("pressure"), Some(1));
245        assert_eq!(schema.get_index("missing"), None);
246        assert_eq!(
247            schema.names(),
248            &["temp".to_string(), "pressure".to_string()]
249        );
250    }
251
252    #[test]
253    fn test_series_schema_empty() {
254        let schema = TableSchema::from_names(&[]);
255        assert_eq!(schema.len(), 0);
256        assert!(schema.is_empty());
257        assert_eq!(schema.get_index("anything"), None);
258    }
259
260    /// Helper: build a DataTable with a single f64 column.
261    fn make_f64_table(name: &str, values: Vec<f64>) -> DataTable {
262        use arrow_array::{ArrayRef, Float64Array};
263        use arrow_schema::{DataType, Field, Schema};
264        use std::sync::Arc;
265
266        let schema = Schema::new(vec![Field::new(name, DataType::Float64, false)]);
267        let col: ArrayRef = Arc::new(Float64Array::from(values));
268        let batch = arrow_array::RecordBatch::try_new(Arc::new(schema), vec![col]).unwrap();
269        DataTable::new(batch)
270    }
271
272    #[test]
273    fn test_correlated_kernel_two_tables() {
274        // Two tables: "spy" prices and "vix" values
275        let spy_table = make_f64_table("price", vec![100.0, 102.0, 98.0, 105.0]);
276        let vix_table = make_f64_table("value", vec![15.0, 25.0, 30.0, 12.0]);
277
278        let schema = TableSchema::from_names(&["spy", "vix"]);
279        let config = CorrelatedKernelConfig::full(spy_table.row_count());
280        let kernel = CorrelatedKernel::new(config);
281
282        // Strategy: when VIX > 20 and position == 0, buy; when VIX < 15, sell
283        #[derive(Debug, Default)]
284        struct State {
285            position: f64,
286            cash: f64,
287            trades: u32,
288        }
289
290        let initial = State {
291            position: 0.0,
292            cash: 10000.0,
293            trades: 0,
294        };
295
296        let tables: Vec<&DataTable> = vec![&spy_table, &vix_table];
297
298        let result = kernel
299            .run(&tables, schema, initial, |idx, col_ptrs, schema, state| {
300                // col_ptrs[0] = spy price, col_ptrs[1] = vix value
301                let spy_idx = schema.get_index("spy").unwrap();
302                let vix_idx = schema.get_index("vix").unwrap();
303
304                let spy_price = unsafe { *col_ptrs[spy_idx].add(idx) };
305                let vix_value = unsafe { *col_ptrs[vix_idx].add(idx) };
306
307                if vix_value > 20.0 && state.position == 0.0 {
308                    // Buy
309                    let shares = (state.cash / spy_price).floor();
310                    state.cash -= shares * spy_price;
311                    state.position = shares;
312                    state.trades += 1;
313                } else if vix_value < 15.0 && state.position > 0.0 {
314                    // Sell
315                    state.cash += state.position * spy_price;
316                    state.position = 0.0;
317                    state.trades += 1;
318                }
319
320                0 // continue
321            })
322            .unwrap();
323
324        assert!(result.completed);
325        assert_eq!(result.ticks_processed, 4);
326        // VIX=25 at idx 1: buy at SPY=102, shares = floor(10000/102) = 98
327        // VIX=12 at idx 3: sell at SPY=105
328        assert_eq!(result.final_state.trades, 2);
329        assert_eq!(result.final_state.position, 0.0);
330        // Bought 98 at 102 = 9996, remaining cash = 10000-9996 = 4
331        // Sold 98 at 105 = 10290, total cash = 4 + 10290 = 10294
332        assert_eq!(result.final_state.cash, 10294.0);
333    }
334
335    #[test]
336    fn test_correlated_kernel_mismatched_rows() {
337        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0]);
338        let table2 = make_f64_table("b", vec![1.0, 2.0]); // different length
339
340        let schema = TableSchema::from_names(&["a", "b"]);
341        let tables: Vec<&DataTable> = vec![&table1, &table2];
342
343        let result = simulate_correlated(&tables, schema, 0.0_f64, |_idx, _ptrs, _s, _st| 0);
344        assert!(result.is_err());
345    }
346
347    #[test]
348    fn test_correlated_kernel_empty_tables() {
349        let schema = TableSchema::from_names(&["a"]);
350        let tables: Vec<&DataTable> = vec![];
351
352        let result = simulate_correlated(&tables, schema, 0.0_f64, |_idx, _ptrs, _s, _st| 0);
353        assert!(result.is_err());
354    }
355
356    #[test]
357    fn test_correlated_kernel_with_warmup() {
358        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0, 4.0, 5.0]);
359        let table2 = make_f64_table("b", vec![10.0, 20.0, 30.0, 40.0, 50.0]);
360
361        let schema = TableSchema::from_names(&["a", "b"]);
362        let config = CorrelatedKernelConfig::with_warmup(table1.row_count(), 2);
363        let kernel = CorrelatedKernel::new(config);
364
365        let tables: Vec<&DataTable> = vec![&table1, &table2];
366        let mut visited = Vec::new();
367
368        let result = kernel
369            .run(&tables, schema, 0.0_f64, |idx, col_ptrs, _schema, state| {
370                visited.push(idx);
371                unsafe {
372                    *state += *col_ptrs[0].add(idx) + *col_ptrs[1].add(idx);
373                }
374                0
375            })
376            .unwrap();
377
378        assert!(result.completed);
379        // Warmup=2, so should process indices 2, 3, 4
380        assert_eq!(visited, vec![2, 3, 4]);
381        // Sum: (3+30) + (4+40) + (5+50) = 33 + 44 + 55 = 132
382        assert_eq!(result.final_state, 132.0);
383    }
384
385    #[test]
386    fn test_correlated_kernel_early_stop() {
387        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0, 4.0]);
388        let table2 = make_f64_table("b", vec![10.0, 20.0, 30.0, 40.0]);
389
390        let schema = TableSchema::from_names(&["a", "b"]);
391        let config = CorrelatedKernelConfig::full(table1.row_count());
392        let kernel = CorrelatedKernel::new(config);
393        let tables: Vec<&DataTable> = vec![&table1, &table2];
394
395        let result = kernel
396            .run(&tables, schema, 0.0_f64, |idx, col_ptrs, _schema, state| {
397                let val = unsafe { *col_ptrs[1].add(idx) };
398                if val > 25.0 {
399                    return 1; // done
400                }
401                *state += val;
402                0
403            })
404            .unwrap();
405
406        assert!(result.completed); // 1 = normal completion
407        assert_eq!(result.ticks_processed, 2); // processed idx 0, 1; stopped at 2
408        assert_eq!(result.final_state, 30.0); // 10+20
409    }
410
411    #[test]
412    fn test_simulate_correlated_convenience() {
413        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0]);
414        let table2 = make_f64_table("b", vec![4.0, 5.0, 6.0]);
415
416        let schema = TableSchema::from_names(&["a", "b"]);
417        let tables: Vec<&DataTable> = vec![&table1, &table2];
418
419        let result = simulate_correlated(&tables, schema, 0.0_f64, |idx, col_ptrs, _s, state| {
420            unsafe {
421                *state += *col_ptrs[0].add(idx) * *col_ptrs[1].add(idx);
422            }
423            0
424        })
425        .unwrap();
426
427        assert!(result.completed);
428        assert_eq!(result.ticks_processed, 3);
429        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
430        assert_eq!(result.final_state, 32.0);
431    }
432}