Skip to main content

somatize_runtime/
executor.rs

1//! Plan executor — walks [`ExecutionPlan`] trees and runs filter nodes.
2//!
3//! Handles sequential, parallel (scoped threads), cached, remote, loop,
4//! and branch execution. Uses [`GraphInfo`] for topology-aware input resolution.
5
6use crate::event_bus::EventBus;
7use crate::filter_library::FilterLibrary;
8use somatize_compiler::ExecutionPlan;
9use somatize_core::cache::CacheStore;
10use somatize_core::error::{Result, SomaError};
11use somatize_core::event::Event;
12use somatize_core::store::DataStore;
13use somatize_core::value::Value;
14use somatize_core::virtual_value::VirtualValue;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Instant;
18
19/// Graph topology information for input resolution.
20///
21/// Maps each node to its predecessor node IDs so the executor knows
22/// where to read inputs from in the context store.
23#[derive(Debug, Clone, Default)]
24pub struct GraphInfo {
25    /// node_id → list of predecessor node IDs
26    predecessors: HashMap<String, Vec<String>>,
27}
28
29impl GraphInfo {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Register predecessors for a node.
35    pub fn set_predecessors(&mut self, node_id: impl Into<String>, preds: Vec<String>) {
36        self.predecessors.insert(node_id.into(), preds);
37    }
38
39    /// Build GraphInfo from a somatize_core::graph::Graph.
40    pub fn from_graph(graph: &somatize_core::graph::Graph) -> Self {
41        let mut info = Self::new();
42        for node in &graph.nodes {
43            let preds: Vec<String> = graph
44                .predecessors(&node.id)
45                .into_iter()
46                .map(|s| s.to_string())
47                .collect();
48            info.set_predecessors(node.id.clone(), preds);
49        }
50        info
51    }
52
53    /// Build GraphInfo for a linear pipeline (each node depends on the previous).
54    pub fn for_linear(node_ids: &[&str]) -> Self {
55        let mut info = Self::new();
56        for (i, &id) in node_ids.iter().enumerate() {
57            let preds = if i > 0 {
58                vec![node_ids[i - 1].to_string()]
59            } else {
60                vec![]
61            };
62            info.set_predecessors(id, preds);
63        }
64        info
65    }
66
67    /// Get predecessors for a node.
68    pub fn predecessors(&self, node_id: &str) -> &[String] {
69        self.predecessors
70            .get(node_id)
71            .map(|v| v.as_slice())
72            .unwrap_or(&[])
73    }
74}
75
76/// Trait for executing plan nodes on remote workers.
77///
78/// When set on Context, `ExecutionPlan::Remote` nodes delegate to this
79/// instead of executing locally. The implementation sends the sub-plan
80/// to a worker and returns the result.
81pub trait RemoteExecutor: Send + Sync {
82    /// Execute a sub-plan remotely and return the output value.
83    fn execute_remote(
84        &self,
85        node_id: &str,
86        target: &somatize_core::filter::RemoteTarget,
87        input: Option<&Value>,
88    ) -> Result<Value>;
89}
90
91/// Execution context passed to filters during runtime.
92///
93/// Node outputs are stored as [`VirtualValue`]s — they may be materialized
94/// in memory, cached on disk, or deferred (not yet computed). The executor
95/// resolves them on demand when a downstream node needs the data.
96pub struct Context {
97    /// Node outputs as virtual values (may be lazy).
98    pub store: HashMap<String, VirtualValue>,
99    /// Event bus for emitting runtime events.
100    pub event_bus: Arc<EventBus>,
101    /// Current run ID.
102    pub run_id: String,
103    /// Track execution order.
104    pub execution_order: Vec<String>,
105    /// Graph topology for input resolution.
106    pub graph_info: GraphInfo,
107    /// Optional remote executor for distributed plans.
108    pub remote_executor: Option<Arc<dyn RemoteExecutor>>,
109    /// Optional data store for persisting intermediate results.
110    pub data_store: Option<Arc<dyn DataStore>>,
111    /// Minimum value size (bytes) to spill to DataStore instead of keeping in memory.
112    /// Default: 0 (disabled — all values stay in memory).
113    pub spill_threshold: usize,
114}
115
116impl Context {
117    pub fn new(event_bus: Arc<EventBus>, run_id: impl Into<String>) -> Self {
118        Self {
119            store: HashMap::new(),
120            event_bus,
121            run_id: run_id.into(),
122            execution_order: Vec::new(),
123            graph_info: GraphInfo::new(),
124            remote_executor: None,
125            data_store: None,
126            spill_threshold: 0,
127        }
128    }
129
130    pub fn with_graph_info(mut self, info: GraphInfo) -> Self {
131        self.graph_info = info;
132        self
133    }
134
135    pub fn with_remote_executor(mut self, executor: Arc<dyn RemoteExecutor>) -> Self {
136        self.remote_executor = Some(executor);
137        self
138    }
139
140    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
141        self.data_store = Some(store);
142        self
143    }
144
145    /// Set spill threshold: values larger than this (in bytes) are offloaded
146    /// to the DataStore and replaced with a VirtualValue::Cached reference.
147    /// Requires a DataStore to be set via `with_data_store()`.
148    pub fn with_spill_threshold(mut self, bytes: usize) -> Self {
149        self.spill_threshold = bytes;
150        self
151    }
152
153    /// If a DataStore and spill threshold are configured, check if the value
154    /// should be offloaded. Returns VirtualValue (materialized or cached ref).
155    fn maybe_spill(&self, node_id: &str, value: Value) -> VirtualValue {
156        if self.spill_threshold > 0
157            && let Some(store) = &self.data_store
158        {
159            let size = value.size() * 8; // approximate bytes (f64 = 8 bytes)
160            if size >= self.spill_threshold {
161                let key = somatize_core::cache::CacheKey::from_parts(&[
162                    self.run_id.as_bytes(),
163                    node_id.as_bytes(),
164                ]);
165                let vv_for_schema = VirtualValue::materialized(value.clone());
166                let schema = vv_for_schema.schema().clone();
167                if let Ok(_data_ref) = store.put(&key, &value) {
168                    tracing::debug!("spilled node `{node_id}` ({size} bytes) to DataStore");
169                    return VirtualValue::cached(key, schema);
170                }
171            }
172        }
173        VirtualValue::materialized(value)
174    }
175
176    /// Get the materialized Value for a node, if present and materialized.
177    pub fn get(&self, node_id: &str) -> Option<&Value> {
178        self.store.get(node_id).and_then(|vv| vv.as_value())
179    }
180
181    /// Get the raw VirtualValue for a node.
182    pub fn get_virtual(&self, node_id: &str) -> Option<&VirtualValue> {
183        self.store.get(node_id)
184    }
185
186    /// Store a materialized value for a node.
187    pub fn set(&mut self, node_id: impl Into<String>, value: Value) {
188        let id = node_id.into();
189        self.execution_order.push(id.clone());
190        self.store.insert(id, VirtualValue::materialized(value));
191    }
192
193    /// Store a virtual value (which may be deferred or cached).
194    pub fn set_virtual(&mut self, node_id: impl Into<String>, vv: VirtualValue) {
195        let id = node_id.into();
196        self.execution_order.push(id.clone());
197        self.store.insert(id, vv);
198    }
199
200    fn snapshot(&self) -> Self {
201        Self {
202            store: self.store.clone(),
203            event_bus: self.event_bus.clone(),
204            run_id: self.run_id.clone(),
205            execution_order: self.execution_order.clone(),
206            graph_info: self.graph_info.clone(),
207            remote_executor: self.remote_executor.clone(),
208            data_store: self.data_store.clone(),
209            spill_threshold: self.spill_threshold,
210        }
211    }
212}
213
214/// Execute a compiled plan synchronously.
215/// For parallel branches, uses the async executor under the hood.
216pub fn execute(
217    plan: &ExecutionPlan,
218    ctx: &mut Context,
219    filters: &FilterLibrary,
220    cache: &dyn CacheStore,
221) -> Result<()> {
222    match plan {
223        ExecutionPlan::Empty => Ok(()),
224
225        ExecutionPlan::Execute { node_id } => execute_node(node_id, ctx, filters, cache),
226
227        ExecutionPlan::Cached { node_id, key } => {
228            let start = Instant::now();
229            let value = cache.get(key)?.ok_or_else(|| {
230                SomaError::Cache(format!(
231                    "expected cached value for node `{node_id}` not found"
232                ))
233            })?;
234            ctx.set(node_id.clone(), value);
235            ctx.event_bus.emit(Event::NodeCacheHit {
236                run_id: ctx.run_id.clone(),
237                node_id: node_id.clone(),
238                key: key.clone(),
239                tier: somatize_core::cache::CacheTier::Memory,
240                load_time: start.elapsed(),
241            });
242            Ok(())
243        }
244
245        ExecutionPlan::Sequence(steps) => {
246            for step in steps {
247                execute(step, ctx, filters, cache)?;
248            }
249            Ok(())
250        }
251
252        ExecutionPlan::Parallel(branches) => execute_parallel(branches, ctx, filters, cache),
253
254        ExecutionPlan::Loop {
255            node_id,
256            body,
257            max_iterations,
258        } => {
259            let max = max_iterations.unwrap_or(100);
260            for i in 0..max {
261                execute(body, ctx, filters, cache)?;
262
263                // Check termination: if the last executed node produced a Value
264                // that indicates "done" (true, "done", "stop", or empty), break.
265                let should_stop = ctx
266                    .execution_order
267                    .last()
268                    .and_then(|last_id| ctx.get(last_id))
269                    .map(|v| match v {
270                        Value::Json(j) => {
271                            j.as_bool() == Some(true)
272                                || j.as_str().map(|s| s == "done" || s == "stop") == Some(true)
273                                || j.get("done").and_then(|d| d.as_bool()) == Some(true)
274                        }
275                        Value::Empty => true,
276                        _ => false,
277                    })
278                    .unwrap_or(false);
279
280                if should_stop {
281                    ctx.event_bus.emit(Event::NodeCompleted {
282                        run_id: ctx.run_id.clone(),
283                        node_id: node_id.clone(),
284                        duration: std::time::Duration::ZERO,
285                        output_summary: format!("Loop terminated at iteration {}", i + 1),
286                    });
287                    break;
288                }
289            }
290            Ok(())
291        }
292
293        ExecutionPlan::Branch { node_id, arms } => {
294            // Execute the branch node first (it produces the condition value)
295            execute_node(node_id, ctx, filters, cache)?;
296
297            // Get the condition result
298            let condition = ctx.get(node_id).cloned().unwrap_or(Value::Empty);
299
300            // Match against arm labels
301            let selected_arm = match &condition {
302                Value::Json(j) => {
303                    // Try matching by string value, bool, or "branch" field
304                    let selector = j
305                        .as_str()
306                        .map(String::from)
307                        .or_else(|| j.as_bool().map(|b| b.to_string()))
308                        .or_else(|| j.get("branch").and_then(|b| b.as_str()).map(String::from))
309                        .unwrap_or_else(|| "true".to_string());
310
311                    arms.iter()
312                        .find(|(label, _)| label == &selector)
313                        .or_else(|| {
314                            arms.iter()
315                                .find(|(label, _)| label == "default" || label == "else")
316                        })
317                        .or_else(|| arms.first())
318                }
319                _ => arms.first(),
320            };
321
322            if let Some((label, plan)) = selected_arm {
323                ctx.event_bus.emit(Event::NodeCompleted {
324                    run_id: ctx.run_id.clone(),
325                    node_id: node_id.clone(),
326                    duration: std::time::Duration::ZERO,
327                    output_summary: format!("Branch selected: {label}"),
328                });
329                execute(plan, ctx, filters, cache)?;
330            }
331            Ok(())
332        }
333
334        ExecutionPlan::Remote {
335            node_id,
336            target,
337            plan,
338        } => {
339            if let Some(remote) = &ctx.remote_executor {
340                // Gather input from predecessors
341                let input = ctx
342                    .graph_info
343                    .predecessors(node_id)
344                    .first()
345                    .and_then(|pred| ctx.get(pred));
346
347                let result = remote.execute_remote(node_id, target, input)?;
348                ctx.set(node_id.clone(), result);
349                ctx.execution_order.push(node_id.clone());
350                Ok(())
351            } else {
352                // No remote executor — fall back to local execution
353                execute(plan, ctx, filters, cache)
354            }
355        }
356
357        _ => {
358            tracing::warn!("Unhandled ExecutionPlan variant");
359            Ok(())
360        }
361    }
362}
363
364/// Execute a single filter node.
365fn execute_node(
366    node_id: &str,
367    ctx: &mut Context,
368    filters: &FilterLibrary,
369    _cache: &dyn CacheStore,
370) -> Result<()> {
371    let start = Instant::now();
372
373    let filter = filters
374        .get(node_id)
375        .ok_or_else(|| SomaError::NodeNotFound(node_id.to_string()))?;
376
377    ctx.event_bus.emit(Event::NodeStarted {
378        run_id: ctx.run_id.clone(),
379        node_id: node_id.to_string(),
380        kind: filter.meta().kind,
381    });
382
383    let input = resolve_input(node_id, ctx);
384    let state = filters.get_state(node_id).cloned().unwrap_or(Value::Empty);
385    let result = filter.forward(&input, &state);
386
387    match result {
388        Ok(output) => {
389            let duration = start.elapsed();
390            let summary = format!("{output}");
391            let vv = ctx.maybe_spill(node_id, output);
392            ctx.set_virtual(node_id, vv);
393            ctx.event_bus.emit(Event::NodeCompleted {
394                run_id: ctx.run_id.clone(),
395                node_id: node_id.to_string(),
396                duration,
397                output_summary: summary,
398            });
399            Ok(())
400        }
401        Err(e) => {
402            ctx.event_bus.emit(Event::NodeFailed {
403                run_id: ctx.run_id.clone(),
404                node_id: node_id.to_string(),
405                error: e.to_string(),
406            });
407            Err(e)
408        }
409    }
410}
411
412/// Execute parallel branches concurrently using std::thread::scope.
413///
414/// Each branch gets a snapshot of the context. After all branches complete,
415/// their new outputs are merged back into the main context.
416fn execute_parallel(
417    branches: &[ExecutionPlan],
418    ctx: &mut Context,
419    filters: &FilterLibrary,
420    cache: &dyn CacheStore,
421) -> Result<()> {
422    let snapshot_keys: Arc<std::collections::HashSet<String>> =
423        Arc::new(ctx.store.keys().cloned().collect());
424
425    // Use scoped threads for true parallelism without Send requirements
426    let results: Vec<Result<Vec<(String, VirtualValue)>>> = std::thread::scope(|s| {
427        let handles: Vec<_> = branches
428            .iter()
429            .map(|branch| {
430                let mut branch_ctx = ctx.snapshot();
431                let keys = snapshot_keys.clone();
432                s.spawn(move || {
433                    execute(branch, &mut branch_ctx, filters, cache)?;
434                    let new_entries: Vec<(String, VirtualValue)> = branch_ctx
435                        .store
436                        .into_iter()
437                        .filter(|(k, _)| !keys.contains(k))
438                        .collect();
439                    Ok(new_entries)
440                })
441            })
442            .collect();
443
444        handles.into_iter().map(|h| h.join().unwrap()).collect()
445    });
446
447    // Merge results and propagate first error
448    for result in results {
449        let entries = result?;
450        for (key, vv) in entries {
451            ctx.set_virtual(key, vv);
452        }
453    }
454
455    Ok(())
456}
457
458/// Resolve a VirtualValue to a concrete Value, loading from DataStore if needed.
459fn resolve_value(vv: &VirtualValue, data_store: &Option<Arc<dyn DataStore>>) -> Option<Value> {
460    match vv {
461        VirtualValue::Materialized { value, .. } => Some(value.clone()),
462        VirtualValue::Cached { key, .. } => {
463            // Try to load from DataStore
464            if let Some(store) = data_store {
465                let data_ref = somatize_core::store::DataRef::Cached {
466                    cache_key: key.clone(),
467                };
468                store.get(&data_ref).ok()
469            } else {
470                None
471            }
472        }
473        _ => None,
474    }
475}
476
477/// Resolve the input for a node from the context store using graph topology.
478/// If a predecessor was spilled to DataStore, loads it back.
479pub fn resolve_input(node_id: &str, ctx: &Context) -> Value {
480    let preds = ctx.graph_info.predecessors(node_id);
481
482    let resolve_node = |id: &str| -> Option<Value> {
483        ctx.store
484            .get(id)
485            .and_then(|vv| resolve_value(vv, &ctx.data_store))
486    };
487
488    match preds.len() {
489        0 => ctx
490            .execution_order
491            .last()
492            .and_then(|id| resolve_node(id))
493            .unwrap_or(Value::Empty),
494        1 => resolve_node(&preds[0]).unwrap_or(Value::Empty),
495        _ => {
496            let mut merged = serde_json::Map::new();
497            for pred_id in preds {
498                if let Some(val) = resolve_node(pred_id) {
499                    let json_val = serde_json::to_value(&val).unwrap_or(serde_json::Value::Null);
500                    merged.insert(pred_id.clone(), json_val);
501                }
502            }
503            Value::Json(serde_json::Value::Object(merged))
504        }
505    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::MemoryCache;
    use somatize_core::cache::CacheKey;
    use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// How long a `RendezvousFilter` waits for its peers before failing.
    const RENDEZVOUS_TIMEOUT: Duration = Duration::from_secs(5);

    /// Doubles every tensor element; passes other values through.
    struct DoublerFilter;

    impl Filter for DoublerFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Doubler"])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Doubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Adds `amount` to every tensor element; passes other values through.
    struct AdderFilter {
        amount: f64,
    }

    impl Filter for AdderFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Adder", &self.amount.to_le_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let added: Vec<f64> = values.iter().map(|v| v + self.amount).collect();
                    Ok(Value::tensor(added, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "Adder".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Identity filter that blocks in `forward` until `expected` instances
    /// sharing `arrived` have entered `forward`.
    ///
    /// It only completes when branches genuinely overlap in time. If they run
    /// sequentially it panics after `RENDEZVOUS_TIMEOUT`. This replaces a
    /// wall-clock threshold, which is flaky on loaded machines.
    struct RendezvousFilter {
        id: String,
        arrived: Arc<AtomicUsize>,
        expected: usize,
    }

    impl Filter for RendezvousFilter {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"Rendezvous", self.id.as_bytes()])
        }
        fn fit(&self, _x: &Value, _y: Option<&Value>) -> Result<Value> {
            Ok(Value::Empty)
        }
        fn forward(&self, x: &Value, _state: &Value) -> Result<Value> {
            self.arrived.fetch_add(1, Ordering::SeqCst);
            let deadline = Instant::now() + RENDEZVOUS_TIMEOUT;
            while self.arrived.load(Ordering::SeqCst) < self.expected {
                assert!(
                    Instant::now() < deadline,
                    "branch `{}` timed out waiting for {} concurrent branches",
                    self.id,
                    self.expected
                );
                std::thread::sleep(Duration::from_millis(1));
            }
            Ok(x.clone())
        }
        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: format!("Rendezvous_{}", self.id),
                kind: FilterKind::Stateless,
                cacheable: false,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    fn setup() -> (Arc<EventBus>, MemoryCache) {
        (Arc::new(EventBus::new(64)), MemoryCache::default())
    }

    #[test]
    fn execute_single_node() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0, 3.0], vec![3]));
        ctx.graph_info
            .set_predecessors("doubler", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("doubler", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Execute {
            node_id: "doubler".into(),
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("doubler").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn execute_sequence_with_graph_info() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0, 2.0], vec![2]));

        let graph_info = GraphInfo::for_linear(&["input", "add", "double"]);
        ctx.graph_info = graph_info;

        let mut filters = FilterLibrary::new();
        filters.register("add", Box::new(AdderFilter { amount: 10.0 }));
        filters.register("double", Box::new(DoublerFilter));

        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let result = ctx.get("double").unwrap();
        let (data, _) = result.as_tensor().unwrap();
        assert_eq!(data, &[22.0, 24.0]);
    }

    #[test]
    fn execute_cached_node() {
        let (bus, cache) = setup();
        let key = CacheKey::hash_data(b"cached_result");
        let cached_value = Value::tensor(vec![99.0], vec![1]);
        cache.put(&key, &cached_value).unwrap();

        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let plan = ExecutionPlan::Cached {
            node_id: "cached_node".into(),
            key,
        };

        execute(&plan, &mut ctx, &filters, &cache).unwrap();
        assert_eq!(*ctx.get("cached_node").unwrap(), cached_value);
    }

    #[test]
    fn execute_emits_events() {
        let bus = Arc::new(EventBus::new(64));
        let cache = MemoryCache::default();
        let mut rx = bus.subscribe();

        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));

        execute(
            &ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        )
        .unwrap();

        let e1 = rx.try_recv().unwrap();
        assert!(matches!(e1, Event::NodeStarted { .. }));
        let e2 = rx.try_recv().unwrap();
        assert!(matches!(e2, Event::NodeCompleted { .. }));
    }

    #[test]
    fn execute_missing_filter_errors() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();

        let result = execute(
            &ExecutionPlan::Execute {
                node_id: "nonexistent".into(),
            },
            &mut ctx,
            &filters,
            &cache,
        );
        assert!(matches!(result, Err(SomaError::NodeNotFound(_))));
    }

    #[test]
    fn execute_empty_plan() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        let filters = FilterLibrary::new();
        execute(&ExecutionPlan::Empty, &mut ctx, &filters, &cache).unwrap();
    }

    #[test]
    fn execute_parallel_branches_merge_outputs() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![5.0], vec![1]));
        ctx.graph_info
            .set_predecessors("double", vec!["input".into()]);
        ctx.graph_info.set_predecessors("add", vec!["input".into()]);

        let mut filters = FilterLibrary::new();
        filters.register("double", Box::new(DoublerFilter));
        filters.register("add", Box::new(AdderFilter { amount: 100.0 }));

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "double".into(),
            },
            ExecutionPlan::Execute {
                node_id: "add".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        let double_out = ctx.get("double").unwrap().as_tensor().unwrap().0;
        assert_eq!(double_out, &[10.0]);
        let add_out = ctx.get("add").unwrap().as_tensor().unwrap().0;
        assert_eq!(add_out, &[105.0]);
    }

    #[test]
    fn parallel_branches_run_concurrently() {
        let (bus, cache) = setup();
        let mut ctx = Context::new(bus, "run_1");
        ctx.set("input", Value::tensor(vec![1.0], vec![1]));
        ctx.graph_info
            .set_predecessors("slow_a", vec!["input".into()]);
        ctx.graph_info
            .set_predecessors("slow_b", vec!["input".into()]);

        // Each filter blocks until both have entered `forward`, so the plan
        // can only complete if the two branches overlap in time.
        let arrived = Arc::new(AtomicUsize::new(0));
        let mut filters = FilterLibrary::new();
        for (node_id, id) in [("slow_a", "a"), ("slow_b", "b")] {
            filters.register(
                node_id,
                Box::new(RendezvousFilter {
                    id: id.into(),
                    arrived: Arc::clone(&arrived),
                    expected: 2,
                }),
            );
        }

        let plan = ExecutionPlan::Parallel(vec![
            ExecutionPlan::Execute {
                node_id: "slow_a".into(),
            },
            ExecutionPlan::Execute {
                node_id: "slow_b".into(),
            },
        ]);

        execute(&plan, &mut ctx, &filters, &cache).unwrap();

        assert_eq!(arrived.load(Ordering::SeqCst), 2);
        assert!(ctx.get("slow_a").is_some());
        assert!(ctx.get("slow_b").is_some());
    }

    #[test]
    fn resolve_input_single_predecessor() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![42.0], vec![1]));
        ctx.graph_info.set_predecessors("B", vec!["A".into()]);

        let input = resolve_input("B", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[42.0]);
    }

    #[test]
    fn resolve_input_multiple_predecessors() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("A", Value::tensor(vec![1.0], vec![1]));
        ctx.set("B", Value::tensor(vec![2.0], vec![1]));
        ctx.graph_info
            .set_predecessors("C", vec!["A".into(), "B".into()]);

        let input = resolve_input("C", &ctx);
        let json = input.as_json().unwrap();
        assert!(json.get("A").is_some());
        assert!(json.get("B").is_some());
    }

    #[test]
    fn resolve_input_no_predecessors_fallback() {
        let bus = Arc::new(EventBus::new(8));
        let mut ctx = Context::new(bus, "r");
        ctx.set("prev", Value::tensor(vec![7.0], vec![1]));

        let input = resolve_input("root", &ctx);
        let (data, _) = input.as_tensor().unwrap();
        assert_eq!(data, &[7.0]);
    }

    #[test]
    fn graph_info_from_linear() {
        let info = GraphInfo::for_linear(&["a", "b", "c"]);
        assert!(info.predecessors("a").is_empty());
        assert_eq!(info.predecessors("b"), &["a"]);
        assert_eq!(info.predecessors("c"), &["b"]);
    }
}