1use shape_ast::error::{Result, ShapeError};
7use shape_value::DataTable;
8
/// Signature of a natively compiled (JIT) simulation kernel.
///
/// The kernel is invoked once per cursor position with:
/// - the current row index,
/// - a pointer to an array of per-column `f64` series base pointers,
/// - an opaque pointer to the kernel's state buffer.
///
/// `DenseKernel::run_jit` interprets the returned code as follows:
/// - `0` continues to the next row;
/// - `1` stops the run and reports it as completed;
/// - any other non-zero value stops the run and reports it as not completed.
pub type SimulationKernelFn =
    unsafe extern "C" fn(cursor: usize, series: *const *const f64, state: *mut u8) -> i32;
21
/// Layout information handed to a `KernelCompiler` when compiling a kernel.
///
/// Built with [`KernelCompileConfig::new`] followed by the chained
/// `with_state_field` / `with_column` builder methods.
#[derive(Debug, Clone, Default)]
pub struct KernelCompileConfig {
    /// `(field name, offset)` pairs describing the state buffer, in insertion order.
    pub state_field_offsets: Vec<(String, usize)>,
    /// Identifier of the state schema the kernel is compiled against.
    pub state_schema_id: u32,
    /// `(column name, index)` pairs describing the input series, in insertion order.
    pub column_map: Vec<(String, usize)>,
    /// Total number of input columns.
    pub column_count: usize,
}

impl KernelCompileConfig {
    /// Creates a config for `schema_id` with `column_count` columns and no
    /// state fields or column mappings registered yet.
    pub fn new(schema_id: u32, column_count: usize) -> Self {
        Self {
            column_count,
            state_schema_id: schema_id,
            ..Self::default()
        }
    }

    /// Appends a state field `name` located at `offset` and returns the config.
    pub fn with_state_field(mut self, name: &str, offset: usize) -> Self {
        let entry = (String::from(name), offset);
        self.state_field_offsets.push(entry);
        self
    }

    /// Appends a column mapping `name -> index` and returns the config.
    pub fn with_column(mut self, name: &str, index: usize) -> Self {
        let entry = (String::from(name), index);
        self.column_map.push(entry);
        self
    }
}
63
64pub trait KernelCompiler: Send + Sync {
71 fn compile_kernel(
81 &self,
82 name: &str,
83 function_bytecode: &[u8],
84 config: &KernelCompileConfig,
85 ) -> std::result::Result<SimulationKernelFn, String>;
86
87 fn supports_feature(&self, feature: &str) -> bool {
89 match feature {
90 "typed_object" => true,
91 "closures" => false, _ => false,
93 }
94 }
95}
96
/// Row range a `DenseKernel` iterates over.
///
/// The kernel visits cursor indices in the half-open range
/// `[start + warmup, end)`.
#[derive(Debug, Clone)]
pub struct DenseKernelConfig {
    /// First row of the range, inclusive.
    pub start: usize,
    /// End of the range, exclusive.
    pub end: usize,
    /// Number of rows after `start` that are skipped before the first strategy call.
    pub warmup: usize,
}

impl DenseKernelConfig {
    /// Covers rows `0..len` with no warmup.
    pub fn full(len: usize) -> Self {
        Self::range(0, len)
    }

    /// Covers rows `0..len`, skipping the first `warmup` rows.
    pub fn with_warmup(len: usize, warmup: usize) -> Self {
        Self {
            warmup,
            ..Self::full(len)
        }
    }

