1use crate::event_bus::EventBus;
7use crate::filter_library::FilterLibrary;
8use somatize_compiler::ExecutionPlan;
9use somatize_core::cache::CacheStore;
10use somatize_core::error::{Result, SomaError};
11use somatize_core::event::Event;
12use somatize_core::store::DataStore;
13use somatize_core::value::Value;
14use somatize_core::virtual_value::VirtualValue;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Instant;
18
19#[derive(Debug, Clone, Default)]
24pub struct GraphInfo {
25 predecessors: HashMap<String, Vec<String>>,
27}
28
29impl GraphInfo {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36 self.predecessors.insert(node_id.into(), preds);
37 }
38
39 pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41 let mut info = Self::new();
42 for node in &graph.nodes {
43 let preds: Vec<String> = graph
44 .predecessors(&node.id)
45 .into_iter()
46 .map(|s| s.to_string())
47 .collect();
48 info.set_predecessors(node.id.clone(), preds);
49 }
50 info
51 }
52
53 pub fn for_linear(node_ids: &[&str]) -> Self {
55 let mut info = Self::new();
56 for (i, &id) in node_ids.iter().enumerate() {
57 let preds = if i > 0 {
58 vec![node_ids[i - 1].to_string()]
59 } else {
60 vec![]
61 };
62 info.set_predecessors(id, preds);
63 }
64 info
65 }
66
67 pub fn predecessors(&self, node_id: &str) -> &[String] {
69 self.predecessors
70 .get(node_id)
71 .map(|v| v.as_slice())
72 .unwrap_or(&[])
73 }
74}
75
/// Executes `ExecutionPlan::Remote` nodes outside the local filter library.
///
/// The `Send + Sync` bound is required because the executor's `Arc` is copied
/// into each branch context when parallel branches run on scoped threads.
pub trait RemoteExecutor: Send + Sync {
    /// Runs `node_id` on `target` and returns the node's output.
    ///
    /// `input` is the output of the node's first declared predecessor, or `None`
    /// when the node has no predecessor or that output is unavailable.
    fn execute_remote(
        &self,
        node_id: &str,
        target: &somatize_core::filter::RemoteTarget,
        input: Option<&Value>,
    ) -> Result<Value>;
}
90
91pub struct Context {
97 pub store: HashMap<String, VirtualValue>,
99 pub event_bus: Arc<EventBus>,
101 pub run_id: String,
103 pub execution_order: Vec<String>,
105 pub graph_info: GraphInfo,
107 pub remote_executor: Option<Arc<dyn RemoteExecutor>>,
109 pub data_store: Option<Arc<dyn DataStore>>,
111 pub spill_threshold: usize,
114}
115
116impl Context {
117 pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
118 Self {
119 store: HashMap::new(),
120 event_bus,
121 run_id: run_id.into(),
122 execution_order: Vec::new(),
123 graph_info: GraphInfo::new(),
124 remote_executor: None,
125 data_store: None,
126 spill_threshold: 0,
127 }
128 }
129
130 pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
131 self.graph_info = info;
132 self
133 }
134
135 pub fn with_remote_executor(mut self, executor: Arc<dyn RemoteExecutor>) -> Self {
136 self.remote_executor = Some(executor);
137 self
138 }
139
140 pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
141 self.data_store = Some(store);
142 self
143 }
144
145 pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
149 self.spill_threshold = bytes;
150 self
151 }
152
153 fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
156 if self.spill_threshold > 0
157 && let Some(store) = &self.data_store
158 {
159 let size = value.size() * 8; if size >= self.spill_threshold {
161 let key = somatize_core::cache::CacheKey::from_parts(&[
162 self.run_id.as_bytes(),
163 node_id.as_bytes(),
164 ]);
165 let vv_for_schema = VirtualValue::materialized(value.clone());
166 let schema = vv_for_schema.schema().clone();
167 if let Ok(_data_ref) = store.put(&key, &value) {
168 tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
169 return VirtualValue::cached(key, schema);
170 }
171 }
172 }
173 VirtualValue::materialized(value)
174 }
175
176 pub fn get(&self, node_id: &str) -> Option<&Value> {
178 self.store.get(node_id).and_then(|vv| vv.as_value())
179 }
180
181 pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
183 self.store.get(node_id)
184 }
185
186 pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
188 let id = node_id.into();
189 self.execution_order.push(id.clone());
190 self.store.insert(id, VirtualValue::materialized(value));
191 }
192
193 pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
195 let id = node_id.into();
196 self.execution_order.push(id.clone());
197 self.store.insert(id, vv);
198 }
199
200 fn snapshot(&self) -> Self {
201 Self {
202 store: self.store.clone(),
203 event_bus: self.event_bus.clone(),
204 run_id: self.run_id.clone(),
205 execution_order: self.execution_order.clone(),
206 graph_info: self.graph_info.clone(),
207 remote_executor: self.remote_executor.clone(),
208 data_store: self.data_store.clone(),
209 spill_threshold: self.spill_threshold,
210 }
211 }
212}
213
/// A plan (or plan fragment) that can be run against an execution [`Context`].
pub trait Executable {
    /// Runs `self`, recording node outputs in `ctx`.
    ///
    /// `filters` supplies the filter implementation and any fitted state for
    /// each node. `cache` is read by cached plan nodes.
    ///
    /// # Errors
    /// Returns the first error raised while running any part of the plan.
    fn execute(
        &self,
        ctx: &mut Context,
        filters: &FilterLibrary,
        cache: &dyn CacheStore,
    ) -> Result<()>;
}
225
226impl Executable for ExecutionPlan {
227 fn execute(
228 &self,
229 ctx: &mut Context,
230 filters: &FilterLibrary,
231 cache: &dyn CacheStore,
232 ) -> Result<()> {
233 match self {
234 ExecutionPlan::Empty => Ok(()),
235
236 ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
237
238 ExecutionPlan::Cached { node_id, key } => {
239 let start = Instant::now();
240 let value = cache.get(key)?.ok_or_else(|| {
241 SomaError::Cache(format!(
242 "expected cached value for node `{node_id}` not found"
243 ))
244 })?;
245 ctx.set(node_id.clone(), value);
246 ctx.event_bus.emit(Event::NodeCacheHit {
247 run_id: ctx.run_id.clone(),
248 node_id: node_id.clone(),
249 key: key.clone(),
250 tier: somatize_core::cache::CacheTier::Memory,
251 load_time: start.elapsed(),
252 });
253 Ok(())
254 }
255
256 ExecutionPlan::Sequence(steps) => {
257 for step in steps {
258 step.execute(ctx, filters, cache)?;
259 }
260 Ok(())
261 }
262
263 ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
264
265 ExecutionPlan::Loop {
266 node_id,
267 body,
268 max_iterations,
269 } => {
270 let max = max_iterations.unwrap_or(100);
271 for i in 0..max {
272 body.execute(ctx, filters, cache)?;
273
274 let should_stop = ctx
277 .execution_order
278 .last()
279 .and_then(|last_id| ctx.get(last_id))
280 .map(|v| match v {
281 Value::Json(j) => {
282 j.as_bool() == Some(true)
283 || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
284 || j.get("done").and_then(|d| d.as_bool()) == Some(true)
285 }
286 Value::Empty => true,
287 _ => false,
288 })
289 .unwrap_or(false);
290
291 if should_stop {
292 ctx.event_bus.emit(Event::NodeCompleted {
293 run_id: ctx.run_id.clone(),
294 node_id: node_id.clone(),
295 duration: std::time::Duration::ZERO,
296 output_summary: format!("Loop terminated at iteration {}", i + 1),
297 });
298 break;
299 }
300 }
301 Ok(())
302 }
303
304 ExecutionPlan::Branch { node_id, arms } => {
305 execute_node(node_id, ctx, filters, cache)?;
307
308 let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
310
311 let selected_arm = match &condition {
313 Value::Json(j) => {
314 let selector = j
316 .as_str()
317 .map(String::from)
318 .or_else(|| j.as_bool().map(|b| b.to_string()))
319 .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
320 .unwrap_or_else(|| "true".to_string());
321
322 arms.iter()
323 .find(|(label, _)| label == &selector)
324 .or_else(|| {
325 arms.iter()
326 .find(|(label, _)| label == "default" || label == "else")
327 })
328 .or_else(|| arms.first())
329 }
330 _ => arms.first(),
331 };
332
333 if let Some((label, plan)) = selected_arm {
334 ctx.event_bus.emit(Event::NodeCompleted {
335 run_id: ctx.run_id.clone(),
336 node_id: node_id.clone(),
337 duration: std::time::Duration::ZERO,
338 output_summary: format!("Branch selected: {label}"),
339 });
340 plan.execute(ctx, filters, cache)?;
341 }
342 Ok(())
343 }
344
345 ExecutionPlan::Remote {
346 node_id,
347 target,
348 plan,
349 } => {
350 if let Some(remote) = &ctx.remote_executor {
351 let input = ctx
353 .graph_info
354 .predecessors(node_id)
355 .first()
356 .and_then(|pred| ctx.get(pred));
357
358 let result = remote.execute_remote(node_id, target, input)?;
359 ctx.set(node_id.clone(), result);
360 ctx.execution_order.push(node_id.clone());
361 Ok(())
362 } else {
363 plan.execute(ctx, filters, cache)
365 }
366 }
367
368 ExecutionPlan::Composite { node_ids } => {
369 for nid in node_ids {
372 execute_node(nid, ctx, filters, cache)?;
373 }
374 Ok(())
375 }
376
377 _ => {
378 tracing::warn!("Unhandled ExecutionPlan variant");
379 Ok(())
380 }
381 }
382 }
383}
384
/// Runs `plan` against `ctx`; free-function form of [`Executable::execute`].
///
/// # Errors
/// Returns the first error raised while running the plan.
pub fn execute(
    plan: &ExecutionPlan,
    ctx: &mut Context,
    filters: &FilterLibrary,
    cache: &dyn CacheStore,
) -> Result<()> {
    plan.execute(ctx, filters, cache)
}
394
395fn execute_node(
397 node_id: &str,
398 ctx: &mut Context,
399 filters: &FilterLibrary,
400 _cache: &dyn CacheStore,
401) -> Result<()> {
402 let start = Instant::now();
403
404 let filter = filters
405 .get(node_id)
406 .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
407
408 ctx.event_bus.emit(Event::NodeStarted {
409 run_id: ctx.run_id.clone(),
410 node_id: node_id.to_string(),
411 kind: filter.meta().kind,
412 });
413
414 let input = resolve_input(node_id, ctx);
415 let state = filters.get_state(node_id).cloned().unwrap_or(Value::Empty);
416 let result = filter.forward(&input, &state);
417
418 match result {
419 Ok(output) => {
420 let duration = start.elapsed();
421 let summary = format!("{output}");
422 let vv = ctx.maybe_spill(node_id, output);
423 ctx.set_virtual(node_id, vv);
424 ctx.event_bus.emit(Event::NodeCompleted {
425 run_id: ctx.run_id.clone(),
426 node_id: node_id.to_string(),
427 duration,
428 output_summary: summary,
429 });
430 Ok(())
431 }
432 Err(e) => {
433 ctx.event_bus.emit(Event::NodeFailed {
434 run_id: ctx.run_id.clone(),
435 node_id: node_id.to_string(),
436 error: e.to_string(),
437 });
438 Err(e)
439 }
440 }
441}
442
443fn execute_parallel(
448 branches: &[ExecutionPlan],
449 ctx: &mut Context,
450 filters: &FilterLibrary,
451 cache: &dyn CacheStore,
452) -> Result<()> {
453 let snapshot_keys: Arc<std::collections::HashSet<String>> =
454 Arc::new(ctx.store.keys().cloned().collect());
455
456 let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
458 let handles: Vec<_> = branches
459 .iter()
460 .map(|branch| {
461 let mut branch_ctx = ctx.snapshot();
462 let keys = snapshot_keys.clone();
463 s.spawn(move || {
464 execute(branch, &mut branch_ctx, filters, cache)?;
465 let new_entries: Vec<(String, VirtualValue)> = branch_ctx
466 .store
467 .into_iter()
468 .filter(|(k, _)| !keys.contains(k))
469 .collect();
470 Ok(new_entries)
471 })
472 })
473 .collect();
474
475 handles.into_iter().map(|h| h.join().unwrap()).collect()
476 });
477
478 for result in results {
480 let entries = result?;
481 for (key, vv) in entries {
482 ctx.set_virtual(key, vv);
483 }
484 }
485
486 Ok(())
487}
488
489fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
491 match vv {
492 VirtualValue::Materialized { value, .. } => Some(value.clone()),
493 VirtualValue::Cached { key, .. } => {
494 if let Some(store) = data_store {
496 let data_ref = somatize_core::store::DataRef::Cached {
497 cache_key: key.clone(),
498 };
499 store.get(&data_ref).ok()
500 } else {
501 None
502 }
503 }
504 _ => None,
505 }
506}
507
508pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
511 let preds = ctx.graph_info.predecessors(node_id);
512
513 let resolve_node = |id: &str| -> Option<Value> {
514 ctx.store
515 .get(id)
516 .and_then(|vv| resolve_value(vv, &ctx.data_store))
517 };
518
519 match preds.len() {
520 0 => ctx
521 .execution_order
522 .last()
523 .and_then(|id| resolve_node(id))
524 .unwrap_or(Value::Empty),
525 1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
526 _ => {
527 let mut merged = serde_json::Map::new();
528 for pred_id in preds {
529 if let Some(val) = resolve_node(pred_id) {
530 let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
531 merged.insert(pred_id.clone(), json_val);
532 }
533 }
534 Value::Json(serde_json::Value::Object(merged))
535 }
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Condvar, Mutex};
    use std::time::Duration;

    /// Metadata shared by every test filter: local, stateless, differentiable,
    /// fixed-state streaming, no declared schemas.
    fn stateless_meta(name: impl Into<String>, cacheable: bool) -> FilterMeta {
        FilterMeta {
            name: name.into(),
            kind: FilterKind::Stateless,
            cacheable,
            differentiable: true,
            stream_mode: StreamMode::FixedState,
            distribution: somatize_core::filter::Distribution::Local,
            input_schema: None,
            output_schema: None,
        }
    }

    /// Doubles every element of a tensor input; passes anything else through.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            stateless_meta("Doubler", true)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Adds `amount` to every element of a tensor input; passes anything else through.
    struct AdderFilter {
        amount: f64,
    }

    impl Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let added: Vec<f64> = values.iter().map(|v| v + self.amount).collect();
                    Ok(Value::tensor(added, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            stateless_meta("Adder", true)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Meeting point that is only passed once `parties` callers are waiting on
    /// it at the same time, which can only happen if they run concurrently.
    struct Rendezvous {
        parties: usize,
        timeout: Duration,
        arrived: Mutex<usize>,
        all_arrived: Condvar,
        missed: AtomicBool,
    }

    impl Rendezvous {
        fn new(parties: usize, timeout: Duration) -> Arc<Self> {
            Arc::new(Self {
                parties,
                timeout,
                arrived: Mutex::new(0),
                all_arrived: Condvar::new(),
                missed: AtomicBool::new(false),
            })
        }

        /// Waits until every party has arrived or `timeout` elapses. A timeout is
        /// recorded rather than panicking, so a sequential executor fails the
        /// test cleanly instead of hanging it.
        fn arrive(&self) {
            let mut arrived = self.arrived.lock().unwrap();
            *arrived += 1;
            if *arrived >= self.parties {
                self.all_arrived.notify_all();
                return;
            }
            let (arrived, _) = self
                .all_arrived
                .wait_timeout_while(arrived, self.timeout, |n| *n < self.parties)
                .unwrap();
            if *arrived < self.parties {
                self.missed.store(true, Ordering::SeqCst);
            }
        }

        /// True if some caller gave up waiting for the others.
        fn missed(&self) -> bool {
            self.missed.load(Ordering::SeqCst)
        }
    }

    /// Pass-through filter that waits at a shared [`Rendezvous`] before returning.
    struct RendezvousFilter {
        id: String,
        gate: Arc<Rendezvous>,
    }

    impl Filter for RendezvousFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Rendezvous", self.id.as_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            self.gate.arrive();
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            stateless_meta(format!("Rendezvous_{}", self.id), false)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    fn setup() -> (Arc<EventBus>, MemoryCache) {
        (Arc::new(EventBus::new(64)), MemoryCache::default())
    }

    #[test]
    fn execute_single_node() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        ctx.graph_info
            .set_predecessors("doubler", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("doubler", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Execute {
            node_id: "doubler".into(),
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("doubler").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn execute_sequence_with_graph_info() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));

        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
        ctx.graph_info = graph_info;

        let mut filters = FilterLibrary::new();
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[22.0, 24.0]);
    }

    #[test]
    fn execute_cached_node() {
        let (bus, cache) = setup();
        let key = CacheKey::hash_data(b"cached_result");
        let cached_value = Value::tensor(vec![99.0], vec![1]);
        cache.put(&key, &cached_value).unwrap();

        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let plan = ExecutionPlan::Cached {
            node_id: "cached_node".into(),
            key,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
    }

    #[test]
    fn execute_emits_events() {
        let bus = Arc::new(EventBus::new(64));
        let cache = MemoryCache::default();
        let mut rx = bus.subscribe();

        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        execute(
            &ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        )
        .unwrap();

        let e1 = rx.try_recv().unwrap();
        assert!(matches!(e1, Event::NodeStarted { .. }));
        let e2 = rx.try_recv().unwrap();
        assert!(matches!(e2, Event::NodeCompleted { .. }));
    }

    #[test]
    fn execute_missing_filter_errors() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let result = execute(
            &ExecutionPlan::Execute {
                node_id: "nonexistent".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        );
        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn execute_empty_plan() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();
        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
    }

    #[test]
    fn execute_parallel_branches_merge_outputs() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);
        ctx.graph_info.set_predecessors("add", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
        assert_eq!(double_out, &[10.0]);
        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
        assert_eq!(add_out, &[105.0]);
    }

    #[test]
    fn parallel_branches_run_concurrently() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("branch_a", vec!["input".into()]);
        ctx.graph_info
            .set_predecessors("branch_b", vec!["input".into()]);

        // Both filters must be inside `forward` at the same time to pass the
        // rendezvous. Run one after the other, the first one gives up after the
        // timeout and the test fails instead of depending on wall-clock timing.
        let gate = Rendezvous::new(2, Duration::from_secs(5));
        let mut filters = FilterLibrary::new();
        for (node, id) in [("branch_a", "a"), ("branch_b", "b")] {
            filters.register(
                node,
                Box::new(RendezvousFilter {
                    id: id.into(),
                    gate: Arc::clone(&gate),
                }),
            );
        }

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "branch_a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "branch_b".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        assert!(
            !gate.missed(),
            "parallel branches did not overlap: one branch waited out the rendezvous alone"
        );
        assert!(ctx.get("branch_a").is_some());
        assert!(ctx.get("branch_b").is_some());
    }

    #[test]
    fn resolve_input_single_predecessor() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
        ctx.graph_info.set_predecessors("B", vec!["A".into()]);

        let input = resolve_input("B", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);
    }

    #[test]
    fn resolve_input_multiple_predecessors() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
        ctx.graph_info
            .set_predecessors("C", vec!["A".into(), "B".into()]);

        let input = resolve_input("C", &ctx);
        let json = input.as_json().unwrap();
        assert!(json.get("A").is_some());
        assert!(json.get("B").is_some());
    }

    #[test]
    fn resolve_input_no_predecessors_fallback() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));

        let input = resolve_input("root", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[7.0]);
    }

    #[test]
    fn graph_info_from_linear() {
        let info = GraphInfo::for_linear(&["a", "b", "c"]);
        assert!(info.predecessors("a").is_empty());
        assert_eq!(info.predecessors("b"), &["a"]);
        assert_eq!(info.predecessors("c"), &["b"]);
    }
}