1use super::event_scheduler::{EventQueue, ScheduledEvent};
7use shape_ast::error::{Result, ShapeError};
8use shape_value::DataTable;
9
/// Outcome of a [`HybridKernel::run`] / [`simulate_hybrid`] call.
#[derive(Debug)]
pub struct HybridKernelResult<S> {
    /// State after the last tick/event that was processed before the run ended.
    pub final_state: S,
    /// Number of ticks for which the strategy returned `0`. A tick whose strategy call
    /// returns a non-zero status (and thereby stops the run) is not counted.
    pub ticks_processed: usize,
    /// Number of events that were passed to the event handler.
    pub events_processed: usize,
    /// `true` if the run reached the end of its range or the strategy stopped it with
    /// status `1`; `false` if the strategy stopped it with any other non-zero status.
    pub completed: bool,
}
22
/// Row range that a [`HybridKernel`] simulates over.
///
/// Ticks are issued for row indices `start + warmup .. end` (upper bound exclusive),
/// so the first `warmup` rows after `start` never reach the tick strategy.
#[derive(Debug, Clone)]
pub struct HybridKernelConfig {
    /// First row index of the range.
    pub start: usize,
    /// Exclusive upper bound of the row range.
    pub end: usize,
    /// Number of rows after `start` that are skipped before the first tick.
    pub warmup: usize,
}

impl HybridKernelConfig {
    /// Range covering all `len` rows, with no warmup.
    pub fn full(len: usize) -> Self {
        Self::with_warmup(len, 0)
    }

