1use shape_ast::error::Result;
7use shape_value::ValueWord;
8use std::sync::Arc;
9
/// Module-local shorthand for the runtime value type, [`ValueWord`].
type Value = ValueWord;
12
/// Processing mode selector stored in [`SimulationEngineConfig`].
///
/// Nothing in this file branches on the mode; it is carried in the
/// configuration only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SimulationMode {
    /// The default mode.
    // NOTE(review): presumably "process the whole data set at once" — confirm with consumers.
    #[default]
    Batch,
    // NOTE(review): presumably incremental/streaming processing — confirm with consumers.
    Stream,
}
22
/// An event emitted by a step handler during a simulation run.
#[derive(Debug, Clone)]
pub struct SimulationEvent {
    /// Position associated with the event. The engine never sets this itself;
    /// the handler that builds the event chooses the value.
    pub index: usize,
    /// Free-form event kind. Exported under the key `"type"` by
    /// `SimulationEngineResult::to_value`.
    pub event_type: String,
    /// Payload attached to the event.
    pub data: Value,
}
33
/// Outcome of a single handler invocation.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// State after this step; the engine passes it to the next handler call.
    pub state: Value,
    /// Events produced by this step.
    pub events: Vec<SimulationEvent>,
    /// `false` asks the engine to stop after this step.
    pub continue_processing: bool,
}
44
45impl StepResult {
46 pub fn with_state(state: Value) -> Self {
48 Self {
49 state,
50 events: vec![],
51 continue_processing: true,
52 }
53 }
54
55 pub fn with_events(state: Value, events: Vec<SimulationEvent>) -> Self {
57 Self {
58 state,
59 events,
60 continue_processing: true,
61 }
62 }
63
64 pub fn stop(state: Value) -> Self {
66 Self {
67 state,
68 events: vec![],
69 continue_processing: false,
70 }
71 }
72}
73
/// Boxed, dynamically dispatched step callback.
///
/// Arguments are, in order: the current data element, the current state, and
/// the element's index. The callback returns the [`StepResult`] for that step.
pub type StepHandler = Box<dyn Fn(&Value, &Value, usize) -> Result<StepResult>>;
76
77pub struct SimulationEngineConfig {
79 pub mode: SimulationMode,
81 pub initial_state: Value,
83 pub max_events: usize,
85 pub track_state_history: bool,
87}
88
89impl Default for SimulationEngineConfig {
90 fn default() -> Self {
91 Self {
92 mode: SimulationMode::default(),
93 initial_state: ValueWord::none(),
94 max_events: 0,
95 track_state_history: false,
96 }
97 }
98}
99
100impl SimulationEngineConfig {
101 pub fn new() -> Self {
102 Self::default()
103 }
104
105 pub fn with_initial_state(mut self, state: Value) -> Self {
106 self.initial_state = state;
107 self
108 }
109
110 pub fn with_mode(mut self, mode: SimulationMode) -> Self {
111 self.mode = mode;
112 self
113 }
114
115 pub fn with_max_events(mut self, max: usize) -> Self {
116 self.max_events = max;
117 self
118 }
119
120 pub fn tracking_state_history(mut self) -> Self {
121 self.track_state_history = true;
122 self
123 }
124}
125
/// Summary of a finished simulation run.
#[derive(Debug, Clone)]
pub struct SimulationEngineResult {
    /// State after the last executed step (the initial state if no step ran).
    pub final_state: Value,
    /// Events collected from the handlers, subject to the configured event cap.
    pub events: Vec<SimulationEvent>,
    /// State after each executed step; `None` unless history tracking was enabled.
    pub state_history: Option<Vec<Value>>,
    /// Element count the engine reports for the run.
    pub elements_processed: usize,
    /// `false` when a handler requested that processing stop; `true` otherwise.
    pub completed: bool,
}
140
141impl SimulationEngineResult {
142 pub fn to_value(&self) -> Value {
144 let events_value: Vec<ValueWord> = self
146 .events
147 .iter()
148 .map(|e| {
149 crate::type_schema::typed_object_from_nb_pairs(&[
150 ("index", ValueWord::from_f64(e.index as f64)),
151 (
152 "type",
153 ValueWord::from_string(Arc::new(e.event_type.clone())),
154 ),
155 ("data", e.data.clone()),
156 ])
157 })
158 .collect();
159
160 let mut nb_pairs: Vec<(&str, ValueWord)> = vec![
161 ("final_state", self.final_state.clone()),
162 (
163 "elements_processed",
164 ValueWord::from_f64(self.elements_processed as f64),
165 ),
166 ("completed", ValueWord::from_bool(self.completed)),
167 ("events", ValueWord::from_array(Arc::new(events_value))),
168 ];
169
170 if let Some(history) = &self.state_history {
171 let history_nb: Vec<ValueWord> = history.iter().map(|v| v.clone()).collect();
172 nb_pairs.push(("state_history", ValueWord::from_array(Arc::new(history_nb))));
173 }
174
175 crate::type_schema::typed_object_from_nb_pairs(&nb_pairs).clone()
176 }
177}
178
/// Drives a step handler over a slice of values while threading state from
/// one step to the next.
pub struct SimulationEngine {
    /// Settings applied to every run of this engine.
    config: SimulationEngineConfig,
}
183
184impl SimulationEngine {
185 pub fn new(config: SimulationEngineConfig) -> Self {
187 Self { config }
188 }
189
190 pub fn default_config() -> Self {
192 Self::new(SimulationEngineConfig::default())
193 }
194
195 pub fn run<F>(&self, data: &[Value], handler: F) -> Result<SimulationEngineResult>
200 where
201 F: Fn(&Value, &Value, usize) -> Result<StepResult>,
202 {
203 let mut state = self.config.initial_state.clone();
204 let mut events = Vec::new();
205 let mut state_history = if self.config.track_state_history {
206 Some(Vec::with_capacity(data.len()))
207 } else {
208 None
209 };
210 let mut completed = true;
211
212 for (i, value) in data.iter().enumerate() {
213 let step_result = handler(value, &state, i)?;
215
216 state = step_result.state;
218
219 if let Some(ref mut history) = state_history {
221 history.push(state.clone());
222 }
223
224 for event in step_result.events {
226 if self.config.max_events == 0 || events.len() < self.config.max_events {
227 events.push(event);
228 }
229 }
230
231 if !step_result.continue_processing {
233 completed = false;
234 break;
235 }
236 }
237
238 Ok(SimulationEngineResult {
239 final_state: state,
240 events,
241 state_history,
242 elements_processed: data.len(),
243 completed,
244 })
245 }
246
247 pub fn run_with_hooks<F, Pre, Post>(
249 &self,
250 data: &[Value],
251 pre_process: Pre,
252 handler: F,
253 post_process: Post,
254 ) -> Result<SimulationEngineResult>
255 where
256 F: Fn(&Value, &Value, usize) -> Result<StepResult>,
257 Pre: Fn(&Value) -> Result<Value>,
258 Post: Fn(SimulationEngineResult) -> Result<SimulationEngineResult>,
259 {
260 let initial_state = pre_process(&self.config.initial_state)?;
262
263 let modified_engine = SimulationEngine::new(SimulationEngineConfig {
265 initial_state,
266 ..self.config.clone()
267 });
268
269 let result = modified_engine.run(data, handler)?;
271
272 post_process(result)
274 }
275}
276
277impl Clone for SimulationEngineConfig {
279 fn clone(&self) -> Self {
280 Self {
281 mode: self.mode,
282 initial_state: self.config_value_clone(&self.initial_state),
283 max_events: self.max_events,
284 track_state_history: self.track_state_history,
285 }
286 }
287}
288
289impl SimulationEngineConfig {
290 fn config_value_clone(&self, value: &Value) -> Value {
291 value.clone()
292 }
293}
294
295#[cfg(test)]
296mod tests {
297 use super::*;
298 use std::sync::Arc;
299
300 #[test]
301 fn test_simulation_engine_basic() {
302 let config = SimulationEngineConfig::new().with_initial_state(ValueWord::from_f64(0.0));
303
304 let engine = SimulationEngine::new(config);
305
306 let data = vec![
307 ValueWord::from_f64(1.0),
308 ValueWord::from_f64(2.0),
309 ValueWord::from_f64(3.0),
310 ];
311
312 let result = engine
314 .run(&data, |value, state, _idx| {
315 let v = value.as_f64().unwrap_or(0.0);
316 let s = state.as_f64().unwrap_or(0.0);
317 Ok(StepResult::with_state(ValueWord::from_f64(v + s)))
318 })
319 .unwrap();
320
321 assert_eq!(result.elements_processed, 3);
322 assert!(result.completed);
323 assert_eq!(result.final_state.as_f64(), Some(6.0));
324 }
325
326 #[test]
327 fn test_simulation_engine_with_events() {
328 let config = SimulationEngineConfig::new().with_initial_state(ValueWord::from_f64(0.0));
329
330 let engine = SimulationEngine::new(config);
331
332 let data = vec![
333 ValueWord::from_f64(5.0),
334 ValueWord::from_f64(15.0), ValueWord::from_f64(8.0),
336 ];
337
338 let result = engine
340 .run(&data, |value, state, idx| {
341 let mut events = vec![];
342
343 if let Some(v) = value.as_f64() {
344 if v > 10.0 {
345 events.push(SimulationEvent {
346 index: idx,
347 event_type: "threshold_exceeded".to_string(),
348 data: value.clone(),
349 });
350 }
351 }
352
353 Ok(StepResult::with_events(state.clone(), events))
354 })
355 .unwrap();
356
357 assert_eq!(result.events.len(), 1);
358 assert_eq!(result.events[0].event_type, "threshold_exceeded");
359 assert_eq!(result.events[0].index, 1);
360 }
361
362 #[test]
363 fn test_simulation_engine_early_stop() {
364 let config = SimulationEngineConfig::new().with_initial_state(ValueWord::from_f64(0.0));
365
366 let engine = SimulationEngine::new(config);
367
368 let data = vec![
369 ValueWord::from_f64(1.0),
370 ValueWord::from_f64(-1.0), ValueWord::from_f64(3.0),
372 ];
373
374 let result = engine
376 .run(&data, |value, state, _idx| {
377 if let Some(v) = value.as_f64() {
378 if v < 0.0 {
379 return Ok(StepResult::stop(state.clone()));
380 }
381 }
382 Ok(StepResult::with_state(state.clone()))
383 })
384 .unwrap();
385
386 assert!(!result.completed);
388 }
389
390 #[test]
391 fn test_simulation_result_to_value() {
392 let result = SimulationEngineResult {
393 final_state: ValueWord::from_f64(42.0),
394 events: vec![SimulationEvent {
395 index: 5,
396 event_type: "test".to_string(),
397 data: ValueWord::from_string(Arc::new("data".to_string())),
398 }],
399 state_history: None,
400 elements_processed: 10,
401 completed: true,
402 };
403
404 let value = result.to_value();
405 let obj =
406 crate::type_schema::typed_object_to_hashmap_nb(&value).expect("Expected TypedObject");
407 assert!(obj.contains_key("final_state"));
408 assert!(obj.contains_key("events"));
409 assert!(obj.contains_key("elements_processed"));
410 assert!(obj.contains_key("completed"));
411 }
412
413 #[test]
414 fn test_state_history_tracking() {
415 let config = SimulationEngineConfig::new()
416 .with_initial_state(ValueWord::from_f64(0.0))
417 .tracking_state_history();
418
419 let engine = SimulationEngine::new(config);
420 let data = vec![
421 ValueWord::from_f64(1.0),
422 ValueWord::from_f64(2.0),
423 ValueWord::from_f64(3.0),
424 ];
425
426 let result = engine
427 .run(&data, |value, state, _idx| {
428 let v = value.as_f64().unwrap_or(0.0);
429 let s = state.as_f64().unwrap_or(0.0);
430 Ok(StepResult::with_state(ValueWord::from_f64(v + s)))
431 })
432 .unwrap();
433
434 let history = result
435 .state_history
436 .as_ref()
437 .expect("Should have state history");
438 assert_eq!(history.len(), 3);
439 assert_eq!(history[0].as_f64(), Some(1.0));
441 assert_eq!(history[1].as_f64(), Some(3.0));
443 assert_eq!(history[2].as_f64(), Some(6.0));
445 }
446
447 #[test]
448 fn test_state_maintained_across_steps() {
449 let initial_state = crate::type_schema::typed_object_from_nb_pairs(&[
451 ("cash", ValueWord::from_f64(10000.0)),
452 ("position", ValueWord::from_f64(0.0)),
453 ]);
454
455 let config = SimulationEngineConfig::new()
456 .with_initial_state(initial_state)
457 .tracking_state_history();
458
459 let engine = SimulationEngine::new(config);
460
461 let data = vec![
463 ValueWord::from_f64(100.0), ValueWord::from_f64(105.0), ValueWord::from_f64(110.0), ];
467
468 let result = engine
469 .run(&data, |value, state, idx| {
470 let price = match value.as_f64() {
471 Some(p) => p,
472 None => return Ok(StepResult::with_state(state.clone())),
473 };
474 let obj = crate::type_schema::typed_object_to_hashmap_nb(state).unwrap_or_default();
475 let cash = obj.get("cash").and_then(|v| v.as_f64()).unwrap_or(0.0);
476 let position = obj.get("position").and_then(|v| v.as_f64()).unwrap_or(0.0);
477
478 let (new_cash, new_position) = match idx {
479 0 => (cash - 10.0 * price, 10.0), 2 => (cash + position * price, 0.0), _ => (cash, position), };
483 let new_state = crate::type_schema::typed_object_from_nb_pairs(&[
484 ("cash", ValueWord::from_f64(new_cash)),
485 ("position", ValueWord::from_f64(new_position)),
486 ]);
487 Ok(StepResult::with_state(new_state))
488 })
489 .unwrap();
490
491 assert!(result.completed);
492 assert_eq!(result.elements_processed, 3);
493
494 let obj = crate::type_schema::typed_object_to_hashmap_nb(&result.final_state)
496 .expect("Expected TypedObject for final state");
497 {
498 let cash = obj
499 .get("cash")
500 .and_then(|v| v.as_f64())
501 .expect("Expected cash");
502 let position = obj
503 .get("position")
504 .and_then(|v| v.as_f64())
505 .expect("Expected position");
506 assert_eq!(cash, 10100.0);
508 assert_eq!(position, 0.0);
509 }
510
511 let history = result.state_history.as_ref().unwrap();
513 assert_eq!(history.len(), 3);
514 }
515
516 #[test]
517 fn test_max_events_limit() {
518 let config = SimulationEngineConfig::new()
519 .with_initial_state(ValueWord::from_f64(0.0))
520 .with_max_events(2);
521
522 let engine = SimulationEngine::new(config);
523 let data = vec![
524 ValueWord::from_f64(1.0),
525 ValueWord::from_f64(2.0),
526 ValueWord::from_f64(3.0),
527 ValueWord::from_f64(4.0),
528 ValueWord::from_f64(5.0),
529 ];
530
531 let result = engine
532 .run(&data, |value, state, idx| {
533 let events = vec![SimulationEvent {
534 index: idx,
535 event_type: "tick".to_string(),
536 data: value.clone(),
537 }];
538 Ok(StepResult::with_events(state.clone(), events))
539 })
540 .unwrap();
541
542 assert_eq!(result.events.len(), 2);
544 assert_eq!(result.events[0].index, 0);
545 assert_eq!(result.events[1].index, 1);
546 }
547
548 #[test]
549 fn test_run_with_hooks() {
550 let config = SimulationEngineConfig::new().with_initial_state(ValueWord::from_f64(0.0));
551
552 let engine = SimulationEngine::new(config);
553 let data = vec![
554 ValueWord::from_f64(1.0),
555 ValueWord::from_f64(2.0),
556 ValueWord::from_f64(3.0),
557 ];
558
559 let result = engine
560 .run_with_hooks(
561 &data,
562 |_initial| Ok(ValueWord::from_f64(10.0)),
564 |value, state, _idx| {
566 let v = value.as_f64().unwrap_or(0.0);
567 let s = state.as_f64().unwrap_or(0.0);
568 Ok(StepResult::with_state(ValueWord::from_f64(v + s)))
569 },
570 |mut result| {
572 if let Some(n) = result.final_state.as_f64() {
573 result.final_state = ValueWord::from_f64(n * 2.0);
574 }
575 Ok(result)
576 },
577 )
578 .unwrap();
579
580 assert_eq!(result.final_state.as_f64(), Some(32.0));
582 }
583
584 #[test]
585 fn test_empty_data() {
586 let config = SimulationEngineConfig::new().with_initial_state(ValueWord::from_f64(42.0));
587
588 let engine = SimulationEngine::new(config);
589 let data: Vec<Value> = vec![];
590
591 let result = engine
592 .run(&data, |_value, _state, _idx| {
593 panic!("Should not be called on empty data");
594 })
595 .unwrap();
596
597 assert!(result.completed);
598 assert_eq!(result.elements_processed, 0);
599 assert_eq!(result.final_state.as_f64(), Some(42.0));
600 }
601
602 #[test]
603 fn test_result_to_value_with_history() {
604 let result = SimulationEngineResult {
605 final_state: ValueWord::from_f64(10.0),
606 events: vec![],
607 state_history: Some(vec![
608 ValueWord::from_f64(1.0),
609 ValueWord::from_f64(5.0),
610 ValueWord::from_f64(10.0),
611 ]),
612 elements_processed: 3,
613 completed: true,
614 };
615
616 let value = result.to_value();
617 let obj =
618 crate::type_schema::typed_object_to_hashmap_nb(&value).expect("Expected TypedObject");
619 assert!(obj.contains_key("state_history"));
620 if let Some(history_nb) = obj.get("state_history") {
621 let history = history_nb
622 .as_any_array()
623 .expect("Expected array")
624 .to_generic();
625 assert_eq!(history.len(), 3);
626 } else {
627 panic!("Expected state_history field");
628 }
629 }
630}