1use shape_ast::error::{Result, ShapeError};
7use shape_value::DataTable;
8use std::collections::HashMap;
9
/// Ordered list of column names with constant-time name → position lookup.
///
/// A name's position is its index in the list given at construction. If the
/// same name appears more than once, lookups resolve to its *last* occurrence.
#[derive(Debug, Clone)]
pub struct TableSchema {
    /// Column names in positional order.
    names: Vec<String>,
    /// Reverse index from a column name to its position in `names`.
    name_to_index: HashMap<String, usize>,
}

impl TableSchema {
    /// Builds a schema from owned column names, indexing each by its position.
    pub fn new(names: Vec<String>) -> Self {
        let mut name_to_index = HashMap::with_capacity(names.len());
        for (position, column) in names.iter().enumerate() {
            // `insert` overwrites, so a repeated name keeps its last position.
            name_to_index.insert(column.clone(), position);
        }
        Self {
            names,
            name_to_index,
        }
    }

    /// Builds a schema from borrowed names, copying each into an owned `String`.
    pub fn from_names(names: &[&str]) -> Self {
        let owned: Vec<String> = names.iter().map(|&n| String::from(n)).collect();
        Self::new(owned)
    }

    /// Returns the position of `name`, or `None` when no column has that name.
    #[inline]
    pub fn get_index(&self, name: &str) -> Option<usize> {
        let position = self.name_to_index.get(name)?;
        Some(*position)
    }

    /// Number of columns (duplicates included).
    #[inline]
    pub fn len(&self) -> usize {
        self.names.len()
    }

    /// `true` when the schema has no columns.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// All column names in positional order.
    pub fn names(&self) -> &[String] {
        self.names.as_slice()
    }
}
67
/// Row range and warmup period that a [`CorrelatedKernel`] iterates over.
///
/// The strategy is invoked for rows `start + warmup .. end`.
#[derive(Debug, Clone)]
pub struct CorrelatedKernelConfig {
    /// First row of the range (inclusive).
    pub start: usize,
    /// One past the last row of the range (exclusive).
    pub end: usize,
    /// Rows after `start` that are skipped before the strategy is first called.
    pub warmup: usize,
}

impl CorrelatedKernelConfig {
    /// Covers every row in `0..len` with no warmup.
    pub fn full(len: usize) -> Self {
        Self::with_warmup(len, 0)
    }