    /// Range covering all `len` rows, skipping the first `warmup` of them before ticking.
    pub fn with_warmup(len: usize, warmup: usize) -> Self {
        let (start, end) = (0, len);
        Self { start, end, warmup }
    }
}
53
/// Callback invoked for every event the queue reports as due at the current tick.
///
/// Receives the event, mutable access to the simulation state, and the event queue itself,
/// so the handler may schedule follow-up events. Returning `Err` aborts the run and the
/// error is propagated to the caller of [`HybridKernel::run`].
///
/// This is a plain function pointer, so capturing closures cannot be used as handlers.
pub type EventHandlerFn<S> = fn(&ScheduledEvent, &mut S, &mut EventQueue) -> Result<()>;
56
57pub struct HybridKernel {
59 config: HybridKernelConfig,
60}
61
62impl HybridKernel {
63 pub fn new(config: HybridKernelConfig) -> Self {
65 Self { config }
66 }
67
68 pub fn run<S, F>(
72 &self,
73 data: &DataTable,
74 mut initial_state: S,
75 mut event_queue: EventQueue,
76 mut tick_strategy: F,
77 event_handler: EventHandlerFn<S>,
78 ) -> Result<HybridKernelResult<S>>
79 where
80 F: FnMut(usize, &[*const f64], &mut S) -> i32,
81 {
82 let col_ptrs: Vec<*const f64> = data
83 .column_ptrs()
84 .iter()
85 .filter(|cp| cp.stride == 8)
86 .map(|cp| cp.values_ptr as *const f64)
87 .collect();
88
89 let effective_start = self.config.start + self.config.warmup;
90 if effective_start >= self.config.end {
91 return Err(ShapeError::RuntimeError {
92 message: format!(
93 "Warmup ({}) exceeds available range ({} - {})",
94 self.config.warmup, self.config.start, self.config.end
95 ),
96 location: None,
97 });
98 }
99
100 let mut ticks_processed = 0;
101 let mut events_processed = 0;
102
103 for cursor_index in effective_start..self.config.end {
104 let result = tick_strategy(cursor_index, &col_ptrs, &mut initial_state);
106 if result != 0 {
107 return Ok(HybridKernelResult {
108 final_state: initial_state,
109 ticks_processed,
110 events_processed,
111 completed: result == 1,
112 });
113 }
114 ticks_processed += 1;
115
116 while let Some(event) = event_queue.pop_due(cursor_index as i64) {
118 event_handler(&event, &mut initial_state, &mut event_queue)?;
119 events_processed += 1;
120 }
121 }
122
123 Ok(HybridKernelResult {
124 final_state: initial_state,
125 ticks_processed,
126 events_processed,
127 completed: true,
128 })
129 }
130}
131
132pub fn simulate_hybrid<S, F>(
134 data: &DataTable,
135 initial_state: S,
136 event_queue: EventQueue,
137 tick_strategy: F,
138 event_handler: EventHandlerFn<S>,
139) -> Result<HybridKernelResult<S>>
140where
141 F: FnMut(usize, &[*const f64], &mut S) -> i32,
142{
143 let config = HybridKernelConfig::full(data.row_count());
144 let kernel = HybridKernel::new(config);
145 kernel.run(
146 data,
147 initial_state,
148 event_queue,
149 tick_strategy,
150 event_handler,
151 )
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-column, non-nullable `Float64` table whose column is named `name`
    /// and holds `values`.
    fn make_f64_table(name: &str, values: Vec<f64>) -> DataTable {
        use arrow_array::{ArrayRef, Float64Array};
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Schema::new(vec![Field::new(name, DataType::Float64, false)]);
        let col: ArrayRef = Arc::new(Float64Array::from(values));
        let batch = arrow_array::RecordBatch::try_new(Arc::new(schema), vec![col]).unwrap();
        DataTable::new(batch)
    }

    /// With an empty queue only ticks run; the strategy sums the single column.
    #[test]
    fn test_hybrid_kernel_ticks_only() {
        let table = make_f64_table("price", vec![10.0, 20.0, 30.0]);
        let event_queue = EventQueue::new();

        // The handler panics, so the test fails if any event is ever dispatched.
        fn no_op_handler(
            _event: &ScheduledEvent,
            _state: &mut f64,
            _queue: &mut EventQueue,
        ) -> Result<()> {
            panic!("Should not be called with no events");
        }

        let result = simulate_hybrid(
            &table,
            0.0_f64,
            event_queue,
            |idx, col_ptrs, state| {
                // SAFETY: `simulate_hybrid` only passes idx in 0..row_count. Assumes
                // col_ptrs[0] is the table's single Float64 column (stride 8) — it is the
                // only column `make_f64_table` creates.
                unsafe { *state += *col_ptrs[0].add(idx) };
                0
            },
            no_op_handler,
        )
        .unwrap();

        assert!(result.completed);
        assert_eq!(result.ticks_processed, 3);
        assert_eq!(result.events_processed, 0);
        // 10 + 20 + 30; integer-valued f64 sums are exact, so == is safe here.
        assert_eq!(result.final_state, 60.0);
    }

    /// Events due at rows 2 and 4 fire alongside the ticks, and each fires exactly once.
    #[test]
    fn test_hybrid_kernel_with_scheduled_events() {
        let table = make_f64_table("price", vec![100.0, 105.0, 110.0, 108.0, 112.0]);

        let mut event_queue = EventQueue::new();
        // NOTE(review): the first argument is presumably the due time (row index); the
        // meaning of the other two arguments is not visible from this file.
        event_queue.schedule(2, 1, 0);
        event_queue.schedule(4, 1, 0);

        #[derive(Debug, Default)]
        struct State {
            sum: f64,
            rebalance_count: u32,
        }

        fn rebalance_handler(
            _event: &ScheduledEvent,
            state: &mut State,
            _queue: &mut EventQueue,
        ) -> Result<()> {
            state.rebalance_count += 1;
            Ok(())
        }

        let config = HybridKernelConfig::full(table.row_count());
        let kernel = HybridKernel::new(config);

        let result = kernel
            .run(
                &table,
                State::default(),
                event_queue,
                |idx, col_ptrs, state| {
                    // SAFETY: idx < row_count for a `full` config; col_ptrs[0] is assumed
                    // to be the single Float64 column built by `make_f64_table`.
                    unsafe { state.sum += *col_ptrs[0].add(idx) };
                    0
                },
                rebalance_handler,
            )
            .unwrap();

        assert!(result.completed);
        assert_eq!(result.ticks_processed, 5);
        assert_eq!(result.events_processed, 2);
        assert_eq!(result.final_state.rebalance_count, 2);
        // 100 + 105 + 110 + 108 + 112
        assert_eq!(result.final_state.sum, 535.0);
    }

    /// A handler may schedule further events; the spawned event (time 3) is also processed.
    #[test]
    fn test_hybrid_kernel_event_spawns_event() {
        let table = make_f64_table("price", vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let mut event_queue = EventQueue::new();
        event_queue.schedule(1, 1, 0);

        #[derive(Debug, Default)]
        struct State {
            events_seen: u32,
        }

        fn cascading_handler(
            event: &ScheduledEvent,
            state: &mut State,
            queue: &mut EventQueue,
        ) -> Result<()> {
            state.events_seen += 1;
            // Only the original event (time 1) spawns a follow-up, so the chain stops
            // after two events.
            if event.time == 1 {
                queue.schedule(3, 2, 0);
            }
            Ok(())
        }

        let result = simulate_hybrid(
            &table,
            State::default(),
            event_queue,
            |_idx, _col_ptrs, _state| 0,
            cascading_handler,
        )
        .unwrap();

        assert!(result.completed);
        assert_eq!(result.ticks_processed, 5);
        // The original event plus the one it spawned.
        assert_eq!(result.events_processed, 2);
        assert_eq!(result.final_state.events_seen, 2);
    }

    /// Several events due at the same tick are all drained in that tick.
    #[test]
    fn test_hybrid_kernel_multiple_events_same_tick() {
        let table = make_f64_table("price", vec![1.0, 2.0, 3.0]);

        let mut event_queue = EventQueue::new();
        event_queue.schedule(1, 10, 0);
        event_queue.schedule(1, 20, 0);
        event_queue.schedule(1, 30, 0);

        fn counting_handler(
            _event: &ScheduledEvent,
            state: &mut u32,
            _queue: &mut EventQueue,
        ) -> Result<()> {
            *state += 1;
            Ok(())
        }

        let result = simulate_hybrid(
            &table,
            0_u32,
            event_queue,
            |_idx, _col_ptrs, _state| 0,
            counting_handler,
        )
        .unwrap();

        assert!(result.completed);
        assert_eq!(result.events_processed, 3);
        assert_eq!(result.final_state, 3);
    }

    /// A strategy status of 1 stops the run early while still reporting `completed`.
    #[test]
    fn test_hybrid_kernel_tick_early_stop() {
        let table = make_f64_table("price", vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let mut event_queue = EventQueue::new();
        // This event is due after the stop point, so it must never be dispatched.
        event_queue.schedule(4, 1, 0);
        fn noop_handler(
            _event: &ScheduledEvent,
            _state: &mut u32,
            _queue: &mut EventQueue,
        ) -> Result<()> {
            panic!("Should not fire - tick stops before tick 4");
        }

        let result = simulate_hybrid(
            &table,
            0_u32,
            event_queue,
            |idx, _col_ptrs, state| {
                *state += 1;
                if idx == 2 {
                    // Status 1 = stop and report completed.
                    1
                } else {
                    0
                }
            },
            noop_handler,
        )
        .unwrap();

        assert!(result.completed);
        // Ticks 0 and 1 are counted; the stopping tick (2) is not.
        assert_eq!(result.ticks_processed, 2);
        assert_eq!(result.events_processed, 0);
    }

    /// Warmup rows are skipped by the strategy, but events that became due during warmup
    /// are still dispatched at the first real tick.
    #[test]
    fn test_hybrid_kernel_with_warmup() {
        let table = make_f64_table("price", vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let mut event_queue = EventQueue::new();
        // Due inside the warmup window (rows 0-1); popped at the first tick (row 2).
        event_queue.schedule(0, 1, 0);
        // Due after warmup.
        event_queue.schedule(3, 2, 0);

        fn handler(
            _event: &ScheduledEvent,
            state: &mut u32,
            _queue: &mut EventQueue,
        ) -> Result<()> {
            *state += 100;
            Ok(())
        }

        let config = HybridKernelConfig::with_warmup(table.row_count(), 2);
        let kernel = HybridKernel::new(config);

        let result = kernel
            .run(
                &table,
                0_u32,
                event_queue,
                |_idx, _col_ptrs, state| {
                    *state += 1;
                    0
                },
                handler,
            )
            .unwrap();

        assert!(result.completed);
        // Rows 2, 3 and 4 are ticked.
        assert_eq!(result.ticks_processed, 3);
        assert_eq!(result.events_processed, 2);
        // 3 ticks * 1 + 2 events * 100.
        assert_eq!(result.final_state, 203);
    }

    /// A warmup longer than the table leaves nothing to simulate and must be an error.
    #[test]
    fn test_hybrid_kernel_warmup_exceeds_range() {
        let table = make_f64_table("price", vec![1.0, 2.0]);
        let config = HybridKernelConfig::with_warmup(table.row_count(), 10);
        let kernel = HybridKernel::new(config);

        fn noop_handler(
            _event: &ScheduledEvent,
            _state: &mut f64,
            _queue: &mut EventQueue,
        ) -> Result<()> {
            Ok(())
        }

        let result = kernel.run(
            &table,
            0.0_f64,
            EventQueue::new(),
            |_idx, _col_ptrs, _state| 0,
            noop_handler,
        );
        assert!(result.is_err());
    }
}