1use crate::event_bus::EventBus;
7use crate::filter_library::FilterLibrary;
8use somatize_compiler::ExecutionPlan;
9use somatize_core::cache::CacheStore;
10use somatize_core::error::{Result, SomaError};
11use somatize_core::event::Event;
12use somatize_core::store::DataStore;
13use somatize_core::value::Value;
14use somatize_core::virtual_value::VirtualValue;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Instant;
18
/// Predecessor lists for the nodes of a graph, used to resolve each node's input.
///
/// A node without an entry is treated as having no predecessors.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphInfo {
    /// Maps a node id to the ids of the nodes whose outputs feed it, in order.
    predecessors: HashMap<String, Vec<String>>,
}
28
29impl GraphInfo {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36 self.predecessors.insert(node_id.into(), preds);
37 }
38
39 pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41 let mut info = Self::new();
42 for node in &graph.nodes {
43 let preds: Vec<String> = graph
44 .predecessors(&node.id)
45 .into_iter()
46 .map(|s| s.to_string())
47 .collect();
48 info.set_predecessors(node.id.clone(), preds);
49 }
50 info
51 }
52
53 pub fn for_linear(node_ids: &[&str]) -> Self {
55 let mut info = Self::new();
56 for (i, &id) in node_ids.iter().enumerate() {
57 let preds = if i > 0 {
58 vec![node_ids[i - 1].to_string()]
59 } else {
60 vec![]
61 };
62 info.set_predecessors(id, preds);
63 }
64 info
65 }
66
67 pub fn predecessors(&self, node_id: &str) -> &[String] {
69 self.predecessors
70 .get(node_id)
71 .map(|v| v.as_slice())
72 .unwrap_or(&[])
73 }
74}
75
/// Mutable state threaded through the execution of a plan for a single run.
pub struct Context {
    /// Outputs recorded so far, keyed by node id. Each entry is either held in memory or
    /// refers to a value spilled to `data_store`.
    pub store: HashMap<String, VirtualValue>,
    /// Bus on which node lifecycle events (`NodeStarted`, `NodeCompleted`, ...) are emitted.
    pub event_bus: Arc<EventBus>,
    /// Identifier of this run; copied into emitted events and used to build spill keys.
    pub run_id: String,
    /// Node ids in the order their outputs were recorded. Every call to `set` /
    /// `set_virtual` appends one entry, so a node written repeatedly appears repeatedly.
    pub execution_order: Vec<String>,
    /// Predecessor lists used to resolve each node's input.
    pub graph_info: GraphInfo,
    /// When set, `Remote` plan steps are executed through this transport instead of locally.
    pub transport: Option<Arc<dyn crate::runner::Transport>>,
    /// Optional store that large outputs are spilled to and reloaded from.
    pub data_store: Option<Arc<dyn DataStore>>,
    /// Estimated output size (bytes) at or above which outputs are spilled; `0` disables it.
    pub spill_threshold: usize,
}
100
101impl Context {
102 pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
103 Self {
104 store: HashMap::new(),
105 event_bus,
106 run_id: run_id.into(),
107 execution_order: Vec::new(),
108 graph_info: GraphInfo::new(),
109 transport: None,
110 data_store: None,
111 spill_threshold: 0,
112 }
113 }
114
115 pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
116 self.graph_info = info;
117 self
118 }
119
120 pub fn with_transport(mut self, transport: Arc<dyn crate::runner::Transport>) -> Self {
121 self.transport = Some(transport);
122 self
123 }
124
125 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
126 self.data_store = Some(store);
127 self
128 }
129
130 pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
134 self.spill_threshold = bytes;
135 self
136 }
137
138 fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
141 if self.spill_threshold > 0
142 && let Some(store) = &self.data_store
143 {
144 let size = value.size() * 8; if size >= self.spill_threshold {
146 let key = somatize_core::cache::CacheKey::from_parts(&[
147 self.run_id.as_bytes(),
148 node_id.as_bytes(),
149 ]);
150 let vv_for_schema = VirtualValue::materialized(value.clone());
151 let schema = vv_for_schema.schema().clone();
152 if let Ok(_data_ref) = store.put(&key, &value) {
153 tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
154 return VirtualValue::cached(key, schema);
155 }
156 }
157 }
158 VirtualValue::materialized(value)
159 }
160
161 pub fn get(&self, node_id: &str) -> Option<&Value> {
163 self.store.get(node_id).and_then(|vv| vv.as_value())
164 }
165
166 pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
168 self.store.get(node_id)
169 }
170
171 pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
173 let id = node_id.into();
174 self.execution_order.push(id.clone());
175 self.store.insert(id, VirtualValue::materialized(value));
176 }
177
178 pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
180 let id = node_id.into();
181 self.execution_order.push(id.clone());
182 self.store.insert(id, vv);
183 }
184
185 fn snapshot(&self) -> Self {
186 Self {
187 store: self.store.clone(),
188 event_bus: self.event_bus.clone(),
189 run_id: self.run_id.clone(),
190 execution_order: self.execution_order.clone(),
191 graph_info: self.graph_info.clone(),
192 transport: self.transport.clone(),
193 data_store: self.data_store.clone(),
194 spill_threshold: self.spill_threshold,
195 }
196 }
197}
198
/// Something that can be run against an execution [`Context`].
pub trait Executable {
    /// Runs `self`, reading node inputs from and recording node outputs into `ctx`.
    ///
    /// `filters` provides the filter registered for each node id. `cache` supplies
    /// previously computed outputs (the `ExecutionPlan` implementation reads it for
    /// `Cached` steps).
    fn execute(
        &self,
        ctx: &mut Context,
        filters: &FilterLibrary,
        cache: &dyn CacheStore,
    ) -> Result<()>;
}
210
211impl Executable for ExecutionPlan {
212 fn execute(
213 &self,
214 ctx: &mut Context,
215 filters: &FilterLibrary,
216 cache: &dyn CacheStore,
217 ) -> Result<()> {
218 match self {
219 ExecutionPlan::Empty => Ok(()),
220
221 ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
222
223 ExecutionPlan::Cached { node_id, key } => {
224 let start = Instant::now();
225 let value = cache.get(key)?.ok_or_else(|| {
226 SomaError::Cache(format!(
227 "expected cached value for node `{node_id}` not found"
228 ))
229 })?;
230 ctx.set(node_id.clone(), value);
231 ctx.event_bus.emit(Event::NodeCacheHit {
232 run_id: ctx.run_id.clone(),
233 node_id: node_id.clone(),
234 key: key.clone(),
235 tier: somatize_core::cache::CacheTier::Memory,
236 load_time: start.elapsed(),
237 });
238 Ok(())
239 }
240
241 ExecutionPlan::Sequence(steps) => {
242 for step in steps {
243 step.execute(ctx, filters, cache)?;
244 }
245 Ok(())
246 }
247
248 ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
249
250 ExecutionPlan::Loop {
251 node_id,
252 body,
253 max_iterations,
254 } => {
255 let max = max_iterations.unwrap_or(100);
256 for i in 0..max {
257 body.execute(ctx, filters, cache)?;
258
259 let should_stop = ctx
262 .execution_order
263 .last()
264 .and_then(|last_id| ctx.get(last_id))
265 .map(|v| match v {
266 Value::Json(j) => {
267 j.as_bool() == Some(true)
268 || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
269 || j.get("done").and_then(|d| d.as_bool()) == Some(true)
270 }
271 Value::Empty => true,
272 _ => false,
273 })
274 .unwrap_or(false);
275
276 if should_stop {
277 ctx.event_bus.emit(Event::NodeCompleted {
278 run_id: ctx.run_id.clone(),
279 node_id: node_id.clone(),
280 duration: std::time::Duration::ZERO,
281 output_summary: format!("Loop terminated at iteration {}", i + 1),
282 });
283 break;
284 }
285 }
286 Ok(())
287 }
288
289 ExecutionPlan::Branch { node_id, arms } => {
290 execute_node(node_id, ctx, filters, cache)?;
292
293 let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
295
296 let selected_arm = match &condition {
298 Value::Json(j) => {
299 let selector = j
301 .as_str()
302 .map(String::from)
303 .or_else(|| j.as_bool().map(|b| b.to_string()))
304 .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
305 .unwrap_or_else(|| "true".to_string());
306
307 arms.iter()
308 .find(|(label, _)| label == &selector)
309 .or_else(|| {
310 arms.iter()
311 .find(|(label, _)| label == "default" || label == "else")
312 })
313 .or_else(|| arms.first())
314 }
315 _ => arms.first(),
316 };
317
318 if let Some((label, plan)) = selected_arm {
319 ctx.event_bus.emit(Event::NodeCompleted {
320 run_id: ctx.run_id.clone(),
321 node_id: node_id.clone(),
322 duration: std::time::Duration::ZERO,
323 output_summary: format!("Branch selected: {label}"),
324 });
325 plan.execute(ctx, filters, cache)?;
326 }
327 Ok(())
328 }
329
330 ExecutionPlan::Remote {
331 node_id,
332 target: _,
333 plan,
334 } => {
335 if let Some(transport) = &ctx.transport {
336 let input = ctx
338 .graph_info
339 .predecessors(node_id)
340 .first()
341 .and_then(|pred| ctx.get(pred));
342
343 let result = transport.execute_node(node_id, input)?;
344 ctx.set(node_id.clone(), result);
345 Ok(())
346 } else {
347 plan.execute(ctx, filters, cache)
349 }
350 }
351
352 ExecutionPlan::Composite { node_ids } => {
353 for nid in node_ids {
356 execute_node(nid, ctx, filters, cache)?;
357 }
358 Ok(())
359 }
360
361 _ => {
362 tracing::warn!("Unhandled ExecutionPlan variant");
363 Ok(())
364 }
365 }
366 }
367}
368
369pub fn execute(
371 plan: &ExecutionPlan,
372 ctx: &mut Context,
373 filters: &FilterLibrary,
374 cache: &dyn CacheStore,
375) -> Result<()> {
376 plan.execute(ctx, filters, cache)
377}
378
379fn execute_node(
381 node_id: &str,
382 ctx: &mut Context,
383 filters: &FilterLibrary,
384 _cache: &dyn CacheStore,
385) -> Result<()> {
386 let start = Instant::now();
387
388 let filter = filters
389 .get(node_id)
390 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
391
392 ctx.event_bus.emit(Event::NodeStarted {
393 run_id: ctx.run_id.clone(),
394 node_id: node_id.to_string(),
395 kind: filter.meta().kind,
396 });
397
398 let _span = tracing::info_span!("execute_node", %node_id).entered();
399
400 let input = resolve_input(node_id, ctx);
401 let state = filters.get_state(node_id).cloned().unwrap_or(Value::Empty);
402
403 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
405 filter.forward(&input, &state)
406 }));
407
408 let result = match result {
409 Ok(inner) => inner,
410 Err(panic) => {
411 let msg = panic
412 .downcast_ref::<String>()
413 .map(|s| s.as_str())
414 .or_else(|| panic.downcast_ref::<&str>().copied())
415 .unwrap_or("unknown panic");
416 tracing::error!(node_id, "filter panicked: {msg}");
417 Err(SomaError::Execution {
418 node_id: node_id.to_string(),
419 message: format!("filter panicked: {msg}"),
420 })
421 }
422 };
423
424 match result {
425 Ok(output) => {
426 let duration = start.elapsed();
427 let summary = format!("{output}");
428 let vv = ctx.maybe_spill(node_id, output);
429 ctx.set_virtual(node_id, vv);
430 ctx.event_bus.emit(Event::NodeCompleted {
431 run_id: ctx.run_id.clone(),
432 node_id: node_id.to_string(),
433 duration,
434 output_summary: summary,
435 });
436 Ok(())
437 }
438 Err(e) => {
439 tracing::error!(node_id, error = %e, "node execution failed");
440 ctx.event_bus.emit(Event::NodeFailed {
441 run_id: ctx.run_id.clone(),
442 node_id: node_id.to_string(),
443 error: e.to_string(),
444 });
445 Err(e)
446 }
447 }
448}
449
450fn execute_parallel(
455 branches: &[ExecutionPlan],
456 ctx: &mut Context,
457 filters: &FilterLibrary,
458 cache: &dyn CacheStore,
459) -> Result<()> {
460 let snapshot_keys: Arc<std::collections::HashSet<String>> =
461 Arc::new(ctx.store.keys().cloned().collect());
462
463 let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
465 let handles: Vec<_> = branches
466 .iter()
467 .map(|branch| {
468 let mut branch_ctx = ctx.snapshot();
469 let keys = snapshot_keys.clone();
470 s.spawn(move || {
471 execute(branch, &mut branch_ctx, filters, cache)?;
472 let new_entries: Vec<(String, VirtualValue)> = branch_ctx
473 .store
474 .into_iter()
475 .filter(|(k, _)| !keys.contains(k))
476 .collect();
477 Ok(new_entries)
478 })
479 })
480 .collect();
481
482 handles.into_iter().map(|h| h.join().unwrap()).collect()
483 });
484
485 for result in results {
487 let entries = result?;
488 for (key, vv) in entries {
489 ctx.set_virtual(key, vv);
490 }
491 }
492
493 Ok(())
494}
495
496fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
498 match vv {
499 VirtualValue::Materialized { value, .. } => Some(value.clone()),
500 VirtualValue::Cached { key, .. } => {
501 if let Some(store) = data_store {
503 let data_ref = somatize_core::store::DataRef::Cached {
504 cache_key: key.clone(),
505 };
506 store.get(&data_ref).ok()
507 } else {
508 None
509 }
510 }
511 _ => None,
512 }
513}
514
515pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
518 let preds = ctx.graph_info.predecessors(node_id);
519
520 let resolve_node = |id: &str| -> Option<Value> {
521 ctx.store
522 .get(id)
523 .and_then(|vv| resolve_value(vv, &ctx.data_store))
524 };
525
526 match preds.len() {
527 0 => ctx
528 .execution_order
529 .last()
530 .and_then(|id| resolve_node(id))
531 .unwrap_or(Value::Empty),
532 1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
533 _ => {
534 let mut merged = serde_json::Map::new();
535 for pred_id in preds {
536 if let Some(val) = resolve_node(pred_id) {
537 let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
538 merged.insert(pred_id.clone(), json_val);
539 }
540 }
541 Value::Json(serde_json::Value::Object(merged))
542 }
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};

    /// Metadata shared by every test filter: stateless, differentiable, local, fixed-state
    /// streaming, with no declared input/output schema.
    fn stateless_meta(name: impl Into<String>, cacheable: bool) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind: FilterKind::Stateless,
            cacheable,
            differentiable: true,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Applies `f` element-wise to a tensor input; any other value is passed through.
    fn map_tensor(x: &Value, f: impl Fn(f64) -> f64) -> Value {
        match x {
            Value::Tensor { values, shape } => {
                let mapped: Vec<f64> = values.iter().map(|&v| f(v)).collect();
                Value::tensor(mapped, shape.clone())
            }
            _ => x.clone(),
        }
    }

    /// Multiplies every tensor element by two.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            Ok(map_tensor(x, |v| v * 2.0))
        }
        fn meta(&self) -> FilterMeta {
            stateless_meta("Doubler", true)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Adds `amount` to every tensor element.
    struct AdderFilter {
        amount: f64,
    }

    impl Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            Ok(map_tensor(x, |v| v + self.amount))
        }
        fn meta(&self) -> FilterMeta {
            stateless_meta("Adder", true)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Sleeps for `delay_ms`, then passes its input through unchanged.
    struct SlowFilter {
        id: String,
        delay_ms: u64,
    }

    impl Filter for SlowFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Slow", self.id.as_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            std::thread::sleep(std::time::Duration::from_millis(self.delay_ms));
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            stateless_meta(format!("Slow_{}", self.id), false)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    fn setup() -> (Arc<EventBus>, MemoryCache) {
        (Arc::new(EventBus::new(64)), MemoryCache::default())
    }

    #[test]
    fn execute_single_node() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        ctx.graph_info
            .set_predecessors("doubler", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("doubler", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Execute {
            node_id: "doubler".into(),
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("doubler").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn execute_sequence_with_graph_info() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));

        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
        ctx.graph_info = graph_info;

        let mut filters = FilterLibrary::new();
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[22.0, 24.0]);
    }

    #[test]
    fn execute_cached_node() {
        let (bus, cache) = setup();
        let key = CacheKey::hash_data(b"cached_result");
        let cached_value = Value::tensor(vec![99.0], vec![1]);
        cache.put(&key, &cached_value).unwrap();

        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let plan = ExecutionPlan::Cached {
            node_id: "cached_node".into(),
            key,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
    }

    #[test]
    fn execute_emits_events() {
        let bus = Arc::new(EventBus::new(64));
        let cache = MemoryCache::default();
        let mut rx = bus.subscribe();

        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        execute(
            &ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        )
        .unwrap();

        let e1 = rx.try_recv().unwrap();
        assert!(matches!(e1, Event::NodeStarted { .. }));
        let e2 = rx.try_recv().unwrap();
        assert!(matches!(e2, Event::NodeCompleted { .. }));
    }

    #[test]
    fn execute_missing_filter_errors() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let result = execute(
            &ExecutionPlan::Execute {
                node_id: "nonexistent".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        );
        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn execute_empty_plan() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();
        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
    }

    #[test]
    fn execute_parallel_branches_merge_outputs() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);
        ctx.graph_info.set_predecessors("add", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
        assert_eq!(double_out, &[10.0]);
        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
        assert_eq!(add_out, &[105.0]);
    }

    #[test]
    fn parallel_branches_run_concurrently() {
        // Each branch sleeps DELAY_MS, so running them one after another takes at least
        // 2 * DELAY_MS. The bound stays below that while leaving generous headroom over a
        // single DELAY_MS for thread start-up on a loaded CI machine.
        const DELAY_MS: u64 = 150;
        let bound_ms = u128::from(2 * DELAY_MS) - 20;

        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("slow_a", vec!["input".into()]);
        ctx.graph_info
            .set_predecessors("slow_b", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register(
            "slow_a",
            Box::new(SlowFilter {
                id: "a".into(),
                delay_ms: DELAY_MS,
            }),
        );
        filters.register(
            "slow_b",
            Box::new(SlowFilter {
                id: "b".into(),
                delay_ms: DELAY_MS,
            }),
        );

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "slow_a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "slow_b".into(),
            },
        ]);

        let start = Instant::now();
        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        let elapsed = start.elapsed();

        assert!(
            elapsed.as_millis() < bound_ms,
            "parallel branches took {}ms, expected <{bound_ms}ms (sequential would be ~{}ms)",
            elapsed.as_millis(),
            2 * DELAY_MS
        );

        assert!(ctx.get("slow_a").is_some());
        assert!(ctx.get("slow_b").is_some());
    }

    #[test]
    fn resolve_input_single_predecessor() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
        ctx.graph_info.set_predecessors("B", vec!["A".into()]);

        let input = resolve_input("B", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);
    }

    #[test]
    fn resolve_input_multiple_predecessors() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
        ctx.graph_info
            .set_predecessors("C", vec!["A".into(), "B".into()]);

        let input = resolve_input("C", &ctx);
        let json = input.as_json().unwrap();
        assert!(json.get("A").is_some());
        assert!(json.get("B").is_some());
    }

    #[test]
    fn resolve_input_no_predecessors_fallback() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));

        let input = resolve_input("root", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[7.0]);
    }

    #[test]
    fn graph_info_from_linear() {
        let info = GraphInfo::for_linear(&["a", "b", "c"]);
        assert!(info.predecessors("a").is_empty());
        assert_eq!(info.predecessors("b"), &["a"]);
        assert_eq!(info.predecessors("c"), &["b"]);
    }
}