    /// Covers rows `0..len`, skipping the first `warmup` rows before the strategy runs.
    pub fn with_warmup(len: usize, warmup: usize) -> Self {
        Self {
            warmup,
            start: 0,
            end: len,
        }
    }
}
98
/// Outcome of a correlated kernel run (`CorrelatedKernel::run` / `simulate_correlated`).
#[derive(Debug)]
pub struct CorrelatedKernelResult<S> {
    /// Strategy state after the last strategy call.
    pub final_state: S,
    /// Number of strategy calls that returned `0` (continue). The call that
    /// stops the run with a nonzero code is not counted.
    pub ticks_processed: usize,
    /// `true` when the run reached the end of its range or the strategy stopped
    /// it with code `1`; `false` when the strategy stopped it with any other
    /// nonzero code.
    pub completed: bool,
}
109
110pub struct CorrelatedKernel {
114 config: CorrelatedKernelConfig,
115}
116
117impl CorrelatedKernel {
118 pub fn new(config: CorrelatedKernelConfig) -> Self {
120 Self { config }
121 }
122
123 #[inline(always)]
128 pub fn run<S, F>(
129 &self,
130 tables: &[&DataTable],
131 schema: TableSchema,
132 mut initial_state: S,
133 mut strategy: F,
134 ) -> Result<CorrelatedKernelResult<S>>
135 where
136 F: FnMut(usize, &[*const f64], &TableSchema, &mut S) -> i32,
137 {
138 if tables.is_empty() {
139 return Err(ShapeError::RuntimeError {
140 message: "CorrelatedKernel requires at least one DataTable".to_string(),
141 location: None,
142 });
143 }
144
145 let row_count = tables[0].row_count();
147 for (i, table) in tables.iter().enumerate().skip(1) {
148 if table.row_count() != row_count {
149 return Err(ShapeError::RuntimeError {
150 message: format!(
151 "Table {} has {} rows but table 0 has {} rows",
152 i,
153 table.row_count(),
154 row_count
155 ),
156 location: None,
157 });
158 }
159 }
160
161 let col_ptrs: Vec<*const f64> = tables
163 .iter()
164 .flat_map(|t| {
165 t.column_ptrs()
166 .iter()
167 .filter(|cp| cp.stride == 8)
168 .map(|cp| cp.values_ptr as *const f64)
169 })
170 .collect();
171
172 let effective_start = self.config.start + self.config.warmup;
173 if effective_start >= self.config.end {
174 return Err(ShapeError::RuntimeError {
175 message: format!(
176 "Warmup ({}) exceeds available range ({} - {})",
177 self.config.warmup, self.config.start, self.config.end
178 ),
179 location: None,
180 });
181 }
182
183 let mut ticks_processed = 0;
184
185 for cursor_index in effective_start..self.config.end {
186 let result = strategy(cursor_index, &col_ptrs, &schema, &mut initial_state);
187 if result != 0 {
188 return Ok(CorrelatedKernelResult {
189 final_state: initial_state,
190 ticks_processed,
191 completed: result == 1,
192 });
193 }
194 ticks_processed += 1;
195 }
196
197 Ok(CorrelatedKernelResult {
198 final_state: initial_state,
199 ticks_processed,
200 completed: true,
201 })
202 }
203}
204
205pub fn simulate_correlated<S, F>(
207 tables: &[&DataTable],
208 schema: TableSchema,
209 initial_state: S,
210 strategy: F,
211) -> Result<CorrelatedKernelResult<S>>
212where
213 F: FnMut(usize, &[*const f64], &TableSchema, &mut S) -> i32,
214{
215 if tables.is_empty() {
216 return Err(ShapeError::RuntimeError {
217 message: "simulate_correlated requires at least one DataTable".to_string(),
218 location: None,
219 });
220 }
221 let config = CorrelatedKernelConfig::full(tables[0].row_count());
222 let kernel = CorrelatedKernel::new(config);
223 kernel.run(tables, schema, initial_state, strategy)
224}
225
/// C-ABI function-pointer type for a per-row correlated kernel.
///
/// Calling a value of this type is `unsafe` (declared `unsafe extern "C"`).
/// Nothing visible here calls it, so the parameter roles below are inferred
/// from their names and from the `CorrelatedKernel::run` strategy signature.
/// NOTE(review): confirm against whatever produces or invokes these kernels.
///
/// * `cursor_index`: the row being processed. Presumably the same role as the
///   strategy's first argument in `run`.
/// * `series_ptrs` / `series_count`: presumably an array of `series_count`
///   column base pointers. This looks like the FFI counterpart of the
///   `&[*const f64]` slice given to `run` strategies.
/// * `state_ptr`: opaque pointer to caller-owned mutable state.
/// * Return value: presumably the same convention as `run` strategies
///   (`0` = continue, `1` = stop as completed, other = stop as aborted).
///   TODO confirm.
///
/// # Safety
///
/// The caller presumably must ensure every pointer is valid for the reads and
/// writes the kernel performs at `cursor_index`. The exact contract is not
/// established here and must be confirmed.
pub type CorrelatedKernelFn = unsafe extern "C" fn(
    cursor_index: usize,
    series_ptrs: *const *const f64,
    series_count: usize,
    state_ptr: *mut u8,
) -> i32;
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-column, non-nullable Float64 `DataTable` whose column is `name`.
    fn make_f64_table(name: &str, values: Vec<f64>) -> DataTable {
        use arrow_array::{ArrayRef, Float64Array};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Schema::new(vec![Field::new(name, DataType::Float64, false)]);
        let col: ArrayRef = Arc::new(Float64Array::from(values));
        let batch = arrow_array::RecordBatch::try_new(Arc::new(schema), vec![col]).unwrap();
        DataTable::new(batch)
    }

    #[test]
    fn test_table_schema_basic() {
        let schema = TableSchema::from_names(&["temp", "pressure"]);
        assert_eq!(schema.len(), 2);
        assert!(!schema.is_empty());
        assert_eq!(schema.get_index("temp"), Some(0));
        assert_eq!(schema.get_index("pressure"), Some(1));
        assert_eq!(schema.get_index("missing"), None);
        assert_eq!(
            schema.names(),
            &["temp".to_string(), "pressure".to_string()]
        );
    }

    #[test]
    fn test_table_schema_empty() {
        let schema = TableSchema::from_names(&[]);
        assert_eq!(schema.len(), 0);
        assert!(schema.is_empty());
        assert_eq!(schema.get_index("anything"), None);
    }

    #[test]
    fn test_correlated_kernel_two_tables() {
        let spy_table = make_f64_table("price", vec![100.0, 102.0, 98.0, 105.0]);
        let vix_table = make_f64_table("value", vec![15.0, 25.0, 30.0, 12.0]);

        let schema = TableSchema::from_names(&["spy", "vix"]);
        let config = CorrelatedKernelConfig::full(spy_table.row_count());
        let kernel = CorrelatedKernel::new(config);

        #[derive(Debug, Default)]
        struct State {
            position: f64,
            cash: f64,
            trades: u32,
        }

        let initial = State {
            position: 0.0,
            cash: 10000.0,
            trades: 0,
        };

        let tables: Vec<&DataTable> = vec![&spy_table, &vix_table];

        let result = kernel
            .run(&tables, schema, initial, |idx, col_ptrs, schema, state| {
                let spy_idx = schema.get_index("spy").unwrap();
                let vix_idx = schema.get_index("vix").unwrap();

                // SAFETY: each table contributes exactly one f64 column, so
                // `col_ptrs` has one pointer per schema entry. The kernel only
                // yields `idx < config.end == row_count` (4 rows in both tables).
                let spy_price = unsafe { *col_ptrs[spy_idx].add(idx) };
                // SAFETY: same invariant as above.
                let vix_value = unsafe { *col_ptrs[vix_idx].add(idx) };

                // Buy on a volatility spike, sell when volatility calms down.
                if vix_value > 20.0 && state.position == 0.0 {
                    let shares = (state.cash / spy_price).floor();
                    state.cash -= shares * spy_price;
                    state.position = shares;
                    state.trades += 1;
                } else if vix_value < 15.0 && state.position > 0.0 {
                    state.cash += state.position * spy_price;
                    state.position = 0.0;
                    state.trades += 1;
                }

                0
            })
            .unwrap();

        assert!(result.completed);
        assert_eq!(result.ticks_processed, 4);
        assert_eq!(result.final_state.trades, 2);
        // Bought 98 shares at 102 (cash 4), sold at 105: 4 + 98 * 105 = 10294.
        assert_eq!(result.final_state.position, 0.0);
        assert_eq!(result.final_state.cash, 10294.0);
    }

    #[test]
    fn test_correlated_kernel_mismatched_rows() {
        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0]);
        let table2 = make_f64_table("b", vec![1.0, 2.0]);
        let schema = TableSchema::from_names(&["a", "b"]);
        let tables: Vec<&DataTable> = vec![&table1, &table2];

        let result = simulate_correlated(&tables, schema, 0.0_f64, |_idx, _ptrs, _s, _st| 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_correlated_kernel_empty_tables() {
        let schema = TableSchema::from_names(&["a"]);
        let tables: Vec<&DataTable> = vec![];

        let result = simulate_correlated(&tables, schema, 0.0_f64, |_idx, _ptrs, _s, _st| 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_correlated_kernel_with_warmup() {
        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let table2 = make_f64_table("b", vec![10.0, 20.0, 30.0, 40.0, 50.0]);

        let schema = TableSchema::from_names(&["a", "b"]);
        let config = CorrelatedKernelConfig::with_warmup(table1.row_count(), 2);
        let kernel = CorrelatedKernel::new(config);

        let tables: Vec<&DataTable> = vec![&table1, &table2];
        let mut visited = Vec::new();

        let result = kernel
            .run(&tables, schema, 0.0_f64, |idx, col_ptrs, _schema, state| {
                visited.push(idx);
                // SAFETY: two single-column f64 tables with 5 rows each. The
                // kernel only yields `idx < 5`.
                unsafe {
                    *state += *col_ptrs[0].add(idx) + *col_ptrs[1].add(idx);
                }
                0
            })
            .unwrap();

        assert!(result.completed);
        // Rows 0 and 1 are warmup and never reach the strategy.
        assert_eq!(visited, vec![2, 3, 4]);
        // (3 + 30) + (4 + 40) + (5 + 50)
        assert_eq!(result.final_state, 132.0);
    }

    #[test]
    fn test_correlated_kernel_warmup_exhausts_range() {
        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0]);
        let schema = TableSchema::from_names(&["a"]);
        let config = CorrelatedKernelConfig::with_warmup(table1.row_count(), 3);
        let kernel = CorrelatedKernel::new(config);
        let tables: Vec<&DataTable> = vec![&table1];

        let mut calls = 0;
        let result = kernel.run(&tables, schema, (), |_idx, _ptrs, _s, _st| {
            calls += 1;
            0
        });

        assert!(result.is_err());
        assert_eq!(calls, 0);
    }

    #[test]
    fn test_correlated_kernel_early_stop() {
        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0, 4.0]);
        let table2 = make_f64_table("b", vec![10.0, 20.0, 30.0, 40.0]);

        let schema = TableSchema::from_names(&["a", "b"]);
        let config = CorrelatedKernelConfig::full(table1.row_count());
        let kernel = CorrelatedKernel::new(config);
        let tables: Vec<&DataTable> = vec![&table1, &table2];

        let result = kernel
            .run(&tables, schema, 0.0_f64, |idx, col_ptrs, _schema, state| {
                // SAFETY: two single-column f64 tables with 4 rows each. The
                // kernel only yields `idx < 4`.
                let val = unsafe { *col_ptrs[1].add(idx) };
                if val > 25.0 {
                    // Code 1 stops the run and still reports it as completed.
                    return 1;
                }
                *state += val;
                0
            })
            .unwrap();

        assert!(result.completed);
        // Rows 0 and 1 ran. Row 2 stopped the run and is not counted.
        assert_eq!(result.ticks_processed, 2);
        assert_eq!(result.final_state, 30.0);
    }

    #[test]
    fn test_correlated_kernel_abort_code_is_not_completed() {
        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0]);
        let table2 = make_f64_table("b", vec![4.0, 5.0, 6.0]);
        let schema = TableSchema::from_names(&["a", "b"]);
        let tables: Vec<&DataTable> = vec![&table1, &table2];

        let result = simulate_correlated(&tables, schema, 0.0_f64, |idx, _ptrs, _s, state| {
            if idx == 1 {
                // Any nonzero code other than 1 aborts the run.
                return -1;
            }
            *state += 1.0;
            0
        })
        .unwrap();

        assert!(!result.completed);
        assert_eq!(result.ticks_processed, 1);
        assert_eq!(result.final_state, 1.0);
    }

    #[test]
    fn test_simulate_correlated_convenience() {
        let table1 = make_f64_table("a", vec![1.0, 2.0, 3.0]);
        let table2 = make_f64_table("b", vec![4.0, 5.0, 6.0]);

        let schema = TableSchema::from_names(&["a", "b"]);
        let tables: Vec<&DataTable> = vec![&table1, &table2];

        let result = simulate_correlated(&tables, schema, 0.0_f64, |idx, col_ptrs, _s, state| {
            // SAFETY: two single-column f64 tables with 3 rows each. The
            // kernel only yields `idx < 3`.
            unsafe {
                *state += *col_ptrs[0].add(idx) * *col_ptrs[1].add(idx);
            }
            0
        })
        .unwrap();

        assert!(result.completed);
        assert_eq!(result.ticks_processed, 3);
        // 1*4 + 2*5 + 3*6
        assert_eq!(result.final_state, 32.0);
    }
}