    /// Covers rows `start..end` with no warmup.
    pub fn range(start: usize, end: usize) -> Self {
        Self {
            start,
            end,
            warmup: 0,
        }
    }
}
136
/// Outcome of a `DenseKernel` run (`run` or `run_jit`).
#[derive(Debug)]
pub struct DenseKernelResult<S> {
    /// Strategy state after the last invocation. For `run_jit` this is `()`
    /// because the state lives behind the caller-provided raw pointer.
    pub final_state: S,
    /// Number of ticks on which the strategy returned `0`. A tick that stops
    /// the run (non-zero return) is not counted.
    pub ticks_processed: usize,
    /// `true` when either the whole range was consumed or the strategy stopped
    /// with code `1`. `false` when the strategy stopped with any other non-zero code.
    pub completed: bool,
}
147
148pub struct DenseKernel {
155 config: DenseKernelConfig,
156}
157
158impl DenseKernel {
159 pub fn new(config: DenseKernelConfig) -> Self {
161 Self { config }
162 }
163
164 #[inline(always)]
169 pub fn run<S, F>(
170 &self,
171 data: &DataTable,
172 mut initial_state: S,
173 mut strategy: F,
174 ) -> Result<DenseKernelResult<S>>
175 where
176 F: FnMut(usize, &[*const f64], &mut S) -> i32,
177 {
178 let col_ptrs: Vec<*const f64> = data
180 .column_ptrs()
181 .iter()
182 .filter(|cp| cp.stride == 8)
183 .map(|cp| cp.values_ptr as *const f64)
184 .collect();
185
186 let effective_start = self.config.start + self.config.warmup;
187
188 if effective_start >= self.config.end {
189 return Err(ShapeError::RuntimeError {
190 message: format!(
191 "Warmup ({}) exceeds available range ({} - {})",
192 self.config.warmup, self.config.start, self.config.end
193 ),
194 location: None,
195 });
196 }
197
198 let mut ticks_processed = 0;
199
200 for cursor_index in effective_start..self.config.end {
201 let result = strategy(cursor_index, &col_ptrs, &mut initial_state);
202
203 if result != 0 {
204 return Ok(DenseKernelResult {
205 final_state: initial_state,
206 ticks_processed,
207 completed: result == 1,
208 });
209 }
210
211 ticks_processed += 1;
212 }
213
214 Ok(DenseKernelResult {
215 final_state: initial_state,
216 ticks_processed,
217 completed: true,
218 })
219 }
220
221 #[inline(always)]
238 pub unsafe fn run_jit(
239 &self,
240 column_ptrs: &[*const f64],
241 state_ptr: *mut u8,
242 kernel: SimulationKernelFn,
243 ) -> Result<DenseKernelResult<()>> {
244 let series_ptrs = column_ptrs.as_ptr();
245 let effective_start = self.config.start + self.config.warmup;
246
247 if effective_start >= self.config.end {
248 return Err(ShapeError::RuntimeError {
249 message: format!(
250 "Warmup ({}) exceeds available range ({} - {})",
251 self.config.warmup, self.config.start, self.config.end
252 ),
253 location: None,
254 });
255 }
256
257 let mut ticks_processed = 0;
258
259 for cursor_index in effective_start..self.config.end {
261 let result = unsafe { kernel(cursor_index, series_ptrs, state_ptr) };
262
263 if result != 0 {
264 return Ok(DenseKernelResult {
266 final_state: (),
267 ticks_processed,
268 completed: result == 1, });
270 }
271
272 ticks_processed += 1;
273 }
274
275 Ok(DenseKernelResult {
276 final_state: (),
277 ticks_processed,
278 completed: true,
279 })
280 }
281}
282
283pub fn simulate<S, F>(
288 data: &DataTable,
289 initial_state: S,
290 strategy: F,
291) -> Result<DenseKernelResult<S>>
292where
293 F: FnMut(usize, &[*const f64], &mut S) -> i32,
294{
295 let config = DenseKernelConfig::full(data.row_count());
296 let kernel = DenseKernel::new(config);
297 kernel.run(data, initial_state, strategy)
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_compile_config() {
        let config = KernelCompileConfig::new(42, 3)
            .with_state_field("cash", 0)
            .with_state_field("position", 8)
            .with_column("open", 0)
            .with_column("close", 1)
            .with_column("volume", 2);

        assert_eq!(config.state_schema_id, 42);
        assert_eq!(config.column_count, 3);
        assert_eq!(config.state_field_offsets.len(), 2);
        assert_eq!(config.column_map.len(), 3);
    }

    #[test]
    fn test_dense_kernel_config() {
        let config = DenseKernelConfig::full(100);
        assert_eq!(config.start, 0);
        assert_eq!(config.end, 100);
        assert_eq!(config.warmup, 0);

        let config = DenseKernelConfig::with_warmup(100, 10);
        assert_eq!(config.warmup, 10);

        let config = DenseKernelConfig::range(5, 50);
        assert_eq!(config.start, 5);
        assert_eq!(config.end, 50);
    }

    /// Builds a single-column table whose only column is `price` (Float64, non-null).
    fn make_price_table(prices: Vec<f64>) -> DataTable {
        use arrow_array::{ArrayRef, Float64Array};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Schema::new(vec![Field::new("price", DataType::Float64, false)]);
        let col: ArrayRef = Arc::new(Float64Array::from(prices));
        let batch = arrow_array::RecordBatch::try_new(Arc::new(schema), vec![col]).unwrap();
        DataTable::new(batch)
    }

    #[test]
    fn test_dense_kernel_run_sum() {
        let table = make_price_table(vec![10.0, 20.0, 30.0, 40.0, 50.0]);

        let config = DenseKernelConfig::full(table.row_count());
        let kernel = DenseKernel::new(config);

        let result = kernel
            .run(&table, 0.0_f64, |idx, col_ptrs, state| {
                // SAFETY: column 0 is the Float64 `price` column and `idx` is
                // within `0..row_count` for a `full` config.
                let price = unsafe { *col_ptrs[0].add(idx) };
                *state += price;
                0 // continue
            })
            .unwrap();

        assert!(result.completed);
        assert_eq!(result.ticks_processed, 5);
        assert_eq!(result.final_state, 150.0);
    }

    #[test]
    fn test_dense_kernel_run_early_stop() {
        let table = make_price_table(vec![10.0, 20.0, 100.0, 40.0, 50.0]);

        let config = DenseKernelConfig::full(table.row_count());
        let kernel = DenseKernel::new(config);

        let result = kernel
            .run(&table, 0.0_f64, |idx, col_ptrs, state| {
                // SAFETY: column 0 is the Float64 `price` column; `idx` < row_count.
                let price = unsafe { *col_ptrs[0].add(idx) };
                if price > 50.0 {
                    return 1; // stop, reported as completed
                }
                *state += price;
                0
            })
            .unwrap();

        // Code 1 means "completed". The stopping tick (idx 2) is not counted,
        // and its price is not accumulated.
        assert!(result.completed);
        assert_eq!(result.ticks_processed, 2);
        assert_eq!(result.final_state, 30.0);
    }

    #[test]
    fn test_dense_kernel_with_warmup() {
        let table = make_price_table(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let config = DenseKernelConfig::with_warmup(table.row_count(), 2);
        let kernel = DenseKernel::new(config);

        let mut processed_indices = Vec::new();
        let result = kernel
            .run(&table, 0.0_f64, |idx, col_ptrs, state| {
                processed_indices.push(idx);
                // SAFETY: column 0 is the Float64 `price` column; `idx` < row_count.
                *state += unsafe { *col_ptrs[0].add(idx) };
                0
            })
            .unwrap();

        assert!(result.completed);
        // The first two rows are warmup and are never passed to the strategy.
        assert_eq!(processed_indices, vec![2, 3, 4]);
        assert_eq!(result.ticks_processed, 3);
        assert_eq!(result.final_state, 12.0);
    }

    #[test]
    fn test_dense_kernel_range() {
        let table = make_price_table(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let config = DenseKernelConfig::range(1, 4);
        let kernel = DenseKernel::new(config);

        let result = kernel
            .run(&table, 0.0_f64, |idx, col_ptrs, state| {
                // SAFETY: column 0 is Float64; `idx` is in 1..4, within the 5 rows.
                *state += unsafe { *col_ptrs[0].add(idx) };
                0
            })
            .unwrap();

        assert!(result.completed);
        // Rows 1..4 hold 2.0 + 3.0 + 4.0.
        assert_eq!(result.ticks_processed, 3);
        assert_eq!(result.final_state, 9.0);
    }

    #[test]
    fn test_dense_kernel_warmup_exceeds_range() {
        let table = make_price_table(vec![1.0, 2.0, 3.0]);

        // A warmup of 10 rows on a 3-row table leaves nothing to process.
        let config = DenseKernelConfig::with_warmup(table.row_count(), 10);
        let kernel = DenseKernel::new(config);

        let result = kernel.run(&table, 0.0_f64, |_idx, _col_ptrs, _state| 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_simulate_convenience_fn() {
        let table = make_price_table(vec![10.0, 20.0, 30.0]);

        let result = simulate(&table, 0.0_f64, |idx, col_ptrs, state| {
            // SAFETY: column 0 is Float64; `simulate` iterates 0..row_count.
            *state += unsafe { *col_ptrs[0].add(idx) };
            0
        })
        .unwrap();

        assert!(result.completed);
        assert_eq!(result.ticks_processed, 3);
        assert_eq!(result.final_state, 60.0);
    }

    #[test]
    fn test_dense_kernel_struct_state() {
        #[derive(Debug, Default)]
        struct BacktestState {
            cash: f64,
            position: f64,
            trades: u32,
        }

        let table = make_price_table(vec![100.0, 105.0, 110.0]);
        let config = DenseKernelConfig::full(table.row_count());
        let kernel = DenseKernel::new(config);

        let initial = BacktestState {
            cash: 10000.0,
            position: 0.0,
            trades: 0,
        };

        let result = kernel
            .run(&table, initial, |idx, col_ptrs, state| {
                // SAFETY: column 0 is Float64; `idx` < row_count.
                let price = unsafe { *col_ptrs[0].add(idx) };
                match idx {
                    0 => {
                        // Buy 10 units at the first price.
                        state.cash -= 10.0 * price;
                        state.position = 10.0;
                        state.trades += 1;
                    }
                    2 => {
                        // Sell the whole position at the last price.
                        state.cash += state.position * price;
                        state.position = 0.0;
                        state.trades += 1;
                    }
                    _ => {}
                }
                0
            })
            .unwrap();

        assert!(result.completed);
        assert_eq!(result.final_state.trades, 2);
        assert_eq!(result.final_state.position, 0.0);
        // 10000 - 10 * 100 + 10 * 110
        assert_eq!(result.final_state.cash, 10100.0);
    }

    #[test]
    fn test_dense_kernel_multi_column() {
        use arrow_array::{ArrayRef, Float64Array};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Schema::new(vec![
            Field::new("price", DataType::Float64, false),
            Field::new("volume", DataType::Float64, false),
        ]);
        let prices: ArrayRef = Arc::new(Float64Array::from(vec![100.0, 105.0, 98.0]));
        let volumes: ArrayRef = Arc::new(Float64Array::from(vec![1000.0, 2000.0, 1500.0]));
        let batch =
            arrow_array::RecordBatch::try_new(Arc::new(schema), vec![prices, volumes]).unwrap();
        let table = DataTable::new(batch);

        let config = DenseKernelConfig::full(table.row_count());
        let kernel = DenseKernel::new(config);

        let result = kernel
            .run(&table, (0.0_f64, 0.0_f64), |idx, col_ptrs, state| {
                // SAFETY: both columns are Float64 (stride 8) in table order;
                // `idx` < row_count.
                let (price, volume) =
                    unsafe { (*col_ptrs[0].add(idx), *col_ptrs[1].add(idx)) };
                state.0 += price * volume; // running sum of price * volume
                state.1 += volume; // running sum of volume
                0
            })
            .unwrap();

        // VWAP = sum(price * volume) / sum(volume) = 457000 / 4500.
        let (weighted_sum, total_vol) = result.final_state;
        let vwap = weighted_sum / total_vol;
        assert!((vwap - 101.5556).abs() < 0.001);
    }

    /// End-to-end check: an OHLCV table driven by a simple momentum backtest
    /// with slippage and commission.
    #[test]
    fn test_full_loop_csv_to_backtest() {
        use arrow_array::{ArrayRef, Float64Array};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let opens = vec![
            100.0, 102.0, 104.0, 106.0, 108.0, 110.0, 108.0, 106.0, 104.0, 102.0,
        ];
        let highs = vec![
            103.0, 105.0, 107.0, 109.0, 111.0, 112.0, 110.0, 108.0, 106.0, 104.0,
        ];
        let lows = vec![
            99.0, 101.0, 103.0, 105.0, 107.0, 108.0, 106.0, 104.0, 102.0, 100.0,
        ];
        let closes = vec![
            102.0, 104.0, 106.0, 108.0, 110.0, 109.0, 107.0, 105.0, 103.0, 101.0,
        ];
        let vols = vec![
            1000.0, 1200.0, 1100.0, 1300.0, 1500.0, 1400.0, 1600.0, 1100.0, 900.0, 800.0,
        ];

        let schema = Schema::new(vec![
            Field::new("open", DataType::Float64, false),
            Field::new("high", DataType::Float64, false),
            Field::new("low", DataType::Float64, false),
            Field::new("close", DataType::Float64, false),
            Field::new("volume", DataType::Float64, false),
        ]);
        let batch = arrow_array::RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Float64Array::from(opens)) as ArrayRef,
                Arc::new(Float64Array::from(highs)) as ArrayRef,
                Arc::new(Float64Array::from(lows)) as ArrayRef,
                Arc::new(Float64Array::from(closes.clone())) as ArrayRef,
                Arc::new(Float64Array::from(vols)) as ArrayRef,
            ],
        )
        .unwrap();
        let table = DataTable::new(batch);

        assert_eq!(table.row_count(), 10);
        assert_eq!(table.column_count(), 5);
        assert_eq!(
            table.column_names(),
            vec!["open", "high", "low", "close", "volume"]
        );

        // `run` passes only stride-8 columns. If all five qualify, then
        // col_ptrs[3] is the `close` column.
        let f64_col_count = table
            .column_ptrs()
            .iter()
            .filter(|cp| cp.stride == 8)
            .count();
        assert_eq!(f64_col_count, 5, "All OHLCV columns must be f64 (stride 8)");

        // One warmup row so that `idx - 1` below never underflows.
        let config = DenseKernelConfig::with_warmup(table.row_count(), 1);
        let kernel = DenseKernel::new(config);

        #[derive(Debug)]
        struct BacktestState {
            cash: f64,
            position: f64,
            entry_price: f64,
            trades: u32,
            wins: u32,
            losses: u32,
            total_pnl: f64,
        }

        let initial = BacktestState {
            cash: 100_000.0,
            position: 0.0,
            entry_price: 0.0,
            trades: 0,
            wins: 0,
            losses: 0,
            total_pnl: 0.0,
        };

        let slippage_bps = 5.0;
        let commission_pct = 0.1;

        let result = kernel
            .run(&table, initial, |idx, col_ptrs, state| {
                // SAFETY: col_ptrs[3] is the Float64 `close` column (asserted
                // above). `idx` is in 1..row_count because of the warmup, so
                // both `idx` and `idx - 1` are in bounds.
                let (close, prev_close) =
                    unsafe { (*col_ptrs[3].add(idx), *col_ptrs[3].add(idx - 1)) };

                // Momentum signal: buy on an up-close, sell otherwise.
                let rising = close > prev_close;

                if rising && state.position == 0.0 {
                    // Buy: pay slippage above the close and commission on top.
                    let slip = close * slippage_bps / 10_000.0;
                    let fill_price = close + slip;
                    // Invest 10% of cash in whole units.
                    let size = (state.cash * 0.1 / fill_price).floor();
                    if size > 0.0 {
                        let cost = fill_price * size;
                        let commission = cost * commission_pct / 100.0;
                        state.cash -= cost + commission;
                        state.position = size;
                        state.entry_price = fill_price;
                    }
                } else if !rising && state.position > 0.0 {
                    // Sell: receive the close minus slippage, less commission.
                    let slip = close * slippage_bps / 10_000.0;
                    let fill_price = close - slip;
                    let proceeds = fill_price * state.position;
                    let commission = proceeds * commission_pct / 100.0;
                    let pnl = (fill_price - state.entry_price) * state.position - commission;

                    state.cash += proceeds - commission;
                    state.total_pnl += pnl;
                    state.trades += 1;
                    if pnl > 0.0 {
                        state.wins += 1;
                    } else {
                        state.losses += 1;
                    }
                    state.position = 0.0;
                    state.entry_price = 0.0;
                }
                0
            })
            .unwrap();

        assert!(result.completed);
        // 10 rows minus 1 warmup row.
        assert_eq!(result.ticks_processed, 9);
        let s = &result.final_state;

        assert!(s.trades > 0, "Should have completed at least one trade");

        assert!(s.total_pnl.is_finite(), "P&L should be finite");

        // Mark any open position to the final close.
        let equity = if s.position > 0.0 {
            s.cash + s.position * closes[9]
        } else {
            s.cash
        };
        assert!(equity > 0.0, "Equity should be positive");

        assert_eq!(
            s.wins + s.losses,
            s.trades,
            "wins + losses should equal total trades"
        );
    }

    /// An Int64 column also reports an 8-byte stride, so `run` passes it to the
    /// strategy as `*const f64`. Its bits are NOT f64 values, so only the
    /// Float64 column (index 0) is read here.
    #[test]
    fn test_csv_int64_column_compatibility() {
        use arrow_array::{ArrayRef, Float64Array, Int64Array};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Schema::new(vec![
            Field::new("close", DataType::Float64, false),
            Field::new("volume", DataType::Int64, false),
        ]);
        let closes: ArrayRef = Arc::new(Float64Array::from(vec![100.0, 105.0, 110.0]));
        let volumes: ArrayRef = Arc::new(Int64Array::from(vec![1000_i64, 2000, 3000]));
        let batch =
            arrow_array::RecordBatch::try_new(Arc::new(schema), vec![closes, volumes]).unwrap();
        let table = DataTable::new(batch);

        let strides: Vec<usize> = table.column_ptrs().iter().map(|cp| cp.stride).collect();
        assert_eq!(strides, vec![8, 8]);

        let config = DenseKernelConfig::full(table.row_count());
        let kernel = DenseKernel::new(config);

        let result = kernel
            .run(&table, 0.0_f64, |idx, col_ptrs, state| {
                // SAFETY: col_ptrs[0] is the Float64 `close` column; `idx` < row_count.
                *state += unsafe { *col_ptrs[0].add(idx) };
                0
            })
            .unwrap();

        assert_eq!(result.final_state, 315.0);
    }